Skip to content

Commit

Permalink
langchain[patch]: Adds progress bar to GooglePalmEmbeddings (#13812)
Browse files Browse the repository at this point in the history
- **Description:** Adds a tqdm progress bar to GooglePalmEmbeddings when
embedding a list.
  - **Issue:** #13637
  - **Dependencies:** TQDM as a main dependency (instead of extra)


Signed-off-by: ugm2 <unaigaraymaestre@gmail.com>

---------

Signed-off-by: ugm2 <unaigaraymaestre@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
  • Loading branch information
ugm2 and hwchase17 committed Nov 29, 2023
1 parent 1cd9d5f commit 9e2ae86
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion libs/langchain/langchain/embeddings/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class GooglePalmEmbeddings(BaseModel, Embeddings):
google_api_key: Optional[str]
model_name: str = "models/embedding-gecko-001"
"""Model name to use."""
show_progress_bar: bool = False
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand All @@ -79,7 +81,20 @@ def validate_environment(cls, values: Dict) -> Dict:
return values

def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(text) for text in texts]
if self.show_progress_bar:
try:
from tqdm import tqdm

iter_ = tqdm(texts, desc="GooglePalmEmbeddings")
except ImportError:
logger.warning(
"Unable to show progress bar because tqdm could not be imported. "
"Please install with `pip install tqdm`."
)
iter_ = texts
else:
iter_ = texts
return [self.embed_query(text) for text in iter_]

def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
Expand Down

0 comments on commit 9e2ae86

Please sign in to comment.