In [None]:
%pip install chromadb torch transformers sentence_transformers
%pip install llama-index llama-index-llms-anthropic

In [1]:
import json
import os
from collections import namedtuple
from typing import List, Tuple

import chromadb
import torch
import torch.nn.functional as F
import tqdm
from dotenv import load_dotenv
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# Populate os.environ from a local .env file before any client is built.
load_dotenv()

# Chroma client (default settings) with one collection per document kind.
# NOTE(review): default `chromadb.Client()` is documented as ephemeral — confirm
# that nothing here is expected to persist across kernel restarts.
chroma_client = chromadb.Client()
paper_collection, note_collection = (
    chroma_client.get_or_create_collection(name=collection_name)
    for collection_name in ("papers", "notes")
)

# One evaluation query: its text, the id of the paper it should retrieve,
# and the synthetic split it was generated from ("weak", "strong", "weakv2").
QueryData = namedtuple("QueryData", ("query", "source", "query_type"))

  from .autonotebook import tqdm as notebook_tqdm


## Data Loading

In [2]:
# Build the retrieval corpus from the "strong" synthetic split: each JSON
# file's "input" field is one paper, and its file stem is the paper id that
# evaluation queries refer to as their `source`.
corpus_dir = os.path.join("data", "synthetic", "strong")
papers = []
ids = []
# sorted() makes corpus order reproducible (os.listdir order is arbitrary);
# the suffix check skips stray non-JSON files instead of crashing json.load.
for fn in sorted(os.listdir(corpus_dir)):
  paper_id, ext = os.path.splitext(fn)
  if ext != ".json":
    continue
  # JSON text is UTF-8; don't depend on the platform's default encoding.
  with open(os.path.join(corpus_dir, fn), encoding="utf-8") as f:
    content = json.load(f)
  papers.append(content["input"])
  ids.append(paper_id)

In [3]:
model_path = 'Alibaba-NLP/gte-base-en-v1.5'
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): trust_remote_code=True executes Python code fetched from the
# model repository; consider pinning `revision=` so that code cannot change.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

def generate_embeddings(texts: List[str], max_length: int = 8192) -> Tensor:
    """Embed a batch of texts with the gte model loaded above.

    Args:
        texts: strings to embed; padded to the longest in the batch and
            truncated to ``max_length`` tokens.
        max_length: token limit passed to the tokenizer (defaults to 8192,
            the value this notebook has always used).

    Returns:
        A tensor with one L2-normalised row per input text, so the dot
        product of two rows is their cosine similarity.
    """
    batch_dict = tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')

    # Inference only: without no_grad, autograd keeps every layer's
    # activations alive for a backward pass that never happens, which is
    # expensive for long (up to max_length-token) inputs.
    with torch.no_grad():
        outputs = model(**batch_dict)

    # First-token pooling: the final hidden state at position 0 of each sequence.
    embeddings = outputs.last_hidden_state[:, 0]

    return F.normalize(embeddings, p=2, dim=1)



In [4]:
# Embed the corpus in batches of 10 (bounding the size of each forward pass)
# and store the texts and vectors in the "papers" collection.
batch_size = 10
# Guard against stale kernel state: `ids` must still be the corpus id list.
assert len(ids) == len(papers), "ids/papers out of sync - re-run the data loading cell"
for start in range(0, len(papers), batch_size):
    end = min(start + batch_size, len(papers))
    print(f"Processing papers {start} to {end}")
    batch_papers = papers[start:end]
    batch_embeddings = generate_embeddings(batch_papers).tolist()
    # upsert (rather than add) keeps this cell idempotent: re-running it
    # overwrites ids that already exist instead of re-adding them.
    paper_collection.upsert(
        documents=batch_papers,
        embeddings=batch_embeddings,
        ids=ids[start:end],
    )

Processing papers 0 to 10
Processing papers 10 to 20
Processing papers 20 to 30
Processing papers 30 to 40
Processing papers 40 to 50
Processing papers 50 to 60
Processing papers 60 to 70


In [8]:
# Load evaluation queries from every synthetic split. Each JSON file's
# "output" field is a query; its file stem is the id of the paper the query
# should retrieve (an id in the corpus loaded from data/synthetic/strong).
queries: List[QueryData] = []
for ref_type in ["weak", "strong", "weakv2"]:
  split_dir = os.path.join("data", "synthetic", ref_type)
  # sorted() for a reproducible query order; skip non-JSON files.
  for fn in sorted(os.listdir(split_dir)):
    source, ext = os.path.splitext(fn)
    if ext != ".json":
      continue
    with open(os.path.join(split_dir, fn), encoding="utf-8") as f:
      content = json.load(f)
    queries.append(QueryData(content["output"], source, query_type=ref_type))

## Base Query: baseline top-k retrieval accuracy (gte-base embeddings)

