# Tuning the Top-k Query parameter

First, we set up the environment.

In [1]:
from dotenv import load_dotenv
import os
import psycopg2
import pandas as pd
from time import time
import torch
from Enrichers import Enricher, get_enricher, ENRICHMENT_FN
from database.database import DatabaseProcessor
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Read settings from the local .env file; override=True lets values in the
# file replace variables that are already set in the process environment.
load_dotenv('.env', override=True)
# Echo the configured database port as a quick check that the file was read.
print(os.environ.get('DB_PORT'))


  from .autonotebook import tqdm as notebook_tqdm


8265


In [2]:
# Database setup
# Each connection keyword is filled from the environment variable paired with
# it below (os.getenv returns None for any variable that is not set).
_DB_PARAM_ENV_VARS = (
    ('dbname', 'DB_NAME'),
    ('user', 'DB_USER'),
    ('password', 'DB_PASSWORD'),
    ('host', 'DB_HOST'),
    ('port', 'DB_PORT'),
)
db_params = {
    param: os.getenv(env_var) for param, env_var in _DB_PARAM_ENV_VARS
}
db = DatabaseProcessor(db_params)
# Check the connection up front, before any long-running cell depends on it.
db.test_connection()

# Pick the compute device: CUDA first, then Apple's MPS backend, else the CPU.
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available() is not present in every PyTorch release and is only
# evaluated when CUDA is absent, so it could fail exactly on non-CUDA machines.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         holy8a30112.rc.fas.harvard.edu   8265            
Database version: ('PostgreSQL 17.0 (Debian 17.0-1.pgdg120+1) on x86_64-pc-linux-gnu, compiled by gcc (Debian 12.2.0-14) 12.2.0, 64-bit',)
Using device: cuda


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 100 examples from the non-trivial training data. Each example typically has 1-2 target DOIs. For each example, we'll query the database with a large `top_k` parameter to start, so we can be confident the database returns the target references. Then we can ask at what index in the query results a target DOI first appears. Ideally, the targets will all be ranked very highly, which shows up as *low* indices in the query results. We also expect enriched examples to have their target DOIs ranked higher (lower indices).

In [3]:
from database.database import SingleQueryResult

# Load the training split (one JSON record per line) and draw a seeded sample
# of 100 rows, so that re-running the cell selects the same examples.
data = pd.read_json(path_or_buf='data/dataset/split/train.jsonl', lines=True)
examples = data.sample(n=100, random_state=42)
examples.head()

Unnamed: 0,source_doi,sent_original,sent_no_cit,sent_idx,citation_dois
9902,10.1007/s00159-013-0064-5,CGRO observations of X-ray binaries detected n...,CGRO observations of X-ray binaries detected n...,821,[10.1086/305746]
1808,10.1016/S1387-6473(99)00004-4,By the time a 0.5 M ⊙ disk has dwindled to a t...,By the time a 0.5 M ⊙ disk has dwindled to a t...,77,[10.1086/115385]
16785,10.1146/annurev.aa.31.090193.001135,"New York: Springer-Verlag (1991) Aitken, D. K....","New York: Aitken, D. K., Smith, C. H., James,...",420,[10.1086/115938]
34992,10.1016/j.newar.2023.101684,", 1988 and the spectroscopic mass ratio from R...",", 1988 and the spectroscopic mass ratio from R...",468,[10.1086/300933]
6646,10.1007/s00159-022-00140-3,Clouds with sufficiently short cooling time in...,Clouds with sufficiently short cooling time in...,662,[10.1093/mnras/stab3351]


In [4]:
def lowest_index_matching_doi(target_doi: str, query_results: list[SingleQueryResult]) -> int:
    """
    Find the position of the first query result whose DOI equals the target.

    Args:
        target_doi: DOI to look for.
        query_results: Results to scan in order; each item's `doi` attribute
            is compared against `target_doi`.

    Returns:
        The zero-based index of the first matching result, or -1 when no
        result matches.
    """
    matching_positions = (
        position
        for position, candidate in enumerate(query_results)
        if target_doi == candidate.doi
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matching_positions, -1)


def get_ranks(
    example: pd.Series,
    embedder: Embedder,
    enricher: Enricher,
    table_name: str,
    top_k: int,
    ef_search: int,
    metric: str = 'vector_cosine_ops',
    db_processor: "DatabaseProcessor | None" = None
) -> list[int]:
    """
    Rank each target DOI of one example within a single vector query.

    The example is enriched, embedded, and used to query `table_name`; every
    DOI in `example['citation_dois']` is then looked up in the returned
    results.

    Args:
        example: Row providing 'citation_dois' (the target DOIs); the whole
            row is also handed to `enricher.enrich`.
        embedder: Callable that takes a list of strings and returns one
            embedding per string.
        enricher: Object whose `enrich(example=...)` builds the text to embed.
        table_name: Vector table to query.
        top_k: Number of results requested from the database.
        ef_search: Forwarded unchanged to `query_vector_table`.
        metric: Forwarded unchanged to `query_vector_table`.
        db_processor: Database handle to query. When omitted, the
            notebook-level `db` is used, which preserves the original
            behaviour for existing callers.

    Returns:
        One entry per target DOI, in the order of `example['citation_dois']`:
        the index of the first result with that DOI, or -1 if it is absent
        from the returned results.
    """
    # Resolve the handle lazily so the global is only required when no
    # explicit processor was injected.
    processor = db if db_processor is None else db_processor
    target_dois = example['citation_dois']

    # Prepare query vector (the embedder works on batches, hence the [0]).
    enriched_sentence = enricher.enrich(example=example)
    embedding = embedder([enriched_sentence])[0]

    # Query
    query_results = processor.query_vector_table(
        table_name=table_name,
        query_vector=embedding,
        metric=metric,
        use_index=True,
        top_k=top_k,
        ef_search=ef_search
    )

    return [
        lowest_index_matching_doi(target_doi=doi, query_results=query_results)
        for doi in target_dois
    ]

