Skip to content

Commit

Permalink
Fixed openai embeddings to be safe by batching them based on token si…
Browse files Browse the repository at this point in the history
…ze calculation. (langchain-ai#991)

I modified the logic of the batch calculation for embedding according to
this cookbook

https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
  • Loading branch information
Hase-U authored and dongreenberg committed Feb 17, 2023
1 parent e58867a commit a018210
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
72 changes: 64 additions & 8 deletions langchain/embeddings/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Wrapper around OpenAI embedding models."""
from typing import Any, Dict, List, Optional

import numpy as np
from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
Expand All @@ -24,6 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
client: Any #: :meta private:
document_model_name: str = "text-embedding-ada-002"
query_model_name: str = "text-embedding-ada-002"
embedding_ctx_length: int = -1
openai_api_key: Optional[str] = None

class Config:
Expand Down Expand Up @@ -69,11 +71,62 @@ def validate_environment(cls, values: Dict) -> Dict:
)
return values

# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: int = 1000
) -> List[List[float]]:
embeddings: List[List[float]] = [[] for i in range(len(texts))]
try:
import tiktoken

tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.document_model_name)
for i, text in enumerate(texts):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
token = encoding.encode(text)
for j in range(0, len(token), self.embedding_ctx_length):
tokens += [token[j : j + self.embedding_ctx_length]]
indices += [i]

batched_embeddings = []
for i in range(0, len(tokens), chunk_size):
response = self.client.create(
input=tokens[i : i + chunk_size], engine=self.document_model_name
)
batched_embeddings += [r["embedding"] for r in response["data"]]

results: List[List[List[float]]] = [[] for i in range(len(texts))]
lens: List[List[int]] = [[] for i in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
lens[indices[i]].append(len(batched_embeddings[i]))

for i in range(len(texts)):
average = np.average(results[i], axis=0, weights=lens[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()

return embeddings

except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to for OpenAIEmbeddings. "
"Please it install it with `pip install tiktoken`."
)

def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
text = text.replace("\n", " ")
return self.client.create(input=[text], engine=engine)["data"][0][
"embedding"
]

def embed_documents(
self, texts: List[str], chunk_size: int = 1000
Expand All @@ -89,13 +142,16 @@ def embed_documents(
List of embeddings, one for each text.
"""
# handle large batches of texts
results = []
for i in range(0, len(texts), chunk_size):
response = self.client.create(
input=texts[i : i + chunk_size], engine=self.document_model_name
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings(
texts, engine=self.document_model_name, chunk_size=chunk_size
)
results += [r["embedding"] for r in response["data"]]
return results
else:
responses = [
self._embedding_func(text, engine=self.document_model_name)
for text in texts
]
return responses

def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.
Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/embeddings/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_openai_embedding_documents_multiple() -> None:
"""Test openai embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = OpenAIEmbeddings()
embedding.embedding_ctx_length = 8191
output = embedding.embed_documents(documents, chunk_size=2)
assert len(output) == 3
assert len(output[0]) == 1536
Expand Down

0 comments on commit a018210

Please sign in to comment.