diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index 1967f6f3347..6a57deb1364 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -1,6 +1,12 @@ """Wrappers around embedding modules.""" from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings -__all__ = ["OpenAIEmbeddings", "HuggingFaceEmbeddings", "CohereEmbeddings"] +__all__ = [ + "OpenAIEmbeddings", + "HuggingFaceEmbeddings", + "CohereEmbeddings", + "HuggingFaceHubEmbeddings", +] diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index a32aef6561a..8d5d817b8a3 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -5,6 +5,8 @@ from langchain.embeddings.base import Embeddings +DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2" + class HuggingFaceEmbeddings(BaseModel, Embeddings): """Wrapper around sentence_transformers embedding models. 
@@ -16,11 +18,11 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): from langchain.embeddings import HuggingFaceEmbeddings model_name = "sentence-transformers/all-mpnet-base-v2" - huggingface = HuggingFaceEmbeddings(model_name=model_name) + hf = HuggingFaceEmbeddings(model_name=model_name) """ client: Any #: :meta private: - model_name: str = "sentence-transformers/all-mpnet-base-v2" + model_name: str = DEFAULT_MODEL_NAME """Model name to use.""" def __init__(self, **kwargs: Any): diff --git a/langchain/embeddings/huggingface_hub.py b/langchain/embeddings/huggingface_hub.py new file mode 100644 index 00000000000..66c662f0554 --- /dev/null +++ b/langchain/embeddings/huggingface_hub.py @@ -0,0 +1,105 @@ +"""Wrapper around HuggingFace Hub embedding models.""" +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +DEFAULT_REPO_ID = "sentence-transformers/all-mpnet-base-v2" +VALID_TASKS = ("feature-extraction",) + + +class HuggingFaceHubEmbeddings(BaseModel, Embeddings): + """Wrapper around HuggingFaceHub embedding models. + + To use, you should have the ``huggingface_hub`` python package installed, and the + environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass + it as a named parameter to the constructor. + + Example: + .. 
code-block:: python + + from langchain.embeddings import HuggingFaceHubEmbeddings + repo_id = "sentence-transformers/all-mpnet-base-v2" + hf = HuggingFaceHubEmbeddings( + repo_id=repo_id, + task="feature-extraction", + huggingfacehub_api_token="my-api-key", + ) + """ + + client: Any #: :meta private: + repo_id: str = DEFAULT_REPO_ID + """Model name to use.""" + task: Optional[str] = "feature-extraction" + """Task to call the model with.""" + model_kwargs: Optional[dict] = None + """Keyword arguments to pass to the model.""" + + huggingfacehub_api_token: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + huggingfacehub_api_token = get_from_dict_or_env( + values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" + ) + try: + from huggingface_hub.inference_api import InferenceApi + + repo_id = values["repo_id"] + if not repo_id.startswith("sentence-transformers"): + raise ValueError( + "Currently only 'sentence-transformers' embedding models " + f"are supported. Got invalid 'repo_id' {repo_id}." + ) + client = InferenceApi( + repo_id=repo_id, + token=huggingfacehub_api_token, + task=values.get("task"), + ) + if client.task not in VALID_TASKS: + raise ValueError( + f"Got invalid task {client.task}, " + f"currently only {VALID_TASKS} are supported" + ) + values["client"] = client + except ImportError: + raise ValueError( + "Could not import huggingface_hub python package. " + "Please install it with `pip install huggingface_hub`." + ) + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to HuggingFaceHub's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. 
+ """ + # replace newlines, which can negatively affect performance. + texts = [text.replace("\n", " ") for text in texts] + _model_kwargs = self.model_kwargs or {} + responses = self.client(inputs=texts, params=_model_kwargs) + return responses + + def embed_query(self, text: str) -> List[float]: + """Call out to HuggingFaceHub's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + response = self.embed_documents([text])[0] + return response diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index c67c9720a4e..5cded677376 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -51,7 +51,7 @@ def validate_environment(cls, values: Dict) -> Dict: try: from huggingface_hub.inference_api import InferenceApi - repo_id = values.get("repo_id", DEFAULT_REPO_ID) + repo_id = values["repo_id"] client = InferenceApi( repo_id=repo_id, token=huggingfacehub_api_token, diff --git a/tests/integration_tests/embeddings/test_huggingface_hub.py b/tests/integration_tests/embeddings/test_huggingface_hub.py new file mode 100644 index 00000000000..ed57bcccd8a --- /dev/null +++ b/tests/integration_tests/embeddings/test_huggingface_hub.py @@ -0,0 +1,19 @@ +"""Test HuggingFaceHub embeddings.""" +from langchain.embeddings import HuggingFaceHubEmbeddings + + +def test_huggingfacehub_embedding_documents() -> None: + """Test huggingfacehub embeddings.""" + documents = ["foo bar"] + embedding = HuggingFaceHubEmbeddings() + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 768 + + +def test_huggingfacehub_embedding_query() -> None: + """Test huggingfacehub embeddings.""" + document = "foo bar" + embedding = HuggingFaceHubEmbeddings() + output = embedding.embed_query(document) + assert len(output) == 768 diff --git a/tests/unit_tests/embeddings/__init__.py b/tests/unit_tests/embeddings/__init__.py new file mode 
100644 index 00000000000..9aaef73a027 --- /dev/null +++ b/tests/unit_tests/embeddings/__init__.py @@ -0,0 +1 @@ +"""All unit tests for Embeddings objects.""" diff --git a/tests/unit_tests/embeddings/test_huggingface_hub.py b/tests/unit_tests/embeddings/test_huggingface_hub.py new file mode 100644 index 00000000000..ecc1fbfbe21 --- /dev/null +++ b/tests/unit_tests/embeddings/test_huggingface_hub.py @@ -0,0 +1,11 @@ +"""Test HuggingFaceHub embeddings.""" +import pytest + +from langchain.embeddings import HuggingFaceHubEmbeddings + + +def test_huggingfacehub_embedding_invalid_repo() -> None: + """Test huggingfacehub embedding repo id validation.""" + # Only sentence-transformers models are currently supported. + with pytest.raises(ValueError): + HuggingFaceHubEmbeddings(repo_id="allenai/specter")