In [None]:
import pandas as pd
from tqdm import tqdm
import json
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from core.embedders import Embedder

from core.statistics_computation import compute_statistics

# Build the query-side and document-side embedders from the same checkpoint;
# the two calls differ only in the `for_queries` flag (query first, then document).
# NOTE(review): neither embedder is referenced by any later cell visible in this
# notebook, and device='mps' is hard-coded — confirm both are still intended.
query_embedder, doc_embedder = (
    Embedder.create("Qwen/Qwen-Embedder-0.6B", device='mps', normalize=True, for_queries=is_query_side)
    for is_query_side in (True, False)
)

In [2]:
# Baseline rank fuser used by the single rerank run below.
# The config keys are signal names; the values are presumably per-signal weights
# (TODO confirm against RankFuser).
# NOTE(review): `RankFuser` is not imported in any cell visible in this notebook,
# so on a fresh kernel this line raises NameError — add the import to the top cell.
rank_fuser = RankFuser(config={"log_citations": 1.0, "negative_log_years_old": 4.0})

In [5]:
FILE = "experiments/results/bge_title_search_results.jsonl"
TOP_K = 200  # number of top-ranked results kept per query

# Load the search results (one JSON object per line) and truncate each query's
# "results" list to its first TOP_K entries.
data = []
with open(FILE, "r", encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at end of file), which
        # would otherwise make json.loads raise.
        if not line.strip():
            continue
        row_data = json.loads(line)
        row_data["results"] = row_data["results"][:TOP_K]
        data.append(row_data)

print(len(data))

14735


In [6]:
# Rerank the loaded results with the baseline fuser and report how many entries came back.
# NOTE: the saved output of this cell shows the run being interrupted
# (KeyboardInterrupt at ~12%), so `reranked` was not assigned by this execution.
reranked = rank_fuser.rerank(data)
print(len(reranked))

Reranking results:  12%|█▏        | 1836/14735 [00:13<01:36, 133.94it/s]


KeyboardInterrupt: 

In [None]:
# Compute evaluation statistics for the reranked results and list the available
# metric names (later cells read the "hitrate", "recall" and "iou" entries).
stats = compute_statistics(reranked)
print(stats.keys())

In [None]:
# Scratch preview of the offset grid: starts at 3 and steps down by 0.5;
# the stop value -3.5 itself is excluded by np.arange.
np.arange(3, -3.5, step=-0.5)

In [7]:
# Sweep the `log_citations` weight around a base value while holding
# `negative_log_years_old` fixed, and record one metric curve per setting.
# (numpy is already imported as `np` in the notebook's first cell.)
BASE_CITATION_WEIGHT = 4.0
RECENCY_WEIGHT = 4.0
SWEEP_TOP_K = 200  # length of each metric curve; must match the top-k truncation applied to `data`

# 13 evenly spaced offsets 3.0, 2.5, ..., -3.0. linspace states the endpoint and
# count explicitly instead of relying on a float-step arange with a fudged stop value.
diffs = np.linspace(3.0, -3.0, num=13)

# One row per offset, one column per top-k cutoff.
hitrate_matrix = np.zeros((len(diffs), SWEEP_TOP_K))
recall_matrix = np.zeros((len(diffs), SWEEP_TOP_K))
iou_matrix = np.zeros((len(diffs), SWEEP_TOP_K))
metric_matrices = {"hitrate": hitrate_matrix, "recall": recall_matrix, "iou": iou_matrix}

for row, diff in enumerate(diffs):
    config = {"log_citations": BASE_CITATION_WEIGHT + diff.item(), "negative_log_years_old": RECENCY_WEIGHT}
    print(f"Preparing rank fusion with: {config}")
    # NOTE: this rebinds the notebook-level `rank_fuser` and `stats` on every iteration.
    rank_fuser = RankFuser(config=config)
    stats = compute_statistics(rank_fuser.rerank(data))
    for metric_name, matrix in metric_matrices.items():
        matrix[row] = stats[metric_name]

