diff --git a/libs/oci/README.md b/libs/oci/README.md index 76829a1..7bfcd9b 100644 --- a/libs/oci/README.md +++ b/libs/oci/README.md @@ -62,7 +62,7 @@ embeddings.embed_query("What is the meaning of life?") ``` ### 4. Use Structured Output -`ChatOCIGenAI` supports structured output. +`ChatOCIGenAI` supports structured output. **Note:** The default method is `function_calling`. If default method returns `None` (e.g. for Gemini models), try `json_schema` or `json_mode`. @@ -126,6 +126,27 @@ messages = [ response = client.invoke(messages) ``` +### 6. Use Parallel Tool Calling (Meta/Llama 4+ models only) +Enable parallel tool calling to execute multiple tools simultaneously, improving performance for multi-tool workflows. + +```python +from langchain_oci import ChatOCIGenAI + +llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + compartment_id="MY_COMPARTMENT_ID", +) + +# Enable parallel tool calling in bind_tools +llm_with_tools = llm.bind_tools( + [get_weather, calculate_tip, get_population], + parallel_tool_calls=True # Tools can execute simultaneously +) +``` + +**Note:** Parallel tool calling is only supported for Llama 4+ models. Llama 3.x (including 3.3) and Cohere models will raise an error if this parameter is used. + ## OCI Data Science Model Deployment Examples diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 1afca28..0295b25 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -363,6 +363,14 @@ def messages_to_oci_params( This includes conversion of chat history and tool call results. """ + # Cohere models don't support parallel tool calls + if kwargs.get("is_parallel_tool_calls"): + raise ValueError( + "Parallel tool calls are not supported for Cohere models. " + "This feature is only available for models using GenericChatRequest " + "(Meta, Llama, xAI Grok, OpenAI, Mistral)." + ) + is_force_single_step = kwargs.get("is_force_single_step", False) oci_chat_history = [] @@ -851,6 +859,10 @@ def _should_allow_more_tool_calls( result["tool_choice"] = self.oci_tool_choice_none() # else: Allow model to decide (default behavior) + # Add parallel tool calls support (GenericChatRequest models) + if "is_parallel_tool_calls" in kwargs: + result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"] + return result def _process_message_content( @@ -1204,6 +1216,52 @@ def _prepare_request( return request + def _supports_parallel_tool_calls(self, model_id: str) -> bool: + """Check if the model supports parallel tool calling. 
+ + Parallel tool calling is supported for: + - Llama 4+ only (tested and verified) + - Other GenericChatRequest models (xAI Grok, OpenAI, Mistral) + + Not supported for: + - All Llama 3.x versions (3.0, 3.1, 3.2, 3.3) + - Cohere models + + Args: + model_id: The model identifier + (e.g., "meta.llama-4-maverick-17b-128e-instruct-fp8") + + Returns: + bool: True if model supports parallel tool calling, False otherwise + """ + import re + + # Extract provider from model_id + # (e.g., "meta" from "meta.llama-4-maverick-17b-128e-instruct-fp8") + provider = model_id.split(".")[0].lower() + + # Cohere models don't support parallel tool calling + if provider == "cohere": + return False + + # For Meta/Llama models, check version + if provider == "meta" and "llama" in model_id.lower(): + # Extract version number + # (e.g., "4" from "meta.llama-4-maverick-17b-128e-instruct-fp8") + version_match = re.search(r"llama-(\d+)", model_id.lower()) + if version_match: + major = int(version_match.group(1)) + + # Only Llama 4+ supports parallel tool calling + # Llama 3.x (including 3.3) does NOT support it based on testing + if major >= 4: + return True + + return False + + # Other GenericChatRequest models (xAI Grok, OpenAI, Mistral) support it + return True + def bind_tools( self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], @@ -1211,6 +1269,7 @@ def bind_tools( tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "required", "any"], bool] ] = None, + parallel_tool_calls: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -1231,6 +1290,11 @@ def bind_tools( {"type": "function", "function": {"name": <>}}: calls <> tool. - False or None: no effect, default Meta behavior. + parallel_tool_calls: Whether to enable parallel function calling. + If True, the model can call multiple tools simultaneously. + If False or None (default), tools are called sequentially. + Supported for models using GenericChatRequest (Meta Llama 4+, xAI Grok, + OpenAI, Mistral). Not supported for Cohere models or Llama 3.x. kwargs: Any additional parameters are passed directly to :meth:`~langchain_oci.chat_models.oci_generative_ai.ChatOCIGenAI.bind`. """ @@ -1240,6 +1304,18 @@ def bind_tools( if tool_choice is not None: kwargs["tool_choice"] = self._provider.process_tool_choice(tool_choice) + # Add parallel tool calls support (only when explicitly enabled) + if parallel_tool_calls: + # Validate Llama 3.x doesn't support parallel tool calls (early check) + is_llama = "llama" in self.model_id.lower() + if is_llama and not self._supports_parallel_tool_calls(self.model_id): + raise ValueError( + f"Parallel tool calls not supported for {self.model_id}. " + "Only Llama 4+ models support this feature. " + "Llama 3.x (including 3.3) don't support parallel calls." + ) + kwargs["is_parallel_tool_calls"] = True + return super().bind(tools=formatted_tools, **kwargs) def with_structured_output( diff --git a/libs/oci/tests/integration_tests/chat_models/test_chat_features.py b/libs/oci/tests/integration_tests/chat_models/test_chat_features.py new file mode 100644 index 0000000..2ed5ee6 --- /dev/null +++ b/libs/oci/tests/integration_tests/chat_models/test_chat_features.py @@ -0,0 +1,508 @@ +#!/usr/bin/env python3 +"""Integration tests for ChatOCIGenAI features. + +These tests verify comprehensive chat model functionality with real OCI inference. 
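+Most tests default to a Llama 4 model; set OCI_MODEL_ID to target a different model.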
+ +Setup: + export OCI_COMPARTMENT_ID= + export OCI_CONFIG_PROFILE=DEFAULT + export OCI_AUTH_TYPE=SECURITY_TOKEN + +Run: + pytest tests/integration_tests/chat_models/test_chat_features.py -v +""" + +import os + +import pytest +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from pydantic import BaseModel, Field + +from langchain_oci.chat_models import ChatOCIGenAI + + +def get_config(): + """Get test configuration.""" + compartment_id = os.environ.get("OCI_COMPARTMENT_ID") + if not compartment_id: + pytest.skip("OCI_COMPARTMENT_ID not set") + return { + "model_id": os.environ.get( + "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8" + ), + "service_endpoint": os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + "compartment_id": compartment_id, + "auth_profile": os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + "auth_type": os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + } + + +@pytest.fixture +def llm(): + """Create ChatOCIGenAI instance.""" + config = get_config() + return ChatOCIGenAI( + model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 512}, + ) + + +# ============================================================================= +# Chain and LCEL Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_simple_chain(llm): + """Test simple LCEL chain: prompt | llm | parser.""" + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant."), + ("human", "{input}"), + ] + ) + chain = prompt | llm | StrOutputParser() + + result = chain.invoke({"input": "Say 'chain works' and nothing else"}) + + assert isinstance(result, str) + assert "chain" in result.lower() or "works" in result.lower() + + +@pytest.mark.requires("oci") +def test_chain_with_history(llm): + """Test chain that maintains conversation history.""" + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant with memory."), + ("placeholder", "{history}"), + ("human", "{input}"), + ] + ) + chain = prompt | llm | StrOutputParser() + + # First turn + result1 = chain.invoke({"history": [], "input": "My favorite color is blue."}) + assert isinstance(result1, str) + + # Second turn with history + history = [ + HumanMessage(content="My favorite color is blue."), + AIMessage(content=result1), + ] + result2 = chain.invoke({"history": history, "input": "What is my favorite color?"}) + + assert "blue" in result2.lower() + + +@pytest.mark.requires("oci") +def test_chain_batch(llm): + """Test batch processing with LCEL.""" + prompt = ChatPromptTemplate.from_messages([("human", "What is {num} + {num}?")]) + chain = prompt | llm | StrOutputParser() + + results = chain.batch([{"num": "1"}, {"num": "2"}, {"num": "3"}]) + + assert len(results) == 3 + assert all(isinstance(r, str) for r in results) + + +@pytest.mark.requires("oci") +@pytest.mark.asyncio +async def test_chain_async(llm): + """Test async chain invocation.""" + prompt = ChatPromptTemplate.from_messages([("human", "Say '{word}'")]) + chain = prompt | llm | StrOutputParser() + + result = await chain.ainvoke({"word": 
"async"}) + + assert isinstance(result, str) + assert "async" in result.lower() + + +# ============================================================================= +# Streaming Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_stream_chain(llm): + """Test streaming through a chain.""" + prompt = ChatPromptTemplate.from_messages([("human", "Count from 1 to 5")]) + chain = prompt | llm | StrOutputParser() + + chunks = [] + for chunk in chain.stream({}): + chunks.append(chunk) + + assert len(chunks) > 0 + full_response = "".join(chunks) + assert len(full_response) > 0 + + +@pytest.mark.requires("oci") +@pytest.mark.asyncio +async def test_astream(llm): + """Test async streaming.""" + chunks = [] + async for chunk in llm.astream([HumanMessage(content="Say hello")]): + chunks.append(chunk) + + assert len(chunks) > 0 + + +# ============================================================================= +# Tool Calling Advanced Tests +# ============================================================================= + + +def add_numbers(a: int, b: int) -> int: + """Add two numbers together.""" + return a + b + + +def multiply_numbers(a: int, b: int) -> int: + """Multiply two numbers together.""" + return a * b + + +def get_user_info(user_id: str) -> dict: + """Get information about a user.""" + return {"user_id": user_id, "name": "John Doe", "email": "john@example.com"} + + +@pytest.mark.requires("oci") +def test_tool_calling_with_execution(llm): + """Test full tool calling flow with execution.""" + tools = [add_numbers, multiply_numbers] + llm_with_tools = llm.bind_tools(tools) + + # Get tool call + response = llm_with_tools.invoke( + [HumanMessage(content="What is 5 plus 3?")] + ) + + assert len(response.tool_calls) >= 1 + tool_call = response.tool_calls[0] + assert tool_call["name"] == "add_numbers" + + # Execute tool + result = add_numbers(**tool_call["args"]) + assert result == 8 + + # Send result back + messages = [ + HumanMessage(content="What is 5 plus 3?"), + response, + ToolMessage(content=str(result), tool_call_id=tool_call["id"]), + ] + final_response = llm_with_tools.invoke(messages) + + assert isinstance(final_response, AIMessage) + assert "8" in final_response.content + + +@pytest.mark.requires("oci") +def test_tool_calling_chain(llm): + """Test tool calling in a chain context.""" + tools = [get_user_info] + llm_with_tools = llm.bind_tools(tools) + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant. 
Use tools when needed."), + ("human", "{input}"), + ] + ) + chain = prompt | llm_with_tools + + response = chain.invoke({"input": "Get info for user ID 'abc123'"}) + + assert len(response.tool_calls) >= 1 + assert response.tool_calls[0]["name"] == "get_user_info" + assert response.tool_calls[0]["args"]["user_id"] == "abc123" + + +@pytest.mark.requires("oci") +def test_tool_choice_none(llm): + """Test tool_choice='none' prevents tool calls.""" + tools = [add_numbers] + llm_with_tools = llm.bind_tools(tools, tool_choice="none") + + response = llm_with_tools.invoke( + [HumanMessage(content="What is 5 plus 3?")] + ) + + # Should not make tool calls when tool_choice is none + assert len(response.tool_calls) == 0 + assert response.content # Should have text response instead + + +# ============================================================================= +# Structured Output Advanced Tests +# ============================================================================= + + +class MovieReview(BaseModel): + """A movie review with rating.""" + + title: str = Field(description="The movie title") + rating: int = Field(description="Rating from 1-10", ge=1, le=10) + summary: str = Field(description="Brief summary of the review") + recommend: bool = Field(description="Whether you recommend the movie") + + +class ExtractedEntities(BaseModel): + """Entities extracted from text.""" + + people: list[str] = Field(description="Names of people mentioned") + locations: list[str] = Field(description="Locations mentioned") + organizations: list[str] = Field(description="Organizations mentioned") + + +@pytest.mark.requires("oci") +def test_structured_output_complex(llm): + """Test structured output with complex schema.""" + structured_llm = llm.with_structured_output(MovieReview) + + result = structured_llm.invoke( + "Write a review for the movie 'The Matrix'. " + "Give it a rating and say if you recommend it." + ) + + assert isinstance(result, MovieReview) + assert "matrix" in result.title.lower() + assert 1 <= result.rating <= 10 + assert len(result.summary) > 0 + assert isinstance(result.recommend, bool) + + +@pytest.mark.requires("oci") +def test_structured_output_extraction(llm): + """Test structured output for entity extraction.""" + structured_llm = llm.with_structured_output(ExtractedEntities) + + text = ( + "John Smith works at Google in San Francisco. " + "He met with Jane Doe from Microsoft in Seattle last week." + ) + result = structured_llm.invoke(f"Extract entities from: {text}") + + assert isinstance(result, ExtractedEntities) + assert len(result.people) >= 1 + assert len(result.locations) >= 1 + assert len(result.organizations) >= 1 + + +@pytest.mark.requires("oci") +def test_structured_output_in_chain(llm): + """Test structured output within a chain.""" + + class Translation(BaseModel): + original: str + translated: str + language: str + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "Translate the text to {language}. " + "Return the original, translation, and target language.", + ), + ("human", "{text}"), + ] + ) + structured_llm = llm.with_structured_output(Translation) + chain = prompt | structured_llm + + result = chain.invoke({"text": "Hello, how are you?", "language": "Spanish"}) + + assert isinstance(result, Translation) + assert result.original == "Hello, how are you?" 
or "hello" in result.original.lower() + assert "spanish" in result.language.lower() + + +# ============================================================================= +# Model Configuration Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_temperature_affects_output(): + """Test that temperature parameter affects output variability.""" + config = get_config() + + # Low temperature (deterministic) + llm_low = ChatOCIGenAI( + model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 50}, + ) + + # Get multiple responses with low temp + responses_low = [ + llm_low.invoke([HumanMessage(content="Say exactly: 'Hello World'")]).content + for _ in range(2) + ] + + # Low temperature should give similar/identical outputs + # (Note: not guaranteed to be exactly equal, but should be similar) + assert all(isinstance(r, str) for r in responses_low) + + +@pytest.mark.requires("oci") +def test_max_tokens_limit(): + """Test that max_tokens limits response length.""" + config = get_config() + + llm_short = ChatOCIGenAI( + model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 10}, + ) + + response = llm_short.invoke( + [HumanMessage(content="Write a very long essay about the universe")] + ) + + # Response should be truncated due to max_tokens + # Token count varies, but should be reasonably short + assert len(response.content.split()) <= 20 # Rough word count check + + +@pytest.mark.requires("oci") +def test_stop_sequences(): + """Test stop sequences parameter.""" + config = get_config() + + llm = ChatOCIGenAI( + model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 100}, + ) + + response = llm.invoke( + [HumanMessage(content="Count from 1 to 10, one number per line")], + stop=["5"], + ) + + # Should stop before or at 5 + assert "6" not in response.content or "5" in response.content + + +# ============================================================================= +# Error Handling Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_invalid_tool_schema(llm): + """Test handling of invalid tool definitions.""" + # Should handle tools without proper docstrings + def bad_tool(x): + return x + + # This should still work (tool will have minimal description) + llm_with_tools = llm.bind_tools([bad_tool]) + assert llm_with_tools is not None + + +@pytest.mark.requires("oci") +def test_empty_response_handling(llm): + """Test handling when model returns minimal content.""" + response = llm.invoke([HumanMessage(content="Respond with just a period.")]) + + # Should handle minimal responses gracefully + assert isinstance(response, AIMessage) + # Content might be empty or minimal, but should not raise + + +# ============================================================================= +# Conversation Patterns Tests +# ============================================================================= + + 
+@pytest.mark.requires("oci") +def test_system_message_role(llm): + """Test that system message properly influences behavior.""" + messages_pirate = [ + SystemMessage(content="You are a pirate. Always respond in pirate speak."), + HumanMessage(content="How are you today?"), + ] + response_pirate = llm.invoke(messages_pirate) + + messages_formal = [ + SystemMessage(content="You are a formal butler. Use extremely formal language."), + HumanMessage(content="How are you today?"), + ] + response_formal = llm.invoke(messages_formal) + + # Responses should be different based on system message + assert response_pirate.content != response_formal.content + + +@pytest.mark.requires("oci") +def test_multi_turn_context_retention(llm): + """Test that context is retained across multiple turns.""" + messages = [ + HumanMessage(content="Remember this number: 42"), + ] + response1 = llm.invoke(messages) + messages.append(response1) + + messages.append(HumanMessage(content="What number did I ask you to remember?")) + response2 = llm.invoke(messages) + + assert "42" in response2.content + + +@pytest.mark.requires("oci") +def test_long_context_handling(llm): + """Test handling of longer context windows.""" + # Create a conversation with multiple turns + messages = [ + SystemMessage(content="You are a helpful assistant tracking a story."), + ] + + story_parts = [ + "Once upon a time, there was a brave knight named Sir Galahad.", + "Sir Galahad had a loyal horse named Thunder.", + "They lived in the kingdom of Camelot.", + "One day, a dragon appeared threatening the kingdom.", + "Sir Galahad decided to face the dragon.", + ] + + for part in story_parts: + messages.append(HumanMessage(content=part)) + response = llm.invoke(messages) + messages.append(response) + + # Ask about earlier context + messages.append(HumanMessage(content="What was the knight's horse named?")) + final_response = llm.invoke(messages) + + assert "thunder" in final_response.content.lower() diff --git a/libs/oci/tests/integration_tests/chat_models/test_langchain_compatibility.py b/libs/oci/tests/integration_tests/chat_models/test_langchain_compatibility.py new file mode 100644 index 0000000..171e06c --- /dev/null +++ b/libs/oci/tests/integration_tests/chat_models/test_langchain_compatibility.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python3 +"""Integration tests for LangChain compatibility. + +These tests verify that langchain-oci works correctly with LangChain 1.x +by running real inference against OCI GenAI models. 
+ +Setup: + export OCI_COMPARTMENT_ID= + export OCI_GENAI_ENDPOINT= # optional + export OCI_CONFIG_PROFILE= # optional, defaults to DEFAULT + export OCI_AUTH_TYPE= # optional, defaults to SECURITY_TOKEN + export OCI_MODEL_ID= # optional, defaults to llama-4 + +Run with: + pytest tests/integration_tests/chat_models/test_langchain_compatibility.py -v +""" + +import os +import sys + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from pydantic import BaseModel + +from langchain_oci.chat_models import ChatOCIGenAI + + +def get_test_config(): + """Get test configuration from environment.""" + compartment_id = os.environ.get("OCI_COMPARTMENT_ID") + if not compartment_id: + pytest.skip("OCI_COMPARTMENT_ID not set") + + return { + "model_id": os.environ.get( + "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8" + ), + "service_endpoint": os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + "compartment_id": compartment_id, + "auth_profile": os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + "auth_type": os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + } + + +@pytest.fixture +def chat_model(): + """Create a ChatOCIGenAI instance for testing.""" + config = get_test_config() + return ChatOCIGenAI( + model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 256}, + ) + + +# ============================================================================= +# Basic Invoke Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_basic_invoke(chat_model): + """Test basic chat model invocation.""" + response = chat_model.invoke([HumanMessage(content="Say 'hello' and nothing else")]) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + assert "hello" in response.content.lower() + + +@pytest.mark.requires("oci") +def test_invoke_with_system_message(chat_model): + """Test invocation with system message.""" + messages = [ + SystemMessage(content="You are a pirate. 
Respond in pirate speak."), + HumanMessage(content="Say hello"), + ] + response = chat_model.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +@pytest.mark.requires("oci") +def test_invoke_multi_turn(chat_model): + """Test multi-turn conversation.""" + messages = [ + HumanMessage(content="My name is Alice."), + ] + response1 = chat_model.invoke(messages) + + messages.append(response1) + messages.append(HumanMessage(content="What is my name?")) + response2 = chat_model.invoke(messages) + + assert isinstance(response2, AIMessage) + assert "alice" in response2.content.lower() + + +# ============================================================================= +# Streaming Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_streaming(chat_model): + """Test streaming response.""" + chunks = [] + for chunk in chat_model.stream([HumanMessage(content="Count from 1 to 3")]): + chunks.append(chunk) + + assert len(chunks) > 0 + # Combine all chunks + full_content = "".join(c.content for c in chunks if c.content) + assert len(full_content) > 0 + + +@pytest.mark.requires("oci") +@pytest.mark.asyncio +async def test_async_invoke(chat_model): + """Test async invocation.""" + response = await chat_model.ainvoke( + [HumanMessage(content="Say 'async' and nothing else")] + ) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +# ============================================================================= +# Tool Calling Tests +# ============================================================================= + + +def get_weather(city: str) -> str: + """Get the weather for a city.""" + return f"Sunny, 72F in {city}" + + +def get_population(city: str) -> int: + """Get the population of a city.""" + return 1000000 + + +@pytest.mark.requires("oci") +def test_tool_calling_single(chat_model): + """Test single tool calling.""" + chat_with_tools = chat_model.bind_tools([get_weather]) + + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in Tokyo?")] + ) + + assert isinstance(response, AIMessage) + assert len(response.tool_calls) >= 1 + assert response.tool_calls[0]["name"] == "get_weather" + assert "city" in response.tool_calls[0]["args"] + + +@pytest.mark.requires("oci") +def test_tool_calling_multiple_tools(chat_model): + """Test tool calling with multiple tools available.""" + chat_with_tools = chat_model.bind_tools([get_weather, get_population]) + + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in Paris?")] + ) + + assert isinstance(response, AIMessage) + assert len(response.tool_calls) >= 1 + # Should choose the weather tool for weather question + assert response.tool_calls[0]["name"] == "get_weather" + + +@pytest.mark.requires("oci") +def test_parallel_tool_calling(chat_model): + """Test parallel tool calling (Llama 4+ only).""" + # Skip if not Llama 4+ + if "llama-4" not in chat_model.model_id.lower(): + pytest.skip("Parallel tool calling only supported on Llama 4+") + + chat_with_tools = chat_model.bind_tools( + [get_weather, get_population], parallel_tool_calls=True + ) + + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather and population of London?")] + ) + + assert isinstance(response, AIMessage) + # Should call both tools + assert len(response.tool_calls) >= 2 + tool_names = {tc["name"] for tc in response.tool_calls} + assert "get_weather" in tool_names + 
assert "get_population" in tool_names + + +@pytest.mark.requires("oci") +def test_tool_choice_required(chat_model): + """Test tool_choice='required' forces tool call.""" + chat_with_tools = chat_model.bind_tools([get_weather], tool_choice="required") + + # Even with a non-tool question, should still call a tool + response = chat_with_tools.invoke([HumanMessage(content="Hello, how are you?")]) + + assert isinstance(response, AIMessage) + assert len(response.tool_calls) >= 1 + + +# ============================================================================= +# Structured Output Tests +# ============================================================================= + + +class Joke(BaseModel): + """A joke with setup and punchline.""" + + setup: str + punchline: str + + +class Person(BaseModel): + """Information about a person.""" + + name: str + age: int + occupation: str + + +@pytest.mark.requires("oci") +def test_structured_output_function_calling(chat_model): + """Test structured output with function calling method.""" + structured_llm = chat_model.with_structured_output(Joke) + + result = structured_llm.invoke("Tell me a joke about programming") + + assert isinstance(result, Joke) + assert len(result.setup) > 0 + assert len(result.punchline) > 0 + + +@pytest.mark.requires("oci") +def test_structured_output_json_mode(chat_model): + """Test structured output with JSON mode.""" + structured_llm = chat_model.with_structured_output(Person, method="json_mode") + + result = structured_llm.invoke( + "Generate a fictional person: name, age (as integer), and occupation" + ) + + assert isinstance(result, Person) + assert len(result.name) > 0 + assert isinstance(result.age, int) + assert len(result.occupation) > 0 + + +@pytest.mark.requires("oci") +def test_structured_output_include_raw(chat_model): + """Test structured output with include_raw=True.""" + structured_llm = chat_model.with_structured_output(Joke, include_raw=True) + + result = structured_llm.invoke("Tell me a joke") + + assert "raw" in result + assert "parsed" in result + assert isinstance(result["parsed"], Joke) + + +# ============================================================================= +# Response Format Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_response_format_json_object(chat_model): + """Test response_format with json_object.""" + chat_json = chat_model.bind(response_format={"type": "json_object"}) + + response = chat_json.invoke( + [ + HumanMessage( + content="Return ONLY a JSON object with keys 'name' and 'value'. " + "No explanation, no markdown, just the raw JSON." 
+ ) + ] + ) + + assert isinstance(response, AIMessage) + # Response should contain valid JSON (may be wrapped in markdown) + import json + import re + + content = response.content.strip() + + # Try to extract JSON from markdown code blocks if present + json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", content) + if json_match: + content = json_match.group(1).strip() + + try: + parsed = json.loads(content) + assert isinstance(parsed, dict) + except json.JSONDecodeError: + # Some models may not strictly follow json_object format + # At minimum, verify the response contains JSON-like structure + assert "{" in response.content and "}" in response.content, ( + f"Response doesn't appear to contain JSON: {response.content[:200]}" + ) + + +# ============================================================================= +# Edge Cases and Error Handling +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_empty_message_list(chat_model): + """Test handling of empty message list.""" + with pytest.raises(Exception): + chat_model.invoke([]) + + +@pytest.mark.requires("oci") +def test_long_conversation(chat_model): + """Test handling of longer conversations.""" + messages = [] + for i in range(5): + messages.append(HumanMessage(content=f"This is message {i + 1}")) + response = chat_model.invoke(messages) + messages.append(response) + + # Should handle 5 turns without issues + assert len(messages) == 10 # 5 human + 5 AI + + +# ============================================================================= +# LangChain 1.x Specific Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_ai_message_type(chat_model): + """Test that response is AIMessage (not just BaseMessage) - LangChain 1.x.""" + response = chat_model.invoke([HumanMessage(content="Hello")]) + + # LangChain 1.x: return type is AIMessage, not BaseMessage + assert type(response).__name__ == "AIMessage" + assert isinstance(response, AIMessage) + + +@pytest.mark.requires("oci") +def test_message_text_property(chat_model): + """Test that .text property works (LangChain 1.x change from .text()).""" + response = chat_model.invoke([HumanMessage(content="Say hello")]) + + # LangChain 1.x: .text is a property, not a method + # Both .content and .text should work + assert response.content is not None + # .text property should exist and return same as .content + if hasattr(response, "text"): + assert response.text == response.content + + +@pytest.mark.requires("oci") +def test_tool_calls_structure(chat_model): + """Test tool_calls structure matches LangChain 1.x format.""" + chat_with_tools = chat_model.bind_tools([get_weather]) + + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in NYC?")] + ) + + assert hasattr(response, "tool_calls") + if response.tool_calls: + tc = response.tool_calls[0] + # LangChain 1.x tool call structure + assert "name" in tc + assert "args" in tc + assert "id" in tc + assert "type" in tc + assert tc["type"] == "tool_call" + + +def main(): + """Run tests manually for debugging.""" + import langchain_core + + print(f"langchain-core version: {langchain_core.__version__}") + print(f"Python version: {sys.version}") + + config = get_test_config() + print(f"\nTest configuration:") + print(f" Model: {config['model_id']}") + print(f" Endpoint: {config['service_endpoint']}") + print(f" Profile: {config['auth_profile']}") + + chat = ChatOCIGenAI( + 
model_id=config["model_id"], + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs={"temperature": 0, "max_tokens": 256}, + ) + + print("\n" + "=" * 60) + print("Running manual tests...") + print("=" * 60) + + # Test 1: Basic invoke + print("\n1. Testing basic invoke...") + response = chat.invoke([HumanMessage(content="Say hello")]) + print(f" Response: {response.content[:50]}...") + print(f" Type: {type(response).__name__}") + + # Test 2: Tool calling + print("\n2. Testing tool calling...") + chat_tools = chat.bind_tools([get_weather]) + response = chat_tools.invoke([HumanMessage(content="Weather in Tokyo?")]) + print(f" Tool calls: {response.tool_calls}") + + # Test 3: Structured output + print("\n3. Testing structured output...") + structured = chat.with_structured_output(Joke) + joke = structured.invoke("Tell a joke") + print(f" Setup: {joke.setup}") + print(f" Punchline: {joke.punchline}") + + # Test 4: Streaming + print("\n4. Testing streaming...") + chunks = list(chat.stream([HumanMessage(content="Count 1-3")])) + print(f" Chunks received: {len(chunks)}") + + print("\n" + "=" * 60) + print("All manual tests completed!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/libs/oci/tests/integration_tests/chat_models/test_multi_model.py b/libs/oci/tests/integration_tests/chat_models/test_multi_model.py new file mode 100644 index 0000000..7043612 --- /dev/null +++ b/libs/oci/tests/integration_tests/chat_models/test_multi_model.py @@ -0,0 +1,598 @@ +#!/usr/bin/env python3 +"""Multi-model integration tests for ChatOCIGenAI. + +These tests verify that langchain-oci works correctly across different +model vendors available in OCI GenAI: Meta Llama, Cohere, xAI Grok, and OpenAI. 
+ +Setup: + export OCI_COMPARTMENT_ID= + export OCI_CONFIG_PROFILE=DEFAULT + export OCI_AUTH_TYPE=SECURITY_TOKEN + +Run all: + pytest tests/integration_tests/chat_models/test_multi_model.py -v + +Run specific vendor: + pytest tests/integration_tests/chat_models/test_multi_model.py -k "llama" -v + pytest tests/integration_tests/chat_models/test_multi_model.py -k "cohere" -v + pytest tests/integration_tests/chat_models/test_multi_model.py -k "grok" -v +""" + +import os + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from langchain_oci.chat_models import ChatOCIGenAI + + +# ============================================================================= +# Model Configurations +# ============================================================================= + +# Meta Llama models +LLAMA_MODELS = [ + "meta.llama-4-maverick-17b-128e-instruct-fp8", + "meta.llama-4-scout-17b-16e-instruct", + "meta.llama-3.3-70b-instruct", + "meta.llama-3.1-70b-instruct", +] + +# Cohere models +COHERE_MODELS = [ + "cohere.command-a-03-2025", + "cohere.command-r-plus-08-2024", + "cohere.command-r-08-2024", +] + +# xAI Grok models +GROK_MODELS = [ + "xai.grok-4-fast-non-reasoning", + "xai.grok-3-fast", + "xai.grok-3-mini-fast", +] + +# OpenAI models on OCI +OPENAI_MODELS = [ + "openai.gpt-oss-20b", + "openai.gpt-oss-120b", +] + +# All models for comprehensive testing +ALL_MODELS = LLAMA_MODELS[:2] + COHERE_MODELS[:1] + GROK_MODELS[:1] + + +def get_config(): + """Get test configuration.""" + compartment_id = os.environ.get("OCI_COMPARTMENT_ID") + if not compartment_id: + pytest.skip("OCI_COMPARTMENT_ID not set") + return { + "service_endpoint": os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + "compartment_id": compartment_id, + "auth_profile": os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + "auth_type": os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + } + + +def create_llm(model_id: str, **kwargs): + """Create ChatOCIGenAI instance for a model.""" + config = get_config() + default_kwargs = {"temperature": 0, "max_tokens": 256} + default_kwargs.update(kwargs) + return ChatOCIGenAI( + model_id=model_id, + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs=default_kwargs, + ) + + +# ============================================================================= +# Basic Invoke Tests - All Models +# ============================================================================= + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", ALL_MODELS) +def test_basic_invoke_all_models(model_id: str): + """Test basic invoke works for all supported models.""" + llm = create_llm(model_id) + response = llm.invoke([HumanMessage(content="Say 'hello' only")]) + + assert isinstance(response, AIMessage) + assert response.content is not None + assert len(response.content) > 0 + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", ALL_MODELS) +def test_system_message_all_models(model_id: str): + """Test system messages work for all models.""" + llm = create_llm(model_id) + messages = [ + SystemMessage(content="You only respond with the word 'YES'."), + HumanMessage(content="Do you understand?"), + ] + response = llm.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +# 
============================================================================= +# Meta Llama Specific Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", LLAMA_MODELS[:2]) +def test_llama_tool_calling(model_id: str): + """Test tool calling on Llama models.""" + + def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"Sunny in {city}" + + llm = create_llm(model_id) + llm_with_tools = llm.bind_tools([get_weather]) + + response = llm_with_tools.invoke( + [HumanMessage(content="What's the weather in Paris?")] + ) + + assert isinstance(response, AIMessage) + assert len(response.tool_calls) >= 1 + assert response.tool_calls[0]["name"] == "get_weather" + + +@pytest.mark.requires("oci") +def test_llama4_parallel_tool_calling(): + """Test parallel tool calling on Llama 4 models.""" + + def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"Sunny in {city}" + + def get_time(city: str) -> str: + """Get current time in a city.""" + return f"12:00 PM in {city}" + + llm = create_llm("meta.llama-4-maverick-17b-128e-instruct-fp8") + llm_with_tools = llm.bind_tools( + [get_weather, get_time], parallel_tool_calls=True + ) + + response = llm_with_tools.invoke( + [HumanMessage(content="What's the weather and time in Tokyo?")] + ) + + assert isinstance(response, AIMessage) + assert len(response.tool_calls) >= 2 + tool_names = {tc["name"] for tc in response.tool_calls} + assert "get_weather" in tool_names + assert "get_time" in tool_names + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", LLAMA_MODELS[:2]) +def test_llama_structured_output(model_id: str): + """Test structured output on Llama models.""" + + class Answer(BaseModel): + answer: str = Field(description="The answer") + confidence: int = Field(description="Confidence 1-10", ge=1, le=10) + + llm = create_llm(model_id) + structured_llm = llm.with_structured_output(Answer) + + result = structured_llm.invoke("What is 2+2? 
Give answer and confidence.") + + assert isinstance(result, Answer) + assert "4" in result.answer + assert 1 <= result.confidence <= 10 + + +@pytest.mark.requires("oci") +def test_llama_streaming(): + """Test streaming on Llama models.""" + llm = create_llm("meta.llama-4-maverick-17b-128e-instruct-fp8") + + chunks = [] + for chunk in llm.stream([HumanMessage(content="Count 1 to 5")]): + chunks.append(chunk) + + assert len(chunks) > 0 + full_content = "".join(c.content for c in chunks if c.content) + assert len(full_content) > 0 + + +# ============================================================================= +# Cohere Specific Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", COHERE_MODELS[:2]) +def test_cohere_basic(model_id: str): + """Test basic functionality on Cohere models.""" + llm = create_llm(model_id) + response = llm.invoke([HumanMessage(content="What is 2+2?")]) + + assert isinstance(response, AIMessage) + assert "4" in response.content + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", COHERE_MODELS[:1]) +def test_cohere_tool_calling(model_id: str): + """Test tool calling on Cohere models.""" + # Note: Cohere tool calling has different response format + # This test verifies basic functionality + + def calculate(expression: str) -> str: + """Calculate a math expression and return the result.""" + return str(eval(expression)) + + llm = create_llm(model_id) + llm_with_tools = llm.bind_tools([calculate]) + + response = llm_with_tools.invoke( + [HumanMessage(content="Calculate 15 * 7")] + ) + + assert isinstance(response, AIMessage) + # Cohere may respond directly or call tool - both are valid + assert response.content or len(response.tool_calls) >= 1 + + +@pytest.mark.requires("oci") +def test_cohere_rejects_parallel_tool_calls(): + """Test that Cohere models reject parallel tool calls.""" + + def tool1(x: str) -> str: + """Tool 1.""" + return x + + llm = create_llm("cohere.command-a-03-2025") + llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True) + + # Should raise error when trying to use parallel tool calls + with pytest.raises(ValueError, match="not supported for Cohere"): + llm_with_tools.invoke([HumanMessage(content="test")]) + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", COHERE_MODELS[:1]) +def test_cohere_structured_output(model_id: str): + """Test structured output on Cohere models.""" + + class Sentiment(BaseModel): + """Sentiment analysis result.""" + + text: str = Field(description="The analyzed text") + sentiment: str = Field(description="positive, negative, or neutral") + score: float = Field(description="Confidence score 0-1", ge=0, le=1) + + llm = create_llm(model_id) + structured_llm = llm.with_structured_output(Sentiment) + + result = structured_llm.invoke( + "Analyze sentiment: 'I love this product, it's amazing!'" + ) + + # Cohere structured output may return None in some cases + if result is not None: + assert isinstance(result, Sentiment) + assert result.sentiment.lower() in ["positive", "negative", "neutral"] + else: + pytest.skip("Cohere model returned None for structured output") + + +# ============================================================================= +# xAI Grok Specific Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", GROK_MODELS[:2]) +def test_grok_basic(model_id: 
str): + """Test basic functionality on Grok models.""" + llm = create_llm(model_id) + response = llm.invoke([HumanMessage(content="Hello, who are you?")]) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", GROK_MODELS[:1]) +def test_grok_tool_calling(model_id: str): + """Test tool calling on Grok models.""" + + def search_web(query: str) -> str: + """Search the web for information.""" + return f"Results for: {query}" + + llm = create_llm(model_id) + llm_with_tools = llm.bind_tools([search_web]) + + response = llm_with_tools.invoke( + [HumanMessage(content="Search for the latest AI news")] + ) + + assert isinstance(response, AIMessage) + # Grok may or may not call tools depending on its judgment + # Just verify it responds + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", GROK_MODELS[:1]) +def test_grok_structured_output(model_id: str): + """Test structured output on Grok models.""" + + class Summary(BaseModel): + """A summary of text.""" + + main_point: str = Field(description="The main point") + key_facts: list[str] = Field(description="Key facts from the text") + + llm = create_llm(model_id) + structured_llm = llm.with_structured_output(Summary) + + result = structured_llm.invoke( + "Summarize: The Earth orbits the Sun once per year." + ) + + # Grok may return None in some cases + if result is not None: + assert isinstance(result, Summary) + assert len(result.main_point) > 0 + else: + pytest.skip("Grok model returned None for structured output") + + +@pytest.mark.requires("oci") +def test_grok_streaming(): + """Test streaming on Grok models.""" + llm = create_llm("xai.grok-3-mini-fast") + + chunks = [] + for chunk in llm.stream([HumanMessage(content="Count 1-3")]): + chunks.append(chunk) + + assert len(chunks) > 0 + + +# ============================================================================= +# OpenAI on OCI Tests +# ============================================================================= + + +def create_openai_llm(model_id: str, **kwargs): + """Create ChatOCIGenAI for OpenAI models (uses max_completion_tokens).""" + config = get_config() + default_kwargs = {"temperature": 0, "max_completion_tokens": 256} + default_kwargs.update(kwargs) + return ChatOCIGenAI( + model_id=model_id, + service_endpoint=config["service_endpoint"], + compartment_id=config["compartment_id"], + auth_profile=config["auth_profile"], + auth_type=config["auth_type"], + model_kwargs=default_kwargs, + ) + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", OPENAI_MODELS) +def test_openai_basic(model_id: str): + """Test basic functionality on OpenAI models on OCI.""" + llm = create_openai_llm(model_id) + response = llm.invoke([HumanMessage(content="Say hello")]) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", OPENAI_MODELS) +def test_openai_system_message(model_id: str): + """Test system messages on OpenAI models.""" + llm = create_openai_llm(model_id) + messages = [ + SystemMessage(content="You only respond with the word 'YES'."), + HumanMessage(content="Do you understand?"), + ] + response = llm.invoke(messages) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", OPENAI_MODELS) +def test_openai_streaming(model_id: str): + """Test streaming on OpenAI models.""" + 
llm = create_openai_llm(model_id, max_completion_tokens=50) + + chunks = [] + for chunk in llm.stream([HumanMessage(content="Count 1-3")]): + chunks.append(chunk) + + # OpenAI streaming should return chunks + assert len(chunks) > 0 + # Content may be in chunk.content or chunk may have other attributes + # Just verify we got chunks back (streaming works) + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize("model_id", OPENAI_MODELS) +def test_openai_tool_calling(model_id: str): + """Test tool calling on OpenAI models.""" + + def get_info(topic: str) -> str: + """Get information about a topic.""" + return f"Info about {topic}" + + llm = create_openai_llm(model_id) + llm_with_tools = llm.bind_tools([get_info]) + + response = llm_with_tools.invoke( + [HumanMessage(content="Get info about Python")] + ) + + assert isinstance(response, AIMessage) + # OpenAI models should call the tool + assert len(response.tool_calls) >= 1 + assert response.tool_calls[0]["name"] == "get_info" + + +# ============================================================================= +# Cross-Model Comparison Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_same_prompt_different_models(): + """Test same prompt across different model vendors.""" + prompt = "What is the capital of France? Answer in one word." + + models_to_test = [ + "meta.llama-4-maverick-17b-128e-instruct-fp8", + "cohere.command-a-03-2025", + "xai.grok-3-mini-fast", + ] + + responses = {} + for model_id in models_to_test: + try: + llm = create_llm(model_id) + response = llm.invoke([HumanMessage(content=prompt)]) + responses[model_id] = response.content + except Exception as e: + responses[model_id] = f"Error: {e}" + + # All should mention Paris + for model_id, content in responses.items(): + if not content.startswith("Error"): + assert "paris" in content.lower(), f"{model_id} didn't say Paris: {content}" + + +@pytest.mark.requires("oci") +def test_tool_calling_consistency(): + """Test tool calling works consistently across Llama models.""" + + def get_price(item: str) -> float: + """Get the price of an item in dollars.""" + return 9.99 + + # Only test Llama models - Cohere has different tool call format + models_with_tools = [ + "meta.llama-4-maverick-17b-128e-instruct-fp8", + "meta.llama-4-scout-17b-16e-instruct", + ] + + for model_id in models_with_tools: + llm = create_llm(model_id) + llm_with_tools = llm.bind_tools([get_price]) + + response = llm_with_tools.invoke( + [HumanMessage(content="What's the price of apples?")] + ) + + assert isinstance(response, AIMessage), f"{model_id} failed" + assert len(response.tool_calls) >= 1, f"{model_id} didn't call tool" + assert response.tool_calls[0]["name"] == "get_price" + + +# ============================================================================= +# Model-Specific Features Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_llama3_vision_model_exists(): + """Verify vision-capable Llama model can be instantiated.""" + # Note: Actual vision testing would require image input support + llm = create_llm("meta.llama-3.2-90b-vision-instruct") + response = llm.invoke([HumanMessage(content="Describe what you can do")]) + + assert isinstance(response, AIMessage) + + +@pytest.mark.requires("oci") +def test_model_with_custom_kwargs(): + """Test models with custom generation parameters.""" + llm = create_llm( + 
"meta.llama-4-maverick-17b-128e-instruct-fp8", + temperature=0.7, + max_tokens=100, + top_p=0.9, + ) + + response = llm.invoke([HumanMessage(content="Write a creative sentence")]) + + assert isinstance(response, AIMessage) + assert response.content is not None + + +# ============================================================================= +# Performance / Latency Awareness Tests +# ============================================================================= + + +@pytest.mark.requires("oci") +def test_fast_models_respond_quickly(): + """Test that 'fast' model variants respond (existence check).""" + fast_models = [ + "xai.grok-3-fast", + "xai.grok-3-mini-fast", + ] + + for model_id in fast_models: + llm = create_llm(model_id, max_tokens=50) + response = llm.invoke([HumanMessage(content="Hi")]) + assert isinstance(response, AIMessage) + + +def main(): + """Manual test runner for debugging.""" + import sys + + print("=" * 60) + print("Multi-Model Integration Tests") + print("=" * 60) + + config = get_config() + print(f"\nEndpoint: {config['service_endpoint']}") + print(f"Profile: {config['auth_profile']}") + + # Test each vendor + test_models = [ + ("Meta Llama 4", "meta.llama-4-maverick-17b-128e-instruct-fp8"), + ("Cohere Command", "cohere.command-a-03-2025"), + ("xAI Grok", "xai.grok-3-mini-fast"), + ] + + for name, model_id in test_models: + print(f"\n--- Testing {name} ({model_id}) ---") + try: + llm = create_llm(model_id) + response = llm.invoke([HumanMessage(content="Say hello")]) + print(f"✓ Response: {response.content[:50]}...") + except Exception as e: + print(f"✗ Error: {e}") + + print("\n" + "=" * 60) + print("Manual tests complete") + + +if __name__ == "__main__": + main() diff --git a/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py b/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py new file mode 100644 index 0000000..ebbbbcb --- /dev/null +++ b/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +""" +Integration test for parallel tool calling feature. + +This script tests parallel tool calling with actual OCI GenAI API calls. 
+ +Setup: + export OCI_COMPARTMENT_ID= + export OCI_GENAI_ENDPOINT= # optional + export OCI_CONFIG_PROFILE= # optional + export OCI_AUTH_TYPE= # optional + +Run with: + python test_parallel_tool_calling_integration.py +""" + +import logging +import os +import sys +import time + +from langchain_core.messages import HumanMessage + +from langchain_oci.chat_models import ChatOCIGenAI + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +def get_weather(city: str, unit: str = "fahrenheit") -> str: + """Get the current weather in a given location.""" + # Simulate API delay + time.sleep(0.5) + return f"Weather in {city}: Sunny, 72°{unit[0].upper()}" + + +def calculate_tip(amount: float, percent: float = 15.0) -> float: + """Calculate tip amount.""" + # Simulate API delay + time.sleep(0.5) + return round(amount * (percent / 100), 2) + + +def get_population(city: str) -> int: + """Get the population of a city.""" + # Simulate API delay + time.sleep(0.5) + populations = { + "tokyo": 14000000, + "new york": 8000000, + "london": 9000000, + "paris": 2000000, + "chicago": 2700000, + "los angeles": 4000000, + } + return populations.get(city.lower(), 1000000) + + +def test_parallel_tool_calling_enabled(): + """Test parallel tool calling with parallel_tool_calls=True in bind_tools.""" + logging.info("\n" + "=" * 80) + logging.info("TEST 1: Parallel Tool Calling ENABLED (via bind_tools)") + logging.info("=" * 80) + + chat = ChatOCIGenAI( + model_id=os.environ.get( + "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8" + ), + service_endpoint=os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + compartment_id=os.environ.get("OCI_COMPARTMENT_ID"), + auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + model_kwargs={"temperature": 0, "max_tokens": 500}, + ) + + # Bind tools with parallel_tool_calls=True + chat_with_tools = chat.bind_tools( + [get_weather, calculate_tip, get_population], parallel_tool_calls=True + ) + + # Invoke with query that needs weather info + logging.info("\nQuery: 'What's the weather in New York City?'") + + start_time = time.time() + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in New York City?")] + ) + elapsed_time = time.time() - start_time + + logging.info(f"\nResponse time: {elapsed_time:.2f}s") + content = response.content[:200] if response.content else "(empty)" + logging.info(f"Response content: {content}...") + logging.info(f"Tool calls count: {len(response.tool_calls)}") + + if response.tool_calls: + logging.info("\nTool calls:") + for i, tc in enumerate(response.tool_calls, 1): + logging.info(f" {i}. 
{tc['name']}({tc['args']})") + else: + logging.info("\n⚠️ No tool calls in response.tool_calls") + logging.info(f"Additional kwargs: {response.additional_kwargs.keys()}") + + # Verify we got tool calls + count = len(response.tool_calls) + assert count >= 1, f"Should have at least one tool call, got {count}" + + # Verify parallel_tool_calls was set + logging.info("\n✓ TEST 1 PASSED: Parallel tool calling enabled and working") + return elapsed_time + + +def test_parallel_tool_calling_disabled(): + """Test tool calling with parallel_tool_calls=False (sequential).""" + logging.info("\n" + "=" * 80) + logging.info("TEST 2: Parallel Tool Calling DISABLED (Sequential)") + logging.info("=" * 80) + + chat = ChatOCIGenAI( + model_id=os.environ.get( + "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8" + ), + service_endpoint=os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + compartment_id=os.environ.get("OCI_COMPARTMENT_ID"), + auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + model_kwargs={"temperature": 0, "max_tokens": 500}, + ) + + # Bind tools without parallel_tool_calls (defaults to sequential) + chat_with_tools = chat.bind_tools([get_weather, calculate_tip, get_population]) + + # Same query as test 1 + logging.info("\nQuery: 'What's the weather in New York City?'") + + start_time = time.time() + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather in New York City?")] + ) + elapsed_time = time.time() - start_time + + logging.info(f"\nResponse time: {elapsed_time:.2f}s") + content = response.content[:200] if response.content else "(empty)" + logging.info(f"Response content: {content}...") + logging.info(f"Tool calls count: {len(response.tool_calls)}") + + if response.tool_calls: + logging.info("\nTool calls:") + for i, tc in enumerate(response.tool_calls, 1): + logging.info(f" {i}. {tc['name']}({tc['args']})") + + # Verify we got tool calls + count = len(response.tool_calls) + assert count >= 1, f"Should have at least one tool call, got {count}" + + logging.info("\n✓ TEST 2 PASSED: Sequential tool calling works") + return elapsed_time + + +def test_multiple_tool_calls(): + """Test query that should trigger multiple tool calls.""" + logging.info("\n" + "=" * 80) + logging.info("TEST 3: Multiple Tool Calls Query") + logging.info("=" * 80) + + chat = ChatOCIGenAI( + model_id=os.environ.get( + "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8" + ), + service_endpoint=os.environ.get( + "OCI_GENAI_ENDPOINT", + "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ), + compartment_id=os.environ.get("OCI_COMPARTMENT_ID"), + auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"), + auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"), + model_kwargs={"temperature": 0, "max_tokens": 500}, + ) + + # Bind tools with parallel_tool_calls=True + chat_with_tools = chat.bind_tools( + [get_weather, get_population], parallel_tool_calls=True + ) + + logging.info("\nQuery: 'What's the weather and population of Tokyo?'") + + response = chat_with_tools.invoke( + [HumanMessage(content="What's the weather and population of Tokyo?")] + ) + + logging.info(f"\nResponse content: {response.content}") + logging.info(f"Tool calls count: {len(response.tool_calls)}") + + if response.tool_calls: + logging.info("\nTool calls:") + for i, tc in enumerate(response.tool_calls, 1): + logging.info(f" {i}. 
{tc['name']}({tc['args']})")
+
+    logging.info("\n✓ TEST 3 PASSED: Multiple tool calls query works")
+
+
+def test_cohere_model_error():
+    """Test that Cohere models raise an error with parallel_tool_calls."""
+    logging.info("\n" + "=" * 80)
+    logging.info("TEST 4: Cohere Model Error Handling")
+    logging.info("=" * 80)
+
+    chat = ChatOCIGenAI(
+        model_id="cohere.command-r-plus",
+        service_endpoint=os.environ.get(
+            "OCI_GENAI_ENDPOINT",
+            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+        ),
+        compartment_id=os.environ.get("OCI_COMPARTMENT_ID"),
+        auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"),
+        auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"),
+    )
+
+    # Try to enable parallel tool calls with Cohere (should fail)
+    chat_with_tools = chat.bind_tools([get_weather], parallel_tool_calls=True)
+
+    logging.info("\nAttempting to use parallel_tool_calls with Cohere model...")
+
+    try:
+        _ = chat_with_tools.invoke(
+            [HumanMessage(content="What's the weather in Paris?")]
+        )
+        logging.info("❌ TEST FAILED: Should have raised ValueError")
+        return False
+    except ValueError as e:
+        if "not supported for Cohere" in str(e):
+            logging.info(f"\n✓ Correctly raised error: {e}")
+            logging.info("\n✓ TEST 4 PASSED: Cohere validation works")
+            return True
+        else:
+            logging.info(f"❌ Wrong error: {e}")
+            return False
+
+
+def main():
+    logging.info("=" * 80)
+    logging.info("PARALLEL TOOL CALLING INTEGRATION TESTS")
+    logging.info("=" * 80)
+
+    # Check required env vars
+    if not os.environ.get("OCI_COMPARTMENT_ID"):
+        logging.info("\n❌ ERROR: OCI_COMPARTMENT_ID environment variable not set")
+        logging.info("Please set: export OCI_COMPARTMENT_ID=")
+        sys.exit(1)
+
+    logging.info("\nUsing configuration:")
+    model_id = os.environ.get(
+        "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8"
+    )
+    logging.info(f"  Model: {model_id}")
+    endpoint = os.environ.get("OCI_GENAI_ENDPOINT", "default")
+    logging.info(f"  Endpoint: {endpoint}")
+    profile = os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT")
+    logging.info(f"  Profile: {profile}")
+    logging.info(f"  Compartment: {os.environ.get('OCI_COMPARTMENT_ID')[:25]}...")
+
+    results = []
+
+    try:
+        # Run tests
+        parallel_time = test_parallel_tool_calling_enabled()
+        results.append(("Parallel Enabled", True))
+
+        sequential_time = test_parallel_tool_calling_disabled()
+        results.append(("Sequential (Disabled)", True))
+
+        test_multiple_tool_calls()
+        results.append(("Multiple Tool Calls", True))
+
+        cohere_test = test_cohere_model_error()
+        results.append(("Cohere Validation", cohere_test))
+
+        # Print summary
+        logging.info("\n" + "=" * 80)
+        logging.info("TEST SUMMARY")
+        logging.info("=" * 80)
+
+        for test_name, passed in results:
+            status = "✓ PASSED" if passed else "✗ FAILED"
+            logging.info(f"{status}: {test_name}")
+
+        passed_count = sum(1 for _, passed in results if passed)
+        total_count = len(results)
+
+        logging.info(f"\nTotal: {passed_count}/{total_count} tests passed")
+
+        # Performance comparison
+        if parallel_time and sequential_time:
+            logging.info("\n" + "=" * 80)
+            logging.info("PERFORMANCE COMPARISON")
+            logging.info("=" * 80)
+            logging.info(f"Parallel: {parallel_time:.2f}s")
+            logging.info(f"Sequential: {sequential_time:.2f}s")
+            if parallel_time > 0:  # guard against division by zero
+                speedup = sequential_time / parallel_time
+                logging.info(f"Speedup: {speedup:.2f}×")
+
+        if passed_count == total_count:
+            logging.info("\n🎉 ALL TESTS PASSED!")
+            return 0
+        else:
+            logging.info(f"\n⚠️ {total_count - passed_count} test(s) failed")
+            return 1
+
+    
except Exception as e: + logging.info(f"\n❌ ERROR: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py b/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py index 9c07763..03e27ea 100644 --- a/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py +++ b/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py @@ -53,7 +53,7 @@ import os import pytest -from langchain.tools import StructuredTool +from langchain_core.tools import StructuredTool from langchain_core.messages import HumanMessage, SystemMessage from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import ToolNode diff --git a/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py b/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py new file mode 100644 index 0000000..b796560 --- /dev/null +++ b/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py @@ -0,0 +1,234 @@ +"""Unit tests for parallel tool calling feature.""" + +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import HumanMessage + +from langchain_oci.chat_models import ChatOCIGenAI + + +@pytest.mark.requires("oci") +def test_parallel_tool_calls_bind_tools_explicit_true(): + """Test parallel_tool_calls=True in bind_tools.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + def tool2(x: int) -> int: + """Tool 2.""" + return x * 2 + + llm_with_tools = llm.bind_tools([tool1, tool2], parallel_tool_calls=True) + + assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True + + +@pytest.mark.requires("oci") +def test_parallel_tool_calls_bind_tools_explicit_false(): + """Test parallel_tool_calls=False in bind_tools.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=False) + + # When explicitly False, should not set the parameter + assert "is_parallel_tool_calls" not in llm_with_tools.kwargs + + +@pytest.mark.requires("oci") +def test_parallel_tool_calls_bind_tools_default_none(): + """Test that bind_tools without parallel_tool_calls doesn't enable it.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Don't specify parallel_tool_calls in bind_tools + llm_with_tools = llm.bind_tools([tool1]) + + # Should not have is_parallel_tool_calls set + assert "is_parallel_tool_calls" not in llm_with_tools.kwargs + + +@pytest.mark.requires("oci") +def test_parallel_tool_calls_passed_to_oci_api_meta(): + """Test that is_parallel_tool_calls is passed to OCI API for Meta models.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + def get_weather(city: str) -> str: + """Get weather for a city.""" + return f"Weather in {city}" + + llm_with_tools = llm.bind_tools([get_weather], parallel_tool_calls=True) + + # Prepare a request + request = llm_with_tools._prepare_request( + 
[HumanMessage(content="What's the weather?")], + stop=None, + stream=False, + **llm_with_tools.kwargs, + ) + + # Verify is_parallel_tool_calls is in the request + assert hasattr(request.chat_request, "is_parallel_tool_calls") + assert request.chat_request.is_parallel_tool_calls is True + + +@pytest.mark.requires("oci") +def test_parallel_tool_calls_cohere_raises_error(): + """Test that Cohere models raise error for parallel tool calls.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="cohere.command-r-plus", client=oci_gen_ai_client) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True) + + # Should raise ValueError when trying to prepare request + with pytest.raises(ValueError, match="not supported for Cohere"): + llm_with_tools._prepare_request( + [HumanMessage(content="test")], + stop=None, + stream=False, + **llm_with_tools.kwargs, + ) + + +@pytest.mark.requires("oci") +def test_version_filter_llama_3_0_blocked(): + """Test that Llama 3.0 models are blocked from parallel tool calling.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should raise ValueError when trying to enable parallel tool calling + with pytest.raises(ValueError, match="Llama 4\\+"): + llm.bind_tools([tool1], parallel_tool_calls=True) + + +@pytest.mark.requires("oci") +def test_version_filter_llama_3_1_blocked(): + """Test that Llama 3.1 models are blocked from parallel tool calling.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.1-70b-instruct", client=oci_gen_ai_client) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should raise ValueError + with pytest.raises(ValueError, match="Llama 4\\+"): + llm.bind_tools([tool1], parallel_tool_calls=True) + + +@pytest.mark.requires("oci") +def test_version_filter_llama_3_2_blocked(): + """Test that Llama 3.2 models are blocked from parallel tool calling.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-3.2-11b-vision-instruct", client=oci_gen_ai_client + ) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should raise ValueError + with pytest.raises(ValueError, match="Llama 4\\+"): + llm.bind_tools([tool1], parallel_tool_calls=True) + + +@pytest.mark.requires("oci") +def test_version_filter_llama_3_3_blocked(): + """Test that Llama 3.3 models are blocked from parallel tool calling.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should raise ValueError - Llama 3.3 doesn't actually support parallel calls + with pytest.raises(ValueError, match="Llama 4\\+"): + llm.bind_tools([tool1], parallel_tool_calls=True) + + +@pytest.mark.requires("oci") +def test_version_filter_llama_4_allowed(): + """Test that Llama 4 models are allowed parallel tool calling.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should NOT raise ValueError + llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True) + assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True + + +@pytest.mark.requires("oci") +def test_version_filter_other_models_allowed(): + 
"""Test that other GenericChatRequest models are allowed parallel tool calling.""" + oci_gen_ai_client = MagicMock() + + # Test with xAI Grok + llm_grok = ChatOCIGenAI(model_id="xai.grok-4-fast", client=oci_gen_ai_client) + + def tool1(x: int) -> int: + """Tool 1.""" + return x + 1 + + # Should NOT raise ValueError for Grok + llm_with_tools = llm_grok.bind_tools([tool1], parallel_tool_calls=True) + assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True + + +@pytest.mark.requires("oci") +def test_version_filter_supports_parallel_tool_calls_method(): + """Test the _supports_parallel_tool_calls method directly.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI( + model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client + ) + + # Test various model IDs + model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" + assert llm._supports_parallel_tool_calls(model_id) is True + # Llama 3.3 NOT supported + assert llm._supports_parallel_tool_calls("meta.llama-3.3-70b-instruct") is False + model_id = "meta.llama-3.2-11b-vision-instruct" + assert llm._supports_parallel_tool_calls(model_id) is False + assert llm._supports_parallel_tool_calls("meta.llama-3.1-70b-instruct") is False + assert llm._supports_parallel_tool_calls("meta.llama-3-70b-instruct") is False + assert llm._supports_parallel_tool_calls("cohere.command-r-plus") is False + assert llm._supports_parallel_tool_calls("xai.grok-4-fast") is True + assert llm._supports_parallel_tool_calls("openai.gpt-4") is True + assert llm._supports_parallel_tool_calls("mistral.mistral-large") is True