Harrison/add huggingface hub #23

Merged 5 commits on Oct 26, 2022
Changes from 4 commits
3 changes: 3 additions & 0 deletions README.md
@@ -38,6 +38,9 @@ The following use cases require specific installs and environment variables:
- *Cohere*:
- Install requirements with `pip install cohere`
- Set the following environment variable: `COHERE_API_KEY`
- *HuggingFace*:
- Install requirements with `pip install huggingface_hub`
- Set the following environment variable: `HUGGINGFACEHUB_API_TOKEN`
- *SerpAPI*:
- Install requirements with `pip install google-search-results`
- Set the following environment variable: `SERPAPI_API_KEY`
3 changes: 2 additions & 1 deletion langchain/__init__.py
@@ -12,7 +12,7 @@
SelfAskWithSearchChain,
SerpAPIChain,
)
-from langchain.llms import Cohere, OpenAI
+from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompt import Prompt

__all__ = [
@@ -24,4 +24,5 @@
"Cohere",
"OpenAI",
"Prompt",
"HuggingFaceHub",
]
3 changes: 2 additions & 1 deletion langchain/llms/__init__.py
@@ -1,5 +1,6 @@
"""Wrappers on top of large language models APIs."""
from langchain.llms.cohere import Cohere
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.openai import OpenAI

__all__ = ["Cohere", "OpenAI"]
__all__ = ["Cohere", "OpenAI", "HuggingFaceHub"]
11 changes: 2 additions & 9 deletions langchain/llms/cohere.py
@@ -5,14 +5,7 @@
from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


-def remove_stop_tokens(text: str, stop: List[str]) -> str:
-    """Remove stop tokens, should they occur at end."""
-    for s in stop:
-        if text.endswith(s):
-            return text[: -len(s)]
-    return text
+from langchain.llms.utils import enforce_stop_tokens


class Cohere(BaseModel, LLM):
@@ -104,5 +97,5 @@ def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
# If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.
if stop is not None:
-text = remove_stop_tokens(text, stop)
+text = enforce_stop_tokens(text, stop)
return text
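Editor's note (not part of the diff): the behavioral difference behind this rename is that the old `remove_stop_tokens` only stripped a stop sequence when it appeared at the very end of the text, while the shared `enforce_stop_tokens` helper added below in `langchain/llms/utils.py` cuts the text at the first occurrence of any stop sequence. A quick illustration using the two functions as defined in this PR:

# Illustration only, using the helpers as defined in this PR.
remove_stop_tokens("foo bar baz", ["baz"])   # -> "foo bar " (stop sequence stripped from the end)
remove_stop_tokens("foo baz bar", ["baz"])   # -> "foo baz bar" (unchanged; "baz" is not at the end)
enforce_stop_tokens("foo baz bar", ["baz"])  # -> "foo " (text cut at the first occurrence)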
102 changes: 102 additions & 0 deletions langchain/llms/huggingface_hub.py
@@ -0,0 +1,102 @@
"""Wrapper around HuggingFace APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

DEFAULT_REPO_ID = "gpt2"


class HuggingFaceHub(BaseModel, LLM):
"""Wrapper around HuggingFaceHub models.

To use, you should have the ``huggingface_hub`` python package installed, and the
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token.

Only supports task `text-generation` for now.

Example:
.. code-block:: python

from langchain import HuggingFaceHub
hf = HuggingFaceHub(repo_id="gpt2")
"""

client: Any #: :meta private:
repo_id: str = DEFAULT_REPO_ID
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
max_new_tokens: int = 200
"""The maximum number of tokens to generate in the completion."""
top_p: int = 1
"""Total probability mass of tokens to consider at each step."""
num_return_sequences: int = 1
"""How many completions to generate for each prompt."""

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
raise ValueError(
"Did not find HuggingFace API token, please add an environment variable"
" `HUGGINGFACEHUB_API_TOKEN` which contains it."
)
try:
from huggingface_hub.inference_api import InferenceApi

repo_id = values.get("repo_id", DEFAULT_REPO_ID)
values["client"] = InferenceApi(
repo_id=repo_id,
token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
Review thread on this line:
Collaborator: These parameters seem like nice ones to include as init args with optional env args as backups. Also the relationship feels a bit convoluted at first glance.
Contributor (author): what in particular seems convoluted?
Collaborator: I think the main thing is that the client is created in a validation function - but that seems out of scope of this PR given the existing structure.
Contributor (author): ah yeah i agree. is a bit weird. may look into post_init
(A sketch of the init-arg suggestion follows the validator below.)
task="text-generation",
)
except ImportError:
raise ValueError(
"Could not import huggingface_hub python package. "
"Please it install it with `pip install huggingface_hub`."
)
return values
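Editor's sketch (not part of this PR) of the reviewer's suggestion above: the token could be accepted as an init arg with the environment variable as a fallback. The `huggingfacehub_api_token` field name is hypothetical; this is a minimal sketch, not the PR's implementation.

# Hypothetical field on the model, with the env var as a backup.
huggingfacehub_api_token: Optional[str] = None

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
    """Resolve the API token from the init arg, falling back to the environment."""
    token = values.get("huggingfacehub_api_token") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if token is None:
        raise ValueError(
            "Did not find HuggingFace API token: pass `huggingfacehub_api_token` "
            "or set the `HUGGINGFACEHUB_API_TOKEN` environment variable."
        )
    values["huggingfacehub_api_token"] = token
    return values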

@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling HuggingFace Hub API."""
return {
"temperature": self.temperature,
"max_new_tokens": self.max_new_tokens,
"top_p": self.top_p,
"num_return_sequences": self.num_return_sequences,
}

def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to HuggingFace Hub's inference endpoint.

Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.

Returns:
The string generated by the model.

Example:
.. code-block:: python

response = hf("Tell me a joke.")
"""
response = self.client(inputs=prompt, params=self._default_params)
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
text = response[0]["generated_text"][len(prompt) :]
if stop is not None:
Review thread on this line:
Collaborator: Could we share this across models? Yeah it's a bit hacky but I don't see better support via InferenceAPI and ultimately it just ends with a bit of wasted computation.
Contributor (author): what do you mean by share across models? its already factored out and used in cohere as well (cohere is a bit different - you can pass stop words but they are included at the end of the prompt)
Collaborator: Yeah this is sufficient - not enough examples to merit more.
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
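For anyone trying the new wrapper, a minimal usage sketch (editor's example, not part of the diff); it assumes `huggingface_hub` is installed and the `HUGGINGFACEHUB_API_TOKEN` environment variable is set:

# Minimal usage sketch: requires `pip install huggingface_hub` and the
# HUGGINGFACEHUB_API_TOKEN environment variable.
from langchain.llms import HuggingFaceHub

hf = HuggingFaceHub(repo_id="gpt2", max_new_tokens=50)
print(hf("Once upon a time", stop=["\n\n"]))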
8 changes: 8 additions & 0 deletions langchain/llms/utils.py
@@ -0,0 +1,8 @@
"""Common utility functions for working with LLM APIs."""
import re
from typing import List


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
Review thread on this line:
Collaborator: Double checking - the prompt isn't returned as part of the generated text, right?
Contributor (author): no it shouldn't be
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text)[0]
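Editor's note (not part of the diff): the stop sequences are joined into a single regex pattern, so plain strings behave as expected, but a stop string containing regex metacharacters (e.g. `?` or `.`) would need `re.escape` before being passed in. For example:

# Behavior of enforce_stop_tokens as defined above.
enforce_stop_tokens("foo bar baz", ["bar"])      # -> "foo " (cut at the first match)
enforce_stop_tokens("foo bar", ["baz", "moo"])   # -> "foo bar" (no stop sequence present)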
5 changes: 5 additions & 0 deletions requirements.txt
@@ -1,10 +1,15 @@
-r test_requirements.txt
# For linting
black
isort
mypy
flake8
flake8-docstrings
# For integrations
cohere
openai
google-search-results
playwright
huggingface_hub
# For development
jupyter
19 changes: 19 additions & 0 deletions tests/integration_tests/llms/test_huggingface_hub.py
@@ -0,0 +1,19 @@
"""Test HuggingFace API wrapper."""

import pytest

from langchain.llms.huggingface_hub import HuggingFaceHub


def test_huggingface_call() -> None:
"""Test valid call to HuggingFace."""
llm = HuggingFaceHub(max_new_tokens=10)
output = llm("Say foo:")
assert isinstance(output, str)


def test_huggingface_call_error() -> None:
"""Test valid call to HuggingFace that errors."""
llm = HuggingFaceHub(max_new_tokens=-1)
with pytest.raises(ValueError):
llm("Say foo:")
17 changes: 0 additions & 17 deletions tests/unit_tests/llms/test_cohere.py

This file was deleted.

19 changes: 19 additions & 0 deletions tests/unit_tests/llms/test_utils.py
@@ -0,0 +1,19 @@
"""Test LLM utility functions."""
from langchain.llms.utils import enforce_stop_tokens


def test_enforce_stop_tokens() -> None:
"""Test removing stop tokens when they occur."""
text = "foo bar baz"
output = enforce_stop_tokens(text, ["moo", "baz"])
assert output == "foo bar "
text = "foo bar baz"
output = enforce_stop_tokens(text, ["moo", "baz", "bar"])
assert output == "foo "


def test_enforce_stop_tokens_none() -> None:
"""Test removing stop tokens when they do not occur."""
text = "foo bar baz"
output = enforce_stop_tokens(text, ["moo"])
assert output == "foo bar baz"