Preparing rank fusion with: {'log_citations': 7.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:49<00:00, 135.15it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2540.50it/s]


Preparing rank fusion with: {'log_citations': 6.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:44<00:00, 141.51it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2693.75it/s]


Preparing rank fusion with: {'log_citations': 6.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:43<00:00, 142.91it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2714.01it/s]


Preparing rank fusion with: {'log_citations': 5.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:45<00:00, 139.89it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2688.22it/s]


Preparing rank fusion with: {'log_citations': 5.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:43<00:00, 141.95it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2705.97it/s]


Preparing rank fusion with: {'log_citations': 4.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:44<00:00, 140.86it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2709.12it/s]


Preparing rank fusion with: {'log_citations': 4.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:44<00:00, 140.71it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2641.10it/s]


Preparing rank fusion with: {'log_citations': 3.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:44<00:00, 141.46it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2691.67it/s]


Preparing rank fusion with: {'log_citations': 3.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:43<00:00, 142.16it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2936.84it/s]


Preparing rank fusion with: {'log_citations': 2.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:34<00:00, 156.23it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:04<00:00, 2987.37it/s]


Preparing rank fusion with: {'log_citations': 2.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:33<00:00, 156.96it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:05<00:00, 2906.47it/s]


Preparing rank fusion with: {'log_citations': 1.5, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:34<00:00, 156.09it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:04<00:00, 2955.86it/s]


Preparing rank fusion with: {'log_citations': 1.0, 'negative_log_years_old': 4.0}


Reranking results: 100%|██████████| 14735/14735 [01:33<00:00, 157.20it/s]
Computing statistics: 100%|██████████| 14735/14735 [00:04<00:00, 3003.42it/s]


In [None]:
# Baseline: statistics of the loaded (truncated, un-reranked) results, used as the
# "original" reference curve in the comparison plots below.
original_stats = compute_statistics(data)
print(original_stats.keys())

In [None]:
def plot_stat_comparison(original, reranks, title="", labels=None, stat_name="HitRate"):
    """Plot one metric curve per reranking setting against the original curve.

    Parameters
    ----------
    original : array-like, shape (n_k,)
        Metric values of the un-reranked results, one per top-k cutoff.
    reranks : array-like, shape (n_settings, n_k)
        Metric values after reranking; one row per setting, one column per cutoff.
    title : str
        Figure title.
    labels : sequence, optional
        One legend label per row of ``reranks`` (e.g. the swept weight offsets).
        Defaults to the row positions ``0..n_settings-1``.
    stat_name : str
        Name of the plotted metric, used as the y-axis label.

    Raises
    ------
    AssertionError
        If ``reranks`` is not 2-D, or the curve lengths or label count do not match.
    """
    original = np.asarray(original)
    reranks = np.asarray(reranks)
    # `original` is a single curve (length on axis 0); `reranks` holds one curve
    # per row, so its curve length is on axis 1.
    assert reranks.ndim == 2, "reranks must be a 2-D (settings x top-k) array"
    assert original.shape[0] == reranks.shape[1], "original and reranks must cover the same top-k range"

    n_settings, n_k = reranks.shape
    if labels is None:
        labels = np.arange(n_settings)
    assert len(labels) == n_settings, "need exactly one label per row of reranks"
    ks = np.arange(1, n_k + 1)  # top-k cutoffs 1..n_k

    # Build a DataFrame (rows = settings, columns = cutoffs) and reshape it to
    # long form so seaborn can colour one line per setting.
    df = pd.DataFrame(reranks, index=pd.Index(list(labels), name="diff"), columns=ks)
    df_long = df.reset_index().melt(id_vars="diff", var_name="k", value_name=stat_name)
    df_long["k"] = df_long["k"].astype(int)

    fig, ax = plt.subplots(figsize=(10, 10))

    # Plot reranked curves and overlay the original curve.
    sns.lineplot(data=df_long, x="k", y=stat_name, hue="diff", legend="full", palette="mako", ax=ax)
    ax.plot(ks, original, label="original", linewidth=1)

    # Rebuild the legend so it includes the "original" line and sits outside the axes.
    handles, labels_legend = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels_legend, title="diff", bbox_to_anchor=(1.05, 1), loc=2)

    ax.set_xlabel("Top-k")
    ax.set_ylabel(stat_name)
    ax.minorticks_on()
    ax.grid(which="major", linestyle="-", linewidth=0.5, color="black")
    ax.grid(which="minor", linestyle=":", linewidth=0.5, color="gray")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# One figure per metric; rows of each matrix correspond to the swept `diffs`.
