# Tuning the Top-k Query parameter

First, we set up the environment.

In [13]:
from dotenv import load_dotenv
import os
import psycopg
import pandas as pd
import numpy as np
from time import time
import torch
from TextEnrichers import get_enricher, TextEnricher
from database.database import Database
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Pull variables from the local .env file; values there win over anything already set.
load_dotenv(dotenv_path='.env', override=True)
# Quick sanity check that the .env values were picked up.
print(os.environ.get('DB_PORT'))


5432


In [14]:
# Database setup: instantiate the project's Database helper and run its connection check.
# (test_connection presumably prints the connection table/version shown in this cell's output.)
db = Database()
db.test_connection()

# Pick the best available torch backend: CUDA GPU first, then Apple MPS (Metal), then CPU.
# torch.backends.mps.is_available() is the availability check documented in the PyTorch
# MPS backend notes and is present across torch releases, so this also works on CPU-only
# machines running older torch builds.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 100 examples from the non-trivial training data (note: the code cell below currently samples 10). Each example typically has 1-2 target DOIs. For each example, we'll start by querying the database with a large `top_k` parameter, so we can be sure the database returns the target references. Then we can ask at which index in the query results a target DOI first appears. Ideally, the ranks will all be very high, indicated by *low* indices in the query results. We also expect enriched examples to have their target DOIs ranked higher (lower indices).

In [26]:
# Load the training split and draw a reproducible random sample of examples.
# NOTE(review): the markdown above plans 100 examples; this currently samples N_EXAMPLES.
N_EXAMPLES = 10
data = pd.read_json('data/dataset/split/train.jsonl', lines=True)
examples = data.sample(n=N_EXAMPLES, random_state=42)

def lowest_index_matching_doi(target_doi: str, query_results: list) -> int:
    """
    Locate the first query result belonging to `target_doi`.

    Each element of `query_results` must expose a `.doi` attribute. The 0-based
    position of the earliest element whose doi equals `target_doi` is returned;
    -1 signals that no element matched (including when `query_results` is empty).
    """
    return next(
        (position for position, hit in enumerate(query_results) if target_doi == hit.doi),
        -1,
    )


def get_ranks(
    example: pd.Series,
    embedding,
    top_k: int,
    ef_search: int,
    table_name: str = 'library',
    target_column: str = 'bge_norm',
    metric: str = 'vector_cosine_ops'
) -> list[int]:
    """
    Query the vector store with one example's embedding and locate each of its target DOIs.

    Args:
        example: One dataset row; its 'citation_dois' entry lists the target DOIs.
        embedding: Query vector for this example, passed unchanged to db.query_vector_column.
        top_k: Number of results requested from the database.
        ef_search: Search-time parameter forwarded to db.query_vector_column.
        table_name: Table to query.
        target_column: Vector column to search against.
        metric: Currently unused by this function; kept for interface compatibility.

    Returns:
        One entry per target DOI, in the order of example['citation_dois']: the index of
        the first query result whose doi matches, or -1 if none of the returned results match.
    """
    target_dois = example['citation_dois']

    query_results = db.query_vector_column(
        query_vector=embedding,
        target_column=target_column,
        table_name=table_name,
        top_k=top_k,
        use_index=True,
        ef_search=ef_search,
    )

    return [
        lowest_index_matching_doi(target_doi=doi, query_results=query_results)
        for doi in target_dois
    ]

In [27]:
def ranks_at_k(
    examples: pd.DataFrame,
    embedder_name: str,
    enricher_name: str,
    target_column: str,
    top_k: int,
    ef_search: int = 256,
    table_name: str = 'library',
    path_to_data: str = 'data/preprocessed/reviews.jsonl',
) -> list[int]:
    """
    Compute retrieval ranks of target DOIs for one embedding model / enrichment strategy.

    Each example's 'sent_no_cit' text is enriched (using its 'source_doi'), embedded, and
    used to query `table_name`. For every target DOI of every example, the rank is the
    0-based index of the first returned chunk with that DOI, so retrieving it needs
    top_k >= rank + 1. A rank of -1 means the DOI was not among the returned results.

    Args:
        examples: Rows with 'sent_no_cit', 'source_doi' and 'citation_dois' columns.
        embedder_name: Name passed to get_embedder.
        enricher_name: Name passed to get_enricher.
        target_column: Vector column to search against.
        top_k: Number of results requested per query.
        ef_search: Search-time parameter forwarded to the database query. The original
            signature annotated this as `256` (no default), which made it required;
            it is now an int defaulting to 256.
        table_name: Table to query.
        path_to_data: Data file handed to get_enricher.

    Returns:
        A flat list of ranks: one entry per target DOI, concatenated across examples in
        row order. Its length is therefore generally not len(examples).
    """
    # Setup
    print(f"Embedding model: {embedder_name}, Enricher: {enricher_name}")
    embedder = get_embedder(embedder_name, device=device)
    enricher = get_enricher(enricher_name, path_to_data=path_to_data)

    # Enrich and embed all example sentences in one batch.
    texts_with_dois = list(
        examples[['sent_no_cit', 'source_doi']].itertuples(index=False, name=None))
    enriched_texts = enricher.enrich_batch(texts_with_dois)
    embeddings = embedder(enriched_texts)

    # Rank: embeddings[i] corresponds positionally to examples.iloc[i].
    ranks: list[int] = []
    for i in tqdm(range(len(examples))):
        ranks.extend(get_ranks(
            example=examples.iloc[i],
            embedding=embeddings[i],
            top_k=top_k,
            ef_search=ef_search,
            table_name=table_name,
            target_column=target_column,
        ))
    return ranks

