# Tuning the Top-k Query parameter

First, we set up the environment.

In [8]:
from dotenv import load_dotenv
import os
import psycopg2
import pandas as pd
import torch
from Enrichers import Enricher, get_enricher
from database.database import DatabaseProcessor
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Load variables from .env; override=True lets .env values take precedence
# over variables already present in the process environment.
load_dotenv('.env', override=True)

# Connection settings for the Postgres database, read from the environment.
# Any variable that is not set comes through as None.
db_params = dict(
    dbname=os.getenv('DB_NAME'),
    user=os.getenv('DB_USER'),
    password=os.getenv('DB_PASSWORD'),
    host=os.getenv('DB_HOST'),
    port=os.getenv('DB_PORT'),
)
db = DatabaseProcessor(db_params)
db.test_connection()

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 100 examples from the non-trivial training data. Each example typically has 1-2 target DOIs. For each example, we'll start by querying the database with a large `top_k` parameter, so that we can be confident the target references are among the returned results. Then we ask at what index in the query results each target DOI first appears; a target DOI that does not appear at all is recorded as `-1`. Ideally, every target will be ranked highly, i.e. appear at a *low* index in the query results. We also expect enriched examples to have their target DOIs ranked higher (at lower indices).

In [9]:
from database.database import SingleQueryResult

# Load the JSON-lines training split and draw a seeded sample of 100 rows,
# so that every experiment below scores the same examples.
data = pd.read_json('data/dataset/split/train.jsonl', lines=True)
examples = data.sample(n=100, random_state=42)

def lowest_index_matching_doi(target_doi: str, query_results: list[SingleQueryResult]) -> int:
    """Return the position of the first result whose ``doi`` equals *target_doi*.

    Args:
        target_doi: DOI to look for.
        query_results: Results in the order returned by the database query.

    Returns:
        The 0-based index of the first matching result, or -1 if none match.
    """
    return next(
        (position for position, hit in enumerate(query_results)
         if target_doi == hit.doi),
        -1,
    )


def get_ranks(
    example: pd.Series,
    embedder: Embedder,
    enricher: Enricher,
    table_name: str,
    top_k: int,
    metric: str = 'vector_cosine_ops',
    use_index: bool = True,
) -> list[int]:
    """Return the rank of each target DOI of *example* in a vector search.

    The example is enriched, embedded, and used to query *table_name* for the
    ``top_k`` nearest rows via the module-level ``db``. For every DOI in
    ``example['citation_dois']``, the 0-based index of its first occurrence in
    the results is returned, or -1 if it does not appear.

    Args:
        example: One dataset row; must contain a ``citation_dois`` entry.
        embedder: Called with a one-element list of strings; its first output
            is used as the query vector.
        enricher: Produces the query text via ``enricher.enrich(example=...)``.
        table_name: Vector table to search.
        top_k: Number of results to request.
        metric: Distance operator class passed to ``db.query_vector_table``.
        use_index: Passed through to ``db.query_vector_table``. Defaults to
            True (the previous hard-coded behaviour). Presumably False forces
            an exact, non-index search; confirm against DatabaseProcessor.

    Returns:
        One rank per target DOI, in the order of ``example['citation_dois']``.
    """
    # NOTE(review): if the table uses a pgvector approximate index (HNSW or
    # IVFFlat), an index scan can return fewer rows than top_k (HNSW is capped
    # by hnsw.ef_search unless iterative scans are enabled). That would show
    # up here as -1 ranks regardless of top_k. Confirm how query_vector_table
    # configures this, or compare against use_index=False.
    target_dois = example['citation_dois']

    # Build the query vector from the enriched sentence.
    enriched_sentence = enricher.enrich(example=example)
    embedding = embedder([enriched_sentence])[0]

    query_results = db.query_vector_table(
        table_name=table_name,
        query_vector=embedding,
        metric=metric,
        use_index=use_index,
        top_k=top_k,
    )
    return [
        lowest_index_matching_doi(target_doi=doi, query_results=query_results)
        for doi in target_dois
    ]

In [None]:
embedder = get_embedder(model_name='BAAI/bge-small-en', device=device, normalize=True)
from Enrichers import ENRICHMENT_FN as enrichment_functions

# Request enough results that the target references should be present.
top_k = 340_000

# Maps each enrichment function to the flat list of target-DOI ranks
# collected over all sampled examples (-1 = not found in the results).
results = {}
for fn in enrichment_functions:
    enricher = get_enricher(fn)
    ranks = []
    for _, example in tqdm(examples.iterrows(), total=len(examples)):
        ranks += get_ranks(
            example=example,
            embedder=embedder,
            enricher=enricher,
            table_name='bge',
            top_k=top_k)
    results[fn] = ranks

    # Keep the -1 "not found" sentinel out of the summary statistics;
    # otherwise it drags down min/mean and hides how many targets were missed.
    series = pd.Series(ranks)
    found = series[series != -1]
    print(f"Enrichment function: {fn}.")
    print(f"Targets not found in top {top_k}: {len(series) - len(found)} of {len(series)}")
    print(f"Stats: {found.describe()}")

# Save results (misses are kept as -1 in the CSV).
df = pd.DataFrame(results)
df.to_csv('ranks.csv', index=False)





  6%|▌         | 6/100 [00:26<06:53,  4.39s/it]

In [None]:
# Try again at 500,000 top k. get_ranks requires an embedder and an enricher;
# this reuses `embedder` and `enricher` from the enrichment loop above
# (`enricher` is whichever one the loop created last).
ranks = []
for _, example in tqdm(examples.iterrows(), total=len(examples)):
    ranks += get_ranks(
        example=example,
        embedder=embedder,
        enricher=enricher,
        table_name='bge',
        top_k=500_000)

print(ranks)
# Number of target DOIs that did not appear in the results at all.
ranks.count(-1)

In [None]:
# Try again at 250,000 top k. get_ranks requires an embedder and an enricher;
# this reuses `embedder` and `enricher` from the enrichment loop above
# (`enricher` is whichever one the loop created last).
ranks = []
for _, example in tqdm(examples.iterrows(), total=len(examples)):
    ranks += get_ranks(
        example=example,
        embedder=embedder,
        enricher=enricher,
        table_name='bge',
        top_k=250_000)

print(ranks)

# Number of target DOIs that did not appear in the results at all.
ranks.count(-1)

In [None]:
# One entry per target DOI across all sampled examples of the last run.
print("Total number of ranks:", len(ranks))

In [None]:
import matplotlib.pyplot as plt

# A rank of -1 marks a target DOI missing from the results. With 100 bins over
# a wide range of ranks, -1 would land in the same bin as the best (lowest)
# ranks, so plot only the found ranks and report the misses in the title.
found_ranks = [rank for rank in ranks if rank != -1]
missing = len(ranks) - len(found_ranks)
plt.hist(found_ranks, bins=100)
plt.xlabel('Index of target DOI in query results')
plt.ylabel('Count')
plt.title(f'Target DOI ranks ({missing} of {len(ranks)} not found)')
plt.show()