Harrison/add huggingface hub #23
Changes from 4 commits
```diff
@@ -1,5 +1,6 @@
 """Wrappers on top of large language models APIs."""
 from langchain.llms.cohere import Cohere
+from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.openai import OpenAI
 
-__all__ = ["Cohere", "OpenAI"]
+__all__ = ["Cohere", "OpenAI", "HuggingFaceHub"]
```
@@ -0,0 +1,102 @@

```python
"""Wrapper around HuggingFace APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

DEFAULT_REPO_ID = "gpt2"


class HuggingFaceHub(BaseModel, LLM):
    """Wrapper around HuggingFaceHub models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token.

    Only supports task `text-generation` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceHub
            hf = HuggingFaceHub(repo_id="gpt2")
    """

    client: Any  #: :meta private:
    repo_id: str = DEFAULT_REPO_ID
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_new_tokens: int = 200
    """The maximum number of tokens to generate in the completion."""
    top_p: float = 1.0
    """Total probability mass of tokens to consider at each step."""
    num_return_sequences: int = 1
    """How many completions to generate for each prompt."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
            raise ValueError(
                "Did not find HuggingFace API token, please add an environment variable"
                " `HUGGINGFACEHUB_API_TOKEN` which contains it."
            )
        try:
            from huggingface_hub.inference_api import InferenceApi

            repo_id = values.get("repo_id", DEFAULT_REPO_ID)
            values["client"] = InferenceApi(
                repo_id=repo_id,
                token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
                task="text-generation",
            )
        except ImportError:
            raise ValueError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling HuggingFace Hub API."""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "num_return_sequences": self.num_return_sequences,
        }

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to HuggingFace Hub's inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = hf("Tell me a joke.")
        """
        response = self.client(inputs=prompt, params=self._default_params)
        if "error" in response:
            raise ValueError(f"Error raised by inference API: {response['error']}")
        # generated_text echoes the prompt, so slice it off the front.
        text = response[0]["generated_text"][len(prompt) :]
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
```

Review thread on the stop-token enforcement (`if stop is not None:`):

> Could we share this across models? Yeah, it's a bit hacky, but I don't see better support via InferenceAPI, and ultimately it just ends with a bit of wasted computation.
>
> **Reply:** What do you mean by share across models? It's already factored out and used in Cohere as well (Cohere is a bit different: you can pass stop words, but they are included at the end of the prompt).
>
> **Reply:** Yeah, this is sufficient; not enough examples to merit more.
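For orientation, here is a minimal usage sketch of the wrapper above. It is an illustration, not part of the PR; the token value and prompt are placeholders, and a real token plus network access would be needed to run it:

```python
import os

from langchain.llms.huggingface_hub import HuggingFaceHub

# The root validator requires this variable at construction time.
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "<your-token>"  # placeholder

# repo_id defaults to "gpt2"; the generation kwargs feed _default_params.
llm = HuggingFaceHub(repo_id="gpt2", temperature=0.5, max_new_tokens=50)

# __call__ posts the prompt to the Inference API, strips the echoed prompt,
# and applies stop sequences client-side via enforce_stop_tokens.
print(llm("Once upon a time,", stop=["\n"]))
```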
@@ -0,0 +1,8 @@

```python
"""Common utility functions for working with LLM APIs."""
import re
from typing import List


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]
```

Review thread on `enforce_stop_tokens`:

> Double-checking: the prompt isn't returned as part of the generated text, right?
>
> **Reply:** No, it shouldn't be.
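One subtlety in the implementation above: `re.split("|".join(stop), text)` treats each stop word as a regular-expression pattern, so a stop sequence containing metacharacters such as `.` or `?` would split in unintended places. A sketch of a literal-matching variant (the escaped helper is hypothetical, not in this PR):

```python
import re
from typing import List


def enforce_stop_tokens_literal(text: str, stop: List[str]) -> str:
    """Variant of enforce_stop_tokens that escapes stop words.

    re.escape makes each stop word match literally rather than as a regex.
    """
    return re.split("|".join(re.escape(s) for s in stop), text)[0]
```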
@@ -1,10 +1,15 @@

```
-r test_requirements.txt
# For linting
black
isort
mypy
flake8
flake8-docstrings
# For integrations
cohere
openai
google-search-results
playwright
huggingface_hub
# For development
jupyter
```
@@ -0,0 +1,19 @@

```python
"""Test HuggingFace API wrapper."""

import pytest

from langchain.llms.huggingface_hub import HuggingFaceHub


def test_huggingface_call() -> None:
    """Test valid call to HuggingFace."""
    llm = HuggingFaceHub(max_new_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)


def test_huggingface_call_error() -> None:
    """Test valid call to HuggingFace that errors."""
    llm = HuggingFaceHub(max_new_tokens=-1)
    with pytest.raises(ValueError):
        llm("Say foo:")
```
This file was deleted.
@@ -0,0 +1,19 @@

```python
"""Test LLM utility functions."""
from langchain.llms.utils import enforce_stop_tokens


def test_enforce_stop_tokens() -> None:
    """Test removing stop tokens when they occur."""
    text = "foo bar baz"
    output = enforce_stop_tokens(text, ["moo", "baz"])
    assert output == "foo bar "
    text = "foo bar baz"
    output = enforce_stop_tokens(text, ["moo", "baz", "bar"])
    assert output == "foo "


def test_enforce_stop_tokens_none() -> None:
    """Test removing stop tokens when they do not occur."""
    text = "foo bar baz"
    output = enforce_stop_tokens(text, ["moo"])
    assert output == "foo bar baz"
```
Review thread on `validate_environment`:

> These parameters seem like nice ones to include as init args, with optional env args as backups. Also, the relationship feels a bit convoluted at first glance.
>
> **Reply:** What in particular seems convoluted?
>
> **Reply:** I think the main thing is that the client is created in a validation function, but that seems out of scope of this PR given the existing structure.
>
> **Reply:** Ah yeah, I agree. It is a bit weird. May look into post_init.
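For illustration, a minimal sketch of the suggestion above (an init arg with an env-var fallback), assuming the same pydantic v1 style as the PR; the class and field names here are hypothetical, not from this PR:

```python
import os
from typing import Dict, Optional

from pydantic import BaseModel, root_validator


class HuggingFaceHubConfig(BaseModel):
    """Hypothetical config object, not the class from this PR."""

    huggingfacehub_api_token: Optional[str] = None

    @root_validator()
    def resolve_token(cls, values: Dict) -> Dict:
        """Prefer an explicit init arg; fall back to the environment."""
        token = values.get("huggingfacehub_api_token") or os.environ.get(
            "HUGGINGFACEHUB_API_TOKEN"
        )
        if token is None:
            raise ValueError(
                "Pass huggingfacehub_api_token or set the "
                "HUGGINGFACEHUB_API_TOKEN environment variable."
            )
        values["huggingfacehub_api_token"] = token
        return values
```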