21 changes: 15 additions & 6 deletions tests/integrations/huggingface_hub/test_huggingface_hub.py
@@ -34,6 +34,15 @@
)


+def get_hf_provider_inference_client():
+    # The provider parameter was added in version 0.28.0 of huggingface_hub
+    return (
+        InferenceClient(model="test-model", provider="hf-inference")
+        if HF_VERSION >= (0, 28, 0)
+        else InferenceClient(model="test-model")
+    )
+
+
def _add_mock_response(
httpx_mock, rsps, method, url, json=None, status=200, body=None, headers=None
):
@@ -616,7 +625,7 @@ def test_chat_completion(
)
events = capture_events()

-client = InferenceClient(model="test-model")
+client = get_hf_provider_inference_client()

with sentry_sdk.start_transaction(name="test"):
client.chat_completion(
@@ -688,7 +697,7 @@ def test_chat_completion_streaming(
)
events = capture_events()

-client = InferenceClient(model="test-model")
+client = get_hf_provider_inference_client()

with sentry_sdk.start_transaction(name="test"):
_ = list(
@@ -752,7 +761,7 @@ def test_chat_completion_api_error(
sentry_init(traces_sample_rate=1.0)
events = capture_events()

-client = InferenceClient(model="test-model")
+client = get_hf_provider_inference_client()

with sentry_sdk.start_transaction(name="test"):
with pytest.raises(HfHubHTTPError):
@@ -804,7 +813,7 @@ def test_span_status_error(sentry_init, capture_events, mock_hf_api_with_errors)
sentry_init(traces_sample_rate=1.0)
events = capture_events()

-client = InferenceClient(model="test-model")
+client = get_hf_provider_inference_client()

with sentry_sdk.start_transaction(name="test"):
with pytest.raises(HfHubHTTPError):
@@ -849,7 +858,7 @@ def test_chat_completion_with_tools(
)
events = capture_events()

-client = InferenceClient(model="test-model")
+client = get_hf_provider_inference_client()

tools = [
{
@@ -938,7 +947,7 @@ def test_chat_completion_streaming_with_tools(
)
events = capture_events()

client = InferenceClient(model="test-model")
client = get_hf_provider_inference_client()

tools = [
{
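Note on the HF_VERSION guard in the new helper: the diff does not show where HF_VERSION is defined. Below is a minimal sketch, assuming it is a version tuple parsed from huggingface_hub.__version__; the actual definition in the test module may differ, and the _parse_version helper here is hypothetical.

    # Sketch only: how a version tuple like HF_VERSION could be built.
    # Assumption: it is derived from huggingface_hub.__version__; the real
    # test module may define it differently.
    import huggingface_hub


    def _parse_version(version_str):
        # "0.28.1" -> (0, 28, 1); non-numeric suffixes (e.g. "0.29.0.dev0") are dropped.
        parts = []
        for part in version_str.split("."):
            if not part.isdigit():
                break
            parts.append(int(part))
        return tuple(parts)


    HF_VERSION = _parse_version(huggingface_hub.__version__)

    # The provider parameter only exists on InferenceClient from 0.28.0 onward,
    # so the tests construct the client conditionally, as the helper above does.
    if HF_VERSION >= (0, 28, 0):
        client = huggingface_hub.InferenceClient(model="test-model", provider="hf-inference")
    else:
        client = huggingface_hub.InferenceClient(model="test-model")

Comparing plain integer tuples keeps the version gate readable and avoids pulling in packaging.version for a simple threshold check.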