# Tuning the Top-k Query parameter

First, we set up the environment: load environment variables, connect to the database, and choose a torch device.

In [1]:
from dotenv import load_dotenv
import os
import psycopg
import pandas as pd
import numpy as np
from time import time
import torch
from TextEnrichers import get_enricher, TextEnricher
from database.database import Database
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Load variables from .env into the process environment (override=True lets values in
# the file win over ones already set), then echo DB_PORT as a quick sanity check.
load_dotenv(dotenv_path='.env', override=True)
print(os.environ.get('DB_PORT'))


5432


In [2]:
# Database setup: open the project Database wrapper and run its connection check
# (the cell output shows it printing the connection table and server version).
db = Database()
db.test_connection()

# Choose the torch device: prefer CUDA, then Apple MPS, then fall back to CPU.
# torch.backends.mps.is_available() is the documented MPS availability check and is
# present across torch versions that support MPS.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating recall over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 10 examples (with a fixed random seed) from the non-trivial training data. Each example typically has 1–2 target DOIs. For each example, we'll start by querying the database with a large `top_k` parameter, so we can be sure the database returns the target references. Then we can ask at what index in the query results a target DOI first appears. A target first found at index $i$ would be retrieved by any `top_k` $\ge i + 1$. Ideally, the ranks will all be very high, which shows up as *low* indices in the query results. We also expect enriched examples to have their target DOIs ranked higher (lower indices).

In [3]:
# Read the training split (JSON Lines) and draw a small, reproducible sample of
# examples; random_state pins which rows are chosen.
data = pd.read_json(path_or_buf='data/dataset/split/train.jsonl', lines=True)
examples = data.sample(n=10, random_state=42)

def lowest_index_matching_doi(target_doi: str, query_results: list) -> int:
    """
    Locate the first query result that carries `target_doi`.

    Each element of `query_results` must expose a `.doi` attribute. Results are
    scanned in order and the scan stops at the first match.

    Returns:
        The 0-based position of the first result whose `.doi` equals `target_doi`,
        or -1 when no result matches.
    """
    hits = (
        position
        for position, hit in enumerate(query_results)
        if target_doi == hit.doi
    )
    return next(hits, -1)


def get_ranks(
    example: pd.Series,
    embedding,
    top_k: int,
    ef_search: int,
    table_name: str = 'library',
    target_column: str = 'bge_norm',
    metric: str = 'vector_cosine_ops'
) -> list[int]:
    """
    Run one vector query for `example` and locate each of its target DOIs in the results.

    Args:
        example: a dataset row; its 'citation_dois' entry is iterated as the target DOIs.
        embedding: the query vector, passed to `db.query_vector_column`.
        top_k: number of nearest chunks to request.
        ef_search: index search parameter forwarded to the query (with use_index=True);
            presumably pgvector's hnsw.ef_search — confirm in Database.
        table_name: table to query.
        target_column: vector column to compare the embedding against.
        metric: NOTE(review): currently unused — it is not forwarded to the query.
            Kept so existing callers keep working.

    Returns:
        One rank per target DOI, in the same order as 'citation_dois': the 0-based
        index of the first result whose `.doi` matches, or -1 if the DOI is absent
        from the returned results.
    """
    target_dois = example['citation_dois']

    query_results = db.query_vector_column(
        query_vector=embedding,
        target_column=target_column,
        table_name=table_name,
        top_k=top_k,
        use_index=True,
        ef_search=ef_search
    )

    return [
        lowest_index_matching_doi(target_doi=doi, query_results=query_results)
        for doi in target_dois
    ]

In [4]:
def ranks_at_k(
    examples: pd.DataFrame,
    embedder_name: str,
    enricher_name: str,
    target_column: str,
    top_k: int,
    ef_search: int = 256,
    table_name: str = 'library',
    path_to_data: str = 'data/preprocessed/reviews.jsonl',
) -> list[int]:
    """
    Calculates the 'ranks' for a given embedding model and enrichment function.

    Each example's `sent_no_cit` is enriched (together with its `source_doi`),
    embedded, and used as a query against `target_column`. For every target DOI of
    every example, the rank is the index of the first returned chunk carrying that
    DOI; a rank of r means `top_k = r + 1` would have been enough to retrieve it.

    Args:
        examples: rows with 'sent_no_cit', 'source_doi' and 'citation_dois' columns.
        embedder_name: model name passed to `get_embedder`.
        enricher_name: strategy name passed to `get_enricher`.
        target_column: vector column to search.
        top_k: number of nearest chunks to request per query.
        ef_search: index search parameter forwarded to each query. (Previously written
            as `ef_search: 256`, which is an annotation, not a default, so callers were
            forced to pass it; 256 is now the real default.)
        table_name: table to search (defaults to 'library', the previous hard-coded value).
        path_to_data: data file handed to the enricher (defaults to the previous
            hard-coded path).

    Returns:
        A flat list with one entry per target DOI across all examples, in example
        order; -1 marks a DOI that did not appear in the returned results.
    """
    # Setup
    print(f"Embedding model: {embedder_name}, Enricher: {enricher_name}")
    embedder = get_embedder(embedder_name, device=device)
    enricher = get_enricher(enricher_name, path_to_data=path_to_data)

    # Enrich: enrich_batch receives (sent_no_cit, source_doi) tuples, one per example.
    texts_with_dois = list(
        examples[['sent_no_cit', 'source_doi']].itertuples(index=False, name=None))
    enriched_texts = enricher.enrich_batch(texts_with_dois)
    embeddings = embedder(enriched_texts)

    # Rank: one vector query per example; per-DOI ranks are flattened into one list.
    ranks = []
    for i in tqdm(range(len(examples))):
        ranks += get_ranks(
            example=examples.iloc[i],
            embedding=embeddings[i],
            table_name=table_name,
            target_column=target_column,
            top_k=top_k,
            ef_search=ef_search)
    return ranks

In [5]:
from itertools import product
                      
# Embedding models to compare; the two BERT-style models are only added on CUDA.
embedding_models = ['BAAI/bge-small-en']
if device == 'cuda':
    embedding_models.extend(['bert-base-uncased', 'adsabs/astroBERT'])

# Enrichment strategies applied to each query sentence before embedding.
enrichment_methods = [
    'identity',
    'add_abstract',
    'add_title',
    'add_title_an_abstract',
]

# Every (model, enricher) pairing, model-major order.
combos = [(model, method) for model in embedding_models for method in enrichment_methods]
print(f"Combos: {combos}")

Combos: [('BAAI/bge-small-en', 'identity'), ('BAAI/bge-small-en', 'add_abstract'), ('BAAI/bge-small-en', 'add_title'), ('BAAI/bge-small-en', 'add_title_an_abstract')]


In [6]:
# Baseline: rank the targets for plain (un-enriched) sentences with bge-small.
# NOTE(review): top_k=2_261_334 looks like the full row count of `library`, so every
# chunk is ranked — confirm. With a pgvector HNSW index scan, the number of rows
# returned can be capped by hnsw.ef_search unless iterative scans are enabled; check
# how Database.query_vector_column reconciles ef_search=20 with this top_k.
ranks = ranks_at_k(
    examples=examples,
    top_k=2_261_334,
    ef_search=20,
    target_column='bge_norm',
    embedder_name='BAAI/bge-small-en',
    enricher_name='identity',
)

Embedding model: BAAI/bge-small-en, Enricher: identity


  0%|          | 0/10 [00:00<?, ?it/s]

  Query execution time: 156.13 seconds


 10%|█         | 1/10 [02:52<25:49, 172.22s/it]

  Query execution time: 138.14 seconds


 20%|██        | 2/10 [05:20<21:04, 158.12s/it]

  Query execution time: 157.61 seconds


 30%|███       | 3/10 [08:09<19:03, 163.33s/it]

  Query execution time: 162.05 seconds


 40%|████      | 4/10 [11:03<16:43, 167.32s/it]

  Query execution time: 148.03 seconds


 50%|█████     | 5/10 [13:42<13:42, 164.47s/it]

  Query execution time: 157.41 seconds


 60%|██████    | 6/10 [16:31<11:04, 166.02s/it]

  Query execution time: 135.40 seconds


 70%|███████   | 7/10 [18:57<07:57, 159.32s/it]

  Query execution time: 136.99 seconds


 80%|████████  | 8/10 [21:26<05:12, 156.06s/it]

  Query execution time: 133.53 seconds


 90%|█████████ | 9/10 [23:50<02:32, 152.22s/it]

  Query execution time: 150.18 seconds


100%|██████████| 10/10 [26:32<00:00, 159.22s/it]


In [None]:
# Repeat the same query setup for each enrichment strategy, reusing the identity
# baseline computed in the previous cell. Entries are added one enricher at a time,
# so an interrupted run keeps the enrichers that already finished.
rank_data = {'identity': ranks}
for enricher_name in ('add_abstract', 'add_title', 'add_title_an_abstract'):
    rank_data[enricher_name] = ranks_at_k(
        examples=examples,
        top_k=2_261_334,
        ef_search=20,
        target_column='bge_norm',
        embedder_name='BAAI/bge-small-en',
        enricher_name=enricher_name,
    )

Embedding model: BAAI/bge-small-en, Enricher: add_abstract


  0%|          | 0/10 [00:00<?, ?it/s]

  Query execution time: 139.02 seconds


 10%|█         | 1/10 [02:30<22:34, 150.49s/it]

  Query execution time: 120.09 seconds


 20%|██        | 2/10 [04:40<18:28, 138.57s/it]

  Query execution time: 155.48 seconds


 30%|███       | 3/10 [07:27<17:38, 151.23s/it]

  Query execution time: 141.55 seconds


 40%|████      | 4/10 [09:59<15:10, 151.69s/it]

  Query execution time: 189.57 seconds


 50%|█████     | 5/10 [13:19<14:05, 169.12s/it]

  Query execution time: 236.72 seconds


 60%|██████    | 6/10 [17:30<13:07, 196.92s/it]

  Query execution time: 126.54 seconds


 70%|███████   | 7/10 [19:48<08:53, 177.80s/it]

  Query execution time: 114.59 seconds


 80%|████████  | 8/10 [21:53<05:21, 160.81s/it]

  Query execution time: 261.31 seconds


 90%|█████████ | 9/10 [26:25<03:15, 195.58s/it]

  Query execution time: 419.79 seconds


100%|██████████| 10/10 [33:36<00:00, 201.67s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title


  0%|          | 0/10 [00:00<?, ?it/s]

  Query execution time: 451.70 seconds


 10%|█         | 1/10 [07:46<1:09:54, 466.07s/it]

  Query execution time: 160.13 seconds


 20%|██        | 2/10 [10:43<39:29, 296.23s/it]  



In [None]:
# Persist the per-enricher rank lists so the slow queries above need not be re-run.
import json
with open('data/rank_data.json', 'w') as rank_file:
    rank_file.write(json.dumps(rank_data))
print('Rank data saved to file')

In [None]:
# Inspect the `sent_original` text of the sampled example at positional index 4.
examples['sent_original'].iloc[4]

In [None]:
# Load the research records and pull out the row(s) for one DOI of interest
# (presumably a target reference from the examples above — confirm).
research = pd.read_json('data/preprocessed/research.jsonl', lines=True)
research.loc[research['doi'] == '10.1093/mnras/stab3351']

In [None]:
# Map each embedding model's name to the short tag used for it in filenames.
model_names_to_filenames = dict([
    ('BAAI/bge-small-en', 'bge'),
    ('bert-base-uncased', 'bert'),
    ('adsabs/astroBERT', 'astrobert'),
])