diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index 670076b..8d8ce34 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -814,9 +814,7 @@ def convert_to_oci_tool(
                 "type": "object",
                 "properties": {
                     p_name: {
-                        "type": JSON_TO_PYTHON_TYPES.get(
-                            p_def.get("type"), p_def.get("type", "string")
-                        ),
+                        "type": p_def.get("type", "any"),
                         "description": p_def.get("description", ""),
                     }
                     for p_name, p_def in tool.args.items()
diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py
index 240598c..88a1f1f 100644
--- a/libs/oci/langchain_oci/llms/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py
@@ -42,6 +42,7 @@ def completion_response_to_text(self, response: Any) -> str:
 
 class GenericProvider(Provider):
     """Provider for models using generic API spec."""
+
     stop_sequence_key: str = "stop"
 
     def __init__(self) -> None:
@@ -51,10 +52,11 @@ def __init__(self) -> None:
 
     def completion_response_to_text(self, response: Any) -> str:
         return response.data.inference_response.choices[0].text
-    
+
 
 class MetaProvider(GenericProvider):
     """Provider for Meta models. This provider is for backward compatibility."""
+
     pass
 
 
@@ -217,15 +219,14 @@ def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
 
         elif self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
             raise ValueError("provider is required for custom endpoints.")
         else:
-
-            provider = provider_map.get(self.model_id.split(".")[0].lower(), "generic")
+            provider = self.model_id.split(".")[0].lower()
+            # Use generic provider for non-custom endpoint
+            # if provider derived from the model_id is not in the provider map
+            if provider not in provider_map:
+                provider = "generic"
 
         if provider not in provider_map:
-            raise ValueError(
-                f"Invalid provider derived from model_id: {self.model_id} "
-                "Please explicitly pass in the supported provider "
-                "when using custom endpoint"
-            )
+            raise ValueError(f"Invalid provider {provider}.")
 
         return provider_map[provider]
diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
index b9940c8..d33adde 100644
--- a/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
+++ b/libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -24,7 +24,7 @@ def __getattr__(self, val):  # type: ignore[no-untyped-def]
 
 @pytest.mark.requires("oci")
 @pytest.mark.parametrize(
-    "test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
+    "test_model_id", ["cohere.command-r-16k", "meta.llama-3.3-70b-instruct"]
 )
 def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
     """Test valid chat call to OCI Generative AI LLM service."""
@@ -77,25 +77,34 @@ def mocked_response(*args):  # type: ignore[no-untyped-def]
                     {
                         "chat_response": MockResponseDict(
                             {
+                                "api_format": "GENERIC",
                                 "choices": [
                                     MockResponseDict(
                                         {
                                             "message": MockResponseDict(
                                                 {
+                                                    "role": "ASSISTANT",
+                                                    "name": None,
                                                     "content": [
                                                         MockResponseDict(
                                                             {
                                                                 "text": response_text,  # noqa: E501
+                                                                "type": "TEXT",
                                                             }
                                                         )
                                                     ],
                                                     "tool_calls": [
                                                         MockResponseDict(
                                                             {
-                                                                "type": "function",
+                                                                "type": "FUNCTION",
                                                                 "id": "call_123",
-                                                                "function": {
-                                                                    "name": "get_weather",  # noqa: E501
+                                                                "name": "get_weather",  # noqa: E501
+                                                                "arguments": '{"location": "current location"}',  # noqa: E501
+                                                                "attribute_map": {
+                                                                    "id": "id",
+                                                                    "type": "type",
+                                                                    "name": "name",
"arguments": "arguments", # noqa: E501 }, } ) @@ -106,10 +115,10 @@ def mocked_response(*args): # type: ignore[no-untyped-def] } ) ], - "time_created": "2024-09-01T00:00:00Z", + "time_created": "2025-08-14T10:00:01.100000+00:00", } ), - "model_id": "meta.llama-3.1-70b-instruct", + "model_id": "meta.llama-3.3-70b-instruct", "model_version": "1.0.0", } ), @@ -164,11 +173,15 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] "tool_calls": [ MockResponseDict( { - "type": "function", + "type": "FUNCTION", "id": "call_456", - "function": { - "name": "get_weather", # noqa: E501 - "arguments": '{"location": "San Francisco"}', # noqa: E501 + "name": "get_weather", # noqa: E501 + "arguments": '{"location": "San Francisco"}', # noqa: E501 + "attribute_map": { + "id": "id", + "type": "type", + "name": "name", + "arguments": "arguments", # noqa: E501 }, } ) @@ -179,7 +192,7 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] } ) ], - "time_created": "2024-09-01T00:00:00Z", + "time_created": "2025-08-14T10:00:01.100000+00:00", } ), "model_id": "meta.llama-3-70b-instruct", @@ -285,9 +298,21 @@ def test_meta_tool_conversion(monkeypatch: MonkeyPatch) -> None: from pydantic import BaseModel, Field oci_gen_ai_client = MagicMock() - llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client) + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] + request = args[0] + # Check the conversion of tools to oci generic API spec + # Function tool + assert request.chat_request.tools[0].parameters["properties"] == { + "x": {"description": "Input number", "type": "integer"} + } + # Pydantic tool + assert request.chat_request.tools[1].parameters["properties"] == { + "x": {"description": "Input number", "type": "integer"}, + "y": {"description": "Input string", "type": "string"}, + } + return MockResponseDict( { "status": 200, @@ -295,26 +320,40 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def] { "chat_response": MockResponseDict( { + "api_format": "GENERIC", "choices": [ MockResponseDict( { "message": MockResponseDict( { - "content": [ + "role": "ASSISTANT", + "content": None, + "tool_calls": [ MockResponseDict( - {"text": "Response"} + { + "arguments": '{"x": "10"}', # noqa: E501 + "id": "chatcmpl-tool-d123", # noqa: E501 + "name": "function_tool", + "type": "FUNCTION", + "attribute_map": { + "id": "id", + "type": "type", + "name": "name", + "arguments": "arguments", # noqa: E501 + }, + } ) - ] + ], } ), - "finish_reason": "completed", + "finish_reason": "tool_calls", } ) ], - "time_created": "2024-09-01T00:00:00Z", + "time_created": "2025-08-14T10:00:01.100000+00:00", } ), - "model_id": "meta.llama-3-70b-instruct", + "model_id": "meta.llama-3.3-70b-instruct", "model_version": "1.0.0", } ), @@ -348,7 +387,10 @@ class PydanticTool(BaseModel): tools=[function_tool, PydanticTool], ).invoke(messages) - assert response.content == "Response" + # For tool calls, the response content should be empty. 
+    assert response.content == ""
+    assert len(response.tool_calls) == 1
+    assert response.tool_calls[0]["name"] == "function_tool"
 
 
 @pytest.mark.requires("oci")
@@ -411,13 +453,13 @@ class WeatherResponse(BaseModel):
     conditions: str = Field(description="Weather conditions")
 
     oci_gen_ai_client = MagicMock()
-    llm = ChatOCIGenAI(model_id="cohere.command-r-16k", client=oci_gen_ai_client)
+    llm = ChatOCIGenAI(model_id="cohere.command-latest", client=oci_gen_ai_client)
 
     def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Verify that response_format contains the schema
         request = args[0]
-        assert request.response_format["type"] == "JSON_OBJECT"
-        assert "schema" in request.response_format
+        assert request.chat_request.response_format["type"] == "JSON_OBJECT"
+        assert "schema" in request.chat_request.response_format
 
         return MockResponseDict(
             {
@@ -426,8 +468,9 @@ def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
                 {
                     "chat_response": MockResponseDict(
                         {
+                            "api_format": "COHERE",
                             "text": '{"temperature": 25.5, "conditions": "Sunny"}',
-                            "finish_reason": "completed",
+                            "finish_reason": "COMPLETE",
                             "is_search_required": None,
                             "search_queries": None,
                             "citations": None,
@@ -435,7 +478,7 @@ def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
                             "tool_calls": None,
                         }
                     ),
-                    "model_id": "cohere.command-r-16k",
+                    "model_id": "cohere.command-latest",
                     "model_version": "1.0.0",
                 }
             ),
@@ -462,13 +505,19 @@ def test_auth_file_location(monkeypatch: MonkeyPatch) -> None:
     from unittest.mock import patch
 
     with patch("oci.config.from_file") as mock_from_file:
-        custom_config_path = "/custom/path/config"
-        ChatOCIGenAI(
-            model_id="cohere.command-r-16k", auth_file_location=custom_config_path
-        )
-        mock_from_file.assert_called_once_with(
-            file_location=custom_config_path, profile_name="DEFAULT"
-        )
+        with patch(
+            "oci.generative_ai_inference.generative_ai_inference_client.validate_config"
+        ):
+            with patch("oci.base_client.validate_config"):
+                with patch("oci.signer.load_private_key"):
+                    custom_config_path = "/custom/path/config"
+                    ChatOCIGenAI(
+                        model_id="cohere.command-r-16k",
+                        auth_file_location=custom_config_path,
+                    )
+                    mock_from_file.assert_called_once_with(
+                        file_location=custom_config_path, profile_name="DEFAULT"
+                    )
 
 
 @pytest.mark.requires("oci")
@@ -524,3 +573,18 @@ def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
     assert isinstance(response["parsed"], WeatherResponse)
     assert response["parsed"].temperature == 25.5
     assert response["parsed"].conditions == "Sunny"
+
+
+@pytest.mark.requires("oci")
+def test_get_provider() -> None:
+    """Test determining the provider based on the model_id."""
+    model_provider_map = {
+        "cohere.command-latest": "CohereProvider",
+        "meta.llama-3.3-70b-instruct": "MetaProvider",
+        "xai.grok-3": "GenericProvider",
+    }
+    for model_id, provider_name in model_provider_map.items():
+        # A mocked client keeps the test offline; provider resolution
+        # depends only on the model_id prefix.
+        llm = ChatOCIGenAI(model_id=model_id, client=MagicMock())
+        assert llm._provider.__class__.__name__ == provider_name
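
Note: the following is a minimal standalone sketch (not part of the patch) of the fallback resolution that _get_provider now implements and that test_get_provider exercises. The provider-map values and the CUSTOM_ENDPOINT_PREFIX value are illustrative stand-ins, not the library's real objects.

    # Hypothetical, self-contained model of the patched _get_provider logic.
    CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"  # assumed value

    def resolve_provider(model_id: str, provider_map: dict) -> str:
        # Custom endpoints carry no vendor prefix, so an explicit provider
        # is still required for them.
        if model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            raise ValueError("provider is required for custom endpoints.")
        # Derive the provider key from the model_id prefix, e.g. "meta" from
        # "meta.llama-3.3-70b-instruct"; unknown prefixes such as "xai" now
        # fall back to the generic provider instead of raising.
        provider = model_id.split(".")[0].lower()
        if provider not in provider_map:
            provider = "generic"
        return provider_map[provider]

    providers = {
        "cohere": "CohereProvider",
        "meta": "MetaProvider",
        "generic": "GenericProvider",
    }
    assert resolve_provider("meta.llama-3.3-70b-instruct", providers) == "MetaProvider"
    assert resolve_provider("xai.grok-3", providers) == "GenericProvider"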
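
Note: an illustrative before/after for the convert_to_oci_tool change in the first hunk, using an abbreviated stand-in for JSON_TO_PYTHON_TYPES. The old path mapped JSON-schema types to Python type names; the new path passes them through (defaulting to "any"), which is what the updated tool-conversion asserts in test_meta_tool_conversion expect.

    JSON_TO_PYTHON_TYPES = {"integer": "int", "string": "str"}  # abbreviated stand-in

    args = {"x": {"type": "integer", "description": "Input number"}}

    # Old behavior: translate the JSON-schema type into a Python type name.
    old = {
        p: JSON_TO_PYTHON_TYPES.get(d.get("type"), d.get("type", "string"))
        for p, d in args.items()
    }
    # New behavior: pass the JSON-schema type through unchanged.
    new = {p: d.get("type", "any") for p, d in args.items()}

    assert old == {"x": "int"}      # Python type name
    assert new == {"x": "integer"}  # JSON-schema type, as the tests assert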