In [9]:
import pandas as pd
from Rerankers import entailment_ranker
from Embedders import get_embedder
from database.database import Database

In [10]:
examples = pd.read_json('data/dataset/100/nontrivial.jsonl', lines=True)
db = Database()
db.test_connection()

bge_embedder = get_embedder('BAAI/bge-small-en', device='mps', normalize=True)
roberta_reranker = entailment_ranker(model_name="cross-encoder/nli-roberta-base", device='mps')



Database         User             Host                             Port            
citeline_db      bbasseri         localhost                        5432            
Database version: ('PostgreSQL 17.3 (Homebrew) on x86_64-apple-darwin23.6.0, compiled by Apple clang version 16.0.0 (clang-1600.0.26.6), 64-bit',)


In [11]:
def get_index_of_target(query_results, target_doi: str) -> int:
    """
    Get the index of the target DOI in the query results.
    :param query_results: The query results from the database.
    :param target_doi: The DOI of the target paper.
    :return: The index of the target DOI in the query results.
    """
    for i, result in enumerate(query_results):
        if result.doi == target_doi:
            return i
    return -1

In [29]:
initial_results = []
for index, example in examples.iterrows():
    print(index)
    if index == 15:
        break

    input_sentence = example['sent_no_cit']
    target_doi = example['citation_dois'][0]
    pubdate = example['pubdate']

    embedding = bge_embedder(input_sentence)
    query_results = db.query_vector_column(
        query_vector=embedding,
        table_name='lib',
        target_column='bge_norm',
        pubdate=pubdate,
        top_k=10_000,
        probes=30,
        explain=False,
    )

    target_index = get_index_of_target(query_results, target_doi)
    initial_results.append({'example_index': index, 'target_rank': target_index, "query_results": query_results})


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15


In [32]:
from time import time


# Iterate over all the initial results
for i, result in enumerate(initial_results):
    print(i)
    if i == 15:
        break

    # Make sure you have a valid target to rerank
    if result['target_rank'] < 0:
        print(f"Index {i} didn't retrieve its target DOI at this top k/probes level")
        continue

    target_doi = examples.iloc[i].citation_dois[0]
    input_sentence = examples.iloc[i].sent_no_cit
    query_results = result['query_results']
    reranked_query_results = query_results.copy()
    start = time()
    reranked_query_results.sort(key=lambda q: roberta_reranker(q.chunk, input_sentence), reverse=True)
    end = time()
    print(f"Reranking took {end - start} seconds")
    reranked_rank = get_index_of_target(reranked_query_results, target_doi)

    result['rerank'] = reranked_rank

0
Index 0 didn't retrieve its target DOI at this top k/probes level
1
Reranking took 332.82916021347046 seconds
2
Reranking took 322.01701307296753 seconds
3
Reranking took 403.6238911151886 seconds
4
Reranking took 489.87950110435486 seconds
5
Reranking took 428.43002486228943 seconds
6
Reranking took 390.4660520553589 seconds
7
Reranking took 443.9020359516144 seconds
8
Index 8 didn't retrieve its target DOI at this top k/probes level
9
Reranking took 473.3110861778259 seconds
10
Index 10 didn't retrieve its target DOI at this top k/probes level
11
Index 11 didn't retrieve its target DOI at this top k/probes level
12
Reranking took 510.01005816459656 seconds
13
Reranking took 385.608925819397 seconds
14
Index 14 didn't retrieve its target DOI at this top k/probes level


In [33]:
diffs = []
for result in initial_results:
    if 'rerank' in result:
        print(f"Original rank: {result['target_rank']}, Reranked rank: {result['rerank']}")
        diffs.append(result['target_rank'] - result['rerank'])
print(f"Average rank improvement: {sum(diffs) / len(diffs)}")

Original rank: 510, Reranked rank: 340
Original rank: 6319, Reranked rank: 220
Original rank: 1893, Reranked rank: 3304
Original rank: 0, Reranked rank: 27
Original rank: 24, Reranked rank: 236
Original rank: 0, Reranked rank: 120
Original rank: 480, Reranked rank: 80
Original rank: 459, Reranked rank: 185
Original rank: 5, Reranked rank: 3
Original rank: 1, Reranked rank: 40
Average rank improvement: 513.6
