# Tuning the Top-k Query parameter

First we set up the environment

In [1]:
from dotenv import load_dotenv
import os
import psycopg
import pandas as pd
import numpy as np
from time import time
import torch
from TextEnrichers import get_enricher, TextEnricher
from database.database import Database
from Embedders import Embedder, get_embedder
from tqdm import tqdm

load_dotenv('.env', override=True)
print(os.getenv('DB_PORT'))


5432


In [2]:
# Database setup
db = Database()
db.test_connection()

device = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.mps.is_available() else 'cpu'
print(f"Using device: {device}")

Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)
Using device: mps


## Investigating precision over `k`

For our various embedding models and enrichment strategies, we want to know the smallest `top_k` value that will still retrieve the target reference for a given sentence. 

To investigate this, we'll sample 100 examples from the non-trivial training data. Each example typically has 1-2 target DOI's. For each example, we'll query the database with a large `top_k` parameter to start, so we can be sure the database returns the target references. Then we can ask at what index in the query results does a target DOI first appear. Ideally, the ranks will all be very high, indicated by having *low* indices in the query results. We also expect enriched examples to have their target doi's higher ranked (lower indices).

In [3]:
data = pd.read_json('data/dataset/split/train.jsonl', lines=True)
examples = data.sample(1, random_state=42)

def lowest_index_matching_doi(target_doi: str, query_results: list) -> int:
    """
    Returns the first index of the query results where the chunk doi matches the target doi.
    If no match is found, returns -1.
    """
    for i, result in enumerate(query_results):
        if target_doi == result.doi:
            return i
    return -1


def get_ranks(
    example: pd.Series,
    embedding,
    top_k: int,
    ef_search: int,
    table_name: str = 'library',
    target_column: str = 'bge_norm',
    metric: str = 'vector_cosine_ops'
) -> list[int]:
    target_dois = example['citation_dois']

    # Query
    start = time()
    query_results = db.query_vector_column(
        query_vector=embedding,
        target_column=target_column,
        table_name=table_name,
        top_k=top_k,
        use_index=False,
        ef_search=ef_search
    )
    print(f"Found {len(query_results)} results")
    for result in query_results:
        if result.doi == '10.1086/305746':
            print('FOUND')
    
    ranks = [lowest_index_matching_doi(
        target_doi=doi, query_results=query_results) for doi in target_dois]
    return ranks


def ranks_at_k(
        examples: pd.DataFrame,
        embedder_name: str,
        enricher_name: str,
        target_column: str,
        top_k: int,
        ef_search: 256) -> list[int]:
    """
    Calculates the 'ranks' for a given embedding model and enrichment function.
    The ranks are the indices of the first chunk with the target doi for each example, i.e.
    the lowest k that would retrieve a chunk with the target doi.
    """

    # Setup
    print(f"Embedding model: {embedder_name}, Enricher: {enricher_name}")
    embedder = get_embedder(embedder_name, device=device)
    enricher = get_enricher(
        enricher_name, path_to_data='data/preprocessed/reviews.jsonl')

    # Enrich
    texts_with_dois = list(
        examples[['sent_no_cit', 'source_doi']].itertuples(index=False, name=None))
    enriched_texts = enricher.enrich_batch(texts_with_dois)
    embeddings = embedder(enriched_texts)

    # Rank
    ranks = []
    for i in tqdm(range(len(examples))):
        embedding = embeddings[i]
        example = examples.iloc[i]
        ranks += get_ranks(
            example=example,
            embedding=embedding,
            table_name='library',
            target_column=target_column,
            top_k=top_k,
            ef_search=ef_search)
    return ranks

In [4]:
from itertools import product
                      
embedding_models = ['BAAI/bge-small-en']
enrichment_methods = ['identity', 'add_abstract',
                      'add_title', 'add_title_and_abstract']
if device == 'cuda':
    embedding_models += ['bert-base-uncased', 'adsabs/astroBERT']

