In [1]:
import pandas as pd
from tqdm import tqdm
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from metrics import Metric
from rank_fuser import RankFuser
from statistics_computation import compute_statistics

import os

# Turn off HuggingFace tokenizers' internal parallelism via its environment switch
# (presumably to avoid the fork-after-parallelism warning during reranking).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Cross-encoder reranker metric shared by the cells below.
bge_reranker = Metric.get_metric("bge_reranker")

In [2]:
import random

# Rank fuser that scores candidates using only the BGE reranker (weight 1.0).
rf_bge = RankFuser(config={"bge_reranker": 1.0})

FILE = "experiments/results/bge_title_search_results.jsonl"
SAMPLE_SIZE = 100  # number of query records to sample from FILE
TOP_K = 200        # keep only the first TOP_K retrieved candidates per query
SEED = 42          # fixes which lines are sampled

# First pass: count lines so indices can be sampled without holding the whole file in memory.
with open(FILE, "r", encoding="utf-8") as f:
    total_lines = sum(1 for _ in f)

# Clamp so a file shorter than SAMPLE_SIZE does not make `sample` raise ValueError.
n_samples = min(SAMPLE_SIZE, total_lines)

# A dedicated RNG leaves the global `random` state untouched for other cells.
# Random(SEED).sample draws the same indices as random.seed(SEED); random.sample(...).
rng = random.Random(SEED)
random_indices = set(rng.sample(range(total_lines), n_samples))

# Second pass: parse only the selected lines and stop as soon as all of them have been read.
data = []
with open(FILE, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i in random_indices:
            row_data = json.loads(line)
            row_data["results"] = row_data["results"][:TOP_K]
            data.append(row_data)
            if len(data) == n_samples:
                break

print(f"Read {len(data)} random lines from the file.")
print(f"Results length: {len(data[0]['results']) if data else 0}")

Read 100 random lines from the file.
Results length: 200


In [3]:
# Look at the first sampled record: the query metadata, and the columns of its candidate list.
first_row = data[0]
query = pd.Series(first_row['record'])
results = pd.DataFrame(first_row['results'])

print(query)
print(results.columns)

source_doi                         10.1146/annurev-astro-081811-125615
sent_original        2. At lower redshifts, the faint-end slope was...
sent_no_cit          2. At lower redshifts, the faint-end slope was...
sent_idx                                                           600
citation_dois                                       ['10.1086/376841']
pubdate                                                       20140801
resolved_bibcodes                              ['2003AJ....126.1607S']
sent_cit_masked      2. At lower redshifts, the faint-end slope was...
expanded_query                                                    None
dtype: object
Index(['text', 'doi', 'pubdate', 'citation_count', 'metric'], dtype='object')


In [4]:
# Score the first query's candidates with the reranker already loaded in the setup cell.
# Re-importing Metric and calling get_metric again here may load the model a second time.
bge_metric = bge_reranker  # alias kept for any code that refers to `bge_metric`
scores = bge_metric(query, results)
assert len(scores) == len(results), "expected one score per retrieved candidate"

print(type(scores))
print(len(scores))
# .iloc is positional; `scores[0]` is a label lookup that breaks for a non-RangeIndex Series.
print(scores.iloc[0])

<class 'pandas.core.series.Series'>
200
-1.468541145324707


In [7]:
import copy

# Rerank every sampled query with the BGE-only fuser (about 6 minutes for 100 queries in the
# recorded run). A deep copy is passed so `data` stays untouched even if `rerank` mutates its
# input; the later `compute_statistics(data)` baseline relies on `data` keeping its original order.
# NOTE(review): confirm whether RankFuser.rerank mutates its argument; if not, the copy is cheap insurance.
reranked = rf_bge.rerank(copy.deepcopy(data))


Reranking results: 100%|██████████| 100/100 [05:51<00:00,  3.52s/it]


In [None]:
# Aggregate retrieval statistics over the reranked queries and show which metrics were produced.
stats = compute_statistics(reranked)
stat_names = stats.keys()
print(stat_names)

In [None]:
# NOTE(review): 99 is a magic index. With SAMPLE_SIZE = 100 it is either the last query or
# the 100th cutoff, depending on what compute_statistics returns. Confirm and name it.
hitrate_values = stats['hitrate']
hitrate_values[99]

In [None]:
# Plot hitrate, IoU and recall of the reranked results on shared axes.
# matplotlib.pyplot is already imported as `plt` in the setup cell.
# NOTE(review): the x-axis is labelled 'Query'. If compute_statistics returns per-cutoff curves
# (metric@k), the label should say so; confirm against statistics_computation.
fig, ax = plt.subplots(figsize=(10, 5))
for metric_key, label in [('hitrate', 'Hitrate'), ('iou', 'IoU'), ('recall', 'Recall')]:
    ax.plot(stats[metric_key], label=label)
ax.set_xlabel('Query')
ax.set_ylabel('Score')
ax.set_title('Hitrate, IoU, and Recall')
ax.legend()
plt.show()


In [None]:
# Baseline statistics on the candidate lists as loaded from FILE, for comparison with `stats`.
# NOTE(review): this runs after `rf_bge.rerank(...)`. If rerank mutated `data` in place, this
# "baseline" would really be the reranked order; confirm, or compute it before reranking.
original_stats = compute_statistics(data)

In [None]:
# Compare hitrate for the candidate order as loaded (original) against the BGE-reranked order.
# matplotlib.pyplot is already imported as `plt` in the setup cell.
fig, ax = plt.subplots(figsize=(10, 5))
for source_stats, label in [(original_stats, 'Original Hitrate'), (stats, 'Reranked Hitrate')]:
    ax.plot(source_stats['hitrate'], label=label)
ax.set_xlabel('Query')
ax.set_ylabel('Score')
ax.set_title('Original vs Reranked Hitrate')
ax.legend()
plt.show()
    