Skip to content

Commit

Permalink
community[minor]: New model parameters and dynamic batching for Verte…
Browse files Browse the repository at this point in the history
…xAIEmbeddings (#13999)

- **Description:** VertexAIEmbeddings performance improvements
  - **Twitter handle:** @vladkol

## Improvements

- Dynamic batch size, starting from 250, lowering down to 5. Batch size
varies across regions.
Some regions support larger batches, and it significantly improves
performance.
When running large batches of texts in `us-central1`, performance gain
can be up to 3.5x.
The dynamic batching also makes sure every batch is below 20K token
limit.
- New model parameter `embeddings_type` that translates to `task_type`
parameter of the API. Newer model versions support [different embeddings
task
types](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023).
  • Loading branch information
vladkol committed Dec 18, 2023
1 parent 2e6a9e6 commit 11fda49
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 17 deletions.
306 changes: 291 additions & 15 deletions libs/community/langchain_community/embeddings/vertexai.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,324 @@
from typing import Dict, List
import logging
import re
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, List, Literal, Optional, Tuple

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator

from langchain_community.llms.vertexai import _VertexAICommon
from langchain_community.utilities.vertexai import raise_vertex_import_error

logger = logging.getLogger(__name__)

_MAX_TOKENS_PER_BATCH = 20000
_MAX_BATCH_SIZE = 250
_MIN_BATCH_SIZE = 5


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""

model_name: str = "textembedding-gecko"
# Instance context
instance: Dict[str, Any] = {} #: :meta private:

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._try_init_vertexai(values)
try:
from vertexai.language_models import TextEmbeddingModel

values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
except ImportError:
raise_vertex_import_error()
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

def embed_documents(
self, texts: List[str], batch_size: int = 5
def __init__(
self,
project: Optional[str] = None,
location: str = "us-central1",
request_parallelism: int = 5,
max_retries: int = 6,
model_name: str = "textembedding-gecko",
credentials: Optional[Any] = None,
**kwargs: Any,
):
"""Initialize the sentence_transformer."""
super().__init__(
project=project,
location=location,
credentials=credentials,
request_parallelism=request_parallelism,
max_retries=max_retries,
model_name=model_name,
**kwargs,
)
self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
self.instance["batch_size"] = self.instance["max_batch_size"]
self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
self.instance["lock"] = threading.Lock()
self.instance["batch_size_validated"] = False
self.instance["task_executor"] = ThreadPoolExecutor(
max_workers=request_parallelism
)
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
split_by = string.punctuation + "\t\n "
pattern = f"([{split_by}])"
# Using re.split to split the text based on the pattern
return [segment for segment in re.split(pattern, text) if segment]

@staticmethod
def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
"""Splits texts in batches based on current maximum batch size
and maximum tokens per request.
"""
text_index = 0
texts_len = len(texts)
batch_token_len = 0
batches: List[List[str]] = []
current_batch: List[str] = []
if texts_len == 0:
return []
while text_index < texts_len:
current_text = texts[text_index]
# Number of tokens per a text is conservatively estimated
# as 2 times number of words, punctuation and whitespace characters.
# Using `count_tokens` API will make batching too expensive.
# Utilizing a tokenizer, would add a dependency that would not
# necessarily be reused by the application using this class.
current_text_token_cnt = (
len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
)
end_of_batch = False
if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
# Current text is too big even for a single batch.
# Such request will fail, but we still make a batch
# so that the app can get the error from the API.
if len(current_batch) > 0:
# Adding current batch if not empty.
batches.append(current_batch)
current_batch = [current_text]
text_index += 1
end_of_batch = True
elif (
batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
or len(current_batch) == batch_size
):
end_of_batch = True
else:
if text_index == texts_len - 1:
# Last element - even though the batch may be not big,
# we still need to make it.
end_of_batch = True
batch_token_len += current_text_token_cnt
current_batch.append(current_text)
text_index += 1
if end_of_batch:
batches.append(current_batch)
current_batch = []
batch_token_len = 0
return batches

def _get_embeddings_with_retry(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> List[List[float]]:
"""Makes a Vertex AI model request with retry logic."""
from google.api_core.exceptions import (
Aborted,
DeadlineExceeded,
ResourceExhausted,
ServiceUnavailable,
)

errors = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
retry_decorator = create_base_retry_decorator(
error_types=errors, max_retries=self.max_retries
)

@retry_decorator
def _completion_with_retry(texts_to_process: List[str]) -> Any:
if embeddings_type and self.instance["embeddings_task_type_supported"]:
from vertexai.language_models import TextEmbeddingInput

requests = [
TextEmbeddingInput(text=t, task_type=embeddings_type)
for t in texts_to_process
]
else:
requests = texts_to_process
embeddings = self.client.get_embeddings(requests)
return [embs.values for embs in embeddings]

return _completion_with_retry(texts)

def _prepare_and_validate_batches(
self, texts: List[str], embeddings_type: Optional[str] = None
) -> Tuple[List[List[float]], List[List[str]]]:
"""Prepares text batches with one-time validation of batch size.
Batch size varies between GCP regions and individual project quotas.
# Returns embeddings of the first text batch that went through,
# and text batches for the rest of the texts.
"""
from google.api_core.exceptions import InvalidArgument

batches = VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# If batch size if less or equal to one that went through before,
# then keep batches as they are.
if len(batches[0]) <= self.instance["min_good_batch_size"]:
return [], batches
with self.instance["lock"]:
# If largest possible batch size was validated
# while waiting for the lock, then check for rebuilding
# our batches, and return.
if self.instance["batch_size_validated"]:
if len(batches[0]) <= self.instance["batch_size"]:
return [], batches
else:
return [], VertexAIEmbeddings._prepare_batches(
texts, self.instance["batch_size"]
)
# Figure out largest possible batch size by trying to push
# batches and lowering their size in half after every failure.
first_batch = batches[0]
first_result = []
had_failure = False
while True:
try:
first_result = self._get_embeddings_with_retry(
first_batch, embeddings_type
)
break
except InvalidArgument:
had_failure = True
first_batch_len = len(first_batch)
if first_batch_len == self.instance["min_batch_size"]:
raise
first_batch_len = max(
self.instance["min_batch_size"], int(first_batch_len / 2)
)
first_batch = first_batch[:first_batch_len]
first_batch_len = len(first_batch)
self.instance["min_good_batch_size"] = max(
self.instance["min_good_batch_size"], first_batch_len
)
# If had a failure and recovered
# or went through with the max size, then it's a legit batch size.
if had_failure or first_batch_len == self.instance["max_batch_size"]:
self.instance["batch_size"] = first_batch_len
self.instance["batch_size_validated"] = True
# If batch size was updated,
# rebuild batches with the new batch size
# (texts that went through are excluded here).
if first_batch_len != self.instance["max_batch_size"]:
batches = VertexAIEmbeddings._prepare_batches(
texts[first_batch_len:], self.instance["batch_size"]
)
else:
# Still figuring out max batch size.
batches = batches[1:]
# Returning embeddings of the first text batch that went through,
# and text batches for the rest of texts.
return first_result, batches

def embed(
self,
texts: List[str],
batch_size: int = 0,
embeddings_task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
]
] = None,
) -> List[List[float]]:
"""Embed a list of strings. Vertex AI currently
sets a max batch size of 5 strings.
"""Embed a list of strings.
Args:
texts: List[str] The list of strings to embed.
batch_size: [int] The batch size of embeddings to send to the model
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
embeddings_task_type: [str] optional embeddings task type,
one of the following
RETRIEVAL_QUERY - Text is a query
in a search/retrieval setting.
RETRIEVAL_DOCUMENT - Text is a document
in a search/retrieval setting.
SEMANTIC_SIMILARITY - Embeddings will be used
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for batch in range(0, len(texts), batch_size):
text_batch = texts[batch : batch + batch_size]
embeddings_batch = self.client.get_embeddings(text_batch)
embeddings.extend([el.values for el in embeddings_batch])
if len(texts) == 0:
return []
embeddings: List[List[float]] = []
first_batch_result: List[List[float]] = []
if batch_size > 0:
# Fixed batch size.
batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
else:
# Dynamic batch size, starting from 250 at the first call.
first_batch_result, batches = self._prepare_and_validate_batches(
texts, embeddings_task_type
)
# First batch result may have some embeddings already.
# In such case, batches have texts that were not processed yet.
embeddings.extend(first_batch_result)
tasks = []
for batch in batches:
tasks.append(
self.instance["task_executor"].submit(
self._get_embeddings_with_retry,
texts=batch,
embeddings_type=embeddings_task_type,
)
)
if len(tasks) > 0:
wait(tasks)
for t in tasks:
embeddings.extend(t.result())
return embeddings

def embed_documents(
self, texts: List[str], batch_size: int = 0
) -> List[List[float]]:
"""Embed a list of documents.
Args:
texts: List[str] The list of texts to embed.
batch_size: [int] The batch size of embeddings to send to the model.
If zero, then the largest batch size will be detected dynamically
at the first request, starting from 250, down to 5.
Returns:
List of embeddings, one for each text.
"""
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

def embed_query(self, text: str) -> List[float]:
"""Embed a text.
Expand All @@ -52,5 +328,5 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
embeddings = self.client.get_embeddings([text])
return embeddings[0].values
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]
14 changes: 12 additions & 2 deletions libs/community/tests/integration_tests/embeddings/test_vertexai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Test Vertex AI API wrapper.
In order to run this test, you need to install VertexAI SDK
In order to run this test, you need to install VertexAI SDK
pip install google-cloud-aiplatform>=1.35.0
Your end-user credentials would be used to make the calls (make sure you've run
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
from langchain_community.embeddings import VertexAIEmbeddings
Expand All @@ -24,6 +24,16 @@ def test_embedding_query() -> None:
assert len(output) == 768


def test_large_batches() -> None:
documents = ["foo bar" for _ in range(0, 251)]
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
model_uscentral1.embed_documents(documents)
model_asianortheast1.embed_documents(documents)
assert model_uscentral1.instance["batch_size"] >= 250
assert model_asianortheast1.instance["batch_size"] < 50


def test_paginated_texts() -> None:
documents = [
"foo bar",
Expand Down

0 comments on commit 11fda49

Please sign in to comment.