In [8]:
# Build the BERT embedder and the 'identity' enricher, then embed every
# sampled sentence (the `sent_no_cit` column) in a single call.
embedder = get_embedder('bert-base-uncased', device=device)
enricher = get_enricher('identity')
embeddings = embedder(examples['sent_no_cit'].tolist())

In [9]:
# Time one database query per precomputed embedding against the BERT table.
# use_index=True matches the call made in get_ranks, so this cell measures the
# same kind of query the experiment runs.
# NOTE(review): confirm the default of `use_index` in query_vector_table.
for embedding in embeddings:
    start = time()
    query_results = db.query_vector_table(
        table_name='bert_hnsw',
        query_vector=embedding,
        metric='vector_cosine_ops',
        use_index=True,
        top_k=50,
        ef_search=50)
    end = time()
    # Wall-clock seconds for the whole call, as seen from the notebook.
    print(f"Time taken: {end - start:.3f}")

  Query execution time: 26.62 seconds
  Query fetch time: 0.00 seconds
Time take: 26.627
  Query execution time: 12.43 seconds
  Query fetch time: 0.00 seconds
Time take: 12.433
  Query execution time: 12.75 seconds
  Query fetch time: 0.00 seconds
Time take: 12.763


KeyboardInterrupt: 

In [9]:
from itertools import product

# Models to evaluate; the two BERT-based models are only added when a CUDA
# device was selected.
embedding_models = ['BAAI/bge-small-en']
if device == 'cuda':
    embedding_models += ['bert-base-uncased', 'adsabs/astroBERT']
enrichers = ENRICHMENT_FN.keys()
combos = list(product(embedding_models, enrichers))

# Vector table queried for each embedding model.
model_names_to_tables = {
    'BAAI/bge-small-en': 'bge',
    'bert-base-uncased': 'bert_hnsw',
    'adsabs/astroBERT': 'astrobert_hnsw'
}

# One DataFrame of ranks per embedding model; each enrichment function
# contributes one column.
results = {model: pd.DataFrame() for model in embedding_models}
top_k=50

# product() varies the enricher fastest, so all combos of one model are
# consecutive: build each embedder once per model instead of once per combo.
loaded_model = None
for embedding_model, enricher_name in combos:
    print(f"Embedding model: {embedding_model}, Enricher: {enricher_name}")
    if embedding_model != loaded_model:
        embedder = get_embedder(embedding_model, device=device)
        loaded_model = embedding_model
    enricher = get_enricher(enricher_name)
    table_name = model_names_to_tables[embedding_model]

    # get_ranks returns one rank per target DOI, so this list holds one entry
    # per (example, target DOI) pair rather than one per example; -1 marks a
    # target that was not among the top_k results.
    ranks = []
    for i in tqdm(range(len(examples))):
        example = examples.iloc[i]
        ranks.extend(get_ranks(
            example=example,
            embedder=embedder,
            enricher=enricher,
            table_name=table_name,
            top_k=top_k,
            ef_search=top_k))
    series = pd.Series(ranks)
    results[embedding_model][enricher_name] = series



Embedding model: BAAI/bge-small-en, Enricher: identity


  1%|          | 1/100 [00:28<47:14, 28.63s/it]

  Query execution time: 21.89 seconds
  Query fetch time: 0.00 seconds


  2%|▏         | 2/100 [00:39<30:06, 18.44s/it]

  Query execution time: 11.17 seconds
  Query fetch time: 0.00 seconds


  3%|▎         | 3/100 [00:51<24:43, 15.29s/it]

  Query execution time: 11.52 seconds
  Query fetch time: 0.00 seconds


  4%|▍         | 4/100 [01:02<21:49, 13.64s/it]

  Query execution time: 11.09 seconds
  Query fetch time: 0.00 seconds


  5%|▌         | 5/100 [01:14<20:41, 13.07s/it]

  Query execution time: 12.05 seconds
  Query fetch time: 0.00 seconds


  5%|▌         | 5/100 [01:25<27:09, 17.15s/it]

KeyboardInterrupt



In [8]:
# Short file-name stem used for each embedding model's results.
model_names_to_filenames = {
    'BAAI/bge-small-en': 'bge',
    'bert-base-uncased': 'bert',
    'adsabs/astroBERT': 'astrobert'
}

# to_csv does not create missing directories, so make sure the output folder
# exists before writing (a no-op when it is already there).
os.makedirs('tests', exist_ok=True)

# Save the results to CSV files (one file per model, one column per enricher)
for model_name, df in results.items():
    filename = f"tests/{model_names_to_filenames[model_name]}_ranks.csv"
    df.to_csv(filename, index=False)
    print(f"Saved results for {model_name} to {filename}")

Saved results for BAAI/bge-small-en to tests/bge_ranks.csv
