Commit 3bb4d01

Apply ruff formatting to parallel tool calling tests
1 parent 1ed506a commit 3bb4d01
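
The diff below is the output of ruff's code formatter run over the two parallel-tool-calling test files. The exact invocation isn't recorded in the commit; a typical run, assuming ruff is installed in the environment and the repository's default settings, would be:

```sh
# Rewrite the test files in place using ruff's formatter
# (paths taken from the file tree below; flags are an assumption).
ruff format libs/oci/tests
```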

File tree

libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py
libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py

2 files changed: +38, -67 lines


libs/oci/tests/integration_tests/chat_models/test_parallel_tool_calling_integration.py

Lines changed: 15 additions & 17 deletions
```diff
@@ -82,9 +82,9 @@ def test_parallel_tool_calling_enabled():
     logging.info("\nQuery: 'What's the weather in New York City?'")
 
     start_time = time.time()
-    response = chat_with_tools.invoke([
-        HumanMessage(content="What's the weather in New York City?")
-    ])
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather in New York City?")]
+    )
     elapsed_time = time.time() - start_time
 
     logging.info(f"\nResponse time: {elapsed_time:.2f}s")
@@ -135,9 +135,9 @@ def test_parallel_tool_calling_disabled():
     logging.info("\nQuery: 'What's the weather in New York City?'")
 
     start_time = time.time()
-    response = chat_with_tools.invoke([
-        HumanMessage(content="What's the weather in New York City?")
-    ])
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather in New York City?")]
+    )
     elapsed_time = time.time() - start_time
 
     logging.info(f"\nResponse time: {elapsed_time:.2f}s")
@@ -181,14 +181,14 @@ def test_bind_tools_override():
     # Override with True in bind_tools
     chat_with_tools = chat.bind_tools(
         [get_weather, get_population],
-        parallel_tool_calls=True  # Override to enable
+        parallel_tool_calls=True,  # Override to enable
     )
 
     logging.info("\nQuery: 'What's the weather and population of Tokyo?'")
 
-    response = chat_with_tools.invoke([
-        HumanMessage(content="What's the weather and population of Tokyo?")
-    ])
+    response = chat_with_tools.invoke(
+        [HumanMessage(content="What's the weather and population of Tokyo?")]
+    )
 
     logging.info(f"\nResponse content: {response.content}")
     logging.info(f"Tool calls count: {len(response.tool_calls)}")
@@ -219,17 +219,14 @@ def test_cohere_model_error():
     )
 
     # Try to enable parallel tool calls with Cohere (should fail)
-    chat_with_tools = chat.bind_tools(
-        [get_weather],
-        parallel_tool_calls=True
-    )
+    chat_with_tools = chat.bind_tools([get_weather], parallel_tool_calls=True)
 
     logging.info("\nAttempting to use parallel_tool_calls with Cohere model...")
 
     try:
-        _ = chat_with_tools.invoke([
-            HumanMessage(content="What's the weather in Paris?")
-        ])
+        _ = chat_with_tools.invoke(
+            [HumanMessage(content="What's the weather in Paris?")]
+        )
         logging.info("❌ TEST FAILED: Should have raised ValueError")
         return False
     except ValueError as e:
@@ -313,6 +310,7 @@ def main():
     except Exception as e:
         logging.info(f"\n❌ ERROR: {e}")
         import traceback
+
         traceback.print_exc()
         return 1
 
```

libs/oci/tests/unit_tests/chat_models/test_parallel_tool_calling.py

Lines changed: 23 additions & 50 deletions
```diff
@@ -1,4 +1,5 @@
 """Unit tests for parallel tool calling feature."""
+
 from unittest.mock import MagicMock
 
 import pytest
@@ -14,7 +15,7 @@ def test_parallel_tool_calls_class_level():
     llm = ChatOCIGenAI(
         model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,
-        client=oci_gen_ai_client
+        client=oci_gen_ai_client,
     )
     assert llm.parallel_tool_calls is True
 
@@ -24,8 +25,7 @@ def test_parallel_tool_calls_default_false():
     """Test that parallel_tool_calls defaults to False."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
     )
     assert llm.parallel_tool_calls is False
 
@@ -35,8 +35,7 @@ def test_parallel_tool_calls_bind_tools_explicit_true():
     """Test parallel_tool_calls=True in bind_tools."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
     )
 
     def tool1(x: int) -> int:
@@ -47,10 +46,7 @@ def tool2(x: int) -> int:
         """Tool 2."""
         return x * 2
 
-    llm_with_tools = llm.bind_tools(
-        [tool1, tool2],
-        parallel_tool_calls=True
-    )
+    llm_with_tools = llm.bind_tools([tool1, tool2], parallel_tool_calls=True)
 
     assert llm_with_tools.kwargs.get("is_parallel_tool_calls") is True
 
@@ -60,18 +56,14 @@ def test_parallel_tool_calls_bind_tools_explicit_false():
     """Test parallel_tool_calls=False in bind_tools."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
     )
 
     def tool1(x: int) -> int:
         """Tool 1."""
         return x + 1
 
-    llm_with_tools = llm.bind_tools(
-        [tool1],
-        parallel_tool_calls=False
-    )
+    llm_with_tools = llm.bind_tools([tool1], parallel_tool_calls=False)
 
     # When explicitly False, should not set the parameter
     assert "is_parallel_tool_calls" not in llm_with_tools.kwargs
@@ -84,7 +76,7 @@ def test_parallel_tool_calls_bind_tools_uses_class_default():
     llm = ChatOCIGenAI(
         model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,  # Set class default
-        client=oci_gen_ai_client
+        client=oci_gen_ai_client,
     )
 
     def tool1(x: int) -> int:
@@ -105,7 +97,7 @@ def test_parallel_tool_calls_bind_tools_overrides_class_default():
     llm = ChatOCIGenAI(
         model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
         parallel_tool_calls=True,  # Set class default to True
-        client=oci_gen_ai_client
+        client=oci_gen_ai_client,
     )
 
     def tool1(x: int) -> int:
@@ -124,8 +116,7 @@ def test_parallel_tool_calls_passed_to_oci_api_meta():
     """Test that is_parallel_tool_calls is passed to OCI API for Meta models."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
     )
 
     def get_weather(city: str) -> str:
@@ -139,22 +130,19 @@ def get_weather(city: str) -> str:
         [HumanMessage(content="What's the weather?")],
         stop=None,
         stream=False,
-        **llm_with_tools.kwargs
+        **llm_with_tools.kwargs,
     )
 
     # Verify is_parallel_tool_calls is in the request
-    assert hasattr(request.chat_request, 'is_parallel_tool_calls')
+    assert hasattr(request.chat_request, "is_parallel_tool_calls")
     assert request.chat_request.is_parallel_tool_calls is True
 
 
 @pytest.mark.requires("oci")
 def test_parallel_tool_calls_cohere_raises_error():
     """Test that Cohere models raise error for parallel tool calls."""
     oci_gen_ai_client = MagicMock()
-    llm = ChatOCIGenAI(
-        model_id="cohere.command-r-plus",
-        client=oci_gen_ai_client
-    )
+    llm = ChatOCIGenAI(model_id="cohere.command-r-plus", client=oci_gen_ai_client)
 
     def tool1(x: int) -> int:
         """Tool 1."""
@@ -168,7 +156,7 @@ def tool1(x: int) -> int:
         [HumanMessage(content="test")],
         stop=None,
         stream=False,
-        **llm_with_tools.kwargs
+        **llm_with_tools.kwargs,
     )
 
 
@@ -179,7 +167,7 @@ def test_parallel_tool_calls_cohere_class_level_raises_error():
     llm = ChatOCIGenAI(
         model_id="cohere.command-r-plus",
         parallel_tool_calls=True,  # Set at class level
-        client=oci_gen_ai_client
+        client=oci_gen_ai_client,
     )
 
     def tool1(x: int) -> int:
@@ -194,18 +182,15 @@ def tool1(x: int) -> int:
         [HumanMessage(content="test")],
         stop=None,
         stream=False,
-        **llm_with_tools.kwargs
+        **llm_with_tools.kwargs,
     )
 
 
 @pytest.mark.requires("oci")
 def test_version_filter_llama_3_0_blocked():
     """Test that Llama 3.0 models are blocked from parallel tool calling."""
     oci_gen_ai_client = MagicMock()
-    llm = ChatOCIGenAI(
-        model_id="meta.llama-3-70b-instruct",
-        client=oci_gen_ai_client
-    )
+    llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client)
 
     def tool1(x: int) -> int:
         """Tool 1."""
@@ -220,10 +205,7 @@ def tool1(x: int) -> int:
 def test_version_filter_llama_3_1_blocked():
     """Test that Llama 3.1 models are blocked from parallel tool calling."""
     oci_gen_ai_client = MagicMock()
-    llm = ChatOCIGenAI(
-        model_id="meta.llama-3.1-70b-instruct",
-        client=oci_gen_ai_client
-    )
+    llm = ChatOCIGenAI(model_id="meta.llama-3.1-70b-instruct", client=oci_gen_ai_client)
 
     def tool1(x: int) -> int:
         """Tool 1."""
@@ -239,8 +221,7 @@ def test_version_filter_llama_3_2_blocked():
     """Test that Llama 3.2 models are blocked from parallel tool calling."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-3.2-11b-vision-instruct",
-        client=oci_gen_ai_client
+        model_id="meta.llama-3.2-11b-vision-instruct", client=oci_gen_ai_client
     )
 
     def tool1(x: int) -> int:
@@ -256,10 +237,7 @@ def tool1(x: int) -> int:
 def test_version_filter_llama_3_3_blocked():
     """Test that Llama 3.3 models are blocked from parallel tool calling."""
     oci_gen_ai_client = MagicMock()
-    llm = ChatOCIGenAI(
-        model_id="meta.llama-3.3-70b-instruct",
-        client=oci_gen_ai_client
-    )
+    llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
 
     def tool1(x: int) -> int:
         """Tool 1."""
@@ -275,8 +253,7 @@ def test_version_filter_llama_4_allowed():
     """Test that Llama 4 models are allowed parallel tool calling."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
     )
 
     def tool1(x: int) -> int:
@@ -294,10 +271,7 @@ def test_version_filter_other_models_allowed():
     oci_gen_ai_client = MagicMock()
 
     # Test with xAI Grok
-    llm_grok = ChatOCIGenAI(
-        model_id="xai.grok-4-fast",
-        client=oci_gen_ai_client
-    )
+    llm_grok = ChatOCIGenAI(model_id="xai.grok-4-fast", client=oci_gen_ai_client)
 
     def tool1(x: int) -> int:
         """Tool 1."""
@@ -313,8 +287,7 @@ def test_version_filter_supports_parallel_tool_calls_method():
     """Test the _supports_parallel_tool_calls method directly."""
     oci_gen_ai_client = MagicMock()
     llm = ChatOCIGenAI(
-        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8",
-        client=oci_gen_ai_client
+        model_id="meta.llama-4-maverick-17b-128e-instruct-fp8", client=oci_gen_ai_client
    )
 
     # Test various model IDs
```
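
Every hunk above is a whitespace or quoting adjustment with no behavioral effect: trailing commas are added to multi-line calls, calls that fit within the line-length limit are collapsed onto one line, and single-quoted strings become double-quoted. To confirm the tree is clean after this commit, ruff's check mode (assuming the same ruff version and settings the author used) exits non-zero if any file would still be reformatted:

```sh
# Verify formatting without rewriting files; path is an assumption.
ruff format --check libs/oci/tests
```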

0 commit comments