plot_stat_comparison(original_stats['hitrate'], hitrate_matrix, title="HitRate Comparison", labels=diffs, stat_name="HitRate")
plot_stat_comparison(original_stats['recall'], recall_matrix, title="Recall Comparison", labels=diffs, stat_name="Recall")
plot_stat_comparison(original_stats['iou'], iou_matrix, title="IoU Comparison", labels=diffs, stat_name="IoU")

### Reranking by age

As seen above, the statistics decrease with recency (negative log years old). Would reranking by age (older is better) improve them?

In [None]:
# Rerank with a fuser configured with `log_citations` only, then compute its statistics.
# NOTE(review): the section heading and the variable names refer to age/recency, but the
# config here contains only "log_citations" — confirm which signal was intended.
recency_reranker = RankFuser(config={"log_citations": 1.0})
recency_stats = compute_statistics(recency_reranker.rerank(data))

In [None]:
# Compare the original ranking against the `recency_reranker` output, one figure per metric.
x = np.arange(1, 201)  # Top-k 1..200

# The three figures previously labelled the same series "Age reranked",
# "Negative Citation Count reranked" and "Citations reranked". `recency_reranker`
# is configured with "log_citations" only, so a single matching label is used.
RERANKED_LABEL = "Citations reranked"

for metric_key, metric_label in (("iou", "IoU"), ("recall", "Recall"), ("hitrate", "HitRate")):
    plt.figure(figsize=(10, 10))
    plt.plot(x, original_stats[metric_key], label="Original", linewidth=2)
    plt.plot(x, recency_stats[metric_key], label=RERANKED_LABEL, linewidth=2)
    plt.xlabel("Top-k")
    plt.ylabel(metric_label)
    plt.legend()
    plt.grid(True, which="major", linestyle="--", linewidth=0.5)
    plt.title(f"{metric_label} Comparison")
    plt.show()

In [None]:
# Fuser configured with the "negative_log_years_old" signal only; print it to inspect the setup.
# NOTE: this rebinds the notebook-level `config` used by the sweep cell above.
config = {"negative_log_years_old": 1.0}
reality_ranker = RankFuser(config=config)
print(reality_ranker)

In [None]:
# Rerank with `reality_ranker` and compute its statistics for the plots below.
# NOTE: this rebinds `reranked`, replacing the result of the earlier baseline rerank cell.
reranked = reality_ranker.rerank(data)
statistics = compute_statistics(reranked)

In [None]:
# Compare the original ranking against the `reality_ranker` output: one figure
# per metric, drawn in the order IoU, Recall, HitRate.
x = np.arange(1, 201)  # Top-k 1..200

for metric_key, metric_label in (("iou", "IoU"), ("recall", "Recall"), ("hitrate", "HitRate")):
    plt.figure(figsize=(10, 10))
    plt.plot(x, original_stats[metric_key], label="Original", linewidth=2)
    plt.plot(x, statistics[metric_key], label="reranked", linewidth=2)
    plt.xlabel("Top-k")
    plt.ylabel(metric_label)
    plt.legend()
    plt.grid(True, which="major", linestyle="--", linewidth=0.5)
    plt.title(f"{metric_label} Comparison")
    plt.show()