From 8dabc2b14c3ae2d73328387346149c7ac1847466 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 1 May 2023 16:53:38 -0700 Subject: [PATCH] Support `**encode_kwargs` for `HuggingfaceEmbeddings` (#3914) Support more configurability when using `HuggingfaceEmbeddings`. This allows you to configure the batch size for encoding. --- langchain/embeddings/huggingface.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index 242562271f07ca..f8cb12bf2c8039 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -58,11 +58,13 @@ class Config: extra = Extra.forbid - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: List[str], **encode_kwargs) -> List[List[float]]: """Compute doc embeddings using a HuggingFace transformer model. Args: texts: The list of texts to embed. + encode_kwargs: Additional kwargs to pass into the `encode` method of the + SentenceTransformer. Returns: List of embeddings, one for each text. @@ -71,11 +73,13 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: embeddings = self.client.encode(texts) return embeddings.tolist() - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str, **encode_kwargs) -> List[float]: """Compute query embeddings using a HuggingFace transformer model. Args: text: The text to embed. + encode_kwargs: Additional kwargs to pass into the `encode` method of the + SentenceTransformer. Returns: Embeddings for the text.