In [1]:
import dotenv
import os
from google import genai
from google.genai import types
from google.api_core import retry
from chromadb import Documents, EmbeddingFunction, Embeddings
import chromadb
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import numpy as np
import pandas as pd
import pytrec_eval

  from tqdm.autonotebook import tqdm


In [2]:
# Retry helper (enables auto-retry if RPM exceeded)
def is_retriable(e):
    """Return True when `e` is a Gemini APIError carrying code 429 or 503."""
    if not isinstance(e, genai.errors.APIError):
        return False
    return e.code in {429, 503}


# Wrap generate_content so that calls failing with a retriable error are retried.
# NOTE(review): re-running this cell wraps the already-wrapped method again.
genai.models.Models.generate_content = retry.Retry(predicate=is_retriable)(
    genai.models.Models.generate_content
)

# Set up client object, name model to use.
# The API key is read from the environment after loading any .env file.
dotenv.load_dotenv()
client = genai.Client(api_key=os.environ.get('GOOGLE_API_KEY'))
model = 'models/text-embedding-004'

In [3]:
# Download smallish NFCorpus dataset of questions and document text
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip"
data_path = util.download_and_unzip(url, "datasets")

# Corpus of text chunks, text queries and "gold" dict mapping each query to its
# relevant documents. Load from the path returned by the download step instead
# of repeating the folder name, so the two cannot drift apart.
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

# Parallel tuples: position i in doc_ids/docs (and q_ids/questions) refers to
# the same item, which lets integer row indices be mapped back to ids later.
doc_ids = tuple(corpus)
docs = tuple(doc['text'] for doc in corpus.values())
q_ids = tuple(queries)
questions = tuple(queries.values())

100%|██████████| 3633/3633 [00:00<00:00, 121791.20it/s]


In [None]:
# Original embedding function from paper - uses vertexai
# vertexai.init(project="PROJECT_ID", location="LOCATION")
# model = TextEmbeddingModel.from_pretrained("text-embedding-005")
# def embed_text(texts, model, task, batch_size=5):
#     embed_mat = np.zeros((len(texts), 768))
#     for batch_start in range(0, len(texts), batch_size):
#          size = min(len(texts) - batch_start, batch_size)
#          inputs = [TextEmbeddingInput(texts[batch_start+i], task_type=task) 
#                    for i in range(size)]
#          embeddings = model.get_embeddings(inputs)
#          for i in range(size):
#              embed_mat[batch_start + i, :] = embeddings[i].values
#     return embed_mat
# embed_text(docs, model, "RETRIEVAL_DOCUMENT")

# My failed attempt to do it with Gemini
# def embed_text(texts, task, batch_size=5):
#     embed_mat = np.zeros((len(texts), 768))
#     for batch_start in range(0, len(texts), batch_size):
#         print(f'\rBatch Start: {batch_start}', end='')
#         try:
#             resp = client.models.embed_content(model='text-embedding-004',
#                                                contents=texts[batch_start: batch_start+batch_size],
#                                                config=types.EmbedContentConfig(task_type=task))
#         except Exception as e:
#             if 'RESOURCE_EXHAUSTED' in str(e):
#                 time.sleep(60)
#             resp = client.models.embed_content(model='text-embedding-004',
#                                                contents=texts[batch_start: batch_start+batch_size],
#                                                config=types.EmbedContentConfig(task_type=task))
#         embed_mat[batch_start: batch_start+batch_size] = [i.values for i in resp.embeddings]
#     return embed_mat
# embed_text(docs, "RETRIEVAL_DOCUMENT")

# # Using faiss to find k nearest neighbors - kept crashing the kernel for some reason
# doc_embeddings = embed_text(docs, model, "RETRIEVAL_DOCUMENT")
# index = faiss.IndexFlatL2(doc_embeddings.shape[1])
# index.add(doc_embeddings)
# # Example look up example query to find relevant doc - note using 'RETRIEVAL_QUERY'
# example_embed = embed_text(['Is Caffeinated Tea Really Dehydrating?'],
# model, 'RETRIEVAL_QUERY')
# s,q = index.search(example_embed,1)
# print(f'Score: {s[0][0]:.2f}, Text: "{docs[q[0][0]]}"')

