Skip to content
57 changes: 38 additions & 19 deletions mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class LiteLLMBackend(FormatterBackend):

def __init__(
self,
model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
model_id: str = "ollama_chat/"
+ str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
formatter: Formatter | None = None,
base_url: str | None = "http://localhost:11434",
model_options: dict | None = None,
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(
# These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
# Usually, values that are intentionally extracted while prepping for the backend generate call
# will be omitted here so that they will be removed when model_options are processed
# for the call to the model.
# for the call to the model. For LiteLLM, this dict might change slightly depending on the provider.
self.from_mellea_model_opts_map = {
ModelOption.SEED: "seed",
ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
Expand Down Expand Up @@ -176,39 +177,57 @@ def _make_backend_specific_and_remove(
Returns:
a new dict
"""
backend_specific = ModelOption.replace_keys(
model_options, self.from_mellea_model_opts_map
)
backend_specific = ModelOption.remove_special_keys(backend_specific)

# We set `drop_params=True` which will drop non-supported openai params; check for non-openai
# params that might cause errors and log which openai params aren't supported here.
# See https://docs.litellm.ai/docs/completion/input.
# standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
self._model_id
)
supported_params = (
set(supported_params_list) if supported_params_list is not None else set()
)

# unknown_keys = [] # keys that are unknown to litellm
unsupported_openai_params = [] # openai params that are known to litellm but not supported for this model/provider
# LiteLLM specific remappings (typically based on provider). There's a few cases where the provider accepts
# different parameters than LiteLLM says it does. Here's a few rules that help in those scenarios.
model_opts_remapping = self.from_mellea_model_opts_map.copy()
if (
"max_completion_tokens" not in supported_params
and "max_tokens" in supported_params
):
# Scenario hit by Watsonx. LiteLLM believes Watsonx doesn't accept "max_completion_tokens" even though
# OpenAI compatible endpoints should accept both (and Watsonx does accept both).
model_opts_remapping[ModelOption.MAX_NEW_TOKENS] = "max_tokens"

backend_specific = ModelOption.replace_keys(model_options, model_opts_remapping)
backend_specific = ModelOption.remove_special_keys(backend_specific)

# Since LiteLLM has many different providers, we add some additional parameter logging here.
# There's two sets of parameters we have to look at:
# - unsupported_openai_params: standard OpenAI parameters that LiteLLM will automatically drop for us when `drop_params=True` if the provider doesn't support them.
# - unknown_keys: parameters that LiteLLM doesn't know about, aren't standard OpenAI parameters, and might be used by the provider. We don't drop these.
# We want to flag both for the end user.
standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
unknown_keys = [] # Keys that are unknown to litellm.
unsupported_openai_params = [] # OpenAI params that are known to litellm but not supported for this model/provider.
for key in backend_specific.keys():
if key not in supported_params:
unsupported_openai_params.append(key)

# if len(unknown_keys) > 0:
# FancyLogger.get_logger().warning(
# f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
# )
if key in standard_openai_subset:
# LiteLLM is pretty confident that this standard OpenAI parameter won't work.
unsupported_openai_params.append(key)
else:
# LiteLLM doesn't make any claims about this parameter; we won't drop it but we will keep track of it..
unknown_keys.append(key)

if len(unknown_keys) > 0:
FancyLogger.get_logger().warning(
f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
)

if len(unsupported_openai_params) > 0:
FancyLogger.get_logger().warning(
f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to force LiteLLM to accept parameters that we should expose through our API as well?

Thinking of an analogy to the --force parameter from some linux commands.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can set drop_params=False in the call to the model. But I think this change is accomplishing what you are asking for with accept parameters that we should expose through our API.

Previously, we dropped all params that weren't "known" and "basic" openai parameters. Now, we let LiteLLM drop "known" but "unsupported" openai params. All other params get passed through transparently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And there are no false negatives - i.e. litellm filters out a parameter that the model understands but LiteLLM assumes it doesn't ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible; I searched the LiteLLM github and the only errors I could find with drop_params is that it is too permissive (ie keeps parameters around that it shouldn't). We can disable drop_params if you'd prefer to just pass through all parameters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok.. false positives are ok IMHO.

f"litellm may drop the following openai keys that it doesn't seem to recognize as being supported by the current model/provider: {', '.join(unsupported_openai_params)}"
"\nThere are sometimes false positives here."
)
for key in unsupported_openai_params:
del backend_specific[key]

return backend_specific

Expand Down
11 changes: 8 additions & 3 deletions mellea/backends/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import warnings
from collections.abc import AsyncGenerator, Callable, Coroutine
from dataclasses import fields
from typing import Any

from ibm_watsonx_ai import APIClient, Credentials
Expand Down Expand Up @@ -110,7 +111,8 @@ def __init__(
# These are usually values that must be extracted before hand or that are common among backend providers.
self.to_mellea_model_opts_map_chats = {
"system": ModelOption.SYSTEM_PROMPT,
"max_tokens": ModelOption.MAX_NEW_TOKENS,
"max_tokens": ModelOption.MAX_NEW_TOKENS, # Is being deprecated in favor of `max_completion_tokens.`
"max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
"tools": ModelOption.TOOLS,
"stream": ModelOption.STREAM,
}
Expand All @@ -120,7 +122,7 @@ def __init__(
# will be omitted here so that they will be removed when model_options are processed
# for the call to the model.
self.from_mellea_model_opts_map_chats = {
ModelOption.MAX_NEW_TOKENS: "max_tokens"
ModelOption.MAX_NEW_TOKENS: "max_completion_tokens"
}

# See notes above.
Expand Down Expand Up @@ -168,7 +170,10 @@ def _get_watsonx_model_id(self) -> str:

def filter_chat_completions_kwargs(self, model_options: dict) -> dict:
"""Filter kwargs to only include valid watsonx chat.completions.create parameters."""
chat_params = TextChatParameters.get_sample_params().keys()
# TextChatParameters.get_sample_params().keys() can't be completely trusted. It doesn't always contain all
# all of the accepted keys. In version 1.3.39, max_tokens was removed even though it's still accepted.
# It's a dataclass so use the fields function to get the names.
chat_params = {field.name for field in fields(TextChatParameters)}
return {k: v for k, v in model_options.items() if k in chat_params}

def _simplify_and_merge(
Expand Down
73 changes: 62 additions & 11 deletions test/backends/test_litellm_ollama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import os
import pytest

from mellea import MelleaSession, generative
Expand All @@ -8,14 +9,73 @@
from mellea.stdlib.chat import Message
from mellea.stdlib.sampling import RejectionSamplingStrategy

@pytest.fixture(scope="function")
def backend(gh_run: int):
"""Shared OpenAI backend configured for Ollama."""
if gh_run == 1:
# LiteLLM prepends 127.0.0.1 with a `/` which causes issues.
url = os.environ.get("OLLAMA_HOST", None)
if url is None:
url = "http://localhost:11434"
else:
url = url.replace("127.0.0.1", "http://localhost")

return LiteLLMBackend(
model_id="ollama_chat/llama3.2:1b",
base_url=url,
model_options={"api_base": url}
)
else:
return LiteLLMBackend()

@pytest.fixture(scope="function")
def session():
def session(backend):
"""Fresh Ollama session for each test."""
session = MelleaSession(LiteLLMBackend())
session = MelleaSession(backend=backend)
yield session
session.reset()

# Use capsys to check that the logging is working.
def test_make_backend_specific_and_remove():
# Doesn't need to be a real model here; just a provider that LiteLLM knows about.
backend = LiteLLMBackend(model_id="ollama_chat/")

params = {
"max_tokens": 1,
"stream": 1,
ModelOption.TEMPERATURE: 1,
"unknown_parameter": 1, # Unknown / non-OpenAI parameter
"web_search_options": 1, # Standard OpenAI parameter not supported by Ollama.
}

mellea = backend._simplify_and_merge(params)
backend_specific = backend._make_backend_specific_and_remove(mellea)


# All of these options should be in the model options that get passed to LiteLLM since it handles the dropping.
assert "max_completion_tokens" in backend_specific, "max_tokens should get remapped to max_completion_tokens for ollama_chat/"
assert "stream" in backend_specific
assert "temperature" in backend_specific
assert "unknown_parameter" in backend_specific
assert "web_search_options" in backend_specific

# TODO: Investigate why this isn't working on github action runners.
# Add the capsys or capfd fixture back.
# out = capsys.readouterr()
# # Check for the specific warning logs.
# assert "supported by the current model/provider: web_search_options" in out.out
# assert "mellea won't validate the following params that may cause issues: unknown_parameter" in out.out

# Do a quick test for the Watsonx specific scenario.
backend = LiteLLMBackend(model_id="watsonx/")
watsonx_params = {"max_tokens": 1}

# Make sure we make it Mellea specific correctly.
watsonx_mellea = backend._simplify_and_merge(watsonx_params)
assert ModelOption.MAX_NEW_TOKENS in watsonx_mellea

watsonx_backend_specific = backend._make_backend_specific_and_remove(watsonx_mellea)
assert "max_tokens" in watsonx_backend_specific

@pytest.mark.qualitative
def test_litellm_ollama_chat(session):
Expand All @@ -26,7 +86,6 @@ def test_litellm_ollama_chat(session):
f"Expected a message with content containing 2 but found {res}"
)

@pytest.mark.qualitative
def test_litellm_ollama_instruct(session):
res = session.instruct(
"Write an email to the interns.",
Expand All @@ -37,7 +96,6 @@ def test_litellm_ollama_instruct(session):
assert isinstance(res.value, str)


@pytest.mark.qualitative
def test_litellm_ollama_instruct_options(session):
model_options={
ModelOption.SEED: 123,
Expand All @@ -59,11 +117,6 @@ def test_litellm_ollama_instruct_options(session):
# make sure that homer_simpson is in the logged model_options
assert "homer_simpson" in res._generate_log.model_options

# make sure the backend function filters out the model option when passing to the generate call
backend = session.backend
assert isinstance(backend, LiteLLMBackend)
assert "homer_simpson" not in backend._make_backend_specific_and_remove(model_options)


@pytest.mark.qualitative
def test_gen_slot(session):
Expand All @@ -77,7 +130,6 @@ def is_happy(text: str) -> bool:
# should yield to true - but, of course, is model dependent
assert h is True

@pytest.mark.qualitative
async def test_async_parallel_requests(session):
model_opts = {ModelOption.STREAM: True}
mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts)
Expand All @@ -104,7 +156,6 @@ async def test_async_parallel_requests(session):
assert m1_final_val == mot1.value
assert m2_final_val == mot2.value

@pytest.mark.qualitative
async def test_async_avalue(session):
mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext())
m1_final_val = await mot1.avalue()
Expand Down
17 changes: 17 additions & 0 deletions test/backends/test_watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ def session(backend: WatsonxAIBackend):
yield session
session.reset()

@pytest.mark.qualitative
def test_filter_chat_completions_kwargs(backend: WatsonxAIBackend):
"""Detect changes to the WatsonxAI TextChatParameters."""

known_keys = ['frequency_penalty', 'logprobs', 'top_logprobs', 'presence_penalty', 'response_format', 'temperature', 'max_tokens', 'max_completion_tokens', 'time_limit', 'top_p', 'n', 'logit_bias', 'seed', 'stop', 'guided_choice', 'guided_regex', 'guided_grammar', 'guided_json']
test_dict = {key: 1 for key in known_keys}

# Make sure keys that we think should be in the TextChatParameters are there.
filtered_dict = backend.filter_chat_completions_kwargs(test_dict)

for key in known_keys:
assert key in filtered_dict

# Make sure unsupported keys still get filtered out.
incorrect_dict = {"random": 1}
filtered_incorrect_dict = backend.filter_chat_completions_kwargs(incorrect_dict)
assert "random" not in filtered_incorrect_dict

@pytest.mark.qualitative
def test_instruct(session: MelleaSession):
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def pytest_runtest_setup(item):
gh_run = int(os.environ.get("CICD", 0))

if gh_run == 1:
pytest.xfail(
pytest.skip(
reason="Skipping qualitative test: got env variable CICD == 1. Used only in gh workflows."
)