# Tuning the Top-k Query parameter

First, we set up the environment.

In [5]:
from dotenv import load_dotenv
import os
import psycopg2
import pandas as pd
from time import time
import torch
from Enrichers import Enricher, get_enricher
from database.database import DatabaseProcessor
from Embedders import Embedder, get_embedder
from tqdm import tqdm

# Read settings from the local .env file (override=True is passed through to load_dotenv).
load_dotenv('.env', override=True)

# Database setup: each connection parameter is read from its own environment variable.
db_params = {
    param: os.getenv(env_var)
    for param, env_var in (
        ('dbname', 'DB_NAME'),
        ('user', 'DB_USER'),
        ('password', 'DB_PASSWORD'),
        ('host', 'DB_HOST'),
        ('port', 'DB_PORT'),
    )
}

# Open the project database wrapper and run its connection check.
db = DatabaseProcessor(db_params)
db.test_connection()

# Pick the best available torch device: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented availability check and
# exists on older PyTorch releases that do not provide torch.mps.is_available().
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 100 examples from the non-trivial training data. Each example typically has 1-2 target DOIs. For each example, we'll query the database with a large `top_k` parameter to start, so we can be sure the database returns the target references. Then we can ask at what index in the query results a target DOI first appears. Ideally, the targets will all be ranked very highly, i.e. appear at *low* indices in the query results. We also expect enriched examples to have their target DOIs ranked higher (at lower indices).

In [6]:
from database.database import SingleQueryResult

# Load the training split (JSON Lines) and draw a fixed, seeded sample of 100 rows.
data = pd.read_json(
    'data/dataset/split/train.jsonl',
    lines=True,
)
examples = data.sample(n=100, random_state=42)

In [None]:
def lowest_index_matching_doi(target_doi: str, query_results: list[SingleQueryResult]) -> int:
    """Return the position of the first query result whose DOI equals ``target_doi``.

    Args:
        target_doi: The DOI to look for.
        query_results: Query results in the order they were returned; each
            item is expected to expose a ``doi`` attribute.

    Returns:
        The zero-based index of the first matching result, or -1 when no
        result has a matching DOI (including when ``query_results`` is empty).
    """
    matching_positions = (
        position
        for position, hit in enumerate(query_results)
        if target_doi == hit.doi
    )
    return next(matching_positions, -1)


def get_ranks(
    example: pd.Series,
    embedder: Embedder,
    enricher: Enricher,
    table_name: str,
    top_k: int,
    metric: str = 'vector_cosine_ops'
) -> list[int]:
    """Find where each target DOI of ``example`` lands in a vector query.

    The example is enriched, embedded, and used to query ``table_name``
    through the notebook-level ``db`` object with ``use_index=True``.

    Args:
        example: A row providing ``citation_dois`` and whatever the enricher reads.
        embedder: Callable that takes a list of strings and returns one embedding per string.
        enricher: Object whose ``enrich(example=...)`` produces the text to embed.
        table_name: Name of the vector table to query.
        top_k: Number of results requested from the database.
        metric: Metric name forwarded to ``db.query_vector_table``.

    Returns:
        One entry per DOI in ``example['citation_dois']``: the index of the
        first query result with that DOI, or -1 if it was not returned.
    """
    target_dois = example['citation_dois']

    # Build the query vector from the enriched sentence.
    query_text = enricher.enrich(example=example)
    query_vector = embedder([query_text])[0]

    # Run the query against the vector table.
    hits = db.query_vector_table(
        table_name=table_name,
        query_vector=query_vector,
        metric=metric,
        use_index=True,
        top_k=top_k,
    )

    # Look up the first occurrence of every target DOI in the results.
    ranks = []
    for doi in target_dois:
        ranks.append(lowest_index_matching_doi(target_doi=doi, query_results=hits))
    return ranks

In [7]:
# Embedding model used for every query in this notebook.
embedder = get_embedder(
    model_name='BAAI/bge-small-en',
    device=device,
    normalize=True,
)
enricher = get_enricher("identity")

# Enrich and embed the entire batch of sampled examples in one call.
embeddings = embedder(
    enricher.enrich_batch(examples)
)

In [8]:
# Time one indexed top-1000 cosine query per precomputed embedding.
for embedding in tqdm(embeddings):
    start = time()
    query_result = db.query_vector_table(
        query_vector=embedding,
        table_name='bge',
        top_k=1000,
        metric='vector_cosine_ops',
        use_index=True,
    )
    # Wall-clock duration of this single query.
    print(f"Query time: {time() - start:.2f} seconds")

  1%|          | 1/100 [00:00<00:48,  2.04it/s]

