Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,9 +814,7 @@ def convert_to_oci_tool(
"type": "object",
"properties": {
p_name: {
"type": JSON_TO_PYTHON_TYPES.get(
p_def.get("type"), p_def.get("type", "string")
),
"type": p_def.get("type", "any"),
"description": p_def.get("description", ""),
}
for p_name, p_def in tool.args.items()
Expand Down
17 changes: 9 additions & 8 deletions libs/oci/langchain_oci/llms/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def completion_response_to_text(self, response: Any) -> str:

class GenericProvider(Provider):
"""Provider for models using generic API spec."""

stop_sequence_key: str = "stop"

def __init__(self) -> None:
Expand All @@ -51,10 +52,11 @@ def __init__(self) -> None:

def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.choices[0].text


class MetaProvider(GenericProvider):
"""Provider for Meta models. This provider is for backward compatibility."""

pass


Expand Down Expand Up @@ -217,15 +219,14 @@ def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
elif self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
raise ValueError("provider is required for custom endpoints.")
else:

provider = provider_map.get(self.model_id.split(".")[0].lower(), "generic")
provider = self.model_id.split(".")[0].lower()
# Use generic provider for non-custom endpoint
# if provider derived from the model_id is not in the provider map
if provider not in provider_map:
provider = "generic"

if provider not in provider_map:
raise ValueError(
f"Invalid provider derived from model_id: {self.model_id} "
"Please explicitly pass in the supported provider "
"when using custom endpoint"
)
raise ValueError(f"Invalid provider {provider}.")
return provider_map[provider]


Expand Down
125 changes: 94 additions & 31 deletions libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __getattr__(self, val): # type: ignore[no-untyped-def]

@pytest.mark.requires("oci")
@pytest.mark.parametrize(
"test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
"test_model_id", ["cohere.command-r-16k", "meta.llama-3.3-70b-instruct"]
)
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid chat call to OCI Generative AI LLM service."""
Expand Down Expand Up @@ -77,25 +77,34 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
{
"chat_response": MockResponseDict(
{
"api_format": "GENERIC",
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"role": "ASSISTANT",
"name": None,
"content": [
MockResponseDict(
{
"text": response_text, # noqa: E501
"type": "TEXT",
}
)
],
"tool_calls": [
MockResponseDict(
{
"type": "function",
"type": "FUNCTION",
"id": "call_123",
"function": {
"name": "get_weather", # noqa: E501
"name": "get_weather", # noqa: E501
"arguments": '{"location": "current location"}', # noqa: E501
"attribute_map": {
"id": "id",
"type": "type",
"name": "name",
"arguments": "arguments", # noqa: E501
},
}
)
Expand All @@ -106,10 +115,10 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
}
)
],
"time_created": "2024-09-01T00:00:00Z",
"time_created": "2025-08-14T10:00:01.100000+00:00",
}
),
"model_id": "meta.llama-3.1-70b-instruct",
"model_id": "meta.llama-3.3-70b-instruct",
"model_version": "1.0.0",
}
),
Expand Down Expand Up @@ -164,11 +173,15 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
"tool_calls": [
MockResponseDict(
{
"type": "function",
"type": "FUNCTION",
"id": "call_456",
"function": {
"name": "get_weather", # noqa: E501
"arguments": '{"location": "San Francisco"}', # noqa: E501
"name": "get_weather", # noqa: E501
"arguments": '{"location": "San Francisco"}', # noqa: E501
"attribute_map": {
"id": "id",
"type": "type",
"name": "name",
"arguments": "arguments", # noqa: E501
},
}
)
Expand All @@ -179,7 +192,7 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
}
)
],
"time_created": "2024-09-01T00:00:00Z",
"time_created": "2025-08-14T10:00:01.100000+00:00",
}
),
"model_id": "meta.llama-3-70b-instruct",
Expand Down Expand Up @@ -285,36 +298,62 @@ def test_meta_tool_conversion(monkeypatch: MonkeyPatch) -> None:
from pydantic import BaseModel, Field

oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client)
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)

def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
request = args[0]
# Check the conversion of tools to oci generic API spec
# Function tool
assert request.chat_request.tools[0].parameters["properties"] == {
"x": {"description": "Input number", "type": "integer"}
}
# Pydantic tool
assert request.chat_request.tools[1].parameters["properties"] == {
"x": {"description": "Input number", "type": "integer"},
"y": {"description": "Input string", "type": "string"},
}

return MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"api_format": "GENERIC",
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"content": [
"role": "ASSISTANT",
"content": None,
"tool_calls": [
MockResponseDict(
{"text": "Response"}
{
"arguments": '{"x": "10"}', # noqa: E501
"id": "chatcmpl-tool-d123", # noqa: E501
"name": "function_tool",
"type": "FUNCTION",
"attribute_map": {
"id": "id",
"type": "type",
"name": "name",
"arguments": "arguments", # noqa: E501
},
}
)
]
],
}
),
"finish_reason": "completed",
"finish_reason": "tool_calls",
}
)
],
"time_created": "2024-09-01T00:00:00Z",
"time_created": "2025-08-14T10:00:01.100000+00:00",
}
),
"model_id": "meta.llama-3-70b-instruct",
"model_id": "meta.llama-3.3-70b-instruct",
"model_version": "1.0.0",
}
),
Expand Down Expand Up @@ -348,7 +387,10 @@ class PydanticTool(BaseModel):
tools=[function_tool, PydanticTool],
).invoke(messages)

assert response.content == "Response"
# For tool calls, the response content should be empty.
assert response.content == ""
assert len(response.tool_calls) == 1
assert response.tool_calls[0]["name"] == "function_tool"


@pytest.mark.requires("oci")
Expand Down Expand Up @@ -411,13 +453,13 @@ class WeatherResponse(BaseModel):
conditions: str = Field(description="Weather conditions")

oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id="cohere.command-r-16k", client=oci_gen_ai_client)
llm = ChatOCIGenAI(model_id="cohere.command-latest", client=oci_gen_ai_client)

def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
# Verify that response_format contains the schema
request = args[0]
assert request.response_format["type"] == "JSON_OBJECT"
assert "schema" in request.response_format
assert request.chat_request.response_format["type"] == "JSON_OBJECT"
assert "schema" in request.chat_request.response_format

return MockResponseDict(
{
Expand All @@ -426,16 +468,17 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
{
"chat_response": MockResponseDict(
{
"api_format": "COHERE",
"text": '{"temperature": 25.5, "conditions": "Sunny"}',
"finish_reason": "completed",
"finish_reason": "COMPLETE",
"is_search_required": None,
"search_queries": None,
"citations": None,
"documents": None,
"tool_calls": None,
}
),
"model_id": "cohere.command-r-16k",
"model_id": "cohere.command-latest",
"model_version": "1.0.0",
}
),
Expand All @@ -462,13 +505,19 @@ def test_auth_file_location(monkeypatch: MonkeyPatch) -> None:
from unittest.mock import patch

with patch("oci.config.from_file") as mock_from_file:
custom_config_path = "/custom/path/config"
ChatOCIGenAI(
model_id="cohere.command-r-16k", auth_file_location=custom_config_path
)
mock_from_file.assert_called_once_with(
file_location=custom_config_path, profile_name="DEFAULT"
)
with patch(
"oci.generative_ai_inference.generative_ai_inference_client.validate_config"
):
with patch("oci.base_client.validate_config"):
with patch("oci.signer.load_private_key"):
custom_config_path = "/custom/path/config"
ChatOCIGenAI(
model_id="cohere.command-r-16k",
auth_file_location=custom_config_path,
)
mock_from_file.assert_called_once_with(
file_location=custom_config_path, profile_name="DEFAULT"
)


@pytest.mark.requires("oci")
Expand Down Expand Up @@ -524,3 +573,17 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
assert isinstance(response["parsed"], WeatherResponse)
assert response["parsed"].temperature == 25.5
assert response["parsed"].conditions == "Sunny"


def test_get_provider():
"""Test determining the provider based on the model_id."""
model_provider_map = {
"cohere.command-latest": "CohereProvider",
"meta.llama-3.3-70b-instruct": "MetaProvider",
"xai.grok-3": "GenericProvider",
}
for model_id, provider_name in model_provider_map.items():
assert (
ChatOCIGenAI(model_id=model_id)._provider.__class__.__name__
== provider_name
)