diff --git a/docs/dev/intrinsics_and_adapters.md b/docs/dev/intrinsics_and_adapters.md new file mode 100644 index 00000000..a41c3e5b --- /dev/null +++ b/docs/dev/intrinsics_and_adapters.md @@ -0,0 +1,38 @@ +# Intrinsics and Adapters +Note: Mellea currently only supports GraniteCommonAdapters and Intrinsics. + +## Basics +In Mellea, intrinsics are a type of Component that signals one or more of the following to a backend: +- a special adapter must be used for generation +- the input/output for generation must be transformed in a particular way +- the model options must be modified in a particular way + +These changes only happen when the intrinsic is the "action" of the request. Intrinsics should usually not be used as an item in the context of generation (in fact, by default, Intrinsics have no string representation). + +These changes are specified by the Adapter that corresponds to a given Intrinsic. Matching happens based on the adapter name and type. + +## Parts of an Intrinsic +Intrinsics specify: +- an adapter name (ie requirement_check) +- types of adapters suitable to be used (ie alora) +- any kwargs necessary (ie a requirement like "make sure the last user message is...") + +## Parts of an Adapter +Adapters specify: +- compatible backends +- adapter type +- functions for getting a path to load them + +## Using Intrinsics +Mellea Intrinsics currently utilize the granite-common package for loading adapters and formatting input/outputs (https://github.com/ibm-granite/granite-common). This means Mellea only allows intrinsics/adapters that follow this pattern. + +## Needed Future Work +### Custom Adapters / Intrinsics +Mellea should support custom intrinsic / adapter implementations. To do this: +- make backend `_generate_from_intrinsic` functions generic and utilize only common adapter functions +- adapters must specify a transformation function that encapsulates the input/output modifications necessary for their generation requests + +### Concurrency Checks +Some backends (currently only LocalHFBackend) that allow adapters to be loaded, cannot independently utilize these adapters without impacting other generation requests. + +These backends should support a generation lock that ensures requests are only performed when the correct set of adapters (or no adapters) are active. diff --git a/docs/dev/requirement_aLoRA_rerouting.md b/docs/dev/requirement_aLoRA_rerouting.md index f7001df0..011073bd 100644 --- a/docs/dev/requirement_aLoRA_rerouting.md +++ b/docs/dev/requirement_aLoRA_rerouting.md @@ -14,14 +14,14 @@ The actual rule is slightly more complicated. ## The Actual Rule -If a `Requirement` is validated using a backend that could either use a `constraint` aLoRA or perform an LLMaJ prompt on the underlying model, then the aLoRA is used for validation, even if the `backend.generate_from_context` method is called instead of the `alora.generate_from_strings` method. +If a `Requirement` is validated using a backend that could either use a `requirement_check` aLoRA or perform an LLMaJ prompt on the underlying model, then the aLoRA is used for validation, even if the `backend.generate_from_context` method is called instead of the `backend._generate_from_intrinsic` method. There are three exceptions to this rule: 1. `Backend.default_to_constraint_checking_alora` is set to `False` (this parameter defaults to `True`). 2. The `Requirement` has a more specific subtype that indicates a more specific intent (`LLMaJRequirement`). 3. The `ALoRA` requirement checker throws an exception. 
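+
+The sketch below shows how the rule and its exceptions map onto the three `Requirement` classes, assuming a backend that already has a `requirement_check` aLoRA adapter added. The class names come from `mellea.stdlib.requirement`; treat the exact constructor signatures as illustrative.
+
+```python
+from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
+
+# Plain Requirement: validation is rerouted to the requirement_check aLoRA,
+# unless backend.default_to_constraint_checking_alora is False (exception 1).
+plain = Requirement("The response mentions both input strings.")
+
+# LLMaJRequirement: never rerouted; always validated with an LLM-as-a-judge
+# prompt against the underlying model (exception 2).
+judge_only = LLMaJRequirement("The response mentions both input strings.")
+
+# ALoraRequirement: always routed to the aLoRA when the adapter is available,
+# regardless of default_to_constraint_checking_alora.
+forced = ALoraRequirement("The response mentions both input strings.")
+```
+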
-There is an exception (or disambiguation) to the first exception: If the user provides an `ALoRARequirement`, then the `backend.generate_from_context` call is rerouted to the constraint checking LoRA, regardless of the value of `deault_to_constraint_checking_alora`. +There is an exception (or disambiguation) to the first exception: If the user provides an `ALoRARequirement`, then the `backend.generate_from_context` call is rerouted to the constraint checking LoRA, regardless of the value of `default_to_constraint_checking_alora`. ## Decision Rationale @@ -33,12 +33,13 @@ Suppose that the user creates a backend and then adds a generic constraint check ```python from mellea import start_session -from mellea.backends.aloras.granite_aloras import add_granite_aloras from mellea.stdlib.requirement import Requirement m = start_session( "huggingface.LocalHFBackend:ibm-granite/granite-3.2-8b-instruct") -add_granite_aloras(m) # This will load the Constraint checint aLoRA. + +# By default, the AloraRequirement uses a GraniteCommonAdapter with "requirement_check". +m.backend.add_adapter(GraniteCommonAdapter("ibm-granite/rag-intrinsics-lib", "requirement_check", base_model_name="granite-3.2-8b-instruct")) m.instruct( "Corporate wants you to find the difference between these two strings:\n\naaa\naba") diff --git a/docs/examples/intrinsics/intrinsics.py b/docs/examples/intrinsics/intrinsics.py new file mode 100644 index 00000000..1392c551 --- /dev/null +++ b/docs/examples/intrinsics/intrinsics.py @@ -0,0 +1,54 @@ +from mellea.backends.huggingface import LocalHFBackend +from mellea.backends.openai import OpenAIBackend, _ServerType +from mellea.backends.adapters.adapter import AdapterType, GraniteCommonAdapter +from mellea.stdlib.base import ChatContext, ModelOutputThunk +from mellea.stdlib.chat import Message +import mellea.stdlib.functional as mfuncs +from mellea.stdlib.intrinsics.intrinsic import Intrinsic + +# This is an example for how you would directly use intrinsics. See `mellea/stdlib/intrinsics/rag.py` +# for helper functions. + +# Create the backend. Example for a VLLM Server. Commented out in favor of the hugging face code for now. +# # Assumes a locally running VLLM server. +# backend = OpenAIBackend( +# model_id="ibm-granite/granite-3.3-8b-instruct", +# base_url="http://0.0.0.0:8000/v1", +# api_key="EMPTY", +# ) + +# # If using a remote VLLM server, utilize the `test/backends/test_openai_vllm/serve.sh` +# # script with `export VLLM_DOWNLOAD_RAG_INTRINSICS=True`. This will download the granite_common +# # adapters on the server. +# backend._server_type = _ServerType.REMOTE_VLLM + +backend = LocalHFBackend(model_id="ibm-granite/granite-3.3-8b-instruct") + +# Create the Adapter. GraniteCommonAdapter's default to ALORAs. +req_adapter = GraniteCommonAdapter( + "requirement_check", base_model_name=backend.base_model_name +) + +# Add the adapter to the backend. +backend.add_adapter(req_adapter) + +ctx = ChatContext() +ctx = ctx.add(Message("user", "Hi, can you help me?")) +ctx = ctx.add(Message("assistant", "Hello; yes! What can I help with?")) + +# Generate from an intrinsic with the same name as the adapter. By default, it will look for +# ALORA and then LORA adapters. +out, new_ctx = mfuncs.act( + Intrinsic( + "requirement_check", + intrinsic_kwargs={"requirement": "The assistant is helpful."}, + ), + ctx, + backend, +) + +# Print the output. The requirement_check adapter has a specific output format: +print(out) # {"requirement_likelihood": 1.0} + +# The AloraRequirement uses this adapter. 
It automatically parses that output +# when validating the output. diff --git a/mellea/backends/_utils.py b/mellea/backends/_utils.py index c6e90ba8..08720bc0 100644 --- a/mellea/backends/_utils.py +++ b/mellea/backends/_utils.py @@ -4,7 +4,6 @@ from collections.abc import Callable from typing import Any, Literal -from mellea.backends.aloras import Alora from mellea.backends.formatter import Formatter from mellea.backends.tools import parse_tools from mellea.helpers.fancy_logger import FancyLogger @@ -57,30 +56,6 @@ def to_chat( return ctx_as_conversation -def use_alora( - action: Component | CBlock, - alora: Alora | None, - default_to_constraint_checking_alora: bool, -) -> bool: - """Returns True when the condition for using alora is met. - - See `docs/dev/requirement_aLoRA_rerouting.md` for an explanation of the following code block. - """ - if issubclass(type(action), Requirement): - # The general rule is that we reroute to the alora if it exists. - reroute_to_alora = alora is not None - # However, there are some exceptions: - if not default_to_constraint_checking_alora: - reroute_to_alora = False - if issubclass(type(action), LLMaJRequirement): - reroute_to_alora = False - if issubclass(type(action), ALoraRequirement): - reroute_to_alora = True - return reroute_to_alora - else: - return False - - def to_tool_calls( tools: dict[str, Callable], decoded_result: str ) -> dict[str, ModelToolCall] | None: diff --git a/mellea/backends/adapters/adapter.py b/mellea/backends/adapters/adapter.py new file mode 100644 index 00000000..8f1577a8 --- /dev/null +++ b/mellea/backends/adapters/adapter.py @@ -0,0 +1,274 @@ +"""Module for adapters to backends.""" + +import abc +import pathlib +from typing import Any, TypeVar + +import granite_common.intrinsics +import yaml +from litellm import cast + +from mellea.backends import Backend +from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata +from mellea.backends.types import _ServerType + + +class Adapter(abc.ABC): + """An adapter that can be added to a single backend.""" + + def __init__(self, name: str, adapter_type: AdapterType): + """An adapter that can be added to a backend. + + Note: An adapter can only be added to a single backend. + + Args: + name: name of the adapter; when referencing this adapter, use + adapter.qualified_name + adapter_type: enum describing what type of adapter it is (ie LORA / ALORA) + """ + self.name = name + self.adapter_type = adapter_type + self.qualified_name = name + "_" + adapter_type.value + """the name of the adapter to use when loading / looking it up""" + + self.backend: Backend | None = None + """set when the adapter is added to a backend""" + + self.path: str | None = None + """set when the adapter is added to a backend""" + + +class OpenAIAdapter(Adapter): + """Adapter for OpenAIBackends.""" + + @abc.abstractmethod + def get_open_ai_path( + self, + base_model_name: str, + server_type: _ServerType = _ServerType.LOCALHOST, + remote_path: str | None = None, + ) -> str: + """Returns the path needed to load the adapter. + + Args: + base_model_name: the base model; typically the last part of the huggingface model id like "granite-3.3-8b-instruct" + server_type: the server type (ie LOCALHOST / OPENAI); usually the backend has information on this + remote_path: optional; used only if the server_type is REMOTE_VLLM; base path at which to find the adapter + """ + ... 
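+
+# Illustrative only: a concrete OpenAIAdapter subclass mainly has to resolve a path
+# that the serving engine can load. The class and path below are hypothetical and not
+# part of the Mellea API; see GraniteCommonAdapter further down for the real example.
+#
+#   class MyVllmAdapter(OpenAIAdapter):
+#       def get_open_ai_path(
+#           self,
+#           base_model_name: str,
+#           server_type: _ServerType = _ServerType.LOCALHOST,
+#           remote_path: str | None = None,
+#       ) -> str:
+#           # e.g. point the server at a locally downloaded adapter directory
+#           return f"./my-adapters/{self.qualified_name}/{base_model_name}"
+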
+ + +class LocalHFAdapter(Adapter): + """Adapter for LocalHFBackends.""" + + @abc.abstractmethod + def get_local_hf_path(self, base_model_name: str) -> str: + """Returns the path needed to load the adapter. + + Args: + base_model_name: the base model; typically the last part of the huggingface model id like "granite-3.3-8b-instruct" + """ + ... + + +class GraniteCommonAdapter(OpenAIAdapter, LocalHFAdapter): + """Adapter for intrinsics that utilize the ``granite-common`` library.""" + + def __init__( + self, + intrinsic_name: str, + adapter_type: AdapterType = AdapterType.ALORA, + config_file: str | pathlib.Path | None = None, + config_dict: dict | None = None, + base_model_name: str | None = None, + ): + """Entry point for creating GraniteCommonAdapter objects. + + An adapter that can be added to either an `OpenAIBackend` or a `LocalHFBackend`. + Most intrinsics support LoRA or aLoRA adapter types. + + Args: + intrinsic_name: name of the intrinsic; the local name of the loaded adapter + that implements this intrinsic will be adapter.qualified_name + adapter_type: enum describing what type of adapter it is (ie LORA / ALORA) + config_file: optional; file for defining the intrinsic / transformations + config_dict: optional; dict for defining the intrinsic / transformations + base_model_name: optional; if provided with no config_file/config_dict, + will be used to look up the granite_common config for this adapter + """ + super().__init__(intrinsic_name, adapter_type) + + self.intrinsic_name = intrinsic_name + self.intrinsic_metadata = fetch_intrinsic_metadata(intrinsic_name) + self.base_model_name = base_model_name + + if adapter_type not in self.intrinsic_metadata.adapter_types: + raise ValueError( + f"Intrinsic '{intrinsic_name}' not available as an adapter of type " + f"'{adapter_type}. Available types are " + f"{self.intrinsic_metadata.adapter_types}." + ) + self.adapter_type = adapter_type + + # If any of the optional params are specified, attempt to set up the + # config for the intrinsic here. + if config_file and config_dict: + raise ValueError( + f"Conflicting values for config_file and config_dict " + f"parameters provided. Values were {config_file=} " + f"and {config_dict=}" + ) + if config_file is None and config_dict is None and self.base_model_name is None: + raise ValueError( + "At least one of [config_file, config_dict, base_model_name] " + "must be provided." + ) + if config_file is None and config_dict is None: + assert self.base_model_name is not None, ( + "must provide `base_model_name` if not providing a `config_file` or `config_dict`" + ) + # We're converting the adapter type to a boolean flag here. + assert adapter_type in (AdapterType.ALORA, AdapterType.LORA), ( + f"{adapter_type} not supported" + ) + is_alora = self.adapter_type == AdapterType.ALORA + config_file = granite_common.intrinsics.obtain_io_yaml( + self.intrinsic_name, + self.base_model_name, + alora=is_alora, + repo_id=self.intrinsic_metadata.repo_id, + ) + if config_file: + with open(config_file, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + if not isinstance(config_dict, dict): + raise ValueError( + f"YAML file {config_file} does not evaluate to a " + f"dictionary when parsed." 
+ ) + assert config_dict is not None # Code above should initialize this variable + self.config: dict = config_dict + + def get_open_ai_path( + self, + base_model_name: str, + server_type: _ServerType = _ServerType.LOCALHOST, + remote_path: str | None = None, + ) -> str: + """Returns the path needed to load the adapter. + + Args: + base_model_name: the base model; typically the last part of the huggingface + model id like "granite-3.3-8b-instruct" + server_type: the server type (ie LOCALHOST / OPENAI); usually the backend + has information on this + remote_path: optional; used only if the server_type is REMOTE_VLLM; base + path at which to find the adapter + """ + if server_type == _ServerType.LOCALHOST: + path = self.download_and_get_path(base_model_name) + elif server_type == _ServerType.REMOTE_VLLM: + if remote_path is None: + remote_path = "rag-intrinsics-lib" + path = self.get_path_on_remote(base_model_name, remote_path) + else: + raise ValueError( + f"{self} not supported for OpenAIBackend with server_type: {server_type}" + ) + + return path + + def get_local_hf_path(self, base_model_name: str) -> str: + """Returns the path needed to load the adapter. + + Args: + base_model_name: the base model; typically the last part of the huggingface + model id like "granite-3.3-8b-instruct" + """ + return self.download_and_get_path(base_model_name) + + def download_and_get_path(self, base_model_name: str) -> str: + """Downloads the required rag intrinsics files if necessary and returns the path to them. + + Args: + base_model_name: the base model; typically the last part of the huggingface + model id like "granite-3.3-8b-instruct" + + Returns: + a path to the files + """ + is_alora = self.adapter_type == AdapterType.ALORA + return str( + granite_common.intrinsics.obtain_lora( + self.intrinsic_name, + base_model_name, + alora=is_alora, + repo_id=self.intrinsic_metadata.repo_id, + ) + ) + + def get_path_on_remote(self, base_model_name: str, base_path: str) -> str: + """Assumes the files have already been downloaded on the remote server.""" + # TODO: This will break when we switch to the new repo!!! + return f"./{base_path}/{self.name}/{self.adapter_type.value}/{base_model_name}" + + +T = TypeVar("T") + + +def get_adapter_for_intrinsic( + intrinsic_name: str, + intrinsic_adapter_types: list[AdapterType] | tuple[AdapterType, ...], + available_adapters: dict[str, T], +) -> T | None: + """Finds an adapter from a dict of available adapters based on the intrinsic name and its allowed adapter types. + + Args: + repo_id: Name of Hugging Face Hub repository containing the adapters that + implement the intrinsic + intrinsic_name: the name of the intrinsic, like "answerability" + intrinsic_adapter_types: the adapter types allowed for this intrinsic, like ALORA / LORA + available_adapters: the available adapters to choose from; maps adapter.qualified_name to the Adapter + + Returns: + an Adapter if found; else None + """ + adapter = None + for adapter_type in intrinsic_adapter_types: + qualified_name = f"{intrinsic_name}_{adapter_type.value}" + adapter = available_adapters.get(qualified_name) + if adapter is not None: + break + + return adapter + + +class AdapterMixin(Backend, abc.ABC): + """Mixin class for backends capable of utilizing adapters.""" + + @property + @abc.abstractmethod + def base_model_name(self) -> str: + """Returns the base_model_id of the model used by the backend. 
For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`.""" + + @abc.abstractmethod + def add_adapter(self, *args, **kwargs): + """Adds the given adapter to the backend. Must not have been added to a different backend.""" + + @abc.abstractmethod + def load_adapter(self, adapter_qualified_name: str): + """Loads the given adapter for the backend. Must have previously been added.""" + + @abc.abstractmethod + def unload_adapter(self, adapter_qualified_name: str): + """Unloads the given adapter from the backend.""" + + @abc.abstractmethod + def list_adapters(self) -> list[str]: + """Lists the adapters added via add_adapter(). + + :returns: list of adapter names that are currently registered with this backend + """ + raise NotImplementedError( + f"Backend type {type(self)} does not implement list_adapters() API call." + ) diff --git a/mellea/backends/adapters/catalog.py b/mellea/backends/adapters/catalog.py new file mode 100644 index 00000000..0ef73219 --- /dev/null +++ b/mellea/backends/adapters/catalog.py @@ -0,0 +1,92 @@ +"""Catalog of available intrinsics. + +Catalog of intrinsics currently known to Mellea,including metadata about where to find +LoRA and aLoRA adapters that implement said intrinsics. +""" + +import enum + +import pydantic + + +class AdapterType(enum.Enum): + """Possible types of adapters for a backend.""" + + LORA = "lora" + ALORA = "alora" + + +class IntriniscsCatalogEntry(pydantic.BaseModel): + """A single row in the main intrinsics catalog table. + + We use Pydantic for this dataclass because the rest of Mellea also uses Pydantic. + """ + + name: str = pydantic.Field(description="User-visible name of the intrinsic.") + internal_name: str | None = pydantic.Field( + default=None, + description="Internal name used for adapter loading, or None if the name used " + "for that purpose is the same as self.name", + ) + repo_id: str = pydantic.Field( + description="Hugging Face repository (aka 'model') where adapters for the " + "intrinsic are located." + ) + adapter_types: tuple[AdapterType, ...] = pydantic.Field( + default=(AdapterType.LORA, AdapterType.ALORA), + description="Adapter types that are known to be available for this intrinsic.", + ) + + +_RAG_REPO = "ibm-granite/rag-intrinsics-lib" + + +_INTRINSICS_CATALOG_ENTRIES = [ + ############################################ + # Core Intrinsics + ############################################ + IntriniscsCatalogEntry(name="requirement_check", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="uncertainty", repo_id=_RAG_REPO), + ############################################ + # RAG Intrinsics + ############################################ + IntriniscsCatalogEntry( + name="answer_relevance_classifier", + repo_id=_RAG_REPO, + adapter_types=(AdapterType.LORA,), + ), + IntriniscsCatalogEntry(name="answer_relevance_rewriter", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="answerability", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="citations", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="context_relevance", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="hallucination_detection", repo_id=_RAG_REPO), + IntriniscsCatalogEntry(name="query_rewrite", repo_id=_RAG_REPO), +] + +_INTRINSICS_CATALOG = {e.name: e for e in _INTRINSICS_CATALOG_ENTRIES} +"""Catalog of intrinsics that Mellea knows about. 
+ +Mellea code should access this catalog via :func:`fetch_intrinsic_metadata()`""" + + +def known_intrinsic_names() -> list[str]: + """:returns: List of all known user-visible names for intrinsics.""" + return list(_INTRINSICS_CATALOG.keys()) + + +def fetch_intrinsic_metadata(intrinsic_name: str) -> IntriniscsCatalogEntry: + """Retrieve information about the adapter that backs an intrinsic. + + :param intrinsic_name: User-visible name of the intrinsic + + :returns: Metadata about the adapter(s) that implement the intrinsic. + """ + if intrinsic_name not in _INTRINSICS_CATALOG: + raise ValueError( + f"Unknown intrinsic name '{intrinsic_name}'. Valid names are " + f"{known_intrinsic_names()}" + ) + + # Make a copy in case some naughty downstream code decides to modify the returned + # value. + return _INTRINSICS_CATALOG[intrinsic_name].model_copy() diff --git a/mellea/backends/aloras/__init__.py b/mellea/backends/aloras/__init__.py deleted file mode 100644 index ae7b37b2..00000000 --- a/mellea/backends/aloras/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Abstract interfaces for Backends that implement Activated LoRAs.""" - -import abc - -from mellea.stdlib.base import CBlock, ModelOutputThunk - - -class Alora(abc.ABC): - """Activated LoRAs (Aloras)](https://arxiv.org/pdf/2504.12397) are are [low-rank adapters](https://arxiv.org/abs/2106.09685) that can reuse KV cache from their underlying model. - - This class should not be directly subclassed by a specific ALora. Each backend that supports ALora should provide a backend-specific abstract class that subclasses `ALora`. Individual ALoras should then be defined by subclassing the model-specific backend. - - ALoras are always attached to an underlying model and use the following calling convention: - 1. The underlying model is prompted (without the Alora active). We call this the `input`. - 2. The underlying model generates some tokens from the `input` context (again, without the ALora active). We call this the `response`. - 3. Then the adapter is activated and generates some tokens. We call then the `alora_response`. - - Args: - name (str): An arbitrary name/label in the model serving engine (e.g. vllm, or local huggingface) to assign to an ALora. This is irrelevant from the alora's (huggingface) model id. - """ - - def __init__(self, name: str): - """Each aLoRA is identified by a name.""" - self.name: str = name - - @abc.abstractmethod - def generate_using_strings(self, *args, **kwargs) -> ModelOutputThunk: - """Generates from the ALora using raw strings as the interface for inputs. In most cases, must be run from a running event loop. - - This has a generic signature because each aLoRA has different parameters depending on its functionality and how it gets called. - """ - - def generate_using_stdlib(self, *args, **kwargs) -> CBlock: - """Generates from the Alora using Span-based backends.""" - # This is NOT marked as an `abc.abstractmethod` for now because we are not releasing span-based backends. When we release a span-based backend, we should mark this method as `abc.abstractmethod`""" - raise NotImplementedError( - "There are not currently ant ALoras trained to use spans." - ) - - -class AloraBackendMixin(abc.ABC): - """Mixin class for backends capable of aLoRA functionality.""" - - @abc.abstractmethod - def add_alora(self, *args, **kwargs): - """Loads an ALora.""" - ... - - @abc.abstractmethod - def get_alora(self, alora_name: str) -> Alora | None: - """Returns the ALora by name, or None of that ALora isn't loaded.""" - ... 
- - @abc.abstractmethod - def get_aloras(self) -> list[Alora]: - """Returns a list of all loaded aLoRA adapters.""" - ... diff --git a/mellea/backends/aloras/huggingface/__init__.py b/mellea/backends/aloras/huggingface/__init__.py deleted file mode 100644 index 6746ef50..00000000 --- a/mellea/backends/aloras/huggingface/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""ALora implementations for `mellea.backends.huggingface` backends.""" diff --git a/mellea/backends/aloras/huggingface/granite_aloras.py b/mellea/backends/aloras/huggingface/granite_aloras.py deleted file mode 100644 index b5e29a47..00000000 --- a/mellea/backends/aloras/huggingface/granite_aloras.py +++ /dev/null @@ -1,285 +0,0 @@ -"""Huggingface implementations for IBM's "starter pack" of Activated LoRAs.""" - -import asyncio -import functools -from copy import deepcopy - -import torch -from transformers.generation.utils import GenerateDecoderOnlyOutput - -from mellea.backends.huggingface import HFAlora, HFAloraCacheInfo, LocalHFBackend -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import send_to_queue -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import GenerateType, ModelOutputThunk - - -class HFConstraintAlora(HFAlora): - """The Requirement Checking ALora for Granite checks if the specified requirement was satisfied by the most recent model generation. Only one requirement is checked at a time. - - Currently supports [Granite 3.2 8B](https://huggingface.co/ibm-granite/granite-3.2-8b-alora-requirement-check) and [Granite 3.3 8B](https://huggingface.co/ibm-granite/granite-3.3-8b-alora-requirement-check) by default. - """ - - def __init__( - self, - name: str, - path_or_model_id: str, - generation_prompt: str, - backend: LocalHFBackend, - *, - constraint_prompt: str | None = None, - include_constraint_in_alora_offset: bool = False, - ): - """Initialize after checking that the backend is correct. - - Args: - name: name of the alora. - path_or_model_id: huggingface path or model id. - generation_prompt: the prompt required to activate the aLoRa. - backend: a LocalHFBackend that this alora is attached to. - constraint_prompt: a template that the constraint can be interpolated into; can only have a single `{}` slot. - include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset. - """ - super().__init__(name, path_or_model_id, generation_prompt, backend) - - # Maintain default behavior. - if constraint_prompt is None: - constraint_prompt = "\nRequirement: {}<|end_of_text|>\n" - - self._constraint_prompt = constraint_prompt - self._include_constraint_in_alora_offset = include_constraint_in_alora_offset - - # We do a lot of logging for ALoras because this is an experimental feature. Maybe we should tag these log messages? - self._logger = FancyLogger.get_logger() - - def generate_using_strings( - self, - input: str, - response: str, - constraint: str, - force_yn: bool = True, - stream: bool = False, - ) -> ModelOutputThunk: - """Generates a constraint response from the ALora. Must be run in a running event loop.""" - assert self._backend.alora_model is not None - # Go ahead and do runtime type-checking because passing CBlocks into this function is a common error. 
- assert type(input) is str - assert type(response) is str - assert type(constraint) is str - self._backend.alora_model.set_adapter(self.name) - cache_hit = self._backend.cache_get(response) - - if stream: - self._logger.warning( - "`HFConstraintAlora` cannot stream output; defaulting to non-streaming approach." - ) - - generate_kwargs = {} - if cache_hit: - self._logger.debug( - f"using cache for alora {self.__class__} and response '{response}'" - ) - generate_kwargs["past_key_values"] = deepcopy(cache_hit.kv_cache) - input_combined = self._generate_using_cache(cache_hit, constraint, force_yn) - - else: - self._logger.debug( - f"not using cache for alora {self.__class__} and response '{response}'" - ) - input_combined = self._generate_not_using_cache( - input, response, constraint, force_yn - ) - - if not self._include_constraint_in_alora_offset: - alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1] - else: - # Get the constraint tokens separately so that we can calculate the alora offsets. - constraint_tokens = self._backend._tokenizer( - self._constraint_prompt.format(constraint), return_tensors="pt" - ).to(self._backend._device) - - alora_offsets = [ - constraint_tokens["input_ids"].shape[1] - + self._generation_prompt_tokens["input_ids"].shape[1] - - 2 - ] - - chat_response = asyncio.to_thread( - self._backend.alora_model.generate, - input_combined["input_ids"].to(self._backend._device), - attention_mask=input_combined["attention_mask"].to(self._backend._device), - max_new_tokens=1, - return_dict_in_generate=True, - alora_offsets=alora_offsets, - output_scores=True, - **generate_kwargs, - ) - - output = ModelOutputThunk(None) - output._meta["alora_name"] = self.name - - output._process = functools.partial( - processing, - backend=self._backend, - force_yn=force_yn, - gen_prompt=self._generation_prompt, - ) - output._post_process = functools.partial(post_processing, backend=self._backend) - - try: - # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. - # We can also support synchronous calls by adding a flag and changing this ._generate function. - - # This function should always be called from a running event loop so we don't have to worry about - # scheduling the task to a specific event loop here. - output._generate = asyncio.create_task( - send_to_queue(chat_response, output._async_queue) # type: ignore - ) - output._generate_type = GenerateType.ASYNC - except RuntimeError as e: - # Most likely cause is running this function without an event loop present. - raise e - - return output - - def _generate_using_cache( - self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool - ) -> dict: - """Returns the input object used for generation.""" - # Must tokenize the constraint here since the requirement isn't known at initialization. 
- constraint_tokens = self._backend._tokenizer( - self._constraint_prompt.format(constraint), return_tensors="pt" - ).to(self._backend._device) - - input_combined = { - "input_ids": torch.cat( - [ - cache_hit.merged_token_ids.unsqueeze(0), - constraint_tokens["input_ids"], - self._generation_prompt_tokens["input_ids"], - ], - dim=1, - ), - "attention_mask": torch.cat( - [ - cache_hit.merged_attention.unsqueeze(0), - constraint_tokens["attention_mask"], - self._generation_prompt_tokens["attention_mask"], - ], - dim=1, - ), - } - - self._logger.debug( - f"Prompt for cached aLoRA({self.name}):\n {self._backend._tokenizer.decode(input_combined['input_ids'][0])}" - ) - - return input_combined - - def _generate_not_using_cache( - self, input: str, response: str, constraint: str, force_yn: bool - ) -> dict: - """Returns the input object used for generation.""" - # Params aren't needed when just getting the backend args. - backend_model_opts = self._backend._simplify_and_merge(None) - sys_prompt = backend_model_opts.get(ModelOption.SYSTEM_PROMPT, None) - - chat = [ - *([{"role": "system", "content": sys_prompt}] if sys_prompt else []), - {"role": "user", "content": input}, - {"role": "assistant", "content": response}, - ] - - templatized = self._backend._tokenizer.apply_chat_template(chat, tokenize=False) - assert type(templatized) is str - - # Must tokenize the constraint here since the requirement isn't known at initialization. - templatized = templatized + self._constraint_prompt.format(constraint) - - tokenized = self._backend._tokenizer(templatized, return_tensors="pt").to( - self._backend._device - ) - - input_combined = { - "input_ids": torch.cat( - [tokenized["input_ids"], self._generation_prompt_tokens["input_ids"]], - dim=1, - ), - "attention_mask": torch.cat( - [ - tokenized["attention_mask"], - self._generation_prompt_tokens["attention_mask"], - ], - dim=1, - ), - } - - self._logger.debug( - f"Prompt for non-cached aLoRA({self.name}):\n{self._backend._tokenizer.decode(input_combined['input_ids'][0])}" - ) - - return input_combined - - -async def processing( - mot: ModelOutputThunk, - chunk: GenerateDecoderOnlyOutput, - backend: LocalHFBackend, - force_yn: bool, - gen_prompt: str, -): - """Called to process the incoming chunks.""" - if mot._underlying_value is None: - mot._underlying_value = "" - - # Don't support async for HFConstraintAlora. Means we can process the output here. - assert isinstance(chunk, GenerateDecoderOnlyOutput) - - if force_yn: - last_logits = chunk.scores[-1].squeeze(0) # type: ignore - token_Y = backend._tokenizer("Y", add_special_tokens=False)["input_ids"][0] # type: ignore - token_N = backend._tokenizer("N", add_special_tokens=False)["input_ids"][0] # type: ignore - logit_Y = last_logits[token_Y].item() - logit_N = last_logits[token_N].item() - mot._underlying_value = "Y" if logit_Y > logit_N else "N" - else: - output_text = backend._tokenizer.decode(chunk.sequences[0]) - constraint_satisfied = output_text.split(gen_prompt)[-1] - mot._underlying_value = constraint_satisfied[ - 0 - ] # Grab the first char of the str. 
- - -async def post_processing(mot: ModelOutputThunk, backend: LocalHFBackend): - """Called after all data has been received.""" - backend.formatter.parse(mot._action, mot) # type: ignore - - -def add_granite_aloras(backend: LocalHFBackend): - """Adds the IBM Granite "starter pack" ALoras to a backend.""" - if backend._hf_model_id == "ibm-granite/granite-3.2-8b-instruct": - backend.add_alora( - HFConstraintAlora( - name="constraint", - path_or_model_id="ibm-granite/granite-3.2-8b-alora-requirement-check", - generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>", - backend=backend, - constraint_prompt="\nRequirement: {}<|end_of_text|>\n", - include_constraint_in_alora_offset=False, - ) - ) - elif backend._hf_model_id == "ibm-granite/granite-3.3-8b-instruct": - backend.add_alora( - HFConstraintAlora( - name="constraint", - path_or_model_id="ibm-granite/granite-3.3-8b-alora-requirement-check", - generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>", - backend=backend, - constraint_prompt="\n<|start_of_role|>requirement<|end_of_role|>{}<|end_of_text|>\n", - include_constraint_in_alora_offset=True, - ) - ) - else: - raise ValueError( - f"cannot add_granite_aloras to unknown huggingface model_id / backend: {backend._hf_model_id}" - ) diff --git a/mellea/backends/aloras/openai/__init__.py b/mellea/backends/aloras/openai/__init__.py deleted file mode 100644 index f07c7f75..00000000 --- a/mellea/backends/aloras/openai/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""ALora implementations for `mellea.backends.openai` backends.""" diff --git a/mellea/backends/aloras/openai/granite_aloras.py b/mellea/backends/aloras/openai/granite_aloras.py deleted file mode 100644 index 6d1b2c6b..00000000 --- a/mellea/backends/aloras/openai/granite_aloras.py +++ /dev/null @@ -1,128 +0,0 @@ -"""OpenAI implementations for IBM's "starter pack" of Activated LoRAs.""" - -import asyncio -import functools -from collections.abc import Coroutine -from typing import Any - -import openai -from openai.types.completion import Completion - -from mellea.backends.aloras import Alora -from mellea.backends.openai import OpenAIAlora, OpenAIBackend -from mellea.backends.types import ModelOption -from mellea.helpers.async_helpers import send_to_queue -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import GenerateType, ModelOutputThunk - - -class OpenAIConstraintAlora(OpenAIAlora): - """The [Requirement Checking ALora for Granite 3.2 8B](https://huggingface.co/ibm-granite/granite-3.2-8b-alora-requirement-check) checks if the specified requirement was satisfied by the most recent model generation. Only one requirement is checked at a time.""" - - def __init__( - self, name: str, path: str, generation_prompt: str, backend: OpenAIBackend - ): - """Initialize after checking that the backend is correct.""" - assert backend._hf_model_id == "ibm-granite/granite-3.2-8b-instruct" - super().__init__(name, path, generation_prompt, backend) - # We do a lot of logging for ALoras because this is an experimental feature. Maybe we should tag these log messages? - self._logger = FancyLogger.get_logger() - - def generate_using_strings( - self, - input: str, - response: str, - constraint: str, - force_yn: bool = True, - stream: bool = False, - ) -> ModelOutputThunk: - """Generates a constraint response from the ALora. Must be run in a running event loop.""" - # Go ahead and do runtime type-checking because passing CBlocks into this function is a common error. 
- assert type(input) is str - assert type(response) is str - assert type(constraint) is str - - # Params aren't needed when just getting the backend args. - backend_model_opts = self._backend._simplify_and_merge(None, False) - sys_prompt = backend_model_opts.get(ModelOption.SYSTEM_PROMPT, None) - - chat = [ - *([{"role": "system", "content": sys_prompt}] if sys_prompt else []), - {"role": "user", "content": input}, - {"role": "assistant", "content": response}, - ] - - prompt = self._backend.apply_chat_template(chat) - prompt += f"\nRequirement: {constraint}<|end_of_text|>\n" # type: ignore - prompt += self._generation_prompt - - self._logger.debug(f"Prompt for non-cached aLoRA({self.name}):\n{prompt}") - - force_yn_args = {} - if force_yn: - assert hasattr(self._backend, "_tokenizer") - token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[ - "input_ids" - ][0] # type: ignore - token_N = self._backend._tokenizer("N", add_special_tokens=False)[ - "input_ids" - ][0] # type: ignore - - force_yn_args["logit_bias"] = {str(token_Y): 100, str(token_N): 100} - - chat_response: Coroutine[ - Any, Any, openai.AsyncStream[Completion] | Completion - ] = self._backend._async_client.completions.create( - model=self.name, - prompt=prompt, - max_tokens=1, - n=1, - stream=stream, - **force_yn_args, - ) # type: ignore - - output = ModelOutputThunk(None) - output._meta["alora_name"] = self.name - - output._process = processing - output._post_process = functools.partial(post_processing, self._backend) - - try: - # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. - # We can also support synchronous calls by adding a flag and changing this ._generate function. - - # This function should always be called from a running event loop so we don't have to worry about - # scheduling the task to a specific event loop here. 
- output._generate = asyncio.create_task( - send_to_queue(chat_response, output._async_queue) - ) - output._generate_type = GenerateType.ASYNC - except RuntimeError as e: - # Most likely cause is running this function without an event loop present - raise e - - return output - - -async def processing(mot: ModelOutputThunk, chunk: Completion): - """Called to process the incoming chunks.""" - if mot._underlying_value is None: - mot._underlying_value = "" - mot._underlying_value += chunk.choices[0].text - - -async def post_processing(backend: OpenAIBackend, mot: ModelOutputThunk): - """Called after all data has been received.""" - backend.formatter.parse(mot._action, mot) # type: ignore - - -def add_granite_aloras(backend: OpenAIBackend): - """Adds the IBM Granite "starter pack" ALoras to a backend.""" - backend.add_alora( - OpenAIConstraintAlora( - name="constraint", - path="ibm-granite/granite-3.2-8b-alora-requirement-check", - generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>", - backend=backend, - ) - ) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 99058e7d..1d82aeb0 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -13,10 +13,13 @@ import inspect import json from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any +from copy import deepcopy +from typing import TYPE_CHECKING, Any, cast +import granite_common import outlines import outlines_core +import peft import torch from transformers import ( AsyncTextIteratorStreamer, @@ -30,11 +33,19 @@ from transformers.generation.utils import GenerateDecoderOnlyOutput from mellea.backends import BaseModelSubclass -from mellea.backends._utils import to_chat, to_tool_calls, use_alora -from mellea.backends.aloras import Alora, AloraBackendMixin +from mellea.backends._utils import to_chat, to_tool_calls +from mellea.backends.adapters.adapter import ( + AdapterMixin, + AdapterType, + GraniteCommonAdapter, + LocalHFAdapter, + get_adapter_for_intrinsic, +) +from mellea.backends.adapters.catalog import fetch_intrinsic_metadata from mellea.backends.cache import Cache, SimpleLRUCache from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier +from mellea.backends.openai import OpenAIBackend from mellea.backends.process_reward_models import PRM from mellea.backends.tools import ( add_tools_from_context_actions, @@ -54,11 +65,9 @@ ModelToolCall, ) from mellea.stdlib.chat import Message +from mellea.stdlib.intrinsics.intrinsic import Intrinsic from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement -if TYPE_CHECKING: - from alora.peft_model_alora import aLoRAPeftModelForCausalLM # type: ignore - assert outlines, "outlines needs to be present to make outlines_core work" """A configuration type for the unhappy path: Tokenizer * Model * torch device string @@ -80,7 +89,7 @@ class HFAloraCacheInfo: q_end: int = -1 -class LocalHFBackend(FormatterBackend, AloraBackendMixin): +class LocalHFBackend(FormatterBackend, AdapterMixin): """The LocalHFBackend uses Huggingface's transformers library for inference, and uses a Formatter to convert `Component`s into prompts. This backend also supports Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397). This backend is designed for running an HF model for small-scale inference locally on your machine. 
@@ -169,21 +178,9 @@ def __init__( self._use_caches = use_caches self._cache = cache if cache is not None else SimpleLRUCache(3) - # Used when running aLoRAs with this backend. - self._alora_model: "aLoRAPeftModelForCausalLM | None" = None # noqa: UP037 - # ALoras that have been loaded for this model. - self._aloras: dict[str, HFAlora] = {} - - @property - def alora_model(self) -> "aLoRAPeftModelForCausalLM | None": # noqa: UP037 - """The ALora model.""" - return self._alora_model - - @alora_model.setter - def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"): # noqa: UP037 - """Sets the ALora model. This should only happen once in a backend's lifetime.""" - assert self._alora_model is None - self._alora_model = model + # Adapters can be made known to the backend (added) and loaded. + self._added_adapters: dict[str, LocalHFAdapter] = {} + self._loaded_adapters: dict[str, LocalHFAdapter] = {} def generate_from_context( self, @@ -198,71 +195,204 @@ def generate_from_context( # Upsert model options. model_opts = self._simplify_and_merge(model_options) - if use_alora( - action, - self.get_alora("constraint"), - self.default_to_constraint_checking_alora, - ): - mot = self._generate_from_context_alora( - action, ctx, _format=format, model_options=model_opts - ) - return mot, ctx.add(action).add(mot) - else: - mot = self._generate_from_context_standard( - action, - ctx, - _format=format, - model_options=model_opts, - tool_calls=tool_calls, + # Requirements can be automatically rerouted to a requirement adapter. + if isinstance(action, Requirement): + # See docs/dev/requirement_aLoRA_rerouting.md + reroute_to_alora = self.default_to_constraint_checking_alora + adapter_name = "requirement_check" + + if isinstance(action, ALoraRequirement): + reroute_to_alora = True + adapter_name = action.intrinsic_name + alora_action = action + else: + assert action.description is not None, ( + "must have a description when generating from a requirement" + ) + alora_action = ALoraRequirement(action.description, adapter_name) + + # Check if a requirement_check (or AloraRequirement specified) adapter + # exists. + alora_req_adapter = get_adapter_for_intrinsic( + adapter_name, [AdapterType.ALORA], self._added_adapters ) + if alora_req_adapter is None: + # Log a warning if using an AloraRequirement but no adapter fit. + if reroute_to_alora and isinstance(action, ALoraRequirement): + FancyLogger.get_logger().warning( + f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation" + ) + reroute_to_alora = False + + if issubclass(type(action), LLMaJRequirement): + reroute_to_alora = False + + if reroute_to_alora: + # Keep the alora requirement handling separate for now. 
+ mot = self._generate_from_intrinsic( + alora_action, ctx, model_options=model_opts + ) + return mot, ctx.add(alora_action).add(mot) + + elif isinstance(action, Intrinsic): + mot = self._generate_from_intrinsic(action, ctx, model_options=model_opts) return mot, ctx.add(action).add(mot) - def _generate_from_context_alora( - self, - action: Component | CBlock, - ctx: Context, - *, - _format: type[BaseModelSubclass] | None = None, - model_options: dict[str, Any], + mot = self._generate_from_context_standard( + action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls + ) + return mot, ctx.add(action).add(mot) + + def _generate_from_intrinsic( + self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any] ) -> ModelOutputThunk: - match action: - case ALoraRequirement(): - alora_for_this_request = ( - self.get_alora("constraint") - if action.alora is None - else action.alora - ) - case _: - alora_for_this_request = self.get_alora("constraint") - assert alora_for_this_request is not None, ( - "This code block should not execute unless there is a 'constraint' alora loaded." - ) - # Construct the linearized context. This is very similar to normal generation. + if not ctx.is_chat_context: + raise Exception("Does not yet support non-chat contexts.") + linearized_ctx = ctx.view_for_generation() - assert linearized_ctx is not None and len(linearized_ctx) > 1 - msgs = self.formatter.to_chat_messages(linearized_ctx) - user_message, assistant_message = msgs[-2].content, msgs[-1].content - assert alora_for_this_request is not None - assert type(user_message) is str - assert type(assistant_message) is str - assert _format is None, "Structured outputs are not supported by ALoRAs." - - alora_output = alora_for_this_request.generate_using_strings( - input=user_message, - response=assistant_message, - constraint=action.description, # type: ignore - stream=model_options.get(ModelOption.STREAM, False), + assert linearized_ctx is not None, ( + "If ctx.is_chat_context, then the context should be linearizable." + ) + ctx_as_message_list: list[Message] = self.formatter.to_chat_messages( + linearized_ctx + ) + + conversation: list[dict] = [] + system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, "") + if system_prompt != "": + conversation.append({"role": "system", "content": system_prompt}) + + conversation.extend( + [OpenAIBackend.message_to_openai_message(m) for m in ctx_as_message_list] + ) + + docs = OpenAIBackend.messages_to_docs(ctx_as_message_list) + + seed = model_options.get(ModelOption.SEED, None) + if seed is not None: + set_seed(seed) + + if model_options.get(ModelOption.STREAM, None) is not None: + # Intrinsics don't support streaming because of their post-processing step. + FancyLogger.get_logger().warning( + "intrinsics cannot use streaming; removing model option" + ) + del model_options[ModelOption.STREAM] + + adapter = get_adapter_for_intrinsic( + action.intrinsic_name, action.adapter_types, self._added_adapters + ) + if adapter is None: + raise ValueError( + f"backend ({self}) has no adapter for processing intrinsic: {action.intrinsic_name}" + ) + + # TODO: Code below this point is mostly specific to RagIntrinsics (and granite_common). + # It should be refactored into a specific adapter.transform() function. 
+ assert isinstance(adapter, GraniteCommonAdapter), ( + "currently Mellea only supports GraniteCommonAdapters and Intrinsics" + ) + + intrinsic_config = adapter.config + assert intrinsic_config is not None + + rewriter = granite_common.IntrinsicsRewriter( + config_dict=intrinsic_config, model_name=adapter.name + ) + result_processor = granite_common.IntrinsicsResultProcessor( + config_dict=intrinsic_config ) - # The alora function doesn't set up all the fields. - alora_output._context = linearized_ctx - alora_output._action = action - alora_output._model_options = model_options + # Convert our conversation into a proper chat completions dict. + # [{role: user, content: Hello}, {...}] -> {messages: [{role:user,...}, ...], model:..., ...} + request_json: dict = { + "messages": conversation, + "extra_body": {"documents": docs}, + } + rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs) + + # TODO: Handle caching here. granite_common doesn't tell us what changed, + # so we will have to invalidate the cache on our side. This requires + # us having specific caching for each Component/Message. + + self.load_adapter(adapter.qualified_name) - # TODO: Figure out what info we want to populate for aloras here. - alora_output._generate_log = GenerateLog() + # TODO: This modifies the underlying model. We should set a non-exclusive lock here. + # It should allow generate requests with the same adapter to proceed. This logic also + # needs to be added to the other generate functions. + self._model.set_adapter(adapter.qualified_name) + + generate_input, other_input = ( + granite_common.util.chat_completion_request_to_transformers_inputs( + rewritten, self._tokenizer, self._model + ) + ) + + chat_response: Coroutine[Any, Any, granite_common.ChatCompletionResponse] = ( + asyncio.to_thread( + granite_common.util.generate_with_transformers, + self._tokenizer, + self._model, + generate_input, + other_input, + ) + ) + + output = ModelOutputThunk(None) + output._context = ctx.view_for_generation() + output._action = action + output._model_options = model_options + + # Add another step to the processing function. + async def granite_common_processing( + mot: ModelOutputThunk, + chunk: granite_common.ChatCompletionResponse, + rewritten: granite_common.ChatCompletion, + result_processor: granite_common.IntrinsicsResultProcessor, + input_ids, + ): + res = result_processor.transform(chunk, rewritten) # type: ignore + + # TODO: If we want to support caches, we need to get the GenerateDecoderOnlyOutput. This means we + # probably need to break out the pieces from `generate_with_transformers`. + # processing expects a str or a GenerateDecoderOnlyOutput. Extract the str. + return await self.processing( + mot, res.choices[0].message.content, input_ids=input_ids + ) + + output._process = functools.partial( + granite_common_processing, + rewritten=rewritten, + result_processor=result_processor, + input_ids=generate_input["input_tokens"], + ) + + # TODO: Post-processing should release the lock for this generation. + output._post_process = functools.partial( + self.post_processing, + conversation=conversation, + input_ids=generate_input["input_tokens"], + _format=None, + tool_calls=False, + tools={}, + seed=seed, + ) - return alora_output + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. 
+ + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. + output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) # type: ignore + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present. + raise e + + return output def _generate_from_context_standard( self, @@ -453,7 +583,7 @@ async def post_processing( assert mot.value is not None # Add an entry to the cache for ALora reuse. - if self._use_caches: + if self._use_caches and mot._meta.get("hf_output", None) is not None: output_complete = mot._meta["hf_output"].sequences[0] cache: DynamicCache = mot._meta["hf_output"].past_key_values # type: ignore @@ -689,72 +819,87 @@ def _filter_chat_template_only_options( } return {k: v for k, v in model_options.items() if k not in chat_template_only} - # region ALora loading, unloading, and utility functions. - def add_alora(self, alora: HFAlora): - """Loads an ALora for this backend. - - Args: - alora (str): identifier for the ALora adapter - """ - from alora.peft_model_alora import aLoRAPeftModelForCausalLM # type: ignore - - assert issubclass(alora.__class__, HFAlora), ( - f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}" - ) - assert alora._backend == self, "Cannot load an ALora into the wrong backend." + # region Adapter loading, unloading, and utility functions. + @property + def base_model_name(self): + """Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`.""" + return self._hf_model_id.split("/")[1] + + def add_adapter(self, adapter: LocalHFAdapter): + """Adds the given adapter to the backend. Must not have been added to a different backend.""" + if adapter.backend is not None: + if adapter.backend is self: + FancyLogger.get_logger().warning( + f"attempted to add adapter {adapter.name} with type {adapter.adapter_type} to the same backend {adapter.backend}" + ) + return + else: + raise Exception( + f"adapter {adapter.name} with type {adapter.adapter_type} has already been added to backend {adapter.backend}" + ) - if self.get_alora(alora.name) is not None: + if self._added_adapters.get(adapter.qualified_name) is not None: FancyLogger.get_logger().warning( - f"Client code attempted to add {alora.name} but {alora.name} was already added to {self.__class__}. The backend is refusing to do this, because ALora loading is not idempotent." + f"Client code attempted to add {adapter.name} with type {adapter.adapter_type} but {adapter.name} was already added to {self.__class__}. The backend is refusing to do this, because adapter loading is not idempotent." ) return None - if self.alora_model is None: - base_model = self._model - self.alora_model = aLoRAPeftModelForCausalLM.from_pretrained( - base_model, alora.path_or_model_id, alora.name - ) - else: - self.alora_model.load_adapter(alora.path_or_model_id, alora.name) + adapter.path = adapter.get_local_hf_path(self.base_model_name) + adapter.backend = self + self._added_adapters[adapter.qualified_name] = adapter - self._aloras[alora.name] = alora + def load_adapter(self, adapter_qualified_name: str): + """Loads the given adapter for the backend. 
Must have previously been added.""" + adapter = self._added_adapters.get(adapter_qualified_name, None) + if adapter is None: + raise ValueError( + f"could not load adapter {adapter_qualified_name} for backend {self}: adapter was not previously added" + ) - def get_alora(self, alora_name: str) -> Alora | None: - """Returns the ALora by name, or None if that ALora isn't loaded.""" - return self._aloras.get(alora_name) + try: + adapter_kwargs = {} - def get_aloras(self) -> list[Alora]: - """Returns a list of all loaded ALora adapters.""" - return list(self._aloras.values()) + # Peft tries to stringify the device. If it's mps, it gets stringified as "mps:0" which causes + # an error when loading with safetensors.torch.load_file. Force the device as a string "mps" to fix. + if self._device == torch.device("mps"): + adapter_kwargs["device"] = "mps" + self._model.load_adapter( + adapter.path, adapter.qualified_name, adapter_kwargs=adapter_kwargs + ) + except ValueError as e: + # If it's just that it's already loaded, ignore it. + if f"Adapter with name {adapter_qualified_name} already exists." not in str( + e + ): + raise e - # endregion + # Loading an adapter activates it. We disable adapters immediately after. + # Prefer this over `.disable_adapters()`; the disable function doesn't always + # seem to work. + self._model.set_adapter([]) + self._loaded_adapters[adapter.qualified_name] = adapter + + def unload_adapter(self, adapter_qualified_name: str): + """Unloads the given adapter from the backend.""" + # Check if the backend knows about this adapter. + adapter = self._loaded_adapters.get(adapter_qualified_name, None) + if adapter is None: + FancyLogger.get_logger().info( + f"could not unload adapter {adapter_qualified_name} for backend {self}: adapter is not loaded" + ) + return + self._model.delete_adapter(adapter.qualified_name) -class HFAlora(Alora, abc.ABC): - """ALoras that work with the local huggingface backend.""" + # Remove the adapter from the list of loaded adapters. + del self._loaded_adapters[adapter.qualified_name] - def __init__( - self, - name: str, - path_or_model_id: str, - generation_prompt: str, - backend: LocalHFBackend, - ): - """Initialize an ALora that should work with huggingface backends that support ALoras. + def list_adapters(self) -> list[str]: + """Lists the adapters added via add_adapter(). - Args: - name (str): An arbitrary name/label to assign to an ALora. This is irrelevant from the alora's (huggingface) model id. - path_or_model_id (str): A local path to ALora's weights or a Huggingface model_id to an ALora. - generation_prompt (str): A prompt used to "activate" the Lora. This string goes between the pre-activation context and the aLora generate call. This needs to be provided by the entity that trained the ALora. - backend (LocalHFBackend): Mained as a pointer to the backend to which this this ALora is attached. 
+ :returns: list of adapter names that are currently registered with this backend """ - super().__init__(name) - self.path_or_model_id = path_or_model_id - self._backend = backend - self._generation_prompt = generation_prompt - self._generation_prompt_tokens = self._backend._tokenizer( - self._generation_prompt, return_tensors="pt" - ).to(self._backend._device) + return list(self._loaded_adapters.keys()) class HFProcessRewardModel(PRM, abc.ABC): @@ -788,7 +933,6 @@ def __init__( self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( self.model_name_or_path, torch_dtype=torch.bfloat16 ) - self.model.to(self._device) # type: ignore self.model.eval() self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 566fa7a2..d28766e9 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -7,20 +7,27 @@ import inspect import json from collections.abc import Callable, Coroutine +from copy import deepcopy from enum import Enum -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse +import granite_common import openai import requests -from huggingface_hub import snapshot_download from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion import Completion import mellea.backends.model_ids as model_ids from mellea.backends import BaseModelSubclass -from mellea.backends.aloras import Alora, AloraBackendMixin +from mellea.backends.adapters.adapter import ( + AdapterMixin, + AdapterType, + GraniteCommonAdapter, + OpenAIAdapter, + get_adapter_for_intrinsic, +) from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter from mellea.backends.model_ids import ModelIdentifier from mellea.backends.tools import ( @@ -28,7 +35,7 @@ add_tools_from_model_options, convert_tools_to_json, ) -from mellea.backends.types import ModelOption +from mellea.backends.types import ModelOption, _server_type, _ServerType from mellea.helpers.async_helpers import ( ClientCache, get_current_event_loop, @@ -43,11 +50,13 @@ CBlock, Component, Context, + Document, GenerateLog, GenerateType, ModelOutputThunk, ) from mellea.stdlib.chat import Message +from mellea.stdlib.intrinsics.intrinsic import Intrinsic from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement if TYPE_CHECKING: @@ -58,25 +67,7 @@ format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors -class _ServerType(Enum): - LOCALHOST = 1 - OPENAI = 2 - - -def _server_type(url: str) -> _ServerType | None: - try: - parsed = urlparse(url) - hostname = parsed.hostname - if hostname in ("localhost", "127.0.0.1", "::1"): - return _ServerType.LOCALHOST - elif hostname == "api.openai.com": - return _ServerType.OPENAI - except Exception as e: - print(f"Error parsing URL: {e}") - return None - - -class OpenAIBackend(FormatterBackend, AloraBackendMixin): +class OpenAIBackend(FormatterBackend, AdapterMixin): """A generic OpenAI compatible backend.""" def __init__( @@ -170,6 +161,8 @@ def __init__( else: self._api_key = api_key + self._server_type = _server_type(self._base_url) + self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs) self._client = openai.OpenAI( # type: ignore @@ -181,8 +174,10 @@ def __init__( # Call once to create an async_client and populate the cache. 
_ = self._async_client - # ALoras that have been loaded for this model. - self._aloras: dict[str, OpenAIAlora] = {} + # Adapters can be made know to the backend (added) and + # loaded / active. + self._added_adapters: dict[str, OpenAIAdapter] = {} + self._loaded_adapters: dict[str, OpenAIAdapter] = {} @property def _async_client(self) -> openai.AsyncOpenAI: @@ -302,14 +297,13 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." ) - mot = self.generate_from_chat_context( + return self.generate_from_chat_context( action, ctx, _format=format, model_options=model_options, tool_calls=tool_calls, ) - return mot, ctx.add(action).add(mot) def generate_from_chat_context( self, @@ -320,81 +314,194 @@ def generate_from_chat_context( | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> tuple[ModelOutputThunk, Context]: """Generates a new completion from the provided Context using this backend's `Formatter`.""" - if issubclass(type(action), Requirement): - # The general rule is that we reroute to the alora if it exists. - reroute_to_alora = self.get_alora("constraint") is not None - # However, there are some exceptions: - if not self.default_to_constraint_checking_alora: + # Requirements can be automatically rerouted to a requirement adapter. + if isinstance(action, Requirement): + # See docs/dev/requirement_aLoRA_rerouting.md + reroute_to_alora = self.default_to_constraint_checking_alora + adapter_name = "requirement_check" + + if isinstance(action, ALoraRequirement): + reroute_to_alora = True + adapter_name = action.intrinsic_name + alora_action = action + else: + assert action.description is not None, ( + "must have a description when generating from a requirement" + ) + alora_action = ALoraRequirement(action.description, adapter_name) + + # Check if a requirement_check (or AloraRequirement specified) adapter exists. + alora_req_adapter = get_adapter_for_intrinsic( + adapter_name, [AdapterType.ALORA], self._added_adapters + ) + if alora_req_adapter is None: + # Log a warning if using an AloraRequirement but no adapter fit. + if reroute_to_alora and isinstance(action, ALoraRequirement): + FancyLogger.get_logger().warning( + f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation" + ) reroute_to_alora = False + if issubclass(type(action), LLMaJRequirement): reroute_to_alora = False - if issubclass(type(action), ALoraRequirement): - reroute_to_alora = True + if reroute_to_alora: - return self._generate_from_chat_context_alora( - action, ctx, _format=_format, model_options=model_options + # Keep the alora requirement handling separate for now. 
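+                # At this point a plain Requirement has already been wrapped in an
+                # ALoraRequirement (alora_action above), so this branch always hands an
+                # intrinsic-style action to _generate_from_intrinsic; LLMaJRequirement
+                # actions never reach here because they force reroute_to_alora = False.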
+ mot = self._generate_from_intrinsic( + alora_action, ctx, model_options=model_options ) + return mot, ctx.add(alora_action).add(mot) - return self._generate_from_chat_context_standard( + elif isinstance(action, Intrinsic): + mot = self._generate_from_intrinsic( + action, ctx, model_options=model_options + ) + return mot, ctx.add(action).add(mot) + + mot = self._generate_from_chat_context_standard( action, ctx, _format=_format, model_options=model_options, tool_calls=tool_calls, ) + return mot, ctx.add(action).add(mot) - def _generate_from_chat_context_alora( - self, - action: Component | CBlock, - ctx: Context, - *, - _format: type[BaseModelSubclass] - | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel - model_options: dict | None = None, + def _generate_from_intrinsic( + self, action: Intrinsic, ctx: Context, *, model_options: dict | None = None ) -> ModelOutputThunk: - match action: - case ALoraRequirement(): - alora_for_this_request = ( - self.get_alora("constraint") - if action.alora is None - else action.alora - ) - case _: - alora_for_this_request = self.get_alora("constraint") - assert alora_for_this_request is not None, ( - "This code block should not execute unless there is a 'constraint' alora loaded." - ) + model_opts = self._simplify_and_merge( + model_options, is_chat_context=ctx.is_chat_context + ) + if len(model_opts.items()) > 0: + FancyLogger.get_logger().info( + "passing in model options when generating with an adapter; some model options may be overwritten / ignored" + ) + + linearized_context = ctx.view_for_generation() + assert linearized_context is not None, ( + "Cannot generate from a non-linear context in a FormatterBackend." + ) + if len(linearized_context) == 0: + FancyLogger.get_logger().warning( + f"generating with an intrinsic when the context is empty; this is typically incorrect: {action}" + ) + + # Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this. + messages: list[Message] = self.formatter.to_chat_messages(linearized_context) + + conversation: list[dict] = [] + + system_prompt = model_opts.get(ModelOption.SYSTEM_PROMPT, "") + if system_prompt != "": + conversation.append({"role": "system", "content": system_prompt}) + conversation.extend([self.message_to_openai_message(m) for m in messages]) + docs = self.messages_to_docs(messages) + + if model_opts.get(ModelOption.STREAM, None) is not None: + # Intrinsics don't support streaming because of their post-processing step. + FancyLogger.get_logger().warning( + "intrinsics cannot use streaming; removing model option" + ) + del model_opts[ModelOption.STREAM] + + adapter = get_adapter_for_intrinsic( + action.intrinsic_name, action.adapter_types, self._added_adapters + ) + if adapter is None: + raise ValueError( + f"backend ({self}) has no adapter for processing intrinsic: {action.intrinsic_name}" + ) + + # TODO: Code below this point is mostly specific to RagIntrinsics (and granite_common). + # It should be refactored into a specific adapter.transform() function. + assert isinstance(adapter, GraniteCommonAdapter), ( + "currently Mellea only supports GraniteCommonAdapters and Intrinsics" + ) + assert adapter.config is not None + rewriter = granite_common.IntrinsicsRewriter( + config_dict=adapter.config, model_name=adapter.qualified_name + ) + result_processor = granite_common.IntrinsicsResultProcessor( + config_dict=adapter.config + ) + + # Convert our conversation into a proper chat completions dict. 
+ # [{role: user, content: Hello}, {...}] -> {messages: [{role:user,...}, ...], model:..., ...} + request_json: dict = { + "messages": conversation, + "extra_body": {"documents": docs}, + } + + rewritten = rewriter.transform(request_json, **action.intrinsic_kwargs) + + self.load_adapter(adapter.qualified_name) + chat_response: Coroutine[Any, Any, ChatCompletion] = ( + self._async_client.chat.completions.create(**rewritten.model_dump()) + ) + + output = ModelOutputThunk(None) + output._context = linearized_context + output._action = action + output._model_options = model_opts + output._meta["granite_common_chat_response"] = rewritten + + # Add another step to the processing function. + async def granite_common_processing( + mot: ModelOutputThunk, + chunk: ChatCompletion, + rewritten: ChatCompletion, + result_processor: granite_common.IntrinsicsResultProcessor, + ): + res = result_processor.transform(chunk, rewritten) # type: ignore + + # processing expects a ChatCompletion object. Granite common differs slightly from this. Re-create the necessary object. + full_res = ChatCompletion( + id=chunk.id, + choices=[], + created=chunk.created, + model=chunk.model, + usage=chunk.usage, + object="chat.completion", + ) + + # Set the choices here so that pydantic validation doesn't error out. + full_res.choices = res.choices # type: ignore + + return await self.processing(mot, full_res) + + output._process = functools.partial( + granite_common_processing, + rewritten=rewritten, # type: ignore + result_processor=result_processor, + ) - # Construct the linearized context. This is very similar to normal generation. - linearized_ctx = ctx.view_for_generation() - assert linearized_ctx is not None and len(linearized_ctx) > 1 - msgs = self.formatter.to_chat_messages(linearized_ctx) - user_message, assistant_message = msgs[-2].content, msgs[-1].content - assert alora_for_this_request is not None - assert type(user_message) is str - assert type(assistant_message) is str - assert _format is None, "Structured outputs are not supported by ALoRAs." - - model_opts = self._simplify_and_merge(model_options, is_chat_context=True) - - alora_output = alora_for_this_request.generate_using_strings( - input=user_message, - response=assistant_message, - constraint=action.description, # type: ignore - stream=model_opts.get(ModelOption.STREAM, False), + output._post_process = functools.partial( + self.post_processing, + tools={}, + conversation=conversation, + thinking=None, + seed=model_opts.get(ModelOption.SEED, None), + _format=None, ) - # The alora function doesn't set up all the fields. - alora_output._context = linearized_ctx - alora_output._action = action - alora_output._model_options = model_options + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. - # TODO: Figure out what info we want to populate for aloras here. - alora_output._generate_log = GenerateLog() + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. 
+ output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e - return alora_output + return output @staticmethod def message_to_openai_message(msg: Message): @@ -431,6 +538,24 @@ def message_to_openai_message(msg: Message): # ] # } + @staticmethod + def messages_to_docs(msgs: list[Message]) -> list[dict[str, str]]: + """Extracts the docs from a list of messages.""" + docs: list[Document] = [] + for message in msgs: + if message._docs is not None: + docs.extend(message._docs) + + json_docs: list[dict[str, str]] = [] + for doc in docs: + json_doc = {"text": doc.text} + if doc.title is not None: + json_doc["title"] = doc.title + if doc.doc_id is not None: + json_doc["doc_id"] = doc.doc_id + json_docs.append(json_doc) + return json_docs + def _generate_from_chat_context_standard( self, action: Component | CBlock, @@ -725,41 +850,87 @@ def generate_from_raw( return results - def add_alora(self, alora: "OpenAIAlora"): - """Loads an ALora for this backend. - - Args: - alora (str): identifier for the ALora adapter - """ - assert issubclass(alora.__class__, OpenAIAlora), ( - f"cannot add an ALora of type {alora.__class__} to model; must inherit from {OpenAIAlora.__class__}" - ) - assert alora._backend == self, "Cannot load an ALora into the wrong backend." + @property + def base_model_name(self): + """Returns the base_model_id of the model used by the backend. For example, `granite-3.3-8b-instruct` for `ibm-granite/granite-3.3-8b-instruct`.""" + return self._hf_model_id.split("/")[1] + + def add_adapter(self, adapter: OpenAIAdapter): + """Adds the given adapter to the backend. Must not have been added to a different backend.""" + if adapter.backend is not None: + if adapter.backend is self: + FancyLogger.get_logger().warning( + f"attempted to add adapter {adapter.name} with type {adapter.adapter_type} to the same backend {adapter.backend}" + ) + return + else: + raise Exception( + f"adapter {adapter.name} with type {adapter.adapter_type} has already been added to backend {adapter.backend}" + ) - if self.get_alora(alora.name) is not None: + if self._added_adapters.get(adapter.qualified_name, None) is not None: FancyLogger.get_logger().warning( - f"Client code attempted to add {alora.name} but {alora.name} was already added to {self.__class__}. The backend is refusing to do this, because ALora loading is not idempotent." + f"Client code attempted to add {adapter.name} with type {adapter.adapter_type} but it was already added to {self.__class__}. This attempt to add the adapter will be ignored." ) return None - assert _server_type(self._base_url) == _ServerType.LOCALHOST, ( - "alora is supported only for locally running vllm instances" + adapter.path = adapter.get_open_ai_path( + self.base_model_name, server_type=self._server_type + ) + adapter.backend = self + self._added_adapters[adapter.qualified_name] = adapter + + def load_adapter(self, adapter_qualified_name: str): + """Loads the given adapter for the backend. 
Must have previously been added.""" + adapter = self._added_adapters.get(adapter_qualified_name, None) + if adapter is None: + raise ValueError( + f"could not load adapter {adapter_qualified_name} for backend {self}: adapter was not previously added" + ) + + url = f"{self._base_url}/load_lora_adapter" + response = requests.post( + url, + json={"lora_name": adapter_qualified_name, "lora_path": adapter.path}, + headers={"Content-Type": "application/json"}, ) - snapshot_path = snapshot_download(alora.path) + err: str | None = None + match response.status_code: + case 200: + FancyLogger.get_logger().info( + f"{url}: status {response.status_code} {response.text}" + ) + case 400: + if "has already been loaded." in str(response.content): + FancyLogger.get_logger().warning( + f"{url}: status {response.status_code} {response.text}" + ) + else: + err = f"{url}: status {response.status_code} {response.text}" + case _: + err = f"{url}: status {response.status_code} {response.text}" - # https://docs.vllm.ai/en/stable/features/lora.html#using-api-endpoints - # curl -X POST http://localhost:8000/v1/load_lora_adapter \ - # -H "Content-Type: application/json" \ - # -d '{ - # "lora_name": "sql_adapter", - # "lora_path": "/path/to/sql-lora-adapter" - # }' + if err is not None: + FancyLogger.get_logger().error(err) + raise Exception(f"error loading adapter {adapter_qualified_name}: {err}") - url = f"{self._base_url}/load_lora_adapter" + self._loaded_adapters[adapter.qualified_name] = adapter + + def unload_adapter(self, adapter_qualified_name: str): + """Unloads the given adapter from the backend.""" + # Check if the backend knows about this adapter. + adapter = self._loaded_adapters.get(adapter_qualified_name, None) + if adapter is None: + FancyLogger.get_logger().info( + f"could not unload adapter {adapter_qualified_name} for backend {self}: adapter is not loaded" + ) + return + + url = f"{self._base_url}/unload_lora_adapter" response = requests.post( url, - json={"lora_name": alora.name, "lora_path": snapshot_path}, + json={"lora_name": adapter_qualified_name}, headers={"Content-Type": "application/json"}, ) @@ -768,23 +939,30 @@ def add_alora(self, alora: "OpenAIAlora"): FancyLogger.get_logger().info( f"{url}: status {response.status_code} {response.text}" ) - self._aloras[alora.name] = alora + case 404: + # This response code indicates that the adapter isn't currently loaded; + # which is the goal of this function. Log it but proceed as if successful. + FancyLogger.get_logger().info( + f"{url}: status {response.status_code} {response.text}" + ) case _: + # Unknown err. FancyLogger.get_logger().error( f"{url}: status {response.status_code} {response.text}" ) + raise Exception( + f"error unloading adapter {adapter_qualified_name}: {url}: status {response.status_code} {response.text}" + ) - self._aloras[alora.name] = alora - - return None + # Remove the adapter from the list of loaded adapters. + del self._loaded_adapters[adapter.qualified_name] - def get_alora(self, alora_name: str) -> Alora | None: - """Returns the ALora by name, or None if that ALora isn't loaded.""" - return self._aloras.get(alora_name) + def list_adapters(self) -> list[str]: + """Lists the adapters added via add_adapter(). 
- def get_aloras(self) -> list[Alora]: - """Returns a list of all loaded ALora adapters.""" - return list(self._aloras.values()) + :returns: list of adapter names that are currently registered with this backend + """ + return list(self._loaded_adapters.keys()) def apply_chat_template(self, chat: list[dict[str, str]]): """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id).""" @@ -804,23 +982,3 @@ def apply_chat_template(self, chat: list[dict[str, str]]): ) return self._tokenizer.apply_chat_template(chat, tokenize=False) - - -class OpenAIAlora(Alora, abc.ABC): - """ALoras that work with OpenAI backend.""" - - def __init__( - self, name: str, path: str, generation_prompt: str, backend: OpenAIBackend - ): - """Initialize an ALora that should work with OpenAI backends that support ALoras. - - Args: - name (str): An arbitrary name/label to assign to an ALora. This is irrelevant from the alora's (huggingface) model id. - path (str): A local path to ALora's weights or a Huggingface model_id to an ALora. - generation_prompt (str): A prompt used to "activate" the Lora. This string goes between the pre-activation context and the aLora generate call. This needs to be provided by the entity that trained the ALora. - backend (OpenAIBackend): Mained as a pointer to the backend to which this this ALora is attached. - """ - super().__init__(name) - self.path = path - self._backend = backend - self._generation_prompt = generation_prompt diff --git a/mellea/backends/types.py b/mellea/backends/types.py index d7f0db12..89f03851 100644 --- a/mellea/backends/types.py +++ b/mellea/backends/types.py @@ -1,6 +1,8 @@ """Useful type definitions for models, formatters, and backends.""" +from enum import Enum from typing import Any +from urllib.parse import urlparse from mellea.helpers.fancy_logger import FancyLogger @@ -109,3 +111,27 @@ def merge_model_options( for k, v in overwrite_opts.items(): new_options[k] = v return new_options + + +class _ServerType(Enum): + """Different types of servers that might be relevant for a backend.""" + + UNKNOWN = 0 + LOCALHOST = 1 + OPENAI = 2 + REMOTE_VLLM = 3 + """Must be set manually for now.""" + + +def _server_type(url: str) -> _ServerType: + """Find a server type based on the url.""" + try: + parsed = urlparse(url) + hostname = parsed.hostname + if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): + return _ServerType.LOCALHOST + elif hostname == "api.openai.com": + return _ServerType.OPENAI + except Exception as e: + print(f"Error parsing URL: {e}") + return _ServerType.UNKNOWN diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index bf0c1954..111d44f6 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -149,6 +149,35 @@ def get_images_from_component(c: Component) -> None | list[ImageBlock]: return None +# TODO: Add support for passing in docs as model options. +class Document(Component): + """Documents should typically be used in a Message object.""" + + def __init__(self, text: str, title: str | None = None, doc_id: str | None = None): + """Create a document object. Should typically be used as a list in the `_docs` field of Message.""" + self.text = text + self.title = title + self.doc_id = doc_id + + def parts(self) -> list[Component | CBlock]: + """The set of all the constituent parts of the `Component`.""" + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> str: + """Formats the `Document` into a string. 
+ + Returns: a string + """ + doc = "" + if self.doc_id is not None: + doc += f"document ID '{self.doc_id}': " + if self.title is not None: + doc += f"'{self.title}': " + doc += f"{self.text}" + + return doc + + class GenerateType(enum.Enum): """Used to track what functions can be used to extract a value from a ModelOutputThunk.""" diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/chat.py index 7f5bbb4a..574e6fa6 100644 --- a/mellea/stdlib/chat.py +++ b/mellea/stdlib/chat.py @@ -8,6 +8,7 @@ CBlock, Component, Context, + Document, ImageBlock, ModelOutputThunk, ModelToolCall, @@ -26,6 +27,7 @@ def __init__( content: str, *, images: None | list[ImageBlock] = None, + documents: None | list[Document] = None, ): """Initializer for Chat messages. @@ -33,10 +35,12 @@ def __init__( role (str): The role that this message came from (e.g., user, assistant). content (str): The content of the message. images (list[ImageBlock]): The images associated with the message if any. + documents (list[Document]): documents associated with the message if any. """ self.role = role self.content = content self._images = images + self._docs = documents @property def images(self) -> None | list[str]: @@ -59,7 +63,12 @@ def format_for_llm(self) -> TemplateRepresentation: """ return TemplateRepresentation( obj=self, - args={"role": self.role, "content": self.content, "images": self.images}, + args={ + "role": self.role, + "content": self.content, + "images": self.images, + "documents": self._docs, + }, template_order=["*", "Message"], ) @@ -68,7 +77,11 @@ def __str__(self): images = [] if self.images is not None: images = [f"{i[:20]}..." for i in self.images] - return f'mellea.Message(role="{self.role}", content="{self.content}", images="{images}")' + + docs = [] + if self._docs is not None: + docs = [f"{doc.format_for_llm()[:10]}..." for doc in self._docs] + return f'mellea.Message(role="{self.role}", content="{self.content}", images="{images}", documents="{docs}")' class ToolMessage(Message): diff --git a/mellea/stdlib/intrinsics/intrinsic.py b/mellea/stdlib/intrinsics/intrinsic.py new file mode 100644 index 00000000..4c54a55d --- /dev/null +++ b/mellea/stdlib/intrinsics/intrinsic.py @@ -0,0 +1,66 @@ +"""Module for Intrinsics.""" + +import pathlib +from copy import copy +from typing import cast + +from mellea.backends.adapters.catalog import AdapterType, fetch_intrinsic_metadata +from mellea.stdlib.base import CBlock, Component, TemplateRepresentation + + +class Intrinsic(Component): + """A component representing an intrinsic.""" + + def __init__( + self, intrinsic_name: str, intrinsic_kwargs: dict | None = None + ) -> None: + """A component for rewriting messages using intrinsics. + + Intrinsics are special components that transform a chat completion request. + These transformations typically take the form of: + - parameter changes (typically structured outputs) + - adding new messages to the chat + - editing existing messages + + An intrinsic component should correspond to a loaded adapter. + + Args: + intrinsic_name: the user-visible name of the intrinsic; must match a known + name in Mellea's intrinsics catalog. 
+ intrinsic_kwargs: some intrinsics require kwargs when utilizing them; + provide them here + """ + self.metadata = fetch_intrinsic_metadata(intrinsic_name) + if intrinsic_kwargs is None: + intrinsic_kwargs = {} + self.intrinsic_kwargs = intrinsic_kwargs + + @property + def intrinsic_name(self): + """User-visible name of this intrinsic.""" + return self.metadata.name + + @property + def adapter_types(self) -> tuple[AdapterType, ...]: + """Tuple of available adapter types that implement this intrinsic.""" + return self.metadata.adapter_types + + def parts(self) -> list[Component | CBlock]: + """The set of all the constituent parts of the `Intrinsic`. + + Will need to be implemented by subclasses since not all intrinsics are output + as text / messages. + """ + raise NotImplementedError("parts isn't implemented by default") + + def format_for_llm(self) -> TemplateRepresentation | str: + """`Intrinsic` doesn't implement `format_for_default`. + + Formats the `Intrinsic` into a `TemplateRepresentation` or string. + + Returns: a `TemplateRepresentation` or string + """ + raise NotImplementedError( + "`Intrinsic` doesn't implement format_for_llm by default. You should only " + "use an `Intrinsic` as the action and not as a part of the context." + ) diff --git a/mellea/stdlib/intrinsics/rag.py b/mellea/stdlib/intrinsics/rag.py new file mode 100644 index 00000000..5885c900 --- /dev/null +++ b/mellea/stdlib/intrinsics/rag.py @@ -0,0 +1,285 @@ +"""Intrinsic functions related to retrieval-augmented generation.""" + +import collections.abc +import json + +import mellea.stdlib.functional as mfuncs +from mellea.backends.adapters.adapter import ( + AdapterMixin, + AdapterType, + GraniteCommonAdapter, +) +from mellea.stdlib.base import ChatContext, Document +from mellea.stdlib.chat import Message +from mellea.stdlib.intrinsics.intrinsic import Intrinsic + +_ANSWER_RELEVANCE_CORRECTION_METHODS = { + "Excessive unnecessary information": "removing the excessive information from the " + "draft response", + "Unduly restrictive": "providing answer without the unwarranted restriction, or " + "indicating that the desired answer is not available", + "Too vague or generic": "providing more crisp and to-the-point answer, or " + "indicating that the desired answer is not available", + "Contextual misalignment": "providing a response that answers the last user " + "inquiry, taking into account the context of the conversation", + "Misinterpreted inquiry": "providing answer only to the correct interpretation of " + "the inquiry, or attempting clarification if the inquiry is ambiguous or otherwise " + "confusing, or indicating that the desired answer is not available", + "No attempt": "providing a relevant response if an inquiry should be answered, or " + "providing a short response if the last user utterance contains no inquiry", +} +"""Prompting strings for the answer relevance rewriter. This model is a (a)LoRA adapter, +so it's important to stick to in-domain prompts.""" + + +def _call_intrinsic( + intrinsic_name: str, + context: ChatContext, + backend: AdapterMixin, + /, + kwargs: dict | None = None, +): + """Shared code for invoking intrinsics. + + :returns: Result of the call in JSON format. + """ + # Adapter needs to be present in the backend before it can be invoked. + # We must create the Adapter object in order to determine whether we need to create + # the Adapter object. 
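+    # (The adapter's qualified name is only known once the GraniteCommonAdapter has
+    # been constructed, so we build it first and register it with the backend only if
+    # that qualified name is not already listed.)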
+ base_model_name = backend.base_model_name + if base_model_name is None: + raise ValueError("Backend has no model ID") + adapter = GraniteCommonAdapter( + intrinsic_name, adapter_type=AdapterType.LORA, base_model_name=base_model_name + ) + if adapter.qualified_name not in backend.list_adapters(): + backend.add_adapter(adapter) + + # Create the AST node for the action we wish to perform. + intrinsic = Intrinsic(intrinsic_name, intrinsic_kwargs=kwargs) + + # Execute the AST node. + model_output_thunk, _ = mfuncs.act( + intrinsic, + context, + backend, + # No rejection sampling, please + strategy=None, + ) + + # act() can return a future. Don't know how to handle one from non-async code. + assert model_output_thunk.is_computed() + + # Output of an Intrinsic action is the string representation of the output of the + # intrinsic. Parse the string. + result_str = model_output_thunk.value + if result_str is None: + raise ValueError("Model output is None.") + result_json = json.loads(result_str) + return result_json + + +def check_answerability( + question: str, + documents: collections.abc.Iterable[Document], + context: ChatContext, + backend: AdapterMixin, +) -> float: + """Test a user's question for answerability. + + Intrinsic function that checks whether the question in the last user turn of a + chat can be answered by a provided set of RAG documents. + + :param context: Chat context containing the conversation thus far + :param question: Question that the user has posed in response to the last turn in + ``context``. + :param documents: Document snippets retrieved that may or may not answer the + indicated question. + :param backend: Backend instance that supports adding the LoRA or aLoRA adapters + for answerability checks + + :return: Answerability score as a floating-point value from 0 to 1. + """ + result_json = _call_intrinsic( + "answerability", + context.add(Message("user", question, documents=list(documents))), + backend, + ) + return result_json["answerability_likelihood"] + + +def rewrite_question( + question: str, context: ChatContext, backend: AdapterMixin +) -> float: + """Rewrite a user's question for retrieval. + + Intrinsic function that rewrites the question in the next user turn into a + self-contained query that can be passed to the retriever. + + :param context: Chat context containing the conversation thus far + :param question: Question that the user has posed in response to the last turn in + ``context``. + :param backend: Backend instance that supports adding the LoRA or aLoRA adapters + + :return: Rewritten version of ``question``. + """ + result_json = _call_intrinsic( + "query_rewrite", context.add(Message("user", question)), backend + ) + return result_json["rewritten_question"] + + +def find_citations( + response: str, + documents: collections.abc.Iterable[Document], + context: ChatContext, + backend: AdapterMixin, +) -> list[dict]: + """Find information in documents that supports an assistant response. + + Intrinsic function that finds sentences in RAG documents that support sentences + in a potential assistant response to a user question. + + :param context: Context of the dialog between user and assistant at the point where + the user has just asked a question that will be answered with RAG documents + :param response: Potential assistant response + :param documents: Documents at were used to generate ``response``. 
These documents + should set the ``doc_id`` field; otherwise the intrinsic will be unable to + specify which document was the source of a given citation. + :param backend: Backend that supports one of the adapters that implements this + intrinsic. + :return: List of records with the following fields: + * ``response_begin`` + * ``response_end`` + * ``response_text`` + * ``citation_doc_id`` + * ``citation_begin`` + * ``citation_end`` + * ``citation_text`` + Begin and end offsets are character offsets into their respective UTF-8 strings. + """ + result_json = _call_intrinsic( + "citations", + context.add(Message("assistant", response, documents=list(documents))), + backend, + ) + return result_json + + +def check_context_relevance( + question: str, document: Document, context: ChatContext, backend: AdapterMixin +) -> float: + """Test whether a document is relevant to a user's question. + + Intrinsic function that checks whether a single document contains part or all of + the answer to a user's question. Does not consider the context in which the + question was asked. + + :param context: The chat up to the point where the user asked a question. + :param question: Question that the user has posed. + :param document: A retrieved document snippet + :param backend: Backend instance that supports the adapters that implement this + intrinsic + + :return: Context relevance score as a floating-point value from 0 to 1. + """ + result_json = _call_intrinsic( + "context_relevance", + context.add(Message("user", question)), + backend, + # Target document is passed as an argument + kwargs={"document_content": document.text}, + ) + return result_json["context_relevance"] + + +def flag_hallucinated_content( + response: str, + documents: collections.abc.Iterable[Document], + context: ChatContext, + backend: AdapterMixin, +) -> float: + """Flag potentially-hallucinated sentences in an agent's response. + + Intrinsic function that checks whether the sentences in an agent's response to a + user question are faithful to the retrieved document snippets. Sentences that do not + align with the retrieved snippets are flagged as potential hallucinations. + + :param context: A chat log that ends with a user asking a question + :param response: The assistant's response to the user's question in the last turn + of ``context`` + :param documents: Document snippets that were used to generate ``response`` + :param backend: Backend instance that supports the adapters that implement this + intrinsic + + :return: List of records with the following fields: + * response_begin + * response_end + * response_text + * faithfulness_likelihood + * explanation + """ + result_json = _call_intrinsic( + "hallucination_detection", + context.add(Message("assistant", response, documents=list(documents))), + backend, + ) + return result_json + + +def rewrite_answer_for_relevance( + response: str, + documents: collections.abc.Iterable[Document], + context: ChatContext, + backend: AdapterMixin, + /, + rewrite_threshold: float = 0.5, +) -> str: + """Rewrite an assistant answer to improve relevance to the user's question. 
+ + :param context: A chat log that ends with a user asking a question + :param response: The assistant's response to the user's question in the last turn + of ``context`` + :param documents: Document snippets that were used to generate ``response`` + :param backend: Backend instance that supports the adapters that implement this + intrinsic + :param rewrite_threshold: Number between 0.0 and 1.0 that determines how eagerly + to skip rewriting the assistant's answer for relevance. 0.0 means never rewrite + and 1.0 means always rewrite. + + :returns: Either the original response, or a rewritten version of the original + response. + """ + # First run the classifier to determine the likelihood of a relevant answer + # Output will have three fields: + # * answer_relevance_analysis + # * answer_relevance_category + # * answer_relevance_likelihood + result_json = _call_intrinsic( + "answer_relevance_classifier", + context.add(Message("assistant", response, documents=list(documents))), + backend, + ) + if result_json["answer_relevance_likelihood"] >= rewrite_threshold: + return response + + # If we get here, the classifier indicated a likely irrelevant response. Trigger + # rewrite. + # Rewrite needs a prompt string that is an expanded version of the classifier's + # short output. + correction_method = _ANSWER_RELEVANCE_CORRECTION_METHODS[ + result_json["answer_relevance_category"] + ] + + result_json = _call_intrinsic( + "answer_relevance_rewriter", + context.add(Message("assistant", response, documents=list(documents))), + backend, + kwargs={ + "answer_relevance_category": result_json["answer_relevance_category"], + "answer_relevance_analysis": result_json["answer_relevance_category"], + "correction_method": correction_method, + }, + ) + # Unpack boxed string + return result_json["answer_relevance_rewrite"] diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 97ee2bec..4cef5e35 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -1,22 +1,23 @@ """Requirements are a special type of Component used as input to the "validate" step in Instruct/Validate/Repair design patterns.""" import inspect +import json import re from collections.abc import Callable from copy import copy from typing import Any, overload from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.aloras import Alora +from mellea.backends.adapters.adapter import AdapterType from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( CBlock, Component, Context, - GenerateLog, ModelOutputThunk, TemplateRepresentation, ) +from mellea.stdlib.intrinsics.intrinsic import Intrinsic def default_output_to_bool(x: CBlock | str) -> bool: @@ -185,19 +186,53 @@ class LLMaJRequirement(Requirement): use_aloras: bool = False -class ALoraRequirement(Requirement): +def requirement_check_to_bool(x: CBlock | str) -> bool: + """Checks if a given output should be marked converted to `True`. + + By default, the requirement check alora outputs: `{"requirement_likelihood": 0.0}`. 
+ True if >.5 + """ + output = str(x) + req_dict: dict[str, Any] = json.loads(output) + + likelihood = req_dict.get("requirement_likelihood", None) + if likelihood is None: + FancyLogger.get_logger().warning( + f"could not get value from alora requirement output; looking for `requirement_likelihood` in {req_dict}" + ) + return False + + if likelihood > 0.5: + return True + + return False + + +class ALoraRequirement(Requirement, Intrinsic): """A requirement that always uses an (possibly specified) ALora. If an exception is thrown during the ALora execution path, `mellea` will fall back to LLMaJ. But that is the only case where LLMaJ will be used.""" - def __init__(self, description: str, alora: Alora | None = None): + def __init__(self, description: str, intrinsic_name: str | None = None): """A requirement that is validated by an ALora. Args: description: See `Requirement.__init__` - alora: if None, the ALora with name "constraint" will be used. + intrinsic_name: the name of the intrinsic; must match the adapter """ - super().__init__(description, validation_fn=None) + # TODO: We may want to actually do the validation_fn here so that we can set the score. + super().__init__( + description, validation_fn=None, output_to_bool=requirement_check_to_bool + ) self.use_aloras: bool = True - self.alora = alora + + if intrinsic_name is None: + intrinsic_name = "requirement_check" + + # Initialize the other side of the inheritance tree + Intrinsic.__init__( + self, + intrinsic_name=intrinsic_name, + intrinsic_kwargs={"requirement": f"{self.description}"}, + ) class ScorerRequirement(Requirement): diff --git a/pyproject.toml b/pyproject.toml index 3105f990..05250525 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,9 @@ dependencies = [ "mistletoe>=1.4.0", "huggingface-hub>=0.33.4", "pillow", - "math_verify", # Needed for Majority Voting Sampling Strategies. # Needed for Majority Voting Sampling Strategies. - "rouge_score", + "granite-common>=0.3.5", # Needed for Intrinsics. + "math_verify", # Needed for Majority Voting Sampling Strategies. + "rouge_score", # Needed for Majority Voting Sampling Strategies. "llm-sandbox[docker]>=0.3.23", ] @@ -72,7 +73,7 @@ hf = [ "datasets>=4.0.0", "outlines-core==0.1.26", "outlines", # intentionally un-versioned, expecting a minor update. coutlines-core version should be enough to specify it - "peft>=0.16.0", + "peft>=0.18.0", # aLoRA support was added in Peft 0.18.0 "transformers>=4.53.2", "trl==0.19.1", ] diff --git a/test/backends/test_adapters/intrinsics-data/answerability.yaml b/test/backends/test_adapters/intrinsics-data/answerability.yaml new file mode 100644 index 00000000..2b72cd18 --- /dev/null +++ b/test/backends/test_adapters/intrinsics-data/answerability.yaml @@ -0,0 +1,25 @@ +# Model name string, or null to use whatever is provided in the chat completion request. 
+model: ~ +# JSON schema of the model's output +response_format: | + { + "type": "string", + "enum": ["answerable", "unanswerable"] + } +transformations: + # Convert categorical answer to continuous value by decoding logprobs + - type: likelihood + categories_to_values: + "answerable": 1.0 + "unanswerable": 0.0 + input_path: [] + # Convert scalar value to a record for consistency with other intrinsics + - type: nest + input_path: [] + field_name: "answerability_likelihood" +instruction: ~ +parameters: + # "unanswerable" can be 6 tokens at high temperatures + max_completion_tokens: 6 +# No sentence boundary detection +sentence_boundaries: ~ diff --git a/test/backends/test_adapters/test_adapter.py b/test/backends/test_adapters/test_adapter.py new file mode 100644 index 00000000..f5abbb84 --- /dev/null +++ b/test/backends/test_adapters/test_adapter.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from mellea.backends.adapters.adapter import GraniteCommonAdapter + + +# The backend tests handle most of the adapter testing. Do a basic test here +# to make sure init and config loading work. +def test_adapter_init(): + dir_file = pathlib.Path(__file__).parent.joinpath("intrinsics-data") + answerability_file = f"{dir_file}/answerability.yaml" + + adapter = GraniteCommonAdapter("answerability", config_file=answerability_file) + + assert adapter.config is not None + assert adapter.config["parameters"]["max_completion_tokens"] == 6 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 75779985..79434097 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -1,39 +1,35 @@ import asyncio + import pydantic import pytest from typing_extensions import Annotated from mellea import MelleaSession -from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras +from mellea.backends.adapters.adapter import GraniteCommonAdapter from mellea.backends.cache import SimpleLRUCache from mellea.backends.formatter import TemplateFormatter from mellea.backends.huggingface import LocalHFBackend from mellea.backends.types import ModelOption -from mellea.stdlib.base import ( - CBlock, - ChatContext, - Context, - ModelOutputThunk, - SimpleContext, -) -from mellea.stdlib.requirement import ( - ALoraRequirement, - LLMaJRequirement, - Requirement, - ValidationResult, - default_output_to_bool, -) +from mellea.stdlib.base import (CBlock, ChatContext, Context, ModelOutputThunk, + SimpleContext) +from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement, + Requirement, ValidationResult, + default_output_to_bool) @pytest.fixture(scope="module") def backend(): """Shared HuggingFace backend for all tests in this module.""" backend = LocalHFBackend( - model_id="ibm-granite/granite-3.2-8b-instruct", + model_id="ibm-granite/granite-3.3-8b-instruct", formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"), cache=SimpleLRUCache(5), ) - add_granite_aloras(backend) + backend.add_adapter( + GraniteCommonAdapter( + "requirement_check", base_model_name=backend.base_model_name + ) + ) return backend @@ -44,6 +40,23 @@ def session(backend): yield session session.reset() +@pytest.mark.qualitative +def test_adapters(backend): + assert len(backend._added_adapters.items()) > 0 + + expected_qualified_name = "requirement_check_alora" + adapter = backend._added_adapters[expected_qualified_name] + backend.load_adapter(adapter.qualified_name) + assert 
adapter.qualified_name in backend._loaded_adapters + + # Ensure you can load the same adapter twice. + backend.load_adapter(adapter.qualified_name) + + # Ensure you can unload an adapter. + backend.unload_adapter(adapter.qualified_name) + backend.unload_adapter(adapter.qualified_name) + assert adapter.qualified_name not in backend._loaded_adapters + @pytest.mark.qualitative def test_system_prompt(session): @@ -54,27 +67,6 @@ def test_system_prompt(session): print(result) -@pytest.mark.qualitative -async def test_constraint_alora(session, backend): - answer = session.instruct( - "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.", - model_options={ - ModelOption.MAX_NEW_TOKENS: 300 - }, # Until aloras get a bit better, try not to abruptly end generation. - ) - - alora_output = backend.get_aloras()[ - 0 - ].generate_using_strings( - input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", - response=str(answer), - constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", - force_yn=False, # make sure that the alora naturally output Y and N without constrained generation - ) - await alora_output.avalue() - assert alora_output.value in ["Y", "N"], alora_output - - @pytest.mark.qualitative def test_constraint_lora_with_requirement(session, backend): answer = session.instruct( @@ -89,7 +81,7 @@ def test_constraint_lora_with_requirement(session, backend): assert len(validation_outputs) == 1 val_result = validation_outputs[0] assert isinstance(val_result, ValidationResult) - assert str(val_result.reason) in ["Y", "N"] + assert "requirement_likelihood" in str(val_result.reason) @pytest.mark.qualitative @@ -122,7 +114,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): assert len(validation_outputs) == 1 val_result = validation_outputs[0] assert isinstance(val_result, ValidationResult) - assert str(val_result.reason) in ["Y", "N"] + assert "requirement_likelihood" in str(val_result.reason) # Ensure the ValidationResult has its thunk and context set. Ensure the context has # the correct actions / results in it. @@ -149,6 +141,7 @@ def test_llmaj_req_does_not_use_alora(session, backend): val_result = validation_outputs[0] assert isinstance(val_result, ValidationResult) assert str(val_result.reason) not in ["Y", "N"] + assert "requirement_likelihood" not in str(val_result.reason) @pytest.mark.qualitative @@ -190,7 +183,10 @@ class Email(pydantic.BaseModel): body: str output = session.instruct( - "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ", + "Write a short email to Olivia, thanking her for organizing a sailing " + "activity. " + "Her email is olivia@example.com. " + "No more than two sentences. ", format=Email, model_options={ModelOption.MAX_NEW_TOKENS: 2**8}, ) @@ -201,7 +197,7 @@ class Email(pydantic.BaseModel): print(email) print("address:", email.to.email_address) - assert "@" in email.to.email_address, "The @ sign should be in the meail address." + assert "@" in email.to.email_address, "The @ sign should be in the email address." 
assert email.to.email_address.endswith("example.com"), ( "The email address should be at example.com" ) diff --git a/test/backends/test_huggingface_tools.py b/test/backends/test_huggingface_tools.py index a9c833b0..0df5f3dc 100644 --- a/test/backends/test_huggingface_tools.py +++ b/test/backends/test_huggingface_tools.py @@ -4,7 +4,6 @@ import mellea.backends.model_ids as model_ids from mellea import MelleaSession -from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras from mellea.backends.cache import SimpleLRUCache from mellea.backends.formatter import TemplateFormatter from mellea.backends.huggingface import LocalHFBackend diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index d5acc79c..5c1ddf86 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -8,6 +8,10 @@ from mellea.stdlib.base import CBlock, SimpleContext from mellea.stdlib.chat import Message from mellea.stdlib.sampling import RejectionSamplingStrategy +from mellea.backends import model_ids + + +_MODEL_ID = f"ollama_chat/{model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name}" @pytest.fixture(scope="function") @@ -22,12 +26,12 @@ def backend(gh_run: int): url = url.replace("127.0.0.1", "http://localhost") return LiteLLMBackend( - model_id="ollama_chat/llama3.2:1b", + model_id=_MODEL_ID, base_url=url, model_options={"api_base": url}, ) else: - return LiteLLMBackend() + return LiteLLMBackend(model_id=_MODEL_ID) @pytest.fixture(scope="function") @@ -106,9 +110,13 @@ def test_litellm_ollama_instruct_options(session): model_options = { ModelOption.SEED: 123, ModelOption.TEMPERATURE: 0.5, - ModelOption.THINKING: True, ModelOption.MAX_NEW_TOKENS: 100, - "reasoning_effort": True, + + # Ollama thinking controls currently broken on Granite; see + # https://github.com/ollama/ollama/issues/10983 + # TODO: Re-enable when this upstream bug gets fixed. + #ModelOption.THINKING: True, + #"reasoning_effort": True, "homer_simpson": "option should be kicked out", } diff --git a/test/backends/test_openai_vllm/environment.yml b/test/backends/test_openai_vllm/environment.yml index 2d0b9e8e..5ca4116b 100644 --- a/test/backends/test_openai_vllm/environment.yml +++ b/test/backends/test_openai_vllm/environment.yml @@ -9,3 +9,4 @@ dependencies: variables: VLLM_USE_PRECOMPILED: 1 # need this flag for alora fork, installation fails otherwise VLLM_ALLOW_RUNTIME_LORA_UPDATING: True # allow loading (a)lora through POST http://localhost:8000/v1/load_lora_adapter + VLLM_DOWNLOAD_RAG_INTRINSICS: False # if True, download the rag-intrinsics-lib (https://huggingface.co/ibm-granite/rag-intrinsics-lib/tree/main); only required for remote vllm servers diff --git a/test/backends/test_openai_vllm/serve.sh b/test/backends/test_openai_vllm/serve.sh index 7746eed1..9b061d0e 100755 --- a/test/backends/test_openai_vllm/serve.sh +++ b/test/backends/test_openai_vllm/serve.sh @@ -26,8 +26,18 @@ # see environment.yml. export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True +# Mellea makes assumptions about the location of the adapter files on the server. By default, it assumes +# referenced adapters are at `./rag-intrinsics-lib/$adapter_name/$adapter_type/$base_model_name`. You can +# change this behavior by defining custom adapter classes that override the path. +# You will also need to set the OpenAIBackend's server_type to REMOTE_VLLM. 
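+# A minimal client-side sketch of what that setup looks like (assuming the names in
+# mellea.backends.openai and mellea.backends.types):
+#   backend = OpenAIBackend(model_id="ibm-granite/granite-3.3-8b-instruct",
+#                           base_url="http://<remote-host>:8000/v1", api_key="EMPTY")
+#   backend._server_type = _ServerType.REMOTE_VLLM  # must be set manually for now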
+if [[ "$VLLM_ALLOW_RUNTIME_LORA_UPDATING" == "True" ]] && [[ "$VLLM_DOWNLOAD_RAG_INTRINSICS" == "True" ]] +then + echo "downloading rag-intrinsics-lib from huggingface" + hf download ibm-granite/rag-intrinsics-lib --local-dir ./rag-intrinsics-lib +fi + echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log" -vllm serve ibm-granite/granite-3.2-8b-instruct \ +vllm serve ibm-granite/granite-3.3-8b-instruct \ --enable-activated-lora \ --enable-lora \ --dtype bfloat16 \ diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index ff3ffa20..8fc05a1b 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -1,23 +1,18 @@ # test/rits_backend_tests/test_openai_integration.py -from contextvars import Context -from mellea import MelleaSession -from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext -from mellea.backends.openai import OpenAIBackend -from mellea.backends.aloras.openai.granite_aloras import add_granite_aloras -from mellea.stdlib.requirement import ( - Requirement, - ALoraRequirement, - LLMaJRequirement, - req, -) -from mellea.backends.formatter import TemplateFormatter -from mellea.backends.types import ModelOption +import os import pydantic -from typing_extensions import Annotated import pytest -import os +from typing_extensions import Annotated +from mellea import MelleaSession +from mellea.backends.adapters.adapter import GraniteCommonAdapter +from mellea.backends.formatter import TemplateFormatter +from mellea.backends.openai import OpenAIBackend +from mellea.backends.types import ModelOption, _ServerType +from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk +from mellea.stdlib.requirement import (ALoraRequirement, LLMaJRequirement, + Requirement, req) # The vllm tests are disabled by default, because we need a test environment with the vLLM server running. # We use an env var VLLM_TESTS_ENABLED to enable these tests. @@ -35,8 +30,8 @@ class TestOpenAIBackend: backend = OpenAIBackend( - model_id="ibm-granite/granite-3.2-8b-instruct", - formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"), + model_id="ibm-granite/granite-3.3-8b-instruct", + formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"), base_url="http://0.0.0.0:8000/v1", api_key="EMPTY", ) @@ -137,13 +132,30 @@ class Answer(pydantic.BaseModel): class TestOpenAIALoraStuff: backend = OpenAIBackend( - model_id="ibm-granite/granite-3.2-8b-instruct", + model_id="ibm-granite/granite-3.3-8b-instruct", formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"), base_url="http://localhost:8000/v1", api_key="EMPTY", ) + backend.add_adapter(GraniteCommonAdapter("requirement_check", + base_model_name=backend.base_model_name)) + m = MelleaSession(backend, ctx=ChatContext()) - add_granite_aloras(backend) + + def test_adapters(self): + assert len(self.backend._added_adapters.items()) > 0 + + adapter = self.backend._added_adapters["requirement_check_alora"] + self.backend.load_adapter(adapter.qualified_name) + assert adapter.qualified_name in self.backend._loaded_adapters + + # Ensure you can load the same adapter twice. + self.backend.load_adapter(adapter.qualified_name) + + # Ensure you can unload an adapter. 
+ self.backend.unload_adapter(adapter.qualified_name) + self.backend.unload_adapter(adapter.qualified_name) + assert adapter.qualified_name not in self.backend._loaded_adapters def test_system_prompt(self): self.m.reset() @@ -153,23 +165,6 @@ def test_system_prompt(self): ) print(result) - @pytest.mark.xfail - def test_constraint_alora(self): - self.m.reset() - answer = self.m.instruct( - "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" - ) - alora_output = self.backend.get_aloras()[ - 0 - ].generate_using_strings( - input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", - response=str(answer), - constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", - force_yn=False, # make sure that the alora naturally output Y and N without constrained generation - ) - assert alora_output in ["Y", "N"], alora_output - self.m.reset() - def test_constraint_lora_with_requirement(self): self.m.reset() answer = self.m.instruct( @@ -182,7 +177,7 @@ def test_constraint_lora_with_requirement(self): ) assert len(validation_outputs) == 1 val_result = validation_outputs[0] - assert str(val_result.reason) in ["Y", "N"] + assert "requirement_likelihood" in str(val_result.reason) self.m.reset() def test_constraint_lora_override(self): @@ -215,7 +210,7 @@ def test_constraint_lora_override_does_not_override_alora(self): ) assert len(validation_outputs) == 1 non_alora_output = validation_outputs[0] - assert str(non_alora_output.reason) in ["Y", "N"] + assert "requirement_likelihood" in str(non_alora_output.reason) # Ensure the ValidationResult has its thunk and context set. Ensure the context has # the correct actions / results in it. diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py index e19c6adc..917619e0 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/stdlib_basics/test_base.py @@ -26,6 +26,5 @@ def format_for_llm(self) -> str: c = _ClosuredComponent() assert len(c.parts()) == 0 - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib_basics/test_chat.py b/test/stdlib_basics/test_chat.py new file mode 100644 index 00000000..3ae911b9 --- /dev/null +++ b/test/stdlib_basics/test_chat.py @@ -0,0 +1,25 @@ +import pytest +from mellea.backends.openai import OpenAIBackend +from mellea.stdlib.base import Document +from mellea.stdlib.chat import Message + +def test_message_with_docs(): + doc = Document("I'm text!", "Im a title!") + msg = Message("user", "hello", documents=[doc]) + + assert msg._docs is not None + assert doc in msg._docs + + docs = OpenAIBackend.messages_to_docs([msg]) + assert len(docs) == 1 + assert docs[0]["text"] == doc.text + assert docs[0]["title"] == doc.title + + assert "Im a titl..." 
in str(msg) + + tr = msg.format_for_llm() + assert tr.args["documents"] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/stdlib_intrinsics/test_rag/test_rag.py b/test/stdlib_intrinsics/test_rag/test_rag.py new file mode 100644 index 00000000..18985f02 --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/test_rag.py @@ -0,0 +1,185 @@ +"""Tests of the code in ``mellea.stdlib.intrinsics.rag``""" + +import gc +import json +import os +import pathlib + +import pytest +import torch + +from mellea.backends.huggingface import LocalHFBackend +from mellea.stdlib.base import ChatContext, Document +from mellea.stdlib.chat import Message +from mellea.stdlib.intrinsics import rag + +DATA_ROOT = pathlib.Path(os.path.dirname(__file__)) / "testdata" +"""Location of data files for the tests in this file.""" + + +BASE_MODEL = "ibm-granite/granite-3.3-2b-instruct" + + +@pytest.fixture(name="backend") +def _backend(): + """Backend used by the tests in this file.""" + + # Prevent thrashing if the default device is CPU + torch.set_num_threads(4) + + backend_ = LocalHFBackend(model_id=BASE_MODEL) + yield backend_ + + # Code after yield is cleanup code. + # Free GPU memory with extreme prejudice. + del backend_ + gc.collect() # Force a collection of the newest generation + gc.collect() + gc.collect() # Hopefully force a collection of the oldest generation + torch.cuda.empty_cache() + + +def _read_input_json(file_name: str): + """Shared code for reading data stored in JSON files and converting to Mellea + types.""" + with open(DATA_ROOT / "input_json" / file_name, encoding="utf-8") as f: + json_data = json.load(f) + + # Data is assumed to be an OpenAI chat completion request. Convert to Mellea format. + context = ChatContext() + for m in json_data["messages"][:-1]: + context = context.add(Message(m["role"], m["content"])) + + # Store the user turn at the end of the messages list separately so that tests can + # play it back. + next_user_turn = json_data["messages"][-1]["content"] + + documents = [] + if "extra_body" in json_data and "documents" in json_data["extra_body"]: + for d in json_data["extra_body"]["documents"]: + documents.append( + Document( + text=d["text"], + doc_id=d["doc_id"], + ) + ) + return context, next_user_turn, documents + + +def _read_output_json(file_name: str): + """Shared code for reading canned outputs stored in JSON files and converting + to Mellea types.""" + with open(DATA_ROOT / "output_json" / file_name, encoding="utf-8") as f: + json_data = json.load(f) + + # Output is in OpenAI chat completion response format. Assume only one choice. + result_str = json_data["choices"][0]["message"]["content"] + + # Intrinsic outputs are always JSON, serialized to a string for OpenAI + # compatibility. 
+    return json.loads(result_str)
+
+
+@pytest.mark.qualitative
+def test_answerability(backend):
+    """Verify that the answerability intrinsic functions properly."""
+    context, next_user_turn, documents = _read_input_json("answerability.json")
+
+    # First call triggers adapter loading
+    result = rag.check_answerability(next_user_turn, documents, context, backend)
+    assert pytest.approx(result) == 1.0
+
+    # Second call hits a different code path from the first one
+    result = rag.check_answerability(next_user_turn, documents, context, backend)
+    assert pytest.approx(result) == 1.0
+
+
+@pytest.mark.qualitative
+def test_query_rewrite(backend):
+    """Verify that the query rewrite intrinsic functions properly."""
+    context, next_user_turn, _ = _read_input_json("query_rewrite.json")
+    expected = (
+        "Is Rex, the dog, more likely to get fleas because he spends a lot of "
+        "time outdoors?"
+    )
+
+    # First call triggers adapter loading
+    result = rag.rewrite_question(next_user_turn, context, backend)
+    assert result == expected
+
+    # Second call hits a different code path from the first one
+    result = rag.rewrite_question(next_user_turn, context, backend)
+    assert result == expected
+
+
+@pytest.mark.qualitative
+def test_citations(backend):
+    """Verify that the citations intrinsic functions properly."""
+    context, assistant_response, docs = _read_input_json("citations.json")
+    expected = _read_output_json("citations.json")
+
+    # First call triggers adapter loading
+    result = rag.find_citations(assistant_response, docs, context, backend)
+    assert result == expected
+
+    # Second call hits a different code path from the first one
+    result = rag.find_citations(assistant_response, docs, context, backend)
+    assert result == expected
+
+
+@pytest.mark.qualitative
+def test_context_relevance(backend):
+    """Verify that the context relevance intrinsic functions properly."""
+    context, question, docs = _read_input_json("context_relevance.json")
+
+    # Context relevance can only check against a single document at a time.
+    document = docs[0]
+
+    # First call triggers adapter loading
+    result = rag.check_context_relevance(question, document, context, backend)
+    assert pytest.approx(result, abs=2e-2) == 0.45
+
+    # Second call hits a different code path from the first one
+    result = rag.check_context_relevance(question, document, context, backend)
+    assert pytest.approx(result, abs=2e-2) == 0.45
+
+
+@pytest.mark.qualitative
+def test_hallucination_detection(backend):
+    """Verify that the hallucination detection intrinsic functions properly."""
+    context, assistant_response, docs = _read_input_json("hallucination_detection.json")
+    expected = _read_output_json("hallucination_detection.json")
+
+    # First call triggers adapter loading
+    result = rag.flag_hallucinated_content(assistant_response, docs, context, backend)
+    # pytest.approx() chokes on lists of records, so we do this complicated dance.
+    for r, e in zip(result, expected, strict=True):
+        assert pytest.approx(r, abs=2e-2) == e
+
+    # Second call hits a different code path from the first one
+    result = rag.flag_hallucinated_content(assistant_response, docs, context, backend)
+    for r, e in zip(result, expected, strict=True):
+        assert pytest.approx(r, abs=2e-2) == e
+
+
+@pytest.mark.qualitative
+def test_answer_relevance(backend):
+    """Verify that the answer relevance composite intrinsic functions properly."""
+    context, answer, docs = _read_input_json("answer_relevance.json")
+    expected_rewrite = "Alice, Bob, and Carol attended the meeting."
+ + # First call triggers adapter loading + result = rag.rewrite_answer_for_relevance(answer, docs, context, backend) + assert result == expected_rewrite + + # Second call hits a different code path from the first one + result = rag.rewrite_answer_for_relevance(answer, docs, context, backend) + assert result == expected_rewrite + + # Canned input always gets rewritten. Set threshold to disable the rewrite. + result = rag.rewrite_answer_for_relevance(answer, docs, context, backend, + rewrite_threshold=0.0) + assert result == answer + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/answer_relevance.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/answer_relevance.json new file mode 100644 index 00000000..30779c0c --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/answer_relevance.json @@ -0,0 +1,13 @@ +{ + "messages": [ + {"role": "user", "content": "Who attended the meeting?"}, + {"role": "assistant", "content": "Many people attended the meeting."} + ], + "extra_body": { + "documents": [ + {"doc_id": "1", "text": "Meeting attendees: Alice, Bob, Carol."}, + {"doc_id": "2", "text": "Meeting time: 9:00 am to 11:00 am."} + ] + }, + "temperature": 0.0 +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/answerability.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/answerability.json new file mode 100644 index 00000000..8b4df2c7 --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/answerability.json @@ -0,0 +1,20 @@ +{ + "messages": [ + { + "role": "assistant", + "content": "Hello there, how can I help you?" + }, + { + "content": "What is the square root of 4?", + "role": "user" + } + ], + "extra_body": { + "documents": [ + { + "doc_id": "1", + "text": "The square root of 4 is 2." + } + ] + } +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/citations.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/citations.json new file mode 100644 index 00000000..f22cfc9a --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/citations.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "role": "user", + "content": "How does Murdoch's expansion in Australia compare to his expansion in New Zealand?" + }, + { + "role": "assistant", + "content": "Murdoch expanded in Australia and New Zealand by acquiring and expanding local newspapers. I do not have information about his expansion in New Zealand after purchasing The Dominion." + } + ], + "extra_body": { + "documents": [ + { + "doc_id": "0", + "text": "Keith Rupert Murdoch was born on 11 March 1931 in Melbourne, Australia, the son of Sir Keith Murdoch (1885-1952) and Dame Elisabeth Murdoch (nee Greene; 1909-2012). He is of English, Irish, and Scottish ancestry. Murdoch's parents were also born in Melbourne. Keith Murdoch was a war correspondent and later a regional newspaper magnate owning two newspapers in Adelaide, South Australia, and a radio station in a faraway mining town. Following his father's death, when he was 21, Murdoch returned from Oxford to take charge of the family business News Limited, which had been established in 1923. Rupert Murdoch turned its Adelaide newspaper, The News, its main asset, into a major success. 
He began to direct his attention to acquisition and expansion, buying the troubled Sunday Times in Perth, Western Australia (1956) and over the next few years acquiring suburban and provincial newspapers in New South Wales, Queensland, Victoria and the Northern Territory, including the Sydney afternoon tabloid, The Daily Mirror (1960). The Economist describes Murdoch as \"inventing the modern tabloid\", as he developed a pattern for his newspapers, increasing sports and scandal coverage and adopting eye-catching headlines. Murdoch's first foray outside Australia involved the purchase of a controlling interest in the New Zealand daily The Dominion. In January 1964, while touring New Zealand with friends in a rented Morris Minor after sailing across the Tasman, Murdoch read of a takeover bid for the Wellington paper by the British-based Canadian newspaper magnate, Lord Thomson of Fleet. On the spur of the moment, he launched a counter-bid. A four-way battle for control ensued in which the 32-year-old Murdoch was ultimately successful. Later in 1964, Murdoch launched The Australian, Australia's first national daily newspaper, which was based first in Canberra and later in Sydney. In 1972, Murdoch acquired the Sydney morning tabloid The Daily Telegraph from Australian media mogul Sir Frank Packer, who later regretted selling it to him. In 1984, Murdoch was appointed Companion of the Order of Australia (AC) for services to publishing. In 1999, Murdoch significantly expanded his music holdings in Australia by acquiring the controlling share in a leading Australian independent label, Michael Gudinski's Mushroom Records; he merged that with Festival Records, and the result was Festival Mushroom Records (FMR). Both Festival and FMR were managed by Murdoch's son James Murdoch for several years." + }, + { + "doc_id": "1", + "text": "This document has nothing to do with Rupert Murdoch. This document is two sentences long." + } + ] + } +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/context_relevance.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/context_relevance.json new file mode 100644 index 00000000..84cdf4e5 --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/context_relevance.json @@ -0,0 +1,16 @@ +{ + "messages": [ + { + "content": "Who is the CEO of Microsoft?", + "role": "user" + } + ], + "extra_body": { + "documents": [ + { + "doc_id": "1", + "text": "Microsoft Corporation is an American multinational corporation and technology conglomerate headquartered in Redmond, Washington.[2] Founded in 1975, the company became influential in the rise of personal computers through software like Windows, and the company has since expanded to Internet services, cloud computing, video gaming and other fields. Microsoft is the largest software maker, one of the most valuable public U.S. companies,[a] and one of the most valuable brands globally." + } + ] + } +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/hallucination_detection.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/hallucination_detection.json new file mode 100644 index 00000000..f224ed20 --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/hallucination_detection.json @@ -0,0 +1,24 @@ +{ + "messages": [ + { + "role": "assistant", + "content": "Hello there, how can I help you?" 
+ }, + { + "content": "Tell me about some yellow fish.", + "role": "user" + }, + { + "role": "assistant", + "content": "Purple bumble fish are yellow. Green bumble fish are also yellow." + } + ], + "extra_body": { + "documents": [ + { + "doc_id": "1", + "text": "The only type of fish that is yellow is the purple bumble fish." + } + ] + } +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/input_json/query_rewrite.json b/test/stdlib_intrinsics/test_rag/testdata/input_json/query_rewrite.json new file mode 100644 index 00000000..0c36933d --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/input_json/query_rewrite.json @@ -0,0 +1,29 @@ +{ + "messages": [ + { + "role": "assistant", + "content": "Welcome to pet questions!" + }, + { + "role": "user", + "content": "I have two pets, a dog named Rex and a cat named Lucy." + }, + { + "role": "assistant", + "content": "Great, what would you like to share about them?" + }, + { + "role": "user", + "content": "Rex spends a lot of time in the backyard and outdoors, and Luna is always inside." + }, + { + "role": "assistant", + "content": "Sounds good! Rex must love exploring outside, while Lucy probably enjoys her cozy indoor life." + }, + { + "role": "user", + "content": "But is he more likely to get fleas because of that?" + } + ], + "temperature": 0.0 +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/output_json/citations.json b/test/stdlib_intrinsics/test_rag/testdata/output_json/citations.json new file mode 100644 index 00000000..804f64f4 --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/output_json/citations.json @@ -0,0 +1,11 @@ +{ + "choices": [ + { + "index": 0, + "message": { + "content": "[{\"response_begin\": 0, \"response_end\": 96, \"response_text\": \"Murdoch expanded in Australia and New Zealand by acquiring and expanding local newspapers. \", \"citation_doc_id\": \"0\", \"citation_begin\": 2468, \"citation_end\": 3533, \"citation_text\": \"He began to direct his attention to acquisition and expansion, buying the troubled Sunday Times in Perth, Western Australia (1956) and over the next few years acquiring suburban and provincial newspapers in New South Wales, Queensland, Victoria and the Northern Territory, including the Sydney afternoon tabloid, The Daily Mirror (1960). \"}, {\"response_begin\": 0, \"response_end\": 96, \"response_text\": \"Murdoch expanded in Australia and New Zealand by acquiring and expanding local newspapers. \", \"citation_doc_id\": \"0\", \"citation_begin\": 4792, \"citation_end\": 6183, \"citation_text\": \"Murdoch's first foray outside Australia involved the purchase of a controlling interest in the New Zealand daily The Dominion. \"}]", + "role": "assistant" + } + } + ] +} \ No newline at end of file diff --git a/test/stdlib_intrinsics/test_rag/testdata/output_json/hallucination_detection.json b/test/stdlib_intrinsics/test_rag/testdata/output_json/hallucination_detection.json new file mode 100644 index 00000000..7546817e --- /dev/null +++ b/test/stdlib_intrinsics/test_rag/testdata/output_json/hallucination_detection.json @@ -0,0 +1,11 @@ +{ + "choices": [ + { + "index": 0, + "message": { + "content": "[{\"response_begin\": 0, \"response_end\": 36, \"response_text\": \"Purple bumble fish are yellow. \", \"faithfulness_likelihood\": 0.2062460112028628, \"explanation\": \"This sentence makes a factual claim about the color of fish. 
However, the provided document only mentions one type of fish that is yellow, which is the purple bumble fish. There is no information about green bumble fish in the document, so the claim about green bumble fish being yellow cannot be verified.\"}, {\"response_begin\": 36, \"response_end\": 70, \"response_text\": \"Green bumble fish are also yellow.\", \"faithfulness_likelihood\": 0.006380047389365753, \"explanation\": \"This sentence makes a factual claim about the color of fish. However, the provided document only mentions one type of fish that is yellow, which is the purple bumble fish. There is no information about green bumble fish in the document, so the claim about green bumble fish being yellow cannot be verified.\"}]", + "role": "assistant" + } + } + ] +} \ No newline at end of file diff --git a/uv.lock b/uv.lock index 0c58fde1..ee281b4a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and python_full_version < '4' and sys_platform == 'darwin'", @@ -850,7 +850,7 @@ name = "colorlog" version = "6.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "(python_full_version < '3.14' and sys_platform == 'win32') or (python_full_version >= '4' and sys_platform == 'win32')" }, + { name = "colorama", marker = "python_full_version < '3.14' and sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a2/61/f083b5ac52e505dfc1c624eafbf8c7589a0d7f32daa398d2e7590efa5fda/colorlog-6.10.1.tar.gz", hash = "sha256:eb4ae5cb65fe7fec7773c2306061a8e63e02efc2c72eba9d27b0fa23c94f1321", size = 17162, upload-time = "2025-10-16T16:14:11.978Z" } wheels = [ @@ -1738,6 +1738,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c4/ab/09169d5a4612a5f92490806649ac8d41e3ec9129c636754575b3553f4ea4/googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038", size = 297515, upload-time = "2025-11-06T18:29:13.14Z" }, ] +[[package]] +name = "granite-common" +version = "0.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/b8/cba7a2399079838f793cc138c5b341df965c82e7cdaaf4c37deeffa0a14c/granite_common-0.3.5.tar.gz", hash = "sha256:80d4251b9294b6ec234d5aa4e273801b66f7cc4c5bc77151e5c22e7d7f5a19cd", size = 273710, upload-time = "2025-11-15T01:32:38.761Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/f4/79d121e4192cf7871122995f060f891c6dbcb91b9b4753e4dd27852704c4/granite_common-0.3.5-py3-none-any.whl", hash = "sha256:fca8fdb7caff7f5714bfda9c81438f4a6df974e0b01631421cd6ca1a19bbb07e", size = 77551, upload-time = "2025-11-15T01:32:37.186Z" }, +] + [[package]] name = "grpcio" version = "1.76.0" @@ -3204,6 +3217,7 @@ dependencies = [ { name = "ansicolors" }, { name = "click" }, { name = "fastapi" }, + { name = "granite-common" }, { name = "huggingface-hub" }, { name = "jinja2" }, { name = "json5" }, @@ -3300,6 +3314,7 @@ requires-dist = [ { name = "datasets", marker = "extra == 'hf'", specifier = ">=4.0.0" }, { name = "docling", marker = "extra == 'docling'", specifier = ">=2.45.0" }, { name = "fastapi" }, + { name = "granite-common", specifier = ">=0.3.5" }, { name = "huggingface-hub", specifier = ">=0.33.4" }, { name = "ibm-watsonx-ai", marker = "extra == 'watsonx'", specifier = 
">=1.3.31" }, { name = "jinja2" }, @@ -3315,7 +3330,7 @@ requires-dist = [ { name = "outlines", marker = "extra == 'hf'" }, { name = "outlines-core", marker = "extra == 'hf'", specifier = "==0.1.26" }, { name = "outlines-core", marker = "extra == 'vllm'", specifier = "==0.1.26" }, - { name = "peft", marker = "extra == 'hf'", specifier = ">=0.16.0" }, + { name = "peft", marker = "extra == 'hf'", specifier = ">=0.18.0" }, { name = "pillow" }, { name = "pydantic" }, { name = "requests", specifier = ">=2.32.3" }, @@ -4066,7 +4081,7 @@ name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386, upload-time = "2024-10-25T19:54:26.39Z" }, @@ -4077,7 +4092,7 @@ name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" }, @@ -4106,9 +4121,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" }, @@ -4120,7 +4135,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" }, @@ -4165,9 +4180,9 @@ name = "ocrmac" version = "1.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, - { name = "pillow" }, - { name = "pyobjc-framework-vision" }, + { name = "click", marker = "sys_platform == 'darwin'" }, + { name = "pillow", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-vision", marker = "sys_platform == 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/dd/dc/de3e9635774b97d9766f6815bbb3f5ec9bce347115f10d9abbf2733a9316/ocrmac-1.0.0.tar.gz", hash = "sha256:5b299e9030c973d1f60f82db000d6c2e5ff271601878c7db0885e850597d1d2e", size = 1463997, upload-time = "2024-11-07T12:00:00.197Z" } wheels = [ @@ -4192,8 +4207,8 @@ name = "omegaconf" version = "2.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "antlr4-python3-runtime", version = "4.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "pyyaml", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, + { name = "antlr4-python3-runtime", version = "4.9.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, + { name = "pyyaml", marker = "python_full_version < '3.14'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } wheels = [ @@ -4224,7 +4239,7 @@ name = "opencv-python" version = "4.11.0.86" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, + { name = "numpy", marker = "python_full_version < '3.14'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/17/06/68c27a523103dad5837dc5b87e71285280c4f098c60e4fe8a8db6486ab09/opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4", size = 95171956, upload-time = "2025-01-16T13:52:24.737Z" } wheels = [ @@ -4592,7 +4607,7 @@ wheels = [ [[package]] name = "peft" -version = "0.17.1" +version = "0.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "accelerate" }, @@ -4606,9 +4621,9 @@ dependencies = [ { name = "tqdm" }, { name = "transformers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/b8/2e79377efaa1e5f0d70a497db7914ffd355846e760ffa2f7883ab0f600fb/peft-0.17.1.tar.gz", hash = "sha256:e6002b42517976c290b3b8bbb9829a33dd5d470676b2dec7cb4df8501b77eb9f", size = 568192, upload-time = "2025-08-21T09:25:22.703Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4b/0c/f2938db546ac7fc961ab5917cd50fcf5d0d70b406de93e3faccaa504e152/peft-0.18.0.tar.gz", hash = "sha256:c81c80b2056ab40c23d58ef25f74daab417ac653970718589a11a8af28218588", size = 634141, upload-time = "2025-11-13T11:13:06.603Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/49/fe/a2da1627aa9cb6310b6034598363bd26ac301c4a99d21f415b1b2855891e/peft-0.17.1-py3-none-any.whl", hash = "sha256:3d129d64def3d74779c32a080d2567e5f7b674e77d546e3585138216d903f99e", size = 504896, upload-time = "2025-08-21T09:25:18.974Z" }, + { url = "https://files.pythonhosted.org/packages/0f/55/481bf25613d40ef53534f664deba7b138fe566356b6ca10304e2b3b2529c/peft-0.18.0-py3-none-any.whl", hash = "sha256:624f69ca6393b765ccc6734adda7ca57d80b238f0900a42c357d8b67a03d62ff", size = 556427, upload-time = "2025-11-13T11:13:03.664Z" }, ] [[package]] @@ -5328,7 +5343,7 @@ name = "pyobjc-framework-cocoa" version = "12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, + { name = "pyobjc-core", marker = "sys_platform == 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/37/6f/89837da349fe7de6476c426f118096b147de923139556d98af1832c64b97/pyobjc_framework_cocoa-12.0.tar.gz", hash = "sha256:02d69305b698015a20fcc8e1296e1528e413d8cf9fdcd590478d359386d76e8a", size = 2771906, upload-time = "2025-10-21T08:30:51.765Z" } wheels = [ @@ -5346,8 +5361,8 @@ name = "pyobjc-framework-coreml" version = "12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, + { name = "pyobjc-core", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0c/a0/875b5174794c984df60944be54df0282945f8bae4a606fbafa0c6b717ddd/pyobjc_framework_coreml-12.0.tar.gz", hash = "sha256:e1d7a9812886150881c86000fba885cb15201352c75fb286bd9e3a1819b5a4d5", size = 40814, upload-time = "2025-10-21T08:31:53.83Z" } wheels = [ @@ -5365,8 +5380,8 @@ name = "pyobjc-framework-quartz" version = "12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, + { name = "pyobjc-core", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/91/0b/3c34fc9de790daff5ca49d1f36cb8dcc353ac10e4e29b4759e397a3831f4/pyobjc_framework_quartz-12.0.tar.gz", hash = "sha256:5bcb9e78d671447e04d89e2e3c39f3135157892243facc5f8468aa333e40d67f", size = 3159509, upload-time = "2025-10-21T08:40:01.918Z" } wheels = [ @@ -5384,10 +5399,10 @@ name = "pyobjc-framework-vision" version = "12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core" }, - { name = "pyobjc-framework-cocoa" }, - { name = "pyobjc-framework-coreml" }, - { name = "pyobjc-framework-quartz" }, + { name = "pyobjc-core", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-coreml", marker = "sys_platform == 'darwin'" }, + { name = "pyobjc-framework-quartz", marker = "sys_platform == 'darwin'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0f/5a/07cdead5adb77d0742b014fa742d503706754e3ad10e39760e67bb58b497/pyobjc_framework_vision-12.0.tar.gz", hash = "sha256:942c9583f1d887ac9f704f3b0c21b3206b68e02852a87219db4309bb13a02f14", size = 59905, upload-time = "2025-10-21T08:41:53.741Z" } wheels = [ @@ -5845,17 +5860,17 @@ name = "rapidocr" version = "3.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorlog", marker = "python_full_version < '3.14' or 
python_full_version >= '4'" }, - { name = "numpy", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "omegaconf", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "opencv-python", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "pillow", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "pyclipper", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "pyyaml", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "requests", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "shapely", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "six", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, - { name = "tqdm", marker = "python_full_version < '3.14' or python_full_version >= '4'" }, + { name = "colorlog", marker = "python_full_version < '3.14'" }, + { name = "numpy", marker = "python_full_version < '3.14'" }, + { name = "omegaconf", marker = "python_full_version < '3.14'" }, + { name = "opencv-python", marker = "python_full_version < '3.14'" }, + { name = "pillow", marker = "python_full_version < '3.14'" }, + { name = "pyclipper", marker = "python_full_version < '3.14'" }, + { name = "pyyaml", marker = "python_full_version < '3.14'" }, + { name = "requests", marker = "python_full_version < '3.14'" }, + { name = "shapely", marker = "python_full_version < '3.14'" }, + { name = "six", marker = "python_full_version < '3.14'" }, + { name = "tqdm", marker = "python_full_version < '3.14'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3c/83/5b8c8075954c5b61d938b8954710d986134c4ca7c32a841ad7d8c844cf6c/rapidocr-3.4.2-py3-none-any.whl", hash = "sha256:17845fa8cc9a20a935111e59482f2214598bba1547000cfd960d8924dd4522a5", size = 15056674, upload-time = "2025-10-11T14:43:00.296Z" }, @@ -6795,7 +6810,7 @@ name = "shapely" version = "2.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "python_full_version < '4'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/4d/bc/0989043118a27cccb4e906a46b7565ce36ca7b57f5a18b78f4f1b0f72d9d/shapely-2.1.2.tar.gz", hash = "sha256:2ed4ecb28320a433db18a5bf029986aa8afcfd740745e78847e330d5d94922a9", size = 315489, upload-time = "2025-09-24T13:51:41.432Z" } wheels = [ @@ -7552,7 +7567,7 @@ name = "triton" version = "3.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools" }, + { name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/76/04/d54d3a6d077c646624dc9461b0059e23fd5d30e0dbe67471e3654aec81f9/triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7", size = 156441993, upload-time = "2025-04-09T20:27:25.107Z" }, @@ -8078,8 +8093,8 @@ name = "xformers" version = "0.0.30" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, - { name = "torch" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "torch", marker = "(platform_machine != 'aarch64' and 
sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/bf/f7/dd2269cce89fd1221947dd7cc3a60707ffe721ef55c1803ac3b1a1f7ae5c/xformers-0.0.30.tar.gz", hash = "sha256:a12bf3eb39e294cdbe8a7253ac9b665f41bac61d6d98df174e34ef7bdb6f2fc4", size = 10214139, upload-time = "2025-04-28T20:51:02.045Z" } wheels = [