diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py
index 7f9b284a..89b61536 100644
--- a/mellea/backends/litellm.py
+++ b/mellea/backends/litellm.py
@@ -48,7 +48,8 @@ class LiteLLMBackend(FormatterBackend):

     def __init__(
         self,
-        model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
+        model_id: str = "ollama_chat/"
+        + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
         formatter: Formatter | None = None,
         base_url: str | None = "http://localhost:11434",
         model_options: dict | None = None,
@@ -100,7 +101,7 @@ def __init__(
         # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
         # Usually, values that are intentionally extracted while prepping for the backend generate call
         # will be omitted here so that they will be removed when model_options are processed
-        # for the call to the model.
+        # for the call to the model. For LiteLLM, this dict might change slightly depending on the provider.
         self.from_mellea_model_opts_map = {
             ModelOption.SEED: "seed",
             ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
@@ -176,15 +177,9 @@ def _make_backend_specific_and_remove(
         Returns:
             a new dict
         """
-        backend_specific = ModelOption.replace_keys(
-            model_options, self.from_mellea_model_opts_map
-        )
-        backend_specific = ModelOption.remove_special_keys(backend_specific)
-
         # We set `drop_params=True` which will drop non-supported openai params; check for non-openai
         # params that might cause errors and log which openai params aren't supported here.
         # See https://docs.litellm.ai/docs/completion/input.
-        # standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
         supported_params_list = litellm.litellm_core_utils.get_supported_openai_params.get_supported_openai_params(
             self._model_id
         )
@@ -192,23 +187,47 @@ def _make_backend_specific_and_remove(
             set(supported_params_list) if supported_params_list is not None else set()
         )

-        # unknown_keys = []  # keys that are unknown to litellm
-        unsupported_openai_params = []  # openai params that are known to litellm but not supported for this model/provider
+        # LiteLLM-specific remappings (typically based on provider). There are a few cases where the provider accepts
+        # different parameters than LiteLLM says it does. Here are a few rules that help in those scenarios.
+        model_opts_remapping = self.from_mellea_model_opts_map.copy()
+        if (
+            "max_completion_tokens" not in supported_params
+            and "max_tokens" in supported_params
+        ):
+            # Scenario hit by Watsonx. LiteLLM believes Watsonx doesn't accept "max_completion_tokens" even though
+            # OpenAI compatible endpoints should accept both (and Watsonx does accept both).
+            model_opts_remapping[ModelOption.MAX_NEW_TOKENS] = "max_tokens"
+
+        backend_specific = ModelOption.replace_keys(model_options, model_opts_remapping)
+        backend_specific = ModelOption.remove_special_keys(backend_specific)
+
+        # Since LiteLLM has many different providers, we add some additional parameter logging here.
+        # There are two sets of parameters we have to look at:
+        # - unsupported_openai_params: standard OpenAI parameters that LiteLLM will automatically drop for us when `drop_params=True` if the provider doesn't support them.
+        # - unknown_keys: parameters that LiteLLM doesn't know about, aren't standard OpenAI parameters, and might be used by the provider. We don't drop these.
+        # We want to flag both for the end user.
+        standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
+        unknown_keys = []  # Keys that are unknown to litellm.
+        unsupported_openai_params = []  # OpenAI params that are known to litellm but not supported for this model/provider.
         for key in backend_specific.keys():
             if key not in supported_params:
-                unsupported_openai_params.append(key)
-
-        # if len(unknown_keys) > 0:
-        #     FancyLogger.get_logger().warning(
-        #         f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
-        #     )
+                if key in standard_openai_subset:
+                    # LiteLLM is pretty confident that this standard OpenAI parameter won't work.
+                    unsupported_openai_params.append(key)
+                else:
+                    # LiteLLM doesn't make any claims about this parameter; we won't drop it, but we will keep track of it.
+                    unknown_keys.append(key)
+
+        if len(unknown_keys) > 0:
+            FancyLogger.get_logger().warning(
+                f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}"
+            )

         if len(unsupported_openai_params) > 0:
             FancyLogger.get_logger().warning(
-                f"litellm will automatically drop the following openai keys that aren't supported by the current model/provider: {', '.join(unsupported_openai_params)}"
+                f"litellm may drop the following openai keys that it doesn't seem to recognize as being supported by the current model/provider: {', '.join(unsupported_openai_params)}"
+                "\nThere are sometimes false positives here."
             )
-            for key in unsupported_openai_params:
-                del backend_specific[key]

         return backend_specific
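Reviewer note (not part of the patch): a hedged sketch of how the reworked `_make_backend_specific_and_remove` is expected to behave for an Ollama-style provider. The option values are made up and the `mellea.backends.types` import path for `ModelOption` is an assumption; `LiteLLMBackend(...)`, `_simplify_and_merge`, and `_make_backend_specific_and_remove` are used the same way as in the new test in test/backends/test_litellm_ollama.py.

# Hedged sketch only; option values are illustrative and the ModelOption import path is assumed.
from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.types import ModelOption

backend = LiteLLMBackend(model_id="ollama_chat/llama3.2:1b")
mellea_opts = backend._simplify_and_merge(
    {ModelOption.MAX_NEW_TOKENS: 64, "web_search_options": {}, "unknown_parameter": 1}
)
backend_specific = backend._make_backend_specific_and_remove(mellea_opts)
# Expected with this change: MAX_NEW_TOKENS is remapped to "max_completion_tokens";
# "web_search_options" stays in the dict (LiteLLM drops it later via drop_params=True)
# but is logged as unsupported; "unknown_parameter" is passed through and logged as unvalidated.
print(backend_specific)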
diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py
index 13961339..7e0ce1b1 100644
--- a/mellea/backends/watsonx.py
+++ b/mellea/backends/watsonx.py
@@ -7,6 +7,7 @@
 import os
 import warnings
 from collections.abc import AsyncGenerator, Callable, Coroutine
+from dataclasses import fields
 from typing import Any

 from ibm_watsonx_ai import APIClient, Credentials
@@ -110,7 +111,8 @@ def __init__(
         # These are usually values that must be extracted before hand or that are common among backend providers.
         self.to_mellea_model_opts_map_chats = {
             "system": ModelOption.SYSTEM_PROMPT,
-            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,  # Being deprecated in favor of `max_completion_tokens`.
+            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
             "tools": ModelOption.TOOLS,
             "stream": ModelOption.STREAM,
         }
@@ -120,7 +122,7 @@ def __init__(
         # will be omitted here so that they will be removed when model_options are processed
         # for the call to the model.
         self.from_mellea_model_opts_map_chats = {
-            ModelOption.MAX_NEW_TOKENS: "max_tokens"
+            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens"
         }

         # See notes above.
@@ -168,7 +170,10 @@ def _get_watsonx_model_id(self) -> str:

     def filter_chat_completions_kwargs(self, model_options: dict) -> dict:
         """Filter kwargs to only include valid watsonx chat.completions.create parameters."""
-        chat_params = TextChatParameters.get_sample_params().keys()
+        # TextChatParameters.get_sample_params().keys() can't be completely trusted. It doesn't always contain
+        # all of the accepted keys. In version 1.3.39, max_tokens was removed even though it's still accepted.
+        # It's a dataclass, so use the fields function to get the names.
+        chat_params = {field.name for field in fields(TextChatParameters)}
         return {k: v for k, v in model_options.items() if k in chat_params}

     def _simplify_and_merge(
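Reviewer note (not part of the patch): a minimal, standalone sketch of the dataclass-based filtering that `filter_chat_completions_kwargs` now performs. The `ibm_watsonx_ai.foundation_models.schema` import path and the helper name are assumptions for illustration; only `fields(TextChatParameters)` and the dict comprehension come from the patch.

# Hedged sketch; the TextChatParameters import path is assumed.
from dataclasses import fields

from ibm_watsonx_ai.foundation_models.schema import TextChatParameters

def filter_chat_kwargs(model_options: dict) -> dict:
    # fields() reflects the dataclass definition itself, so keys such as "max_tokens"
    # survive even when get_sample_params() omits them (as happened in 1.3.39).
    chat_params = {field.name for field in fields(TextChatParameters)}
    return {k: v for k, v in model_options.items() if k in chat_params}

print(filter_chat_kwargs({"max_tokens": 100, "not_a_chat_param": True}))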
diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py
index a7f4879d..6717ef0f 100644
--- a/test/backends/test_litellm_ollama.py
+++ b/test/backends/test_litellm_ollama.py
@@ -1,4 +1,5 @@
 import asyncio
+import os

 import pytest

 from mellea import MelleaSession, generative
@@ -8,14 +9,73 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.sampling import RejectionSamplingStrategy


+@pytest.fixture(scope="function")
+def backend(gh_run: int):
+    """LiteLLM backend configured for Ollama."""
+    if gh_run == 1:
+        # LiteLLM prepends 127.0.0.1 with a `/`, which causes issues.
+        url = os.environ.get("OLLAMA_HOST", None)
+        if url is None:
+            url = "http://localhost:11434"
+        else:
+            url = url.replace("127.0.0.1", "http://localhost")
+
+        return LiteLLMBackend(
+            model_id="ollama_chat/llama3.2:1b",
+            base_url=url,
+            model_options={"api_base": url},
+        )
+    else:
+        return LiteLLMBackend()
+

 @pytest.fixture(scope="function")
-def session():
+def session(backend):
     """Fresh Ollama session for each test."""
-    session = MelleaSession(LiteLLMBackend())
+    session = MelleaSession(backend=backend)
     yield session
     session.reset()


+# Use capsys to check that the logging is working.
+def test_make_backend_specific_and_remove():
+    # Doesn't need to be a real model here; just a provider that LiteLLM knows about.
+    backend = LiteLLMBackend(model_id="ollama_chat/")
+
+    params = {
+        "max_tokens": 1,
+        "stream": 1,
+        ModelOption.TEMPERATURE: 1,
+        "unknown_parameter": 1,  # Unknown / non-OpenAI parameter.
+        "web_search_options": 1,  # Standard OpenAI parameter not supported by Ollama.
+    }
+
+    mellea = backend._simplify_and_merge(params)
+    backend_specific = backend._make_backend_specific_and_remove(mellea)
+
+    # All of these options should be in the model options that get passed to LiteLLM since it handles the dropping.
+    assert "max_completion_tokens" in backend_specific, "max_tokens should get remapped to max_completion_tokens for ollama_chat/"
+    assert "stream" in backend_specific
+    assert "temperature" in backend_specific
+    assert "unknown_parameter" in backend_specific
+    assert "web_search_options" in backend_specific
+
+    # TODO: Investigate why this isn't working on github action runners.
+    # Add the capsys or capfd fixture back.
+    # out = capsys.readouterr()
+    # # Check for the specific warning logs.
+    # assert "supported by the current model/provider: web_search_options" in out.out
+    # assert "mellea won't validate the following params that may cause issues: unknown_parameter" in out.out
+
+    # Do a quick test for the Watsonx-specific scenario.
+    backend = LiteLLMBackend(model_id="watsonx/")
+    watsonx_params = {"max_tokens": 1}
+
+    # Make sure we make it Mellea specific correctly.
+    watsonx_mellea = backend._simplify_and_merge(watsonx_params)
+    assert ModelOption.MAX_NEW_TOKENS in watsonx_mellea
+
+    watsonx_backend_specific = backend._make_backend_specific_and_remove(watsonx_mellea)
+    assert "max_tokens" in watsonx_backend_specific
+

 @pytest.mark.qualitative
 def test_litellm_ollama_chat(session):
@@ -26,7 +86,6 @@
     res = session.chat("What is 1+1?")
     assert "2" in str(res.content), (
         f"Expected a message with content containing 2 but found {res}"
     )

-@pytest.mark.qualitative
 def test_litellm_ollama_instruct(session):
     res = session.instruct(
         "Write an email to the interns.",
@@ -37,7 +96,6 @@
         requirements=["be funny"],
         strategy=RejectionSamplingStrategy(loop_budget=3),
     )
     assert res is not None
     assert isinstance(res.value, str)

-@pytest.mark.qualitative
 def test_litellm_ollama_instruct_options(session):
     model_options={
         ModelOption.SEED: 123,
@@ -59,11 +117,6 @@
     # make sure that homer_simpson is in the logged model_options
     assert "homer_simpson" in res._generate_log.model_options

-    # make sure the backend function filters out the model option when passing to the generate call
-    backend = session.backend
-    assert isinstance(backend, LiteLLMBackend)
-    assert "homer_simpson" not in backend._make_backend_specific_and_remove(model_options)
-

 @pytest.mark.qualitative
 def test_gen_slot(session):
     @generative
     def is_happy(text: str) -> bool:
         """Determine if text is of happy mood."""
@@ -77,7 +130,6 @@
     # should yield to true - but, of course, is model dependent
     assert h is True

-@pytest.mark.qualitative
 async def test_async_parallel_requests(session):
     model_opts = {ModelOption.STREAM: True}
     mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts)
@@ -104,7 +156,6 @@

     assert m1_final_val == mot1.value
     assert m2_final_val == mot2.value

-@pytest.mark.qualitative
 async def test_async_avalue(session):
     mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext())
     m1_final_val = await mot1.avalue()
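Reviewer note (not part of the patch): the CI branch of the new `backend` fixture reduces to the URL normalization below; the helper name is hypothetical and exists only for illustration.

# Hypothetical helper mirroring the fixture's gh_run == 1 branch.
import os

def resolve_ollama_base_url() -> str:
    url = os.environ.get("OLLAMA_HOST", None)
    if url is None:
        return "http://localhost:11434"
    # LiteLLM mishandles bare 127.0.0.1 hosts (it prepends a `/`), so hand it an explicit scheme and hostname.
    return url.replace("127.0.0.1", "http://localhost")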
diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py
index 9cfd7b23..274374a5 100644
--- a/test/backends/test_watsonx.py
+++ b/test/backends/test_watsonx.py
@@ -34,6 +34,23 @@ def session(backend: WatsonxAIBackend):
     yield session
     session.reset()

+@pytest.mark.qualitative
+def test_filter_chat_completions_kwargs(backend: WatsonxAIBackend):
+    """Detect changes to the WatsonxAI TextChatParameters."""
+
+    known_keys = ['frequency_penalty', 'logprobs', 'top_logprobs', 'presence_penalty', 'response_format', 'temperature', 'max_tokens', 'max_completion_tokens', 'time_limit', 'top_p', 'n', 'logit_bias', 'seed', 'stop', 'guided_choice', 'guided_regex', 'guided_grammar', 'guided_json']
+    test_dict = {key: 1 for key in known_keys}
+
+    # Make sure keys that we think should be in the TextChatParameters are there.
+    filtered_dict = backend.filter_chat_completions_kwargs(test_dict)
+
+    for key in known_keys:
+        assert key in filtered_dict
+
+    # Make sure unsupported keys still get filtered out.
+    incorrect_dict = {"random": 1}
+    filtered_incorrect_dict = backend.filter_chat_completions_kwargs(incorrect_dict)
+    assert "random" not in filtered_incorrect_dict

 @pytest.mark.qualitative
 def test_instruct(session: MelleaSession):
diff --git a/test/conftest.py b/test/conftest.py
index e95ce41b..6e5d83c6 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -21,6 +21,6 @@ def pytest_runtest_setup(item):
     gh_run = int(os.environ.get("CICD", 0))

     if gh_run == 1:
-        pytest.xfail(
+        pytest.skip(
             reason="Skipping qualitative test: got env variable CICD == 1. Used only in gh workflows."
         )
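Reviewer note (not part of the patch): a minimal sketch of the assumed reporting difference behind the conftest change. `pytest.xfail()` raised during setup reports the test as an expected failure ("xfailed"), which can mask real regressions in CI summaries, while `pytest.skip()` reports it plainly as skipped.

# Hedged sketch of the intended CI behavior; the CICD gate mirrors the conftest hunk above.
import os

import pytest

def pytest_runtest_setup(item):
    if int(os.environ.get("CICD", 0)) == 1:
        pytest.skip(reason="Skipping qualitative test: got env variable CICD == 1.")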