# Tuning the Top-k Query parameter

First, we set up the environment: imports, environment variables, the database connection, and the compute device.

In [13]:
from dotenv import load_dotenv
import os
import psycopg
import pandas as pd
import numpy as np
from time import time
import torch
from TextEnrichers import get_enricher, TextEnricher
from database.database import Database
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Read settings from the local .env file (override=True is passed so the file's
# values are applied on top of the current environment), then echo the
# configured database port as a quick sanity check.
load_dotenv('.env', override=True)
print(os.environ.get('DB_PORT'))


5432


In [14]:
# Database setup: open the project database wrapper and run its connection
# check before any queries are issued.
db = Database()
db.test_connection()

# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability probe and
# also exists on older PyTorch releases.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample a fixed random subset of examples (30 in the cell below) from the non-trivial training data. Each example typically has 1-2 target DOIs. For each example, we'll query the database with a large `top_k` parameter to start, so we can be sure the database returns the target references. Then we can ask at what index in the query results a target DOI first appears. Ideally, the ranks will all be very high, indicated by having *low* indices in the query results. We also expect enriched examples to have their target DOIs ranked higher (lower indices).

In [15]:
# Load the training split (JSON Lines) and draw a reproducible random sample
# of examples to evaluate. The sample size and seed are named so they are easy
# to find and tune.
TRAIN_PATH = 'data/dataset/split/train.jsonl'
N_EXAMPLES = 30
SAMPLE_SEED = 42

data = pd.read_json(TRAIN_PATH, lines=True)
examples = data.sample(n=N_EXAMPLES, random_state=SAMPLE_SEED)

def lowest_index_matching_doi(target_doi: str, query_results: list) -> int:
    """
    Find the position of the earliest query result whose ``doi`` attribute
    equals ``target_doi``.

    Args:
        target_doi: The DOI to look for.
        query_results: Ordered results; each item must expose a ``doi`` attribute.

    Returns:
        The 0-based index of the first matching result, or -1 when no result matches.
    """
    matching_positions = (
        position
        for position, hit in enumerate(query_results)
        if target_doi == hit.doi
    )
    # next() with a default stops at the first match and yields -1 if there is none.
    return next(matching_positions, -1)


def get_ranks(
    example: pd.Series,
    embedding,
    top_k: int,
    ef_search: int,
    table_name: str = 'library',
    target_column: str = 'bge_norm',
    metric: str = 'vector_cosine_ops'
) -> list[int]:
    """
    Query the vector database once and report where each target DOI of
    ``example`` first appears in the results.

    Args:
        example: A dataset row; its ``'citation_dois'`` entry is iterated as the target DOIs.
        embedding: Query vector passed to ``db.query_vector_column`` as ``query_vector``.
        top_k: Number of results requested from the database.
        ef_search: Forwarded to ``db.query_vector_column`` as ``ef_search``.
        table_name: Table to query.
        target_column: Vector column to search.
        metric: Currently unused; it is accepted for interface compatibility
            but is not forwarded to the database query.

    Returns:
        One entry per target DOI, in the order of ``example['citation_dois']``:
        the 0-based index of the first result with that DOI, or -1 if the DOI
        does not occur in the results.
    """
    target_dois = example['citation_dois']

    # A single index-backed query serves every target DOI of this example.
    query_results = db.query_vector_column(
        query_vector=embedding,
        target_column=target_column,
        table_name=table_name,
        top_k=top_k,
        use_index=True,
        ef_search=ef_search
    )

    return [
        lowest_index_matching_doi(target_doi=doi, query_results=query_results)
        for doi in target_dois
    ]

In [16]:
def ranks_at_k(embedder_name: str,
               enricher_name: str,
               target_column: str,
               top_k: int,
               ef_search: int = 256) -> list[int]:
    """
    Calculates the 'ranks' for a given embedding model and enrichment function.

    The ranks are the 0-based indices of the first chunk with the target DOI for
    each target of each sampled example, so ``index + 1`` is the smallest
    ``top_k`` that would retrieve a chunk with that DOI.

    Reads the module-level ``examples`` DataFrame and ``device`` string.

    Args:
        embedder_name: Name passed to ``get_embedder``.
        enricher_name: Name passed to ``get_enricher``.
        target_column: Vector column to search in the ``library`` table.
        top_k: Number of results requested per query.
        ef_search: Forwarded to the database query for each example.

    Returns:
        A flat list with one rank per target DOI across all examples, in example
        order. A rank of -1 means the DOI was not found in the query results.
    """

    # Setup
    print(f"Embedding model: {embedder_name}, Enricher: {enricher_name}")
    embedder = get_embedder(embedder_name, device=device)
    enricher = get_enricher(enricher_name, path_to_data='data/preprocessed/reviews.jsonl')

    # Enrich: pair each sentence with its source DOI, enrich, then embed the batch.
    texts_with_dois = list(
        examples[['sent_no_cit', 'source_doi']].itertuples(index=False, name=None))
    enriched_texts = enricher.enrich_batch(texts_with_dois)
    # Assumes embeddings[i] corresponds to examples.iloc[i] — TODO confirm against Embedder.
    embeddings = embedder(enriched_texts)

    # Rank: one database query per example; results are flattened into one list.
    ranks = []
    for i in tqdm(range(len(examples))):
        ranks.extend(get_ranks(
            example=examples.iloc[i],
            embedding=embeddings[i],
            table_name='library',
            target_column=target_column,
            top_k=top_k,
            ef_search=ef_search))
    return ranks

In [17]:
from itertools import product
                      
# Embedding models to evaluate; the BERT and astroBERT models are only added
# when the selected device is 'cuda'.
embedding_models = ['BAAI/bge-small-en']
if device == 'cuda':
    embedding_models.extend(['bert-base-uncased', 'adsabs/astroBERT'])

# NOTE(review): 'add_title_an_abstract' looks like a typo for
# 'add_title_and_abstract' — confirm against the names TextEnrichers accepts.
enrichment_methods = [
    'identity',
    'add_abstract',
    'add_title',
    'add_title_an_abstract',
]

# Every (model, enrichment) pairing, grouped by model.
combos = list(product(embedding_models, enrichment_methods))
print(f"Combos: {combos}")

Combos: [('BAAI/bge-small-en', 'identity'), ('BAAI/bge-small-en', 'add_abstract'), ('BAAI/bge-small-en', 'add_title'), ('BAAI/bge-small-en', 'add_title_an_abstract')]


In [None]:
# Baseline run: un-enriched sentences with the BGE model. top_k is deliberately
# huge so the target references should be somewhere in the returned results.
# NOTE(review): ef_search=20 is far smaller than top_k; if the backing index is
# pgvector HNSW, ef_search may cap how many rows an index scan returns —
# confirm this query really yields top_k candidates.
ranks = ranks_at_k(
    top_k=2_000_000,
    ef_search=20,
    target_column='bge_norm',
    embedder_name='BAAI/bge-small-en',
    enricher_name='identity',
)

Embedding model: BAAI/bge-small-en, Enricher: identity


  0%|          | 0/30 [00:00<?, ?it/s]

  Query execution time: 45.11 seconds


  3%|▎         | 1/30 [00:53<25:54, 53.60s/it]

  Query execution time: 35.23 seconds


  7%|▋         | 2/30 [01:37<22:17, 47.78s/it]

  Query execution time: 36.12 seconds


 10%|█         | 3/30 [02:21<20:45, 46.13s/it]

  Query execution time: 32.13 seconds


 13%|█▎        | 4/30 [03:02<19:02, 43.96s/it]

  Query execution time: 31.01 seconds


 17%|█▋        | 5/30 [03:41<17:36, 42.25s/it]

  Query execution time: 29.07 seconds


 20%|██        | 6/30 [04:18<16:14, 40.59s/it]

  Query execution time: 28.12 seconds


 23%|██▎       | 7/30 [04:55<15:07, 39.46s/it]

  Query execution time: 28.85 seconds


 27%|██▋       | 8/30 [05:33<14:12, 38.75s/it]

  Query execution time: 26.99 seconds


 30%|███       | 9/30 [06:08<13:12, 37.72s/it]

  Query execution time: 27.46 seconds


 33%|███▎      | 10/30 [06:44<12:24, 37.24s/it]



In [12]:
# Display the rank (0-based result index) found for each target DOI;
# -1 marks a DOI that did not appear in the query results.
# NOTE(review): this dumps the whole list — consider a summary (e.g. counts of -1, quantiles).
ranks

[912610,
 -1,
 -1,
 -1,
 -1,
 616004,
 -1,
 363759,
 983676,
 387492,
 -1,
 101192,
 -1,
 -1,
 918438,
 -1,
 971662,
 -1,
 929161,
 -1,
 814375,
 748254,
 -1,
 276940,
 391234,
 46754,
 -1,
 -1,
 429886,
 665052,
 -1,
 163978,
 126813,
 -1,
 -1,
 -1,
 455377,
 -1,
 -1,
 -1,
 267118,
 -1,
 -1,
 269072,
 -1,
 -1,
 970349,
 -1,
 -1,
 -1,
 -1,
 -1,
 131949,
 255685,
 915977,
 785395,
 830972,
 -1,
 -1,
 -1,
 -1,
 952792,
 -1,
 311983,
 -1,
 -1,
 916618,
 -1,
 -1,
 933916,
 309379,
 864622,
 -1,
 -1,
 -1,
 93847,
 -1,
 -1,
 350856,
 469779,
 674694,
 657160,
 410045,
 -1,
 -1,
 -1,
 301988,
 -1,
 -1,
 -1,
 125965,
 242438,
 400852,
 -1,
 -1,
 712308,
 -1,
 -1,
 574468,
 343457,
 -1,
 -1,
 353884,
 328252,
 759926,
 900356,
 19798,
 462805,
 40503,
 -1,
 210881,
 541533]

In [None]:
# Inspect one record held by the identity enricher to see which fields it has.
# The enricher is built here so the cell runs on a fresh kernel instead of
# relying on a variable left over from an earlier session.
identity_enricher = get_enricher('identity', path_to_data='data/preprocessed/reviews.jsonl')
some_doi = next(iter(identity_enricher.doi_to_record.keys()))
record = identity_enricher.doi_to_record[some_doi]
print(some_doi)
print(record.keys())

In [None]:
# NOTE(review): repeats the earlier `ranks` display cell — consider removing one of them.
ranks

In [None]:
# Short alias for each embedding model name.
# Presumably used when building output file names — verify against later cells.
model_names_to_filenames = dict([
    ('BAAI/bge-small-en', 'bge'),
    ('bert-base-uncased', 'bert'),
    ('adsabs/astroBERT', 'astrobert'),
])