diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index f09b4a04..c38c988d 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -67,6 +67,8 @@
 """
 TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]
 
+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 @dataclasses.dataclass
 class HFAloraCacheInfo:
@@ -209,11 +211,11 @@ def generate_from_context(
                 reroute_to_alora = True
 
         if reroute_to_alora:
            mot = self._generate_from_context_alora(
-                action, ctx, format=format, model_options=model_opts
+                action, ctx, _format=format, model_options=model_opts
            )
            return mot, ctx.add(mot)
        mot = self._generate_from_context_standard(
-            action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
+            action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
        )
        return mot, ctx.add(action).add(mot)
@@ -222,7 +224,7 @@ def _generate_from_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
     ) -> ModelOutputThunk:
         match action:
@@ -245,7 +247,7 @@ def _generate_from_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         alora_output = alora_for_this_request.generate_using_strings(
             input=user_message,
@@ -269,7 +271,7 @@ def _generate_from_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -310,7 +312,7 @@ def _generate_from_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -338,10 +340,10 @@ def _generate_from_context_standard(
         ).to(self._device)  # type: ignore
 
         format_kwargs = {}
-        if format:
+        if _format:
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
-            schema: dict[str, Any] = format.model_json_schema()
+            schema: dict[str, Any] = _format.model_json_schema()
             schema_json: str = json.dumps(schema)
             regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json
@@ -406,7 +408,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
-            format=format,
+            _format=_format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -463,7 +465,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
-        format: type[BaseModelSubclass] | None,
+        _format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
@@ -494,7 +496,7 @@ async def post_processing(
             self.cache_put(mot.value, cache_info)
 
         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
-        if format is None and tool_calls:
+        if _format is None and tool_calls:
             mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)
 
         assert mot._action is not None, (
@@ -514,7 +516,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot.value
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": seed,
diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py
index 451b9f35..7f9b284a 100644
--- a/mellea/backends/litellm.py
+++ b/mellea/backends/litellm.py
@@ -40,6 +40,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class LiteLLMBackend(FormatterBackend):
     """A generic LiteLLM compatible backend."""
@@ -123,7 +125,7 @@ def generate_from_context(
         mot = self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -215,7 +217,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -249,12 +251,12 @@ def _generate_from_chat_context_standard(
             [OpenAIBackend.message_to_openai_message(m) for m in messages]
         )
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -267,7 +269,7 @@ def _generate_from_chat_context_standard(
             thinking = "medium"
 
         # Append tool call information if applicable.
-        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
+        tools = self._extract_tools(action, _format, model_opts, tool_calls, ctx)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None
 
         model_specific_options = self._make_backend_specific_and_remove(model_opts)
@@ -302,7 +304,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -380,7 +382,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -425,7 +427,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["litellm_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": thinking,
@@ -436,11 +438,11 @@ async def post_processing(
 
     @staticmethod
     def _extract_tools(
-        action, format, model_opts, tool_calls, ctx
+        action, _format, model_opts, tool_calls, ctx
     ) -> dict[str, Callable]:
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py
index 0db3eb92..f58d6513 100644
--- a/mellea/backends/ollama.py
+++ b/mellea/backends/ollama.py
@@ -36,6 +36,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement
 
+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class OllamaModelBackend(FormatterBackend):
     """A model that uses the Ollama Python SDK for local inference."""
@@ -265,7 +267,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -277,7 +279,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:
@@ -325,7 +327,7 @@ def generate_from_chat_context(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -348,7 +350,7 @@ def generate_from_chat_context(
             think=model_opts.get(ModelOption.THINKING, None),
             stream=model_opts.get(ModelOption.STREAM, False),
             options=self._make_backend_specific_and_remove(model_opts),
-            format=format.model_json_schema() if format is not None else None,
+            format=_format.model_json_schema() if _format is not None else None,
         )  # type: ignore
 
         output = ModelOutputThunk(None)
@@ -360,7 +362,10 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools, format=format
+            self.post_processing,
+            conversation=conversation,
+            tools=tools,
+            _format=_format,
         )
 
         try:
@@ -523,7 +528,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
-        format,
+        _format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
@@ -542,7 +547,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": mot._model_options.get(ModelOption.THINKING, None),
             "tools_available": tools,
             "tools_called": mot.tool_calls,
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index a60d1fc9..39b026a8 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -55,6 +55,8 @@
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"
 
+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 
 class _ServerType(Enum):
     LOCALHOST = 1
@@ -303,7 +305,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -314,7 +316,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -332,13 +334,13 @@ def generate_from_chat_context(
                 reroute_to_alora = True
         if reroute_to_alora:
             return self._generate_from_chat_context_alora(
-                action, ctx, format=format, model_options=model_options
+                action, ctx, _format=_format, model_options=model_options
             )
 
         return self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=_format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -348,7 +350,7 @@ def _generate_from_chat_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
     ) -> ModelOutputThunk:
@@ -373,7 +375,7 @@ def _generate_from_chat_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."
 
         model_opts = self._simplify_and_merge(model_options, is_chat_context=True)
 
@@ -434,7 +436,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -463,12 +465,12 @@ def _generate_from_chat_context_standard(
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([self.message_to_openai_message(m) for m in messages])
 
-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -478,7 +480,7 @@ def _generate_from_chat_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )
@@ -527,7 +529,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
-            format=format,
+            _format=_format,
         )
 
         try:
@@ -596,7 +598,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -634,7 +636,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["oai_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": thinking,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py
index 5fbae724..13961339 100644
--- a/mellea/backends/watsonx.py
+++ b/mellea/backends/watsonx.py
@@ -44,6 +44,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement  # type: ignore
 
+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+
 
 class WatsonxAIBackend(FormatterBackend):
     """A generic backend class for watsonx SDK."""
@@ -243,7 +245,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )
@@ -254,7 +256,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,
@@ -285,12 +287,12 @@ def generate_from_chat_context(
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([{"role": m.role, "content": m.content} for m in messages])
 
-        if format is not None:
+        if _format is not None:
             model_opts["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }
@@ -300,7 +302,7 @@ def generate_from_chat_context(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = {}
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"tool calling is superseded by format; will not call tools for request: {action}"
                 )
@@ -356,7 +358,7 @@ def generate_from_chat_context(
             conversation=conversation,
             tools=tools,
             seed=model_opts.get(ModelOption.SEED, None),
-            format=format,
+            _format=_format,
        )
 
         try:
@@ -424,7 +426,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         seed,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
@@ -462,7 +464,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["oai_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": seed,
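
For reference, the snippet below is a minimal, self-contained sketch of the pattern this diff applies in every backend module: keyword parameters are renamed to `_format`, and a module-level `format: None = None` shadows the built-in `format` function so that any stale reference to the old name is flagged by mypy instead of silently binding to the built-in. The `ExampleSchema` model and `build_response_format` helper are illustrative stand-ins, not mellea APIs.

```python
# Sketch of the rename-plus-shadow pattern used in the diff (hypothetical names).
from pydantic import BaseModel

# Shadow builtins.format for this module: mypy now types the name `format` as None,
# so a leftover `format(...)` call or `format.model_json_schema()` attribute access
# is reported (e.g. '"None" not callable' or '"None" has no attribute ...').
format: None = None


class ExampleSchema(BaseModel):
    """Hypothetical structured-output model standing in for a BaseModelSubclass."""

    answer: str


def build_response_format(_format: type[BaseModel] | None = None) -> dict | None:
    """Mirror the backends' `_format` handling: build a JSON-schema response format."""
    if _format is None:
        return None
    return {
        "type": "json_schema",
        "json_schema": {
            "name": _format.__name__,
            "schema": _format.model_json_schema(),
            "strict": True,
        },
    }


if __name__ == "__main__":
    print(build_response_format(ExampleSchema))
    # Had a call site kept the old `format=...` name, or a function body kept using
    # `format` instead of `_format`, mypy would now reject it against the None-typed
    # module variable rather than letting it resolve to the built-in at runtime.
```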