26 changes: 14 additions & 12 deletions mellea/backends/huggingface.py
@@ -67,6 +67,8 @@
"""
TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors


@dataclasses.dataclass
class HFAloraCacheInfo:
@@ -209,11 +211,11 @@ def generate_from_context(
reroute_to_alora = True
if reroute_to_alora:
mot = self._generate_from_context_alora(
action, ctx, format=format, model_options=model_opts
action, ctx, _format=format, model_options=model_opts
)
return mot, ctx.add(mot)
mot = self._generate_from_context_standard(
action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
)
return mot, ctx.add(action).add(mot)

@@ -222,7 +224,7 @@ def _generate_from_context_alora(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
_format: type[BaseModelSubclass] | None = None,
model_options: dict[str, Any],
) -> ModelOutputThunk:
match action:
@@ -245,7 +247,7 @@ def _generate_from_context_alora(
assert alora_for_this_request is not None
assert type(user_message) is str
assert type(assistant_message) is str
assert format is None, "Structured outputs are not supported by ALoRAs."
assert _format is None, "Structured outputs are not supported by ALoRAs."

alora_output = alora_for_this_request.generate_using_strings(
input=user_message,
@@ -269,7 +271,7 @@ def _generate_from_context_standard(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
_format: type[BaseModelSubclass] | None = None,
model_options: dict[str, Any],
tool_calls: bool = False,
) -> ModelOutputThunk:
@@ -310,7 +312,7 @@ def _generate_from_context_standard(
# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
if tool_calls:
if format:
if _format:
FancyLogger.get_logger().warning(
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
)
@@ -338,10 +340,10 @@ def _generate_from_context_standard(
).to(self._device) # type: ignore

format_kwargs = {}
if format:
if _format:
# outlines.generate.json always parses the resulting json into a python dict.
# We however want to keep it as a json string for later storing it in ModelOutputThunk
schema: dict[str, Any] = format.model_json_schema()
schema: dict[str, Any] = _format.model_json_schema()
schema_json: str = json.dumps(schema)
regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore
schema_json
@@ -406,7 +408,7 @@ def _generate_from_context_standard(
self.post_processing,
conversation=ctx_as_conversation,
input_ids=input_ids,
format=format,
_format=_format,
tool_calls=tool_calls,
tools=tools,
seed=seed,
@@ -463,7 +465,7 @@ async def post_processing(
self,
mot: ModelOutputThunk,
conversation: list[dict],
format: type[BaseModelSubclass] | None,
_format: type[BaseModelSubclass] | None,
tool_calls: bool,
tools: dict[str, Callable],
seed,
@@ -494,7 +496,7 @@ async def post_processing(
self.cache_put(mot.value, cache_info)

# Only scan for tools if we are not doing structured output and tool calls were provided to the model.
if format is None and tool_calls:
if _format is None and tool_calls:
mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)

assert mot._action is not None, (
@@ -514,7 +516,7 @@ async def post_processing(
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot.value
generate_log.extra = {
"format": format,
"format": _format,
"tools_available": tools,
"tools_called": mot.tool_calls,
"seed": seed,
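Note on the change above: every backend in this PR gets the same treatment. The public `generate_from_context` keyword stays `format`, the internal helpers take `_format`, and a module-level `format: None = None` shadows the builtin so mypy flags any leftover bare `format` reference inside the module. A minimal sketch of the pattern, with illustrative class and method names (not the repo's actual API):

```python
from pydantic import BaseModel

# Shadow the builtin `format` at module scope. Any stray bare `format`
# reference in this module now resolves to a None-typed variable, so mypy
# flags attribute access on it or calls to it as errors.
format: None = None


class ExampleBackend:
    def generate(self, *, format: type[BaseModel] | None = None) -> str | None:
        # The public API keeps the `format` keyword and immediately forwards
        # it to the internal helper under the `_format` name.
        return self._generate(_format=format)

    def _generate(self, *, _format: type[BaseModel] | None = None) -> str | None:
        if _format is None:
            return None
        # Internal code must use `_format`; writing `format.model_json_schema()`
        # here would now be a mypy error thanks to the module-level shadow.
        import json

        return json.dumps(_format.model_json_schema())
```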
24 changes: 13 additions & 11 deletions mellea/backends/litellm.py
@@ -40,6 +40,8 @@
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors


class LiteLLMBackend(FormatterBackend):
"""A generic LiteLLM compatible backend."""
@@ -123,7 +125,7 @@ def generate_from_context(
mot = self._generate_from_chat_context_standard(
action,
ctx,
format=format,
_format=format,
model_options=model_options,
tool_calls=tool_calls,
)
@@ -215,7 +217,7 @@ def _generate_from_chat_context_standard(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass]
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
@@ -249,12 +251,12 @@ def _generate_from_chat_context_standard(
[OpenAIBackend.message_to_openai_message(m) for m in messages]
)

if format is not None:
if _format is not None:
response_format = {
"type": "json_schema",
"json_schema": {
"name": format.__name__,
"schema": format.model_json_schema(),
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
@@ -267,7 +269,7 @@ def _generate_from_chat_context_standard(
thinking = "medium"

# Append tool call information if applicable.
tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
tools = self._extract_tools(action, _format, model_opts, tool_calls, ctx)
formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

model_specific_options = self._make_backend_specific_and_remove(model_opts)
@@ -302,7 +304,7 @@ def _generate_from_chat_context_standard(
conversation=conversation,
tools=tools,
thinking=thinking,
format=format,
_format=_format,
)

try:
@@ -380,7 +382,7 @@ async def post_processing(
conversation: list[dict],
tools: dict[str, Callable],
thinking,
format,
_format,
):
"""Called when generation is done."""
# Reconstruct the chat_response from chunks if streamed.
@@ -425,7 +427,7 @@ async def post_processing(
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["litellm_chat_response"]
generate_log.extra = {
"format": format,
"format": _format,
"tools_available": tools,
"tools_called": mot.tool_calls,
"seed": thinking,
@@ -436,11 +438,11 @@

@staticmethod
def _extract_tools(
action, format, model_opts, tool_calls, ctx
action, _format, model_opts, tool_calls, ctx
) -> dict[str, Callable]:
tools: dict[str, Callable] = dict()
if tool_calls:
if format:
if _format:
FancyLogger.get_logger().warning(
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
)
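For reference, the structured-output payload that both the litellm and openai backends now build from `_format` is the standard OpenAI-style `json_schema` response format; roughly, with an illustrative Pydantic model:

```python
from pydantic import BaseModel


class Verdict(BaseModel):
    label: str
    score: float


# Mirrors the construction in the diff: the model's JSON schema is sent
# under "json_schema" with strict validation enabled.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": Verdict.__name__,
        "schema": Verdict.model_json_schema(),
        "strict": True,
    },
}
```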
19 changes: 12 additions & 7 deletions mellea/backends/ollama.py
@@ -36,6 +36,8 @@
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors


class OllamaModelBackend(FormatterBackend):
"""A model that uses the Ollama Python SDK for local inference."""
@@ -265,7 +267,7 @@ def generate_from_context(
mot = self.generate_from_chat_context(
action,
ctx,
format=format,
_format=format,
model_options=model_options,
tool_calls=tool_calls,
)
@@ -277,7 +279,7 @@ def generate_from_chat_context(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
_format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
tool_calls: bool = False,
) -> ModelOutputThunk:
@@ -325,7 +327,7 @@ def generate_from_chat_context(
# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
if tool_calls:
if format:
if _format:
FancyLogger.get_logger().warning(
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
)
@@ -348,7 +350,7 @@ def generate_from_chat_context(
think=model_opts.get(ModelOption.THINKING, None),
stream=model_opts.get(ModelOption.STREAM, False),
options=self._make_backend_specific_and_remove(model_opts),
format=format.model_json_schema() if format is not None else None,
format=_format.model_json_schema() if _format is not None else None,
) # type: ignore

output = ModelOutputThunk(None)
@@ -360,7 +362,10 @@ def generate_from_chat_context(
# each processing step.
output._process = functools.partial(self.processing, tools=tools)
output._post_process = functools.partial(
self.post_processing, conversation=conversation, tools=tools, format=format
self.post_processing,
conversation=conversation,
tools=tools,
_format=_format,
)

try:
@@ -523,7 +528,7 @@ async def post_processing(
mot: ModelOutputThunk,
conversation: list[dict],
tools: dict[str, Callable],
format,
_format,
):
"""Called when generation is done."""
assert mot._action is not None, (
@@ -542,7 +547,7 @@ async def post_processing(
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["chat_response"]
generate_log.extra = {
"format": format,
"format": _format,
"thinking": mot._model_options.get(ModelOption.THINKING, None),
"tools_available": tools,
"tools_called": mot.tool_calls,
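The Ollama backend instead hands the schema straight to the SDK's `format` argument (`_format.model_json_schema() if _format is not None else None`). A rough standalone equivalent of that call, with an illustrative model tag and prompt:

```python
import ollama
from pydantic import BaseModel


class Verdict(BaseModel):
    label: str
    score: float


# Same idea as the backend's chat call above: the Pydantic JSON schema is
# passed via `format` so the server constrains the response to it.
response = ollama.chat(
    model="llama3.2",  # illustrative model tag
    messages=[{"role": "user", "content": "Classify: 'great product'"}],
    format=Verdict.model_json_schema(),
)
# The response content should then parse back into the Pydantic model.
verdict = Verdict.model_validate_json(response.message.content)
```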
30 changes: 16 additions & 14 deletions mellea/backends/openai.py
@@ -55,6 +55,8 @@

openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors


class _ServerType(Enum):
LOCALHOST = 1
@@ -303,7 +305,7 @@ def generate_from_context(
mot = self.generate_from_chat_context(
action,
ctx,
format=format,
_format=format,
model_options=model_options,
tool_calls=tool_calls,
)
@@ -314,7 +316,7 @@ def generate_from_chat_context(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass]
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
@@ -332,13 +334,13 @@
reroute_to_alora = True
if reroute_to_alora:
return self._generate_from_chat_context_alora(
action, ctx, format=format, model_options=model_options
action, ctx, _format=_format, model_options=model_options
)

return self._generate_from_chat_context_standard(
action,
ctx,
format=format,
_format=_format,
model_options=model_options,
tool_calls=tool_calls,
)
@@ -348,7 +350,7 @@ def _generate_from_chat_context_alora(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass]
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
) -> ModelOutputThunk:
@@ -373,7 +375,7 @@ def _generate_from_chat_context_alora(
assert alora_for_this_request is not None
assert type(user_message) is str
assert type(assistant_message) is str
assert format is None, "Structured outputs are not supported by ALoRAs."
assert _format is None, "Structured outputs are not supported by ALoRAs."

model_opts = self._simplify_and_merge(model_options, is_chat_context=True)

@@ -434,7 +436,7 @@ def _generate_from_chat_context_standard(
action: Component | CBlock,
ctx: Context,
*,
format: type[BaseModelSubclass]
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
@@ -463,12 +465,12 @@ def _generate_from_chat_context_standard(
conversation.append({"role": "system", "content": system_prompt})
conversation.extend([self.message_to_openai_message(m) for m in messages])

if format is not None:
if _format is not None:
response_format = {
"type": "json_schema",
"json_schema": {
"name": format.__name__,
"schema": format.model_json_schema(),
"name": _format.__name__,
"schema": _format.model_json_schema(),
"strict": True,
},
}
@@ -478,7 +480,7 @@
# Append tool call information if applicable.
tools: dict[str, Callable] = dict()
if tool_calls:
if format:
if _format:
FancyLogger.get_logger().warning(
f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
)
@@ -527,7 +529,7 @@ def _generate_from_chat_context_standard(
conversation=conversation,
thinking=thinking,
seed=model_opts.get(ModelOption.SEED, None),
format=format,
_format=_format,
)

try:
@@ -596,7 +598,7 @@ async def post_processing(
conversation: list[dict],
thinking,
seed,
format,
_format,
):
"""Called when generation is done."""
# Reconstruct the chat_response from chunks if streamed.
@@ -634,7 +636,7 @@ async def post_processing(
generate_log.date = datetime.datetime.now()
generate_log.model_output = mot._meta["oai_chat_response"]
generate_log.extra = {
"format": format,
"format": _format,
"thinking": thinking,
"tools_available": tools,
"tools_called": mot.tool_calls,
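One behavioral rule this diff repeats in every backend: a structured-output `_format` takes precedence over tool calling. A condensed sketch of that guard, with an illustrative function name (not the repo's actual helper):

```python
import logging
from typing import Callable


def resolve_tools(
    available_tools: dict[str, Callable],
    _format: type | None,
    tool_calls: bool,
) -> dict[str, Callable]:
    # Mirrors the warning branch in the backends: when a response format is
    # requested, tool calling is skipped entirely for that request.
    if not tool_calls:
        return {}
    if _format is not None:
        logging.warning(
            "Tool calling is superseded by format; tools will NOT be called."
        )
        return {}
    return available_tools
```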