11"""Unit tests for parallel tool calling feature."""
2+
23from unittest .mock import MagicMock
34
45import pytest
@@ -14,7 +15,7 @@ def test_parallel_tool_calls_class_level():
1415 llm = ChatOCIGenAI (
1516 model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
1617 parallel_tool_calls = True ,
17- client = oci_gen_ai_client
18+ client = oci_gen_ai_client ,
1819 )
1920 assert llm .parallel_tool_calls is True
2021
@@ -24,8 +25,7 @@ def test_parallel_tool_calls_default_false():
2425 """Test that parallel_tool_calls defaults to False."""
2526 oci_gen_ai_client = MagicMock ()
2627 llm = ChatOCIGenAI (
27- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
28- client = oci_gen_ai_client
28+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
2929 )
3030 assert llm .parallel_tool_calls is False
3131
@@ -35,8 +35,7 @@ def test_parallel_tool_calls_bind_tools_explicit_true():
3535 """Test parallel_tool_calls=True in bind_tools."""
3636 oci_gen_ai_client = MagicMock ()
3737 llm = ChatOCIGenAI (
38- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
39- client = oci_gen_ai_client
38+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
4039 )
4140
4241 def tool1 (x : int ) -> int :
@@ -47,10 +46,7 @@ def tool2(x: int) -> int:
4746 """Tool 2."""
4847 return x * 2
4948
50- llm_with_tools = llm .bind_tools (
51- [tool1 , tool2 ],
52- parallel_tool_calls = True
53- )
49+ llm_with_tools = llm .bind_tools ([tool1 , tool2 ], parallel_tool_calls = True )
5450
5551 assert llm_with_tools .kwargs .get ("is_parallel_tool_calls" ) is True
5652
@@ -60,18 +56,14 @@ def test_parallel_tool_calls_bind_tools_explicit_false():
6056 """Test parallel_tool_calls=False in bind_tools."""
6157 oci_gen_ai_client = MagicMock ()
6258 llm = ChatOCIGenAI (
63- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
64- client = oci_gen_ai_client
59+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
6560 )
6661
6762 def tool1 (x : int ) -> int :
6863 """Tool 1."""
6964 return x + 1
7065
71- llm_with_tools = llm .bind_tools (
72- [tool1 ],
73- parallel_tool_calls = False
74- )
66+ llm_with_tools = llm .bind_tools ([tool1 ], parallel_tool_calls = False )
7567
7668 # When explicitly False, should not set the parameter
7769 assert "is_parallel_tool_calls" not in llm_with_tools .kwargs
@@ -84,7 +76,7 @@ def test_parallel_tool_calls_bind_tools_uses_class_default():
8476 llm = ChatOCIGenAI (
8577 model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
8678 parallel_tool_calls = True , # Set class default
87- client = oci_gen_ai_client
79+ client = oci_gen_ai_client ,
8880 )
8981
9082 def tool1 (x : int ) -> int :
@@ -105,7 +97,7 @@ def test_parallel_tool_calls_bind_tools_overrides_class_default():
10597 llm = ChatOCIGenAI (
10698 model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
10799 parallel_tool_calls = True , # Set class default to True
108- client = oci_gen_ai_client
100+ client = oci_gen_ai_client ,
109101 )
110102
111103 def tool1 (x : int ) -> int :
@@ -124,8 +116,7 @@ def test_parallel_tool_calls_passed_to_oci_api_meta():
124116 """Test that is_parallel_tool_calls is passed to OCI API for Meta models."""
125117 oci_gen_ai_client = MagicMock ()
126118 llm = ChatOCIGenAI (
127- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
128- client = oci_gen_ai_client
119+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
129120 )
130121
131122 def get_weather (city : str ) -> str :
@@ -139,22 +130,19 @@ def get_weather(city: str) -> str:
139130 [HumanMessage (content = "What's the weather?" )],
140131 stop = None ,
141132 stream = False ,
142- ** llm_with_tools .kwargs
133+ ** llm_with_tools .kwargs ,
143134 )
144135
145136 # Verify is_parallel_tool_calls is in the request
146- assert hasattr (request .chat_request , ' is_parallel_tool_calls' )
137+ assert hasattr (request .chat_request , " is_parallel_tool_calls" )
147138 assert request .chat_request .is_parallel_tool_calls is True
148139
149140
150141@pytest .mark .requires ("oci" )
151142def test_parallel_tool_calls_cohere_raises_error ():
152143 """Test that Cohere models raise error for parallel tool calls."""
153144 oci_gen_ai_client = MagicMock ()
154- llm = ChatOCIGenAI (
155- model_id = "cohere.command-r-plus" ,
156- client = oci_gen_ai_client
157- )
145+ llm = ChatOCIGenAI (model_id = "cohere.command-r-plus" , client = oci_gen_ai_client )
158146
159147 def tool1 (x : int ) -> int :
160148 """Tool 1."""
@@ -168,7 +156,7 @@ def tool1(x: int) -> int:
168156 [HumanMessage (content = "test" )],
169157 stop = None ,
170158 stream = False ,
171- ** llm_with_tools .kwargs
159+ ** llm_with_tools .kwargs ,
172160 )
173161
174162
@@ -179,7 +167,7 @@ def test_parallel_tool_calls_cohere_class_level_raises_error():
179167 llm = ChatOCIGenAI (
180168 model_id = "cohere.command-r-plus" ,
181169 parallel_tool_calls = True , # Set at class level
182- client = oci_gen_ai_client
170+ client = oci_gen_ai_client ,
183171 )
184172
185173 def tool1 (x : int ) -> int :
@@ -194,18 +182,15 @@ def tool1(x: int) -> int:
194182 [HumanMessage (content = "test" )],
195183 stop = None ,
196184 stream = False ,
197- ** llm_with_tools .kwargs
185+ ** llm_with_tools .kwargs ,
198186 )
199187
200188
201189@pytest .mark .requires ("oci" )
202190def test_version_filter_llama_3_0_blocked ():
203191 """Test that Llama 3.0 models are blocked from parallel tool calling."""
204192 oci_gen_ai_client = MagicMock ()
205- llm = ChatOCIGenAI (
206- model_id = "meta.llama-3-70b-instruct" ,
207- client = oci_gen_ai_client
208- )
193+ llm = ChatOCIGenAI (model_id = "meta.llama-3-70b-instruct" , client = oci_gen_ai_client )
209194
210195 def tool1 (x : int ) -> int :
211196 """Tool 1."""
@@ -220,10 +205,7 @@ def tool1(x: int) -> int:
220205def test_version_filter_llama_3_1_blocked ():
221206 """Test that Llama 3.1 models are blocked from parallel tool calling."""
222207 oci_gen_ai_client = MagicMock ()
223- llm = ChatOCIGenAI (
224- model_id = "meta.llama-3.1-70b-instruct" ,
225- client = oci_gen_ai_client
226- )
208+ llm = ChatOCIGenAI (model_id = "meta.llama-3.1-70b-instruct" , client = oci_gen_ai_client )
227209
228210 def tool1 (x : int ) -> int :
229211 """Tool 1."""
@@ -239,8 +221,7 @@ def test_version_filter_llama_3_2_blocked():
239221 """Test that Llama 3.2 models are blocked from parallel tool calling."""
240222 oci_gen_ai_client = MagicMock ()
241223 llm = ChatOCIGenAI (
242- model_id = "meta.llama-3.2-11b-vision-instruct" ,
243- client = oci_gen_ai_client
224+ model_id = "meta.llama-3.2-11b-vision-instruct" , client = oci_gen_ai_client
244225 )
245226
246227 def tool1 (x : int ) -> int :
@@ -256,10 +237,7 @@ def tool1(x: int) -> int:
256237def test_version_filter_llama_3_3_blocked ():
257238 """Test that Llama 3.3 models are blocked from parallel tool calling."""
258239 oci_gen_ai_client = MagicMock ()
259- llm = ChatOCIGenAI (
260- model_id = "meta.llama-3.3-70b-instruct" ,
261- client = oci_gen_ai_client
262- )
240+ llm = ChatOCIGenAI (model_id = "meta.llama-3.3-70b-instruct" , client = oci_gen_ai_client )
263241
264242 def tool1 (x : int ) -> int :
265243 """Tool 1."""
@@ -275,8 +253,7 @@ def test_version_filter_llama_4_allowed():
275253 """Test that Llama 4 models are allowed parallel tool calling."""
276254 oci_gen_ai_client = MagicMock ()
277255 llm = ChatOCIGenAI (
278- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
279- client = oci_gen_ai_client
256+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
280257 )
281258
282259 def tool1 (x : int ) -> int :
@@ -294,10 +271,7 @@ def test_version_filter_other_models_allowed():
294271 oci_gen_ai_client = MagicMock ()
295272
296273 # Test with xAI Grok
297- llm_grok = ChatOCIGenAI (
298- model_id = "xai.grok-4-fast" ,
299- client = oci_gen_ai_client
300- )
274+ llm_grok = ChatOCIGenAI (model_id = "xai.grok-4-fast" , client = oci_gen_ai_client )
301275
302276 def tool1 (x : int ) -> int :
303277 """Tool 1."""
@@ -313,8 +287,7 @@ def test_version_filter_supports_parallel_tool_calls_method():
313287 """Test the _supports_parallel_tool_calls method directly."""
314288 oci_gen_ai_client = MagicMock ()
315289 llm = ChatOCIGenAI (
316- model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" ,
317- client = oci_gen_ai_client
290+ model_id = "meta.llama-4-maverick-17b-128e-instruct-fp8" , client = oci_gen_ai_client
318291 )
319292
320293 # Test various model IDs
0 commit comments