In [28]:
from itertools import product

# The BERT-style models are only included when a CUDA device is available.
embedding_models = ['BAAI/bge-small-en']
if device == 'cuda':
    embedding_models.extend(['bert-base-uncased', 'adsabs/astroBERT'])

enrichment_methods = [
    'identity',
    'add_abstract',
    'add_title',
    # NOTE(review): "an" may be a typo for "and"; this string must match the name
    # get_enricher expects, so confirm against the enricher registry before renaming.
    'add_title_an_abstract',
]

# Every (embedding model, enrichment method) pairing to evaluate.
combos = [(model, method) for model, method in product(embedding_models, enrichment_methods)]
print(f"Combos: {combos}")

Combos: [('BAAI/bge-small-en', 'identity'), ('BAAI/bge-small-en', 'add_abstract'), ('BAAI/bge-small-en', 'add_title'), ('BAAI/bge-small-en', 'add_title_an_abstract')]


In [30]:
# A very large top_k (presumably the row count of the `library` table; TODO confirm) so
# that every target DOI should appear somewhere in the results.
# NOTE(review): with use_index=True, a pgvector HNSW index scan may return at most about
# ef_search rows; check that db.query_vector_column really yields top_k rows when ef_search=20.
LIBRARY_TOP_K = 2_261_334
ranks = ranks_at_k(
    examples=examples,
    embedder_name='BAAI/bge-small-en',
    enricher_name='identity',
    target_column='bge_norm',
    top_k=LIBRARY_TOP_K,
    ef_search=20,
)

Embedding model: BAAI/bge-small-en, Enricher: identity


ValueError: While enriching example with source doi '10.1016/S1387-6473(99)00004-4', full record not found

In [19]:
# Flat list of ranks, one per target DOI across all examples; -1 means the DOI was not found
# in the returned results.
# NOTE(review): the output below comes from In [19], which ran before the cells above
# (In [26]-[30]). The last of those cells raised, so this output does not reflect the current
# code. Its 35 entries also look like more than 10 sampled examples with 1-2 DOIs each would
# produce. Re-run top to bottom to refresh it.
ranks

[710023,
 133663,
 1894654,
 1290280,
 -1,
 1187023,
 1418260,
 971105,
 574665,
 1250933,
 1878241,
 1909213,
 693581,
 235882,
 1021555,
 406718,
 1136869,
 1145444,
 302827,
 47560,
 470245,
 1663816,
 1152997,
 1432779,
 575236,
 1469286,
 559218,
 -1,
 -1,
 1367275,
 779756,
 1119317,
 94417,
 968245,
 1651681]

In [23]:
# Inspect the example at position 4 (ranks[4] above is -1).
# NOTE(review): `ranks` is flattened over target DOIs, so ranks[4] only corresponds to
# examples.iloc[4] if each earlier example had exactly one target DOI. Confirm this.
examples.iloc[4, :]

source_doi                              10.1007/s00159-022-00140-3
sent_original    Clouds with sufficiently short cooling time in...
sent_no_cit      Clouds with sufficiently short cooling time in...
sent_idx                                                       662
citation_dois                             [10.1093/mnras/stab3351]
Name: 6646, dtype: object

In [24]:
# Original sentence (with its citation) for the example inspected above.
examples.iloc[4]['sent_original']

'Clouds with sufficiently short cooling time in the mixed gas can even grow with time (e.g. Gronke et al. 2022 ). '

In [25]:
# Check that the example's target DOI (10.1093/mnras/stab3351) exists in the research corpus.
research = pd.read_json('data/preprocessed/research.jsonl', lines=True)
research.loc[research['doi'] == '10.1093/mnras/stab3351']

Unnamed: 0,bibcode,abstract,aff,author,bibstem,doctype,doi,id,keyword,pubdate,title,read_count,reference,data,citation_count,citation,body,dois,loaded_from
27600,2022MNRAS.511..859G,Astrophysical gases are commonly multiphase an...,"[Department of Physics &amp; Astronomy, Johns ...","[Gronke, Max, Oh, S. Peng, Ji, Suoqing, Norman...","[MNRAS, MNRAS.511]",article,10.1093/mnras/stab3351,21012594,"[hydrodynamics, ISM: clouds, ISM: structure, g...",2022-03-00,Survival and mass growth of cold gas in a turb...,51,"[1953ApJ...118..513H, 1956ApJ...124...20S, 196...",,82,"[2021ApJ...923..115M, 2021MNRAS.508.6155W, 202...","1 INTRODUCTION Turbulent, multiphase gases are...","[10.1093/mnras/stab3351, 10.48550/arXiv.2107.1...",data/json/salvaged_articles.json


In [None]:
# Short tag for each embedding model name (presumably used in output file names; the
# consuming code is not shown here).
model_names_to_filenames = dict(zip(
    ('BAAI/bge-small-en', 'bert-base-uncased', 'adsabs/astroBERT'),
    ('bge', 'bert', 'astrobert'),
))