Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HuggingFace Hub Embeddings #125

Merged
merged 10 commits into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Wrappers around embedding modules."""
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings

__all__ = ["OpenAIEmbeddings", "HuggingFaceEmbeddings", "CohereEmbeddings"]
__all__ = [
"OpenAIEmbeddings",
"HuggingFaceEmbeddings",
"CohereEmbeddings",
"HuggingFaceHubEmbeddings",
]
6 changes: 4 additions & 2 deletions langchain/embeddings/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from langchain.embeddings.base import Embeddings

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"


class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""Wrapper around sentence_transformers embedding models.
Expand All @@ -16,11 +18,11 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):

from langchain.embeddings import HuggingFaceEmbeddings
model_name = "sentence-transformers/all-mpnet-base-v2"
huggingface = HuggingFaceEmbeddings(model_name=model_name)
hf = HuggingFaceEmbeddings(model_name=model_name)
"""

client: Any #: :meta private:
model_name: str = "sentence-transformers/all-mpnet-base-v2"
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""

def __init__(self, **kwargs: Any):
Expand Down
100 changes: 100 additions & 0 deletions langchain/embeddings/huggingface_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Wrapper around HuggingFace Hub embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

# Default Hub repo to embed with; same model as the local HuggingFaceEmbeddings default.
DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2"
# Inference API tasks this wrapper accepts (checked in validate_environment).
VALID_TASKS = ("feature-extraction",)


class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
    """Wrapper around HuggingFaceHub embedding models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceHubEmbeddings
            repo_id = "sentence-transformers/all-mpnet-base-v2"
            hf = HuggingFaceHubEmbeddings(
                repo_id=repo_id,
                task="feature-extraction",
                huggingfacehub_api_token="my-api-key",
            )
    """

    client: Any  #: :meta private:
    repo_id: str = DEFAULT_REPO_ID
    """Model name to use."""
    task: Optional[str] = None
    """Task to call the model with."""
    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model."""

    huggingfacehub_api_token: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment.

        Resolves the API token from the ``huggingfacehub_api_token`` field or the
        ``HUGGINGFACEHUB_API_TOKEN`` env var, builds the ``InferenceApi`` client,
        and rejects any task outside ``VALID_TASKS``.
        """
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        # Keep the try block minimal: only the import can raise ImportError.
        try:
            from huggingface_hub.inference_api import InferenceApi
        except ImportError as e:
            raise ValueError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            ) from e
        repo_id = values.get("repo_id", DEFAULT_REPO_ID)
        client = InferenceApi(
            repo_id=repo_id,
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # The client resolves a default task from the repo when task is None,
        # so validate the resolved task, not the raw field.
        if client.task not in VALID_TASKS:
            raise ValueError(
                f"Got invalid task {client.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        values["client"] = client
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to HuggingFaceHub's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Replace newlines, which can negatively affect performance.
        texts = [text.replace("\n", " ") for text in texts]
        _model_kwargs = self.model_kwargs or {}
        responses = self.client(inputs=texts, params=_model_kwargs)
        return responses

    def embed_query(self, text: str) -> List[float]:
        """Call out to HuggingFaceHub's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        response = self.embed_documents([text])[0]
        return response
19 changes: 19 additions & 0 deletions tests/integration_tests/embeddings/test_huggingface_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test huggingfacehub embeddings."""
from langchain.embeddings import HuggingFaceHubEmbeddings


def test_huggingfacehub_embedding_documents() -> None:
    """Test huggingfacehub embeddings.

    Embeds one document via the Inference API and checks that exactly one
    768-dimensional vector comes back (all-mpnet-base-v2's embedding size).
    """
    documents = ["foo bar"]
    embedding = HuggingFaceHubEmbeddings(task="feature-extraction")
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768


def test_huggingfacehub_embedding_query() -> None:
    """Embedding a single query string yields one 768-dim vector."""
    embedder = HuggingFaceHubEmbeddings(task="feature-extraction")
    result = embedder.embed_query("foo bar")
    assert len(result) == 768