Query time: 0.49 seconds


  2%|▏         | 2/100 [00:00<00:43,  2.28it/s]

Query time: 0.40 seconds


  3%|▎         | 3/100 [00:01<00:34,  2.84it/s]

Query time: 0.25 seconds


  4%|▍         | 4/100 [00:01<00:37,  2.53it/s]

Query time: 0.46 seconds


  5%|▌         | 5/100 [00:01<00:35,  2.67it/s]

Query time: 0.34 seconds


  6%|▌         | 6/100 [00:02<00:30,  3.09it/s]

Query time: 0.23 seconds


  7%|▋         | 7/100 [00:02<00:28,  3.25it/s]

Query time: 0.27 seconds


  8%|▊         | 8/100 [00:02<00:27,  3.32it/s]

Query time: 0.29 seconds


  9%|▉         | 9/100 [00:03<00:28,  3.21it/s]

Query time: 0.33 seconds


 10%|█         | 10/100 [00:03<00:29,  3.07it/s]

Query time: 0.36 seconds


 11%|█         | 11/100 [00:03<00:26,  3.35it/s]

Query time: 0.24 seconds


 12%|█▏        | 12/100 [00:03<00:25,  3.44it/s]

Query time: 0.27 seconds


 13%|█▎        | 13/100 [00:04<00:27,  3.21it/s]

Query time: 0.36 seconds


 14%|█▍        | 14/100 [00:04<00:25,  3.44it/s]

Query time: 0.24 seconds


 15%|█▌        | 15/100 [00:04<00:25,  3.40it/s]

Query time: 0.30 seconds


 16%|█▌        | 16/100 [00:05<00:26,  3.21it/s]

Query time: 0.35 seconds


 17%|█▋        | 17/100 [00:05<00:23,  3.58it/s]

Query time: 0.20 seconds


 18%|█▊        | 18/100 [00:05<00:24,  3.37it/s]

Query time: 0.34 seconds


 19%|█▉        | 19/100 [00:06<00:24,  3.26it/s]

Query time: 0.33 seconds


 20%|██        | 20/100 [00:06<00:24,  3.21it/s]

Query time: 0.32 seconds


 21%|██        | 21/100 [00:06<00:22,  3.59it/s]

Query time: 0.20 seconds


 22%|██▏       | 22/100 [00:06<00:21,  3.67it/s]

Query time: 0.26 seconds


 23%|██▎       | 23/100 [00:07<00:21,  3.53it/s]

Query time: 0.31 seconds


 24%|██▍       | 24/100 [00:07<00:20,  3.66it/s]

Query time: 0.25 seconds


 26%|██▌       | 26/100 [00:07<00:17,  4.24it/s]

Query time: 0.27 seconds
Query time: 0.15 seconds


 27%|██▋       | 27/100 [00:08<00:19,  3.78it/s]

Query time: 0.33 seconds


 28%|██▊       | 28/100 [00:08<00:19,  3.78it/s]

Query time: 0.26 seconds


 30%|███       | 30/100 [00:08<00:17,  4.01it/s]

Query time: 0.35 seconds
Query time: 0.15 seconds


 31%|███       | 31/100 [00:09<00:16,  4.11it/s]

Query time: 0.23 seconds


 32%|███▏      | 32/100 [00:09<00:17,  3.80it/s]

Query time: 0.31 seconds


 33%|███▎      | 33/100 [00:09<00:20,  3.29it/s]

Query time: 0.40 seconds


 34%|███▍      | 34/100 [00:10<00:18,  3.60it/s]

Query time: 0.22 seconds


 35%|███▌      | 35/100 [00:10<00:17,  3.80it/s]

Query time: 0.23 seconds


 36%|███▌      | 36/100 [00:10<00:16,  3.95it/s]

Query time: 0.23 seconds


 37%|███▋      | 37/100 [00:10<00:18,  3.38it/s]

Query time: 0.39 seconds


 38%|███▊      | 38/100 [00:11<00:20,  3.08it/s]

Query time: 0.39 seconds


 39%|███▉      | 39/100 [00:11<00:19,  3.08it/s]

Query time: 0.32 seconds


 40%|████      | 40/100 [00:11<00:18,  3.25it/s]

Query time: 0.27 seconds


 41%|████      | 41/100 [00:12<00:16,  3.48it/s]

