Skip to content

Commit

Permalink
feat: Add Azure embedders support (#6676)
Browse files Browse the repository at this point in the history
* Add Azure embedders
---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
  • Loading branch information
vblagoje and anakin87 committed Jan 5, 2024
1 parent b7159ad commit 552f0e3
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 0 deletions.
4 changes: 4 additions & 0 deletions haystack/components/embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder

__all__ = [
"HuggingFaceTEITextEmbedder",
Expand All @@ -12,4 +14,6 @@
"SentenceTransformersDocumentEmbedder",
"OpenAITextEmbedder",
"OpenAIDocumentEmbedder",
"AzureOpenAITextEmbedder",
"AzureOpenAIDocumentEmbedder",
]
178 changes: 178 additions & 0 deletions haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
from typing import List, Optional, Dict, Any, Tuple

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from tqdm import tqdm

from haystack import component, Document, default_to_dict


@component
class AzureOpenAIDocumentEmbedder:
"""
A component for computing Document embeddings using OpenAI models.
The embedding of each Document is stored in the `embedding` field of the Document.
Usage example:
```python
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
doc = Document(content="I love pizza!")
document_embedder = AzureOpenAIDocumentEmbedder()
result = document_embedder.run([doc])
print(result['documents'][0].embedding)
# [0.017020374536514282, -0.023255806416273117, ...]
```
"""

def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Create an AzureOpenAITextEmbedder component.
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of Documents to encode at once.
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}

def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]

text_to_embed = (
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
).replace("\n", " ")

texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""

all_embeddings: List[List[float]] = []
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
batch = texts_to_embed[i : i + batch_size]
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)

# Append embeddings to the list
all_embeddings.extend(el.embedding for el in response.data)

# Update the meta information only once if it's empty
if not meta["model"]:
meta["model"] = response.model
meta["usage"] = dict(response.usage)
else:
# Update the usage tokens
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
meta["usage"]["total_tokens"] += response.usage.total_tokens

return all_embeddings, meta

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
"""
Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document.
:param documents: A list of Documents to embed.
"""
if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)):
raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.")

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

# Assign the corresponding embeddings to each document
for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": meta}
123 changes: 123 additions & 0 deletions haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import os
from typing import List, Optional, Dict, Any

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import component, default_to_dict, Document


@component
class AzureOpenAITextEmbedder:
"""
A component for embedding strings using OpenAI models.
Usage example:
```python
from haystack.components.embedders import AzureOpenAITextEmbedder
text_to_embed = "I love pizza!"
text_embedder = AzureOpenAITextEmbedder()
print(text_embedder.run(text_to_embed))
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
# 'meta': {'model': 'text-embedding-ada-002-v2',
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
```
"""

def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
):
"""
Create an AzureOpenAITextEmbedder component.
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
# Why is this here?
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
# None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
# of passing it as a parameter.
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.prefix = prefix
self.suffix = suffix

self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string using AzureOpenAITextEmbedder."""
if not isinstance(text, str):
# Check if input is a list and all elements are instances of Document
if isinstance(text, list) and all(isinstance(elem, Document) for elem in text):
error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents."
else:
error_message = "Input must be a string."
raise TypeError(error_message)

# Preprocess the text by adding prefixes/suffixes
# finally, replace newlines as recommended by OpenAI docs
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")

response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)

return {
"embedding": response.data[0].embedding,
"meta": {"model": response.model, "usage": dict(response.usage)},
}
5 changes: 5 additions & 0 deletions releasenotes/notes/add-azure-embedders-fd5f4fbcab0e1c48.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Adds AzureOpenAIDocumentEmbedder and AzureOpenAITextEmbedder as new embedders. These embedders are very similar to
their OpenAI counterparts, but they use the Azure API instead of the OpenAI API.
Loading

0 comments on commit 552f0e3

Please sign in to comment.