diff --git a/libs/oci/README.md b/libs/oci/README.md
index 76829a1..7bfcd9b 100644
--- a/libs/oci/README.md
+++ b/libs/oci/README.md
@@ -62,7 +62,7 @@ embeddings.embed_query("What is the meaning of life?")
 ```
 ### 4. Use Structured Output
-`ChatOCIGenAI` supports structured output.
+`ChatOCIGenAI` supports structured output. **Note:** The default method is `function_calling`. If the default method returns `None` (e.g. for Gemini models), try `json_schema` or `json_mode`.
@@ -126,6 +126,27 @@ messages = [
 response = client.invoke(messages)
 ```
 
+### 6. Use Parallel Tool Calling (Meta/Llama 4+ models only)
+Enable parallel tool calling to execute multiple tools simultaneously, improving performance for multi-tool workflows.
+
+```python
+from langchain_oci import ChatOCIGenAI
+
+llm = ChatOCIGenAI(
+    model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
+    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+    compartment_id="MY_COMPARTMENT_ID",
+)
+
+# Enable parallel tool calling in bind_tools
+# (get_weather, calculate_tip, and get_population are user-defined tools)
+llm_with_tools = llm.bind_tools(
+    [get_weather, calculate_tip, get_population],
+    parallel_tool_calls=True  # Tools can execute simultaneously
+)
+```
+
+**Note:** Parallel tool calling is only supported for Llama 4+ models. Llama 3.x (including 3.3) and Cohere models will raise an error if this parameter is used.
+
 ## OCI Data Science Model Deployment Examples
 
diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
index 1afca28..00ea3df 100644
--- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
+++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -209,6 +209,18 @@ def process_stream_tool_calls(
         """Process streaming tool calls from event data into chunks."""
         ...
 
+    @property
+    def supports_parallel_tool_calls(self) -> bool:
+        """Whether this provider supports parallel tool calling.
+
+        Parallel tool calling allows the model to call multiple tools
+        simultaneously in a single response.
+
+        Returns:
+            bool: True if parallel tool calling is supported, False otherwise.
+        """
+        return False
+
 
 class CohereProvider(Provider):
     """Provider implementation for Cohere."""
@@ -363,6 +375,14 @@ def messages_to_oci_params(
 
         This includes conversion of chat history and tool call results.
         """
+        # Cohere models don't support parallel tool calls
+        if kwargs.get("is_parallel_tool_calls"):
+            raise ValueError(
+                "Parallel tool calls are not supported for Cohere models. "
+                "This feature is only available for models using GenericChatRequest "
+                "(Meta, Llama, xAI Grok, OpenAI, Mistral)."
+            )
+
         is_force_single_step = kwargs.get("is_force_single_step", False)
         oci_chat_history = []
@@ -585,6 +605,11 @@ class GenericProvider(Provider):
 
     stop_sequence_key: str = "stop"
 
+    @property
+    def supports_parallel_tool_calls(self) -> bool:
+        """GenericProvider models support parallel tool calling."""
+        return True
+
     def __init__(self) -> None:
         from oci.generative_ai_inference import models
@@ -851,6 +876,10 @@ def _should_allow_more_tool_calls(
             result["tool_choice"] = self.oci_tool_choice_none()
         # else: Allow model to decide (default behavior)
 
+        # Add parallel tool calls support (GenericChatRequest models)
+        if "is_parallel_tool_calls" in kwargs:
+            result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"]
+
         return result
 
     def _process_message_content(
@@ -916,23 +945,9 @@ def convert_to_oci_tool(
         Raises:
             ValueError: If the tool type is not supported.
""" - if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): - as_json_schema_function = convert_to_openai_function(tool) - parameters = as_json_schema_function.get("parameters", {}) + # Check BaseTool first since it's callable but needs special handling + if isinstance(tool, BaseTool): return self.oci_function_definition( - name=as_json_schema_function.get("name"), - description=as_json_schema_function.get( - "description", - as_json_schema_function.get("name"), - ), - parameters={ - "type": "object", - "properties": parameters.get("properties", {}), - "required": parameters.get("required", []), - }, - ) - elif isinstance(tool, BaseTool): # type: ignore[unreachable] - return self.oci_function_definition( # type: ignore[unreachable] name=tool.name, description=OCIUtils.remove_signature_from_tool_description( tool.name, tool.description @@ -953,6 +968,21 @@ def convert_to_oci_tool( ], }, ) + if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + return self.oci_function_definition( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + "description", + as_json_schema_function.get("name"), + ), + parameters={ + "type": "object", + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + ) raise ValueError( f"Unsupported tool type {type(tool)}. " "Tool must be passed in as a BaseTool " @@ -1211,6 +1241,7 @@ def bind_tools( tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "required", "any"], bool] ] = None, + parallel_tool_calls: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -1231,6 +1262,11 @@ def bind_tools( {"type": "function", "function": {"name": <>}}: calls <> tool. - False or None: no effect, default Meta behavior. + parallel_tool_calls: Whether to enable parallel function calling. + If True, the model can call multiple tools simultaneously. + If False or None (default), tools are called sequentially. + Supported for models using GenericChatRequest (Meta, xAI Grok, + OpenAI, Mistral). Not supported for Cohere models. kwargs: Any additional parameters are passed directly to :meth:`~langchain_oci.chat_models.oci_generative_ai.ChatOCIGenAI.bind`. """ @@ -1240,6 +1276,15 @@ def bind_tools( if tool_choice is not None: kwargs["tool_choice"] = self._provider.process_tool_choice(tool_choice) + # Add parallel tool calls support (only when explicitly enabled) + if parallel_tool_calls: + if not self._provider.supports_parallel_tool_calls: + raise ValueError( + "Parallel tool calls not supported for this provider. " + "Only GenericChatRequest models support parallel tool calling." + ) + kwargs["is_parallel_tool_calls"] = True + return super().bind(tools=formatted_tools, **kwargs) def with_structured_output( diff --git a/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py b/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py new file mode 100644 index 0000000..9a9ceb4 --- /dev/null +++ b/libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +Integration test for parallel tool calling feature. + +This script tests parallel tool calling with actual OCI GenAI API calls. 
+
+Setup:
+    export OCI_COMPARTMENT_ID=
+    export OCI_GENAI_ENDPOINT=  # optional
+    export OCI_CONFIG_PROFILE=  # optional
+    export OCI_AUTH_TYPE=  # optional
+
+Run with:
+    python test_parallel_tool_calling_integration.py
+"""
+
+import logging
+import os
+import sys
+import time
+
+from langchain_core.messages import HumanMessage
+
+from langchain_oci.chat_models import ChatOCIGenAI
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+
+def get_weather(city: str, unit: str = "fahrenheit") -> str:
+    """Get the current weather in a given location."""
+    # Simulate API delay
+    time.sleep(0.5)
+    return f"Weather in {city}: Sunny, 72°{unit[0].upper()}"
+
+
+def calculate_tip(amount: float, percent: float = 15.0) -> float:
+    """Calculate tip amount."""
+    # Simulate API delay
+    time.sleep(0.5)
+    return round(amount * (percent / 100), 2)
+
+
+def get_population(city: str) -> int:
+    """Get the population of a city."""
+    # Simulate API delay
+    time.sleep(0.5)
+    populations = {
+        "tokyo": 14000000,
+        "new york": 8000000,
+        "london": 9000000,
+        "paris": 2000000,
+        "chicago": 2700000,
+        "los angeles": 4000000,
+    }
+    return populations.get(city.lower(), 1000000)
+
+
+def test_parallel_tool_calling_enabled():
+    """Test parallel tool calling with parallel_tool_calls=True in bind_tools."""
+    logging.info("\n" + "=" * 80)
+    logging.info("TEST 1: Parallel Tool Calling ENABLED (via bind_tools)")
+    logging.info("=" * 80)
+
+    chat = ChatOCIGenAI(
+        model_id=os.environ.get(
+            "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8"
+        ),
+        service_endpoint=os.environ.get(
+            "OCI_GENAI_ENDPOINT",
+            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+        ),
+        compartment_id=os.environ.get("OCI_COMPARTMENT_ID"),
+        auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"),
+        auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"),
+        model_kwargs={"temperature": 0, "max_tokens": 500},
+    )
+
+    # Bind tools with parallel_tool_calls=True
+    chat_with_tools = chat.bind_tools(
+        [get_weather, calculate_tip, get_population], parallel_tool_calls=True
+    )
+
+    # Invoke with query that needs weather info
+    logging.info("\nQuery: 'What's the weather in New York City?'")
+
+    start_time = time.time()
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather in New York City?")]
+    )
+    elapsed_time = time.time() - start_time
+
+    logging.info(f"\nResponse time: {elapsed_time:.2f}s")
+    content = response.content[:200] if response.content else "(empty)"
+    logging.info(f"Response content: {content}...")
+    # AIMessage has tool_calls attribute at runtime
+    tool_calls = getattr(response, "tool_calls", [])
+    logging.info(f"Tool calls count: {len(tool_calls)}")
+
+    if tool_calls:
+        logging.info("\nTool calls:")
+        for i, tc in enumerate(tool_calls, 1):
+            logging.info(f"  {i}. {tc['name']}({tc['args']})")
+    else:
+        logging.info("\n⚠️ No tool calls in response.tool_calls")
+        logging.info(f"Additional kwargs: {response.additional_kwargs.keys()}")
+
+    # Verify we got tool calls
+    count = len(tool_calls)
+    assert count >= 1, f"Should have at least one tool call, got {count}"
+
+    # Verify parallel_tool_calls was set
+    logging.info("\n✓ TEST 1 PASSED: Parallel tool calling enabled and working")
+    return elapsed_time
+
+
+def test_parallel_tool_calling_disabled():
+    """Test tool calling with parallel_tool_calls=False (sequential)."""
+    logging.info("\n" + "=" * 80)
+    logging.info("TEST 2: Parallel Tool Calling DISABLED (Sequential)")
+    logging.info("=" * 80)
+
+    chat = ChatOCIGenAI(
+        model_id=os.environ.get(
+            "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8"
+        ),
+        service_endpoint=os.environ.get(
+            "OCI_GENAI_ENDPOINT",
+            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+        ),
+        compartment_id=os.environ.get("OCI_COMPARTMENT_ID"),
+        auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"),
+        auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"),
+        model_kwargs={"temperature": 0, "max_tokens": 500},
+    )
+
+    # Bind tools without parallel_tool_calls (defaults to sequential)
+    chat_with_tools = chat.bind_tools([get_weather, calculate_tip, get_population])
+
+    # Same query as test 1
+    logging.info("\nQuery: 'What's the weather in New York City?'")
+
+    start_time = time.time()
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather in New York City?")]
+    )
+    elapsed_time = time.time() - start_time
+
+    logging.info(f"\nResponse time: {elapsed_time:.2f}s")
+    content = response.content[:200] if response.content else "(empty)"
+    logging.info(f"Response content: {content}...")
+    # AIMessage has tool_calls attribute at runtime
+    tool_calls = getattr(response, "tool_calls", [])
+    logging.info(f"Tool calls count: {len(tool_calls)}")
+
+    if tool_calls:
+        logging.info("\nTool calls:")
+        for i, tc in enumerate(tool_calls, 1):
+            logging.info(f"  {i}. {tc['name']}({tc['args']})")
+
+    # Verify we got tool calls
+    count = len(tool_calls)
+    assert count >= 1, f"Should have at least one tool call, got {count}"
+
+    logging.info("\n✓ TEST 2 PASSED: Sequential tool calling works")
+    return elapsed_time
+
+
+def test_multiple_tool_calls():
+    """Test query that should trigger multiple tool calls."""
+    logging.info("\n" + "=" * 80)
+    logging.info("TEST 3: Multiple Tool Calls Query")
+    logging.info("=" * 80)
+
+    chat = ChatOCIGenAI(
+        model_id=os.environ.get(
+            "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8"
+        ),
+        service_endpoint=os.environ.get(
+            "OCI_GENAI_ENDPOINT",
+            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+        ),
+        compartment_id=os.environ.get("OCI_COMPARTMENT_ID"),
+        auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"),
+        auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"),
+        model_kwargs={"temperature": 0, "max_tokens": 500},
+    )
+
+    # Bind tools with parallel_tool_calls=True
+    chat_with_tools = chat.bind_tools(
+        [get_weather, get_population], parallel_tool_calls=True
+    )
+
+    logging.info("\nQuery: 'What's the weather and population of Tokyo?'")
+
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather and population of Tokyo?")]
+    )
+
+    logging.info(f"\nResponse content: {response.content}")
+    # AIMessage has tool_calls attribute at runtime
+    tool_calls = getattr(response, "tool_calls", [])
+    logging.info(f"Tool calls count: {len(tool_calls)}")
+
+    if tool_calls:
+        logging.info("\nTool calls:")
+        for i, tc in enumerate(tool_calls, 1):
+            logging.info(f"  {i}. {tc['name']}({tc['args']})")
+
+    logging.info("\n✓ TEST 3 PASSED: Multiple tool calls query works")
+
+
+def test_cohere_model_error():
+    """Test that Cohere models raise an error with parallel_tool_calls."""
+    logging.info("\n" + "=" * 80)
+    logging.info("TEST 4: Cohere Model Error Handling")
+    logging.info("=" * 80)
+
+    chat = ChatOCIGenAI(
+        model_id="cohere.command-r-plus",
+        service_endpoint=os.environ.get(
+            "OCI_GENAI_ENDPOINT",
+            "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
+        ),
+        compartment_id=os.environ.get("OCI_COMPARTMENT_ID"),
+        auth_profile=os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT"),
+        auth_type=os.environ.get("OCI_AUTH_TYPE", "SECURITY_TOKEN"),
+    )
+
+    logging.info("\nAttempting to use parallel_tool_calls with Cohere model...")
+
+    # Try to enable parallel tool calls with Cohere (should fail at bind_tools)
+    try:
+        chat.bind_tools([get_weather], parallel_tool_calls=True)
+        logging.info("❌ TEST FAILED: Should have raised ValueError")
+        return False
+    except ValueError as e:
+        if "not supported" in str(e):
+            logging.info(f"\n✓ Correctly raised error: {e}")
+            logging.info("\n✓ TEST 4 PASSED: Cohere validation works")
+            return True
+        else:
+            logging.info(f"❌ Wrong error: {e}")
+            return False
+
+
+def main():
+    logging.info("=" * 80)
+    logging.info("PARALLEL TOOL CALLING INTEGRATION TESTS")
+    logging.info("=" * 80)
+
+    # Check required env vars
+    if not os.environ.get("OCI_COMPARTMENT_ID"):
+        logging.info("\n❌ ERROR: OCI_COMPARTMENT_ID environment variable not set")
+        logging.info("Please set: export OCI_COMPARTMENT_ID=")
+        sys.exit(1)
+
+    logging.info("\nUsing configuration:")
+    model_id = os.environ.get(
+        "OCI_MODEL_ID", "meta.llama-4-maverick-17b-128e-instruct-fp8"
+    )
+    logging.info(f"  Model: {model_id}")
+    endpoint = os.environ.get("OCI_GENAI_ENDPOINT", "default")
+    logging.info(f"  Endpoint: {endpoint}")
+    profile = os.environ.get("OCI_CONFIG_PROFILE", "DEFAULT")
+    logging.info(f"  Profile: {profile}")
+    compartment_id = os.environ.get("OCI_COMPARTMENT_ID", "")
+    logging.info(f"  Compartment: {compartment_id[:25]}...")
+
+    results = []
+
+    try:
+        # Run tests
+        parallel_time = test_parallel_tool_calling_enabled()
+        results.append(("Parallel Enabled", True))
+
+        sequential_time = test_parallel_tool_calling_disabled()
+        results.append(("Sequential (Disabled)", True))
+
+        test_multiple_tool_calls()
+        results.append(("Multiple Tool Calls", True))
+
+        cohere_test = test_cohere_model_error()
+        results.append(("Cohere Validation", cohere_test))
+
+        # Print summary
+        logging.info("\n" + "=" * 80)
+        logging.info("TEST SUMMARY")
+        logging.info("=" * 80)
+
+        for test_name, passed in results:
+            status = "✓ PASSED" if passed else "✗ FAILED"
+            logging.info(f"{status}: {test_name}")
+
+        passed_count = sum(1 for _, passed in results if passed)
+        total_count = len(results)
+
+        logging.info(f"\nTotal: {passed_count}/{total_count} tests passed")
+
+        # Performance comparison
+        if parallel_time and sequential_time:
+            logging.info("\n" + "=" * 80)
+            logging.info("PERFORMANCE COMPARISON")
+            logging.info("=" * 80)
+            logging.info(f"Parallel: {parallel_time:.2f}s")
+            logging.info(f"Sequential: {sequential_time:.2f}s")
+            if sequential_time > 0:
+                speedup = sequential_time / parallel_time
+                logging.info(f"Speedup: {speedup:.2f}×")
+
+        if passed_count == total_count:
+            logging.info("\n🎉 ALL TESTS PASSED!")
+            return 0
+        else:
+            logging.info(f"\n⚠️ {total_count - passed_count} test(s) failed")
+            return 1
+
+    except Exception as e:
+        logging.info(f"\n❌ ERROR: {e}")
+        import traceback
+
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py b/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py
index 9c07763..cb65206 100644
--- a/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py
+++ b/libs/oci/tests/integration_tests/chat_models/test_tool_calling.py
@@ -53,8 +53,8 @@ import os
 
 import pytest
-from langchain.tools import StructuredTool
 from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.tools import StructuredTool
 from langgraph.graph import END, START, MessagesState, StateGraph
 from langgraph.prebuilt import ToolNode
diff --git a/libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py b/libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py
index 68b7e7a..e2d3b8a 100644
--- a/libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py
+++ b/libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py
@@ -1,11 +1,11 @@
 """Test Chat model for OCI Data Science Model Deployment Endpoint."""
 
 import sys
-from typing import Any, AsyncGenerator, Dict, Generator
+from typing import Any, AsyncGenerator, Dict, Generator, Optional
 from unittest import mock
 
 import pytest
-from langchain_core.messages import AIMessage, AIMessageChunk
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
 from requests.exceptions import HTTPError
 
 from langchain_oci.chat_models import (
@@ -145,19 +145,18 @@ def test_stream_vllm(*args: Any) -> None:
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
     assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
-    output = None
+    output: Optional[BaseMessageChunk] = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
         assert isinstance(chunk, AIMessageChunk)
         if output is None:
             output = chunk
         else:
-            output += chunk  # type: ignore[assignment]
+            output = output + chunk
         count += 1
     assert count == 5
     assert output is not None
-    if output is not None:
-        assert str(output.content).strip() == CONST_COMPLETION
+    assert str(output.content).strip() == CONST_COMPLETION
 
 
 async def mocked_async_streaming_response(
diff --git a/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py b/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py
new file mode 100644
index 0000000..46357b9
--- /dev/null
+++ b/libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py
@@ -0,0 +1,166 @@
+"""Unit tests for parallel tool calling feature."""
+
+from unittest.mock import MagicMock
+
+import pytest
+from langchain_core.messages import HumanMessage
+
+from langchain_oci.chat_models import ChatOCIGenAI
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_bind_tools_explicit_true():
+    """Test parallel_tool_calls=True in bind_tools."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    def tool2(x: int) -> int:
+        """Tool 2."""
+        return x * 2
+
+    llm_with_tools = llm.bind_tools([tool1, tool2], parallel_tool_calls=True)
+
+    # RunnableBinding has kwargs attribute at runtime
+    assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True  # type: ignore[attr-defined]
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_bind_tools_explicit_false():
+    """Test parallel_tool_calls=False in bind_tools."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=False)
+
+    # When explicitly False, should not set the parameter
+    # RunnableBinding has kwargs attribute at runtime
+    assert "is_parallel_tool_calls" not in llm_with_tools.kwargs  # type: ignore[attr-defined]
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_bind_tools_default_none():
+    """Test that bind_tools without parallel_tool_calls doesn't enable it."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Don't specify parallel_tool_calls in bind_tools
+    llm_with_tools = llm.bind_tools([tool1])
+
+    # Should not have is_parallel_tool_calls set
+    # RunnableBinding has kwargs attribute at runtime
+    assert "is_parallel_tool_calls" not in llm_with_tools.kwargs  # type: ignore[attr-defined]
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_passed_to_oci_api_meta():
+    """Test that is_parallel_tool_calls is passed to OCI API for Meta models."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+
+    def get_weather(city: str) -> str:
+        """Get weather for a city."""
+        return f"Weather in {city}"
+
+    llm_with_tools = llm.bind_tools([get_weather], parallel_tool_calls=True)
+
+    # Prepare a request
+    # RunnableBinding has _prepare_request and kwargs attributes at runtime
+    request = llm_with_tools._prepare_request(  # type: ignore[attr-defined]
+        [HumanMessage(content="What's the weather?")],
+        stop=None,
+        stream=False,
+        **llm_with_tools.kwargs,  # type: ignore[attr-defined]
+    )
+
+    # Verify is_parallel_tool_calls is in the request
+    assert hasattr(request.chat_request, "is_parallel_tool_calls")
+    assert request.chat_request.is_parallel_tool_calls is True
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_cohere_raises_error():
+    """Test that Cohere models raise error for parallel tool calls at bind_tools."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(model_id="cohere.command-r-plus", client=oci_gen_ai_client)
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should raise ValueError at bind_tools time (not at request time)
+    with pytest.raises(ValueError, match="not supported"):
+        llm.bind_tools([tool1], parallel_tool_calls=True)
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_meta_allowed():
+    """Test that Meta models are allowed parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+    llm = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should NOT raise ValueError
+    llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=True)
+    # RunnableBinding has kwargs attribute at runtime
+    assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True  # type: ignore[attr-defined]
+
+
+@pytest.mark.requires("oci")
+def test_parallel_tool_calls_other_generic_models_allowed():
+    """Test that other GenericChatRequest models are allowed parallel tool calling."""
+    oci_gen_ai_client = MagicMock()
+
+    # Test with xAI Grok (uses GenericProvider)
+    llm_grok = ChatOCIGenAI(model_id="xai.grok-4-fast", client=oci_gen_ai_client)
+
+    def tool1(x: int) -> int:
+        """Tool 1."""
+        return x + 1
+
+    # Should NOT raise ValueError for Grok
+    llm_with_tools = llm_grok.bind_tools([tool1], parallel_tool_calls=True)
+    # RunnableBinding has kwargs attribute at runtime
+    assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True  # type: ignore[attr-defined]
+
+
+@pytest.mark.requires("oci")
+def test_provider_supports_parallel_tool_calls_property():
+    """Test the provider supports_parallel_tool_calls property."""
+    oci_gen_ai_client = MagicMock()
+
+    # Meta model uses GenericProvider which supports parallel tool calls
+    llm_meta = ChatOCIGenAI(
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
+    )
+    assert llm_meta._provider.supports_parallel_tool_calls is True
+
+    # Cohere model uses CohereProvider which does NOT support parallel tool calls
+    llm_cohere = ChatOCIGenAI(
+        model_id="cohere.command-r-plus", client=oci_gen_ai_client
+    )
+    assert llm_cohere._provider.supports_parallel_tool_calls is False
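For illustration, here is a minimal sketch (not part of the patch) of consuming the multiple tool calls a single response can contain once `parallel_tool_calls=True` is bound. It assumes `llm_with_tools` and the three user-defined example tools from the README section above, and that the model returns well-formed tool calls:

```python
from langchain_core.messages import HumanMessage, ToolMessage

# Map tool names to the user-defined functions bound via bind_tools.
tools_by_name = {
    "get_weather": get_weather,
    "calculate_tip": calculate_tip,
    "get_population": get_population,
}

messages = [HumanMessage(content="What's the weather and population of Tokyo?")]
ai_msg = llm_with_tools.invoke(messages)
messages.append(ai_msg)

# With parallel tool calling enabled, ai_msg.tool_calls may contain several
# entries from a single model turn; run each tool and feed the results back
# as ToolMessages (matched by tool_call_id) before requesting a final answer.
for tc in ai_msg.tool_calls:
    result = tools_by_name[tc["name"]](**tc["args"])
    messages.append(ToolMessage(content=str(result), tool_call_id=tc["id"]))

final_response = llm_with_tools.invoke(messages)
print(final_response.content)
```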