Query time: 0.24 seconds


 42%|████▏     | 42/100 [00:12<00:16,  3.53it/s]

Query time: 0.27 seconds


 43%|████▎     | 43/100 [00:12<00:17,  3.20it/s]

Query time: 0.38 seconds


 44%|████▍     | 44/100 [00:13<00:16,  3.50it/s]

Query time: 0.22 seconds


 45%|████▌     | 45/100 [00:13<00:15,  3.51it/s]

Query time: 0.28 seconds


 46%|████▌     | 46/100 [00:13<00:14,  3.70it/s]

Query time: 0.24 seconds


 47%|████▋     | 47/100 [00:13<00:14,  3.75it/s]

Query time: 0.26 seconds


 48%|████▊     | 48/100 [00:14<00:13,  3.92it/s]

Query time: 0.23 seconds


 49%|████▉     | 49/100 [00:14<00:12,  3.96it/s]

Query time: 0.25 seconds


 50%|█████     | 50/100 [00:14<00:12,  3.90it/s]

Query time: 0.26 seconds


 51%|█████     | 51/100 [00:14<00:12,  3.88it/s]

Query time: 0.26 seconds


 52%|█████▏    | 52/100 [00:15<00:11,  4.14it/s]

Query time: 0.20 seconds


 53%|█████▎    | 53/100 [00:15<00:12,  3.69it/s]

Query time: 0.34 seconds


 54%|█████▍    | 54/100 [00:15<00:12,  3.57it/s]

Query time: 0.30 seconds


 55%|█████▌    | 55/100 [00:15<00:12,  3.71it/s]

Query time: 0.24 seconds


 57%|█████▋    | 57/100 [00:16<00:10,  4.03it/s]

Query time: 0.29 seconds
Query time: 0.18 seconds


 58%|█████▊    | 58/100 [00:16<00:10,  4.01it/s]

Query time: 0.25 seconds


 59%|█████▉    | 59/100 [00:16<00:11,  3.52it/s]

Query time: 0.36 seconds


 60%|██████    | 60/100 [00:17<00:10,  3.69it/s]

Query time: 0.24 seconds


 62%|██████▏   | 62/100 [00:17<00:08,  4.37it/s]

Query time: 0.29 seconds
Query time: 0.11 seconds


 64%|██████▍   | 64/100 [00:18<00:07,  4.57it/s]

Query time: 0.28 seconds
Query time: 0.16 seconds


 65%|██████▌   | 65/100 [00:18<00:08,  4.36it/s]

Query time: 0.25 seconds


 67%|██████▋   | 67/100 [00:18<00:06,  4.75it/s]

Query time: 0.30 seconds
Query time: 0.11 seconds


 69%|██████▉   | 69/100 [00:19<00:06,  4.98it/s]

Query time: 0.19 seconds
Query time: 0.19 seconds


 70%|███████   | 70/100 [00:19<00:06,  4.65it/s]

Query time: 0.25 seconds


 71%|███████   | 71/100 [00:19<00:06,  4.70it/s]

Query time: 0.21 seconds


 73%|███████▎  | 73/100 [00:20<00:05,  4.77it/s]

Query time: 0.23 seconds
Query time: 0.19 seconds


 74%|███████▍  | 74/100 [00:20<00:05,  4.53it/s]

Query time: 0.25 seconds


 76%|███████▌  | 76/100 [00:20<00:04,  4.96it/s]

Query time: 0.26 seconds
Query time: 0.13 seconds


 77%|███████▋  | 77/100 [00:20<00:04,  4.70it/s]

Query time: 0.24 seconds


 78%|███████▊  | 78/100 [00:21<00:04,  4.60it/s]

Query time: 0.23 seconds


 80%|████████  | 80/100 [00:21<00:04,  4.77it/s]

Query time: 0.23 seconds
Query time: 0.18 seconds


 81%|████████  | 81/100 [00:21<00:03,  4.83it/s]

Query time: 0.20 seconds


 82%|████████▏ | 82/100 [00:21<00:03,  4.78it/s]

Query time: 0.21 seconds


 84%|████████▍ | 84/100 [00:22<00:03,  5.04it/s]

Query time: 0.22 seconds
Query time: 0.16 seconds


 86%|████████▌ | 86/100 [00:22<00:02,  5.67it/s]

Query time: 0.17 seconds
Query time: 0.14 seconds


 88%|████████▊ | 88/100 [00:22<00:01,  6.14it/s]