In [12]:
# Top-3 retrieval accuracy: a query is a hit when its source paper is among
# the 3 papers returned for the query embedding.
top_k = 3
correct = 0
scores = {}   # query_type -> number of hits
totals = {}   # query_type -> number of queries of that type
for query_data in tqdm.tqdm(queries):
  query_type = query_data.query_type
  totals[query_type] = totals.get(query_type, 0) + 1
  query_embedding = generate_embeddings([query_data.query]).tolist()
  results = paper_collection.query(query_embeddings=query_embedding, n_results=top_k)
  # Named retrieved_ids (not `ids`) so the corpus id list used by the
  # embedding cell is not overwritten.
  retrieved_ids = results["ids"][0]
  if query_data.source in retrieved_ids:
    correct += 1
    scores[query_type] = scores.get(query_type, 0) + 1
print(f"Top {top_k} accuracy")
print(f"Accuracy: {correct}/{len(queries)} ({correct / len(queries) * 100:.2f}%)")
# Normalise by each type's actual query count (not a hard-coded 70) and
# report 0.0 for types with no hits instead of omitting them.
score_percentage = {k: scores.get(k, 0) / n * 100 for k, n in totals.items()}
print(f"Scores: {score_percentage}")

100%|██████████| 210/210 [00:39<00:00,  5.26it/s]

Top 3 accuracy
Accuracy: 197/210 (93.81%)
Scores: {'weak': 98.57142857142858, 'strong': 97.14285714285714, 'weakv2': 85.71428571428571}





In [13]:
# Top-1 retrieval accuracy: a query is a hit only when its source paper is
# the single paper returned for the query embedding.
top_k = 1
correct = 0
scores = {}   # query_type -> number of hits
totals = {}   # query_type -> number of queries of that type
for query_data in tqdm.tqdm(queries):
  query_type = query_data.query_type
  totals[query_type] = totals.get(query_type, 0) + 1
  query_embedding = generate_embeddings([query_data.query]).tolist()
  results = paper_collection.query(query_embeddings=query_embedding, n_results=top_k)
  # Named retrieved_ids (not `ids`) so the corpus id list used by the
  # embedding cell is not overwritten.
  retrieved_ids = results["ids"][0]
  if query_data.source in retrieved_ids:
    correct += 1
    scores[query_type] = scores.get(query_type, 0) + 1
print(f"Top {top_k} accuracy")
print(f"Accuracy: {correct}/{len(queries)} ({correct / len(queries) * 100:.2f}%)")
# Normalise by each type's actual query count (not a hard-coded 70) and
# report 0.0 for types with no hits instead of omitting them.
score_percentage = {k: scores.get(k, 0) / n * 100 for k, n in totals.items()}
print(f"Scores: {score_percentage}")

100%|██████████| 210/210 [00:39<00:00,  5.29it/s]

Top 1 accuracy
Accuracy: 168/210 (80.00%)
Scores: {'weak': 87.14285714285714, 'strong': 90.0, 'weakv2': 62.857142857142854}





In [None]:
# Larger sibling of the gte-base model named by `model_path` earlier.
# NOTE(review): not referenced by any cell in this part of the notebook yet.
model_id: str = "Alibaba-NLP/gte-large-en-v1.5"


In [None]:
def eval(dataset, model, top_k=5, verbose=False):
  """Measure retrieval hit rate of a LlamaIndex vector index on a labelled dataset.

  Builds a ``VectorStoreIndex`` over ``dataset.corpus`` using ``model`` as the
  embedding model, retrieves ``top_k`` nodes for every query in
  ``dataset.queries``, and records whether the expected document was among them.

  NOTE(review): this name shadows the builtin ``eval``; kept for existing callers.

  Args:
    dataset: object exposing ``corpus`` (doc id -> text), ``queries``
      (query id -> query text) and ``relevant_docs`` (query id -> indexable
      sequence of relevant doc ids). Presumably a LlamaIndex
      ``EmbeddingQAFinetuneDataset`` — confirm.
    model: embedding model handed to ``VectorStoreIndex(embed_model=...)``.
    top_k: number of nodes retrieved per query.
    verbose: currently unused; kept for signature compatibility.

  Returns:
    A list with one dict per query holding ``is_hit`` (bool), ``retrieved``
    (retrieved node ids), ``expected`` (first relevant doc id) and ``query``
    (the query id).
  """
  # Local imports: the notebook's import cell binds neither LlamaIndex name,
  # and its top-level `tqdm` is the module (not callable), so the progress-bar
  # class is imported explicitly here.
  from llama_index.core import VectorStoreIndex
  from llama_index.core.schema import TextNode
  from tqdm.auto import tqdm as progress_bar

  corpus = dataset.corpus
  queries = dataset.queries
  relevant_docs = dataset.relevant_docs

  nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
  # Use the caller-supplied model; the global `embed_model` referenced here
  # before is not defined in this notebook and made `model` a dead argument.
  index = VectorStoreIndex(nodes, embed_model=model, show_progress=True)
  retriever = index.as_retriever(similarity_top_k=top_k)

  eval_results = []
  for query_id, query in progress_bar(queries.items()):
    retrieved_nodes = retriever.retrieve(query)
    retrieved_ids = [node.node.node_id for node in retrieved_nodes]
    # Assumes one relevant document per query: only the first is checked.
    expected_id = relevant_docs[query_id][0]
    eval_results.append({
        "is_hit": expected_id in retrieved_ids,
        "retrieved": retrieved_ids,
        "expected": expected_id,
        "query": query_id,
    })
  return eval_results