combos = list(product(embedding_models, enrichment_methods))
print(f"Combos: {combos}")

Combos: [('BAAI/bge-small-en', 'identity'), ('BAAI/bge-small-en', 'add_abstract'), ('BAAI/bge-small-en', 'add_title'), ('BAAI/bge-small-en', 'add_title_and_abstract')]


In [5]:
rank_data = {}
for enricher_name in ['identity', 'add_abstract', 'add_title', 'add_title_and_abstract']:
    rank_data[enricher_name] = ranks_at_k(
        examples=examples,
        embedder_name='BAAI/bge-small-en',
        enricher_name=enricher_name,
        target_column='bge_norm',
        top_k=2_261_334,
        ef_search=20
    )

Embedding model: BAAI/bge-small-en, Enricher: identity


  0%|          | 0/1 [00:00<?, ?it/s]

  Query execution time: 37.84 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 1/1 [00:48<00:00, 48.30s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_abstract


  0%|          | 0/1 [00:00<?, ?it/s]

  Query execution time: 35.50 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 1/1 [00:45<00:00, 45.27s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title


  0%|          | 0/1 [00:00<?, ?it/s]

  Query execution time: 35.98 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 1/1 [00:46<00:00, 46.11s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title_and_abstract


  0%|          | 0/1 [00:00<?, ?it/s]

  Query execution time: 34.77 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 1/1 [00:44<00:00, 44.96s/it]


In [6]:
print(examples.iloc[0].sent_original)
for enricher_name, ranks in rank_data.items():
    print(f"Enricher: {enricher_name}. Rank: {ranks[0]}")

CGRO observations of X-ray binaries detected non-thermal power-law tails extending well beyond 100 keV with a photon index Γ LE ≈2.5–3 (Grove et al. 1998 ). 
Enricher: identity. Rank: 21
Enricher: add_abstract. Rank: 348
Enricher: add_title. Rank: 1
Enricher: add_title_and_abstract. Rank: 285


In [None]:
new_examples = data.sample(2, random_state=29)

new_rank_data = {}
for enricher_name in ['identity', 'add_abstract', 'add_title', 'add_title_and_abstract']:
    new_rank_data[enricher_name] = ranks_at_k(
        examples=new_examples,
        embedder_name='BAAI/bge-small-en',
        enricher_name=enricher_name,
        target_column='bge_norm',
        top_k=2_261_334,
        ef_search=20
    )

Embedding model: BAAI/bge-small-en, Enricher: identity


  0%|          | 0/2 [00:00<?, ?it/s]

  Query execution time: 34.42 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 1/2 [00:44<00:44, 44.23s/it]

  Query execution time: 34.71 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 2/2 [01:28<00:00, 44.34s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_abstract


  0%|          | 0/2 [00:00<?, ?it/s]

  Query execution time: 34.46 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 1/2 [00:44<00:44, 44.64s/it]

  Query execution time: 34.72 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 2/2 [01:29<00:00, 44.53s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title


  0%|          | 0/2 [00:00<?, ?it/s]

  Query execution time: 34.76 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 1/2 [00:45<00:45, 45.00s/it]

  Query execution time: 34.70 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 2/2 [01:29<00:00, 44.90s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title_and_abstract


  0%|          | 0/2 [00:00<?, ?it/s]

  Query execution time: 34.30 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 1/2 [00:43<00:43, 43.56s/it]

  Query execution time: 35.42 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 2/2 [01:28<00:00, 44.34s/it]

Enricher: identity. Rank: 1159
Enricher: add_abstract. Rank: 164
Enricher: add_title. Rank: 138
Enricher: add_title_and_abstract. Rank: 229





In [8]:
for enricher_name, ranks in new_rank_data.items():
    print(f"Enricher: {enricher_name}. Rank: {ranks}")

Enricher: identity. Rank: [1159, 288, 288]
Enricher: add_abstract. Rank: [164, 2903, 2903]
Enricher: add_title. Rank: [138, 13693, 13693]
Enricher: add_title_and_abstract. Rank: [229, 4420, 4420]


In [9]:
print(new_examples.iloc[1].sent_original)
print(new_examples.iloc[1].citation_dois)

Both Hood et al. ( 2009 ) and MacTaggart and Hood ( 2009 ) reported that a second MFR formed via magnetic reconnection (see below for a detailed description of this process) underneath the bodily emerged MFR. 
['10.1051/0004-6361/200912189', '10.1051/0004-6361/200912189']


In [10]:
more_examples = data.sample(4, random_state=1)
more_rank_data = {}
for enricher_name in ['identity', 'add_abstract', 'add_title', 'add_title_and_abstract']:
    more_rank_data[enricher_name] = ranks_at_k(
        examples=more_examples,
        embedder_name='BAAI/bge-small-en',
        enricher_name=enricher_name,
        target_column='bge_norm',
        top_k=2_261_334,
        ef_search=20
    )

Embedding model: BAAI/bge-small-en, Enricher: identity


  0%|          | 0/4 [00:00<?, ?it/s]

  Query execution time: 34.51 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 25%|██▌       | 1/4 [00:44<02:14, 44.79s/it]

  Query execution time: 35.54 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 2/4 [01:30<01:30, 45.33s/it]

  Query execution time: 36.06 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 75%|███████▌  | 3/4 [02:16<00:45, 45.69s/it]

  Query execution time: 35.24 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 4/4 [03:01<00:00, 45.49s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_abstract


  0%|          | 0/4 [00:00<?, ?it/s]

  Query execution time: 35.42 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 25%|██▌       | 1/4 [00:46<02:18, 46.04s/it]

  Query execution time: 35.21 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 2/4 [01:31<01:31, 45.57s/it]

  Query execution time: 35.37 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 75%|███████▌  | 3/4 [02:17<00:45, 45.66s/it]

  Query execution time: 35.10 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 4/4 [03:02<00:00, 45.60s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title


  0%|          | 0/4 [00:00<?, ?it/s]

  Query execution time: 35.64 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 25%|██▌       | 1/4 [00:45<02:16, 45.36s/it]

  Query execution time: 35.47 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 2/4 [01:30<01:31, 45.51s/it]

  Query execution time: 35.53 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 75%|███████▌  | 3/4 [02:16<00:45, 45.57s/it]

  Query execution time: 35.83 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 4/4 [03:02<00:00, 45.70s/it]


Embedding model: BAAI/bge-small-en, Enricher: add_title_and_abstract


  0%|          | 0/4 [00:00<?, ?it/s]

  Query execution time: 36.36 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 25%|██▌       | 1/4 [00:46<02:19, 46.39s/it]

  Query execution time: 35.10 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 50%|█████     | 2/4 [01:31<01:31, 45.72s/it]

  Query execution time: 35.81 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


 75%|███████▌  | 3/4 [02:17<00:45, 45.68s/it]

  Query execution time: 35.65 seconds
Found 2261334 results
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND
FOUND


100%|██████████| 4/4 [03:02<00:00, 45.58s/it]


In [None]:
for enricher_name, ranks in new_rank_data.items():
    print(f"Enricher: {enricher_name}. Rank: {ranks}")

In [None]:
# write rank_data dict to file
import json
with open('data/rank_data.json', 'w') as f:
    json.dump(rank_data, f)
print('Rank data saved to file')

In [None]:
max(rank_data['identity']), max(rank_data['add_abstract']), max(rank_data['add_title'])

In [None]:
research = pd.read_json('data/preprocessed/research.jsonl', lines=True)
research[research.doi == '10.1093/mnras/stab3351']

In [None]:
model_names_to_filenames = {
    'BAAI/bge-small-en': 'bge',
    'bert-base-uncased': 'bert',
    'adsabs/astroBERT': 'astrobert'
}