In [111]:
# Gemini + ChromaDB
class GeminiEmbeddingFunction(EmbeddingFunction):
    """ChromaDB embedding function that embeds text via `client.models.embed_content`.

    Args:
        client: Gemini client; must expose `models.list()` and
            `models.embed_content(...)`.
        model: Name of the embedding model. It must match the `name` of a model
            returned by `client.models.list()` whose supported actions include
            'embedContent'.
        task: 'RETRIEVAL_DOCUMENT' or 'RETRIEVAL_QUERY' (case-insensitive);
            stored upper-cased and sent as the embedding task type.

    Raises:
        AssertionError: If the model is not an available embedding model or the
            task is not one of the two supported values.
    """

    def __init__(self, client, model, task):
        self.client = client
        # Guard against models that report no supported actions at all, which
        # would otherwise raise TypeError on the `in` test.
        embed_models = [i.name for i in self.client.models.list()
                        if 'embedContent' in (i.supported_actions or [])]
        assert model in embed_models, (
            f'{model!r} is not an available embedding model; choose from {embed_models}')
        self.model = model
        tasks = ['RETRIEVAL_DOCUMENT', 'RETRIEVAL_QUERY']
        assert task.upper() in tasks, f'task must be one of {tasks}, got {task!r}'
        self.task = task.upper()

    @retry.Retry(predicate=is_retriable)
    def __call__(self, input: Documents) -> Embeddings:
        """Embed a batch of documents, retrying when `is_retriable` accepts the error.

        Returns one embedding (the `.values` of each response embedding) per
        input document.
        """
        response = self.client.models.embed_content(
            model=self.model,
            contents=input,
            config=types.EmbedContentConfig(task_type=self.task)
        )
        return [e.values for e in response.embeddings]

def embed_text(texts, model, task, batch_size=100):
    """Embed `texts` with Gemini through a scratch ChromaDB collection.

    A collection named 'vdb' is (re)created on every call, the texts are added
    in batches under the ids '0', '1', ..., and the stored embeddings are read
    back as a single matrix.

    Args:
        texts: Sequence of strings to embed (must support slicing and len()).
        model: Embedding model name passed to GeminiEmbeddingFunction.
        task: 'RETRIEVAL_DOCUMENT' or 'RETRIEVAL_QUERY'.
        batch_size: Number of texts added to the collection per call.

    Returns:
        float32 numpy array with one row per text; row i is the embedding of
        texts[i].
    """
    embed_fn = GeminiEmbeddingFunction(client, model, task)
    chroma_client = chromadb.Client()
    # list_collections() may yield collection objects or bare name strings
    # depending on the installed chromadb release; accept both.
    existing = {getattr(c, 'name', c) for c in chroma_client.list_collections()}
    if 'vdb' in existing:
        chroma_client.delete_collection('vdb')
    db = chroma_client.get_or_create_collection(name='vdb', embedding_function=embed_fn)

    n_texts = len(texts)
    for i in range(0, n_texts, batch_size):
        end_i = min(i+batch_size, n_texts)
        db.add(documents=list(texts[i:end_i]),
               ids=[str(j) for j in range(i, end_i)])
        print(f'\rProgress: {end_i} of {n_texts} texts', end='')
    print()

    stored = db.get(include=['embeddings'])
    # np.asarray works whether the store hands back an ndarray or nested lists.
    embeddings = np.asarray(stored['embeddings'], dtype=np.float32)
    # Re-order rows by their numeric id so row i always corresponds to
    # texts[i], without depending on the order the store returns records in.
    order = np.argsort([int(j) for j in stored['ids']])
    return embeddings[order]

