Huggingface Hub integration (#3033)
Adds an integration for Huggingface Hub. InferenceClient.text_generation calls are wrapped in a span that records the model id, streaming flag, and token usage, plus prompts and responses when sending PII is allowed; exceptions raised by the client are captured as error events.
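
A minimal sketch of how this would be enabled (the DSN is a placeholder; per the change to sentry_sdk/integrations/__init__.py below, the integration is auto-enabled when huggingface_hub is installed, so passing it explicitly is only needed to change its options):

import sentry_sdk
from sentry_sdk.integrations.huggingface_hub import HuggingfaceHubIntegration

sentry_sdk.init(
    dsn="https://<public_key>@<org>.ingest.sentry.io/<project>",  # placeholder DSN
    send_default_pii=True,  # prompts/responses are only attached when PII may be sent
    integrations=[HuggingfaceHubIntegration(include_prompts=True)],
)

Setting include_prompts=False suppresses ai.input_messages and ai.responses even when PII sending is enabled.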

---------

Co-authored-by: Anton Pirker <anton.pirker@sentry.io>
colin-sentry and antonpirker committed May 2, 2024
1 parent eac253a commit 41aa99b
Showing 11 changed files with 364 additions and 1 deletion.
8 changes: 8 additions & 0 deletions .github/workflows/test-integrations-data-processing.yml
@@ -70,6 +70,10 @@ jobs:
        run: |
          set -x # print commands that are executed
          ./scripts/runtox.sh "py${{ matrix.python-version }}-openai-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
      - name: Test huggingface_hub latest
        run: |
          set -x # print commands that are executed
          ./scripts/runtox.sh "py${{ matrix.python-version }}-huggingface_hub-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
      - name: Test rq latest
        run: |
          set -x # print commands that are executed
@@ -134,6 +138,10 @@ jobs:
        run: |
          set -x # print commands that are executed
          ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-openai" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
      - name: Test huggingface_hub pinned
        run: |
          set -x # print commands that are executed
          ./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-huggingface_hub" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
      - name: Test rq pinned
        run: |
          set -x # print commands that are executed
2 changes: 2 additions & 0 deletions mypy.ini
@@ -73,6 +73,8 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-openai.*]
ignore_missing_imports = True
[mypy-huggingface_hub.*]
ignore_missing_imports = True
[mypy-arq.*]
ignore_missing_imports = True
[mypy-grpc.*]
1 change: 1 addition & 0 deletions scripts/split-tox-gh-actions/split-tox-gh-actions.py
@@ -73,6 +73,7 @@
"huey",
"langchain",
"openai",
"huggingface_hub",
"rq",
],
"Databases": [
Expand Down
3 changes: 3 additions & 0 deletions sentry_sdk/consts.py
@@ -325,6 +325,9 @@ class OP:
    MIDDLEWARE_STARLITE_SEND = "middleware.starlite.send"
    OPENAI_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.openai"
    OPENAI_EMBEDDINGS_CREATE = "ai.embeddings.create.openai"
    HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE = (
        "ai.chat_completions.create.huggingface_hub"
    )
    LANGCHAIN_PIPELINE = "ai.pipeline.langchain"
    LANGCHAIN_RUN = "ai.run.langchain"
    LANGCHAIN_TOOL = "ai.tool.langchain"
1 change: 1 addition & 0 deletions sentry_sdk/integrations/__init__.py
@@ -85,6 +85,7 @@ def iter_default_integrations(with_auto_enabling_integrations):
"sentry_sdk.integrations.graphene.GrapheneIntegration",
"sentry_sdk.integrations.httpx.HttpxIntegration",
"sentry_sdk.integrations.huey.HueyIntegration",
"sentry_sdk.integrations.huggingface_hub.HuggingfaceHubIntegration",
"sentry_sdk.integrations.langchain.LangchainIntegration",
"sentry_sdk.integrations.loguru.LoguruIntegration",
"sentry_sdk.integrations.openai.OpenAIIntegration",
Expand Down
173 changes: 173 additions & 0 deletions sentry_sdk/integrations/huggingface_hub.py
@@ -0,0 +1,173 @@
from functools import wraps

from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized
from sentry_sdk.consts import SPANDATA

from typing import Any, Iterable, Callable

import sentry_sdk
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
    ensure_integration_enabled,
)

try:
    import huggingface_hub.inference._client

    from huggingface_hub import ChatCompletionStreamOutput, TextGenerationOutput
except ImportError:
    raise DidNotEnable("Huggingface not installed")


class HuggingfaceHubIntegration(Integration):
    identifier = "huggingface_hub"

    def __init__(self, include_prompts=True):
        # type: (HuggingfaceHubIntegration, bool) -> None
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        huggingface_hub.inference._client.InferenceClient.text_generation = (
            _wrap_text_generation(
                huggingface_hub.inference._client.InferenceClient.text_generation
            )
        )


def _capture_exception(exc):
    # type: (Any) -> None
    event, hint = event_from_exception(
        exc,
        client_options=sentry_sdk.get_client().options,
        mechanism={"type": "huggingface_hub", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)


def _wrap_text_generation(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    @wraps(f)
    @ensure_integration_enabled(HuggingfaceHubIntegration, f)
    def new_text_generation(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        if "prompt" in kwargs:
            prompt = kwargs["prompt"]
        elif len(args) >= 2:
            kwargs["prompt"] = args[1]
            prompt = kwargs["prompt"]
            args = (args[0],) + args[2:]
        else:
            # invalid call, let it return error
            return f(*args, **kwargs)

        model = kwargs.get("model")
        streaming = kwargs.get("stream")

        span = sentry_sdk.start_span(
            op=consts.OP.HUGGINGFACE_HUB_CHAT_COMPLETIONS_CREATE,
            description="Text Generation",
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(HuggingfaceHubIntegration)

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, prompt)

            set_data_normalized(span, SPANDATA.AI_MODEL_ID, model)
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            if isinstance(res, str):
                if should_send_default_pii() and integration.include_prompts:
                    set_data_normalized(
                        span,
                        "ai.responses",
                        [res],
                    )
                span.__exit__(None, None, None)
                return res

            if isinstance(res, TextGenerationOutput):
                if should_send_default_pii() and integration.include_prompts:
                    set_data_normalized(
                        span,
                        "ai.responses",
                        [res.generated_text],
                    )
                if res.details is not None and res.details.generated_tokens > 0:
                    record_token_usage(span, total_tokens=res.details.generated_tokens)
                span.__exit__(None, None, None)
                return res

            if not isinstance(res, Iterable):
                # we only know how to deal with strings and iterables, ignore
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)
                return res

            if kwargs.get("details", False):
                # res is Iterable[TextGenerationStreamOutput]
                def new_details_iterator():
                    # type: () -> Iterable[ChatCompletionStreamOutput]
                    with capture_internal_exceptions():
                        tokens_used = 0
                        data_buf: list[str] = []
                        for x in res:
                            if hasattr(x, "token") and hasattr(x.token, "text"):
                                data_buf.append(x.token.text)
                            if hasattr(x, "details") and hasattr(
                                x.details, "generated_tokens"
                            ):
                                tokens_used = x.details.generated_tokens
                            yield x
                        if (
                            len(data_buf) > 0
                            and should_send_default_pii()
                            and integration.include_prompts
                        ):
                            set_data_normalized(
                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                            )
                        if tokens_used > 0:
                            record_token_usage(span, total_tokens=tokens_used)
                    span.__exit__(None, None, None)

                return new_details_iterator()
            else:
                # res is Iterable[str]

                def new_iterator():
                    # type: () -> Iterable[str]
                    data_buf: list[str] = []
                    with capture_internal_exceptions():
                        for s in res:
                            if isinstance(s, str):
                                data_buf.append(s)
                            yield s
                        if (
                            len(data_buf) > 0
                            and should_send_default_pii()
                            and integration.include_prompts
                        ):
                            set_data_normalized(
                                span, SPANDATA.AI_RESPONSES, "".join(data_buf)
                            )
                    span.__exit__(None, None, None)

                return new_iterator()

    return new_text_generation
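
For illustration, a hypothetical call site once setup_once() has patched the client (the model id and prompts are made up):

from huggingface_hub import InferenceClient

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")  # hypothetical model id

# Wrapped call: runs inside a span with op
# "ai.chat_completions.create.huggingface_hub" and description "Text Generation".
text = client.text_generation("Write a haiku about tracing.")

# Streaming is instrumented too: the returned iterator is replaced by one that
# buffers the chunks and closes the span once the stream is exhausted.
for chunk in client.text_generation("Write a haiku about spans.", stream=True):
    print(chunk, end="")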
2 changes: 1 addition & 1 deletion sentry_sdk/integrations/langchain.py
@@ -63,7 +63,7 @@ def count_tokens(s):

# To avoid double collecting tokens, we do *not* measure
# token counts for models for which we have an explicit integration
NO_COLLECT_TOKEN_MODELS = ["openai-chat"]
NO_COLLECT_TOKEN_MODELS = ["openai-chat"] # TODO add huggingface and anthropic


class LangchainIntegration(Integration):
Expand Down
1 change: 1 addition & 0 deletions setup.py
@@ -60,6 +60,7 @@ def get_file_text(file_name):
"grpcio": ["grpcio>=1.21.1"],
"httpx": ["httpx>=0.16.0"],
"huey": ["huey>=2"],
"huggingface_hub": ["huggingface_hub>=0.22"],
"langchain": ["langchain>=0.0.210"],
"loguru": ["loguru>=0.5"],
"openai": ["openai>=1.0.0", "tiktoken>=0.3.0"],
3 changes: 3 additions & 0 deletions tests/integrations/huggingface_hub/__init__.py
@@ -0,0 +1,3 @@
import pytest

pytest.importorskip("huggingface_hub")
