In [8]:
import os

import google.generativeai as genai
import pandas as pd
from dotenv import load_dotenv

pd.set_option("display.max_colwidth", None)

load_dotenv()  # API key is stored in .env file

SAMPLE_SIZE = 20  # number of CSV rows to load; keeps embedding API usage small
DATA_PATH = "../data/amazon_products.csv"

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    # Fail fast here rather than with an opaque authentication error on the first embed call.
    raise RuntimeError("GOOGLE_API_KEY is not set; add it to your .env file or environment.")
genai.configure(api_key=GOOGLE_API_KEY)

data = pd.read_csv(DATA_PATH, nrows=SAMPLE_SIZE, usecols=["asin", "title"])
data.head(5)

In [10]:
from chromadb import Documents, EmbeddingFunction, Embeddings
from google.api_core import retry


class GeminiEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by ``genai.embed_content``.

    ``document_mode`` selects the Gemini task type sent with each request:
    ``True`` -> ``"retrieval_document"`` (use when indexing documents),
    ``False`` -> ``"retrieval_query"`` (use when embedding search queries).
    """

    document_mode = True
    model = "models/text-embedding-004"
    # Retry only transient API errors. Built once here instead of on every call.
    retry_policy = {"retry": retry.Retry(predicate=retry.if_transient_error)}

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` using the task type chosen by ``document_mode``.

        Args:
            input: Texts to embed, as supplied by Chroma. The parameter name
                ``input`` (which shadows the builtin) is presumably required by
                Chroma's ``EmbeddingFunction`` interface; verify before renaming.

        Returns:
            The ``"embedding"`` entry of the ``genai.embed_content`` response.
        """
        embedding_task = "retrieval_document" if self.document_mode else "retrieval_query"
        response = genai.embed_content(
            model=self.model,
            content=input,
            task_type=embedding_task,
            request_options=self.retry_policy,
        )
        return response["embedding"]


In [11]:
import chromadb

# Persistent Chroma store location and collection name.
DB_PATH = "./chroma"
DB_NAME = "products"

chroma_client = chromadb.PersistentClient(path=DB_PATH)

# Shared embedding function; the collection keeps a reference to this instance,
# so toggling `embed_fn.document_mode` later affects how the collection embeds.
embed_fn = GeminiEmbeddingFunction()
embed_fn.document_mode = True  # index-time embeddings use the document task type

db = chroma_client.get_or_create_collection(
    name=DB_NAME,
    embedding_function=embed_fn,
)

In [12]:
# Documents must be embedded with the "retrieval_document" task. Set it explicitly
# so this cell stays correct even if re-run after a query cell switched the mode off.
embed_fn.document_mode = True
# upsert (rather than add) keeps the cell idempotent against the persistent collection:
# re-running it overwrites existing ids instead of warning about and skipping them.
db.upsert(documents=data.title.tolist(), ids=data.asin.tolist())

In [None]:
N_RESULTS = 2  # number of nearest products to return per query

query = ["iphone"]

# Queries use the "retrieval_query" task type. Restore document mode afterwards so
# that any later indexing through `db` does not silently embed documents as queries.
embed_fn.document_mode = False
try:
    result = db.query(query_texts=query, n_results=N_RESULTS)
finally:
    embed_fn.document_mode = True

result