Query time: 0.14 seconds
Query time: 0.15 seconds


 89%|████████▉ | 89/100 [00:23<00:01,  6.13it/s]

Query time: 0.16 seconds


 91%|█████████ | 91/100 [00:23<00:01,  5.61it/s]

Query time: 0.26 seconds
Query time: 0.14 seconds


 92%|█████████▏| 92/100 [00:23<00:01,  5.41it/s]

Query time: 0.20 seconds


 94%|█████████▍| 94/100 [00:24<00:01,  5.54it/s]

Query time: 0.22 seconds
Query time: 0.14 seconds


 96%|█████████▌| 96/100 [00:24<00:00,  6.22it/s]

Query time: 0.09 seconds
Query time: 0.18 seconds


 97%|█████████▋| 97/100 [00:24<00:00,  5.16it/s]

Query time: 0.29 seconds


 98%|█████████▊| 98/100 [00:24<00:00,  5.10it/s]

Query time: 0.20 seconds


 99%|█████████▉| 99/100 [00:25<00:00,  4.96it/s]

Query time: 0.22 seconds


100%|██████████| 100/100 [00:25<00:00,  3.95it/s]

Query time: 0.28 seconds





In [None]:


# Collect target-DOI ranks for each enrichment strategy, keyed by strategy name.
results = {}
for fn in ['add_headers_and_previous_7_sentences']:
    enricher = get_enricher(fn)
    ranks = []
    for i in tqdm(range(len(examples))):
        example = examples.iloc[i]
        example_ranks = get_ranks(
            example=example,
            embedder=embedder,
            enricher=enricher,
            table_name='bge',
            top_k=340000,
        )
        ranks.extend(example_ranks)
    series = pd.Series(ranks)
    results[fn] = ranks
    print(f"Enrichment function: {fn}.")
    print(f"Stats: {series.describe()}")

# Save results: one column per enrichment strategy, no index column.
df = pd.DataFrame(data=results)
df.to_csv('ranks3.csv', index=False)





100%|██████████| 100/100 [06:47<00:00,  4.08s/it]

Enrichment function: add_headers_and_previous_7_sentences.
Stats: count      112.000000
mean      3136.544643
std       7961.988777
min          0.000000
25%         38.250000
50%        373.000000
75%       2512.750000
max      52857.000000
dtype: float64





In [4]:
# Load the three saved rank tables and place their columns side by side.
df1, df2, df3 = (
    pd.read_csv(csv_path)
    for csv_path in ('ranks.csv', 'ranks2.csv', 'ranks3.csv')
)
df_combined = pd.concat(objs=[df1, df2, df3], axis='columns')
df_combined.head()

Unnamed: 0,identity,add_abstract,add_title,add_title_and_abstract,add_previous_3_sentences,add_previous_7_sentences,add_headers_and_previous_3_sentences,add_headers_and_previous_7_sentences
0,34,1037,6,545,259,2974,1074,568
1,128,44,599,539,216,230,847,865
2,7565,179,175,18,3,57,13,20
3,1869,235,305,303,629,2489,458,375
4,23,3700,119,2313,651,7159,3887,3080


In [None]:
# Summary statistics for the add_title and add_abstract rank columns.
# A cell displays only its last expression, so both columns are described in
# a single call instead of two separate statements (the first of which would
# be computed and silently discarded).
df = pd.read_csv('ranks.csv')
df[['add_title', 'add_abstract']].describe()

In [None]:
# Try again with top_k = 250,000 and count how many target DOIs are still missed.
# get_ranks requires an embedder and an enricher, so pass the ones currently
# defined in the notebook (omitting them raises TypeError).
ranks = []
for i in tqdm(range(len(examples))):
    example = examples.iloc[i]
    ranks += get_ranks(
        example=example,
        embedder=embedder,
        enricher=enricher,
        table_name='bge',
        top_k=250000,
    )

print(ranks)

# A rank of -1 means the target DOI was not found in the returned results.
ranks.count(-1)

In [None]:
# Report how many rank values are currently held in `ranks`.
print(f"Total number of ranks: {len(ranks)}")

In [None]:
import matplotlib.pyplot as plt

# Histogram of the collected rank indices (100 bins), labelled so the figure
# can be read on its own.
fig, ax = plt.subplots()
ax.hist(ranks, bins=100)
ax.set(
    title='Distribution of target DOI ranks',
    xlabel='Index of target DOI in query results',
    ylabel='Number of target DOIs',
)
plt.show()