def knn(emb_query, emb_docs, k=5, batch_size=100):
    """Brute-force k-nearest-neighbour search using Euclidean distance.

    Args:
        emb_query: 2-D array of query embeddings, one row per query.
        emb_docs: 2-D array of document embeddings, one row per document, with
            the same number of columns as `emb_query`.
        k: Number of neighbours to return per query.
        batch_size: Number of queries processed at once. Each batch builds a
            (batch, n_docs, dim) intermediate, so this bounds peak memory.

    Returns:
        (res_dists, res_ix):
            res_dists: float32 array of shape (n_queries, k) holding the
                distances to the nearest documents in ascending order.
            res_ix: int array of shape (n_queries, k) holding the matching row
                indices into `emb_docs`.
        When there are fewer than k documents, the unused trailing columns
        keep their fill values (NaN in res_dists, -1 in res_ix).
    """
    n_query = emb_query.shape[0]
    # Cannot return more neighbours than there are documents.
    n_found = min(k, emb_docs.shape[0])
    res_dists = np.full((n_query, k), np.nan, dtype=np.float32)
    res_ix = -np.ones_like(res_dists, dtype=int)
    for i in range(0, n_query, batch_size):
        end_i = min(i+batch_size, n_query)
        # Squared distances between every query in the batch and every doc.
        d2 = np.sum((np.expand_dims(emb_query[i: end_i], 1)
                     - np.expand_dims(emb_docs, 0)) ** 2,
                    axis=2)
        ixc = np.argsort(d2, axis=1)[:, :n_found]

        # Pick each row's selected columns; unlike a fixed-width row-index
        # helper this also works when fewer than k documents exist.
        res_dists[i: end_i, :n_found] = np.sqrt(np.take_along_axis(d2, ixc, axis=1))
        res_ix[i: end_i, :n_found] = ixc
        print(f'\rProgress: {end_i} of {n_query} queries', end='')
    print()
    return res_dists, res_ix

In [55]:
# Embed the documents and one example query with the same model but different
# task types ('RETRIEVAL_DOCUMENT' vs 'RETRIEVAL_QUERY'), then look up the
# single nearest document for the example.
doc_embeddings = embed_text(docs, model, 'RETRIEVAL_DOCUMENT')
example_embed = embed_text(
    ['Is Caffeinated Tea Really Dehydrating?'],
    model,
    'RETRIEVAL_QUERY',
)
example_dists, example_ix = knn(example_embed, doc_embeddings, k=1)
best_dist = example_dists[0][0]
best_doc = docs[example_ix[0][0]]
print(f'Score: {best_dist:.2f}, Text: "{best_doc}"')

Score: 0.48, Text: "There is a belief that caffeinated drinks, such as tea, may adversely affect hydration. This was investigated in a randomised controlled trial. Healthy resting males (n 21) were recruited from the general population. Following 24 h of abstention from caffeine, alcohol and vigorous physical activity, including a 10 h overnight fast, all men underwent four separate test days in a counter-balanced order with a 5 d washout in between. The test beverages, provided at regular intervals, were 4 × 240 ml black (i.e. regular) tea and 6 × 240 ml black tea, providing 168 or 252 mg of caffeine. The controls were identical amounts of boiled water. The tea was prepared in a standardised way from tea bags and included 20 ml of semi-skimmed milk. All food taken during the 12 h intervention period was controlled, and subjects remained at rest. No other beverages were offered. Blood was sampled at 0, 1, 2, 4, 8 and 12 h, and a 24 h urine sample was collected. Outcome variables were w

In [None]:
# Embed all queries to evaluate quality compared to "gold" answers.
# For each query keep the 10 nearest documents: q_scores holds the distances,
# q_doc_ids the row indices into doc_embeddings (positions, not corpus ids).
query_embeddings = embed_text(texts=questions, model=model, task="RETRIEVAL_QUERY")
q_scores, q_doc_ids = knn(
    emb_query=query_embeddings,
    emb_docs=doc_embeddings,
    k=10,
)

In [308]:
# Pytrec evaluation by query
# Build {query id: {doc id: score}} from the k-NN results. Distances are
# negated so that a closer document receives a larger score.
search_qrels = {
    query_id: {doc_ids[row]: -dist.item() for row, dist in zip(rows, dists)}
    for query_id, rows, dists in zip(q_ids, q_doc_ids, q_scores)
}

# One row per query, one column per requested measure.
measures = {'ndcg_cut.10', 'P_1', 'recall_10'}
evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
per_query_scores = evaluator.evaluate(search_qrels)
eval_results = pd.DataFrame.from_dict(per_query_scores, orient='index')
eval_results

Unnamed: 0,P_1,recall_10,ndcg_cut_10
PLAIN-2,1.0,0.333333,0.796312
PLAIN-12,0.0,0.100000,0.253469
PLAIN-23,1.0,0.055556,0.389724
PLAIN-33,1.0,0.093750,0.376728
PLAIN-44,0.0,0.068966,0.380096
...,...,...,...
PLAIN-3432,0.0,0.047619,0.147829
PLAIN-3442,0.0,0.034483,0.046914
PLAIN-3452,0.0,0.176471,0.302274
PLAIN-3462,1.0,0.109375,0.700095


In [312]:
# Pytrec evaluation averages: mean of each measure over all evaluated queries.
# Reference values to compare against (presumably from an earlier run or the
# source write-up -- TODO confirm origin):
#   P_1          0.517028   precision@1
#   recall_10    0.203507   recall@10
#   ndcg_cut_10  0.402624   nDCG@10
eval_results.mean(axis=0)

P_1            0.504644
recall_10      0.200019
ndcg_cut_10    0.403161
dtype: float64

In [314]:
# (Verify understanding of measures)
# Recompute the three measures by hand for a single query.
q = 'PLAIN-320'
cutoff = 10

# Retrieved doc ids with their scores (negated distances: larger is better),
# and the ids of the documents judged relevant for this query.
qid, sc = zip(*search_qrels[q].items())
rel_id = qrels[q].keys()

# Precision@1: is the top-scoring retrieved document a relevant one?
p_at_1 = float(qid[np.argmax(sc)] in rel_id)

# Recall@10: share of the relevant documents that were retrieved
rec_at_10 = sum(i in rel_id for i in qid) / len(rel_id)

# nDCG@10: discounted gain of the actual ranking divided by that of the ideal
# ranking. The discount vector is sliced to each list's length so the
# computation also works when fewer than `cutoff` documents were retrieved or
# judged relevant (previously a shape-mismatch error).
ranked_ids = np.array(qid)[np.argsort([-i for i in sc])][:cutoff]
relsc_actual_topn = np.array([qrels[q].get(i, 0) for i in ranked_ids], dtype=float)
relsc_ideal_topn = np.flip(np.sort(list(qrels[q].values()))[-cutoff:]).astype(float)
denom = np.log2(np.arange(cutoff) + 2)
dgc = np.sum(relsc_actual_topn / denom[:len(relsc_actual_topn)])
idgc = np.sum(relsc_ideal_topn / denom[:len(relsc_ideal_topn)])
# An ideal gain of zero means nothing relevant exists; report 0 rather than NaN.
ndgc_at_10 = (dgc / idgc).item() if idgc > 0 else 0.0

p_at_1, rec_at_10, ndgc_at_10

(0.0, 0.0, 0.0)

In [338]:
# Look at question, ideal response and actual response
print("QUESTION:")
print(queries[q])
print()

# Judged-relevant documents, grouped by relevancy score (highest first).
print("IDEAL RESPONSES:")
for level, heading in ((2, "(Relevancy score == 2)"), (1, "(Relevancy score == 1)")):
    print(heading)
    for doc_id, relevancy in qrels[q].items():
        if relevancy == level:
            print(doc_id)
            print(corpus[doc_id]['text'])
print()

# Documents actually retrieved for the query.
print("ACTUAL RESPONSES:")
for doc_id in qid:
    print(doc_id)
    print(corpus[doc_id]['text'])

QUESTION:
Breast Cancer and Diet

IDEAL RESPONSES:
(Relevancy score == 2)
(Relevancy score == 1)
MED-4650
Aromatase is a cytochrome P450 enzyme (CYP19) and is the rate limiting enzyme in the conversion of androgens to estrogens. Suppression of in situ estrogen production through aromatase inhibition is the current treatment strategy for hormone-responsive breast cancers. Drugs that inhibit aromatase have been developed and are currently utilized as adjuvant therapy for breast cancer in post-menopausal women with hormone dependent breast cancer. Natural compounds have been studied extensively for important biologic effects such as antioxidant, anti-tumor and anti-viral effects. A significant number of studies have also investigated the aromatase inhibitory properties of a variety of plant extracts and phytochemicals. The identification of natural compounds that inhibit aromatase could be useful both from a chemopreventive standpoint and in the development of new aromatase inhibitory dru