In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd

# JSON-lines files that together make up the working dataset.
paths = [
    "data/training.jsonl",
    "data/TAR_data.jsonl",
    "data/sysrev_conv.jsonl",
]

# Read every file into its own frame, then stack them into one DataFrame.
frames = []
for path in paths:
    df = pd.read_json(path, lines=True)
    frames.append(df)
dataset = pd.concat(frames)

# Optional filter (disabled): drop rows whose natural-language query is empty.
# dataset = dataset[dataset["nl_query"] != ""]
dataset

In [None]:
# Keep a single row per distinct `mission_hash`, preferring the last occurrence.
#
# This cell used to assign `has_mh` twice in a row; the first assignment
# (which additionally kept every row whose `mission_hash` is missing) was
# overwritten immediately and never had any effect, so it has been removed.
# NOTE(review): if all rows without a mission_hash are meant to be retained,
# OR the mask below with `dataset["mission_hash"].isna()`.
has_mh = dataset[~dataset["mission_hash"].duplicated(keep='last')]
has_mh

In [None]:
# Inspect the rows whose natural-language query is the empty string.
dataset.loc[dataset["nl_query"].eq("")]

In [None]:
import torch
from utils.boolrank import DualEncoderModel

# Build the dual encoder on the BGE-small backbone and load weights from the
# checkpoint file.
model = DualEncoderModel('BAAI/bge-small-en-v1.5')
model.load(r"models\clip\bge-small-en-v1.5\b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4\checkpoint-11288\model.safetensors")
# Alternative backbone/checkpoint (disabled):
# model = DualEncoderModel('dmis-lab/biobert-v1.1')
# model.load(r"models\clip\biobert-v1.1\b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4\checkpoint-14110\model.safetensors")

# Encode the Boolean queries of the combined dataset.
# This previously read `df`, which in this notebook is only the loop variable
# left over from the loading cell, i.e. the last file that was read.
# NOTE(review): switch to `has_mh` if the de-duplicated frame is what should be
# embedded.
embeddings = model.encode_bool(dataset["bool_query"].tolist(), batch_size=200).detach().cpu().numpy()
# embeddings = model.encode_text(words, batch_size=200).detach().cpu().numpy()

# The embeddings now live on the CPU as a NumPy array; release cached GPU memory.
torch.cuda.empty_cache()

In [None]:
import wandb
from pathlib import Path

# Client for the Weights & Biases public API; the plotting cells query runs through it.
api = wandb.Api()

# Run name of the BioBERT configuration that several comparison plots reuse.
biobert = "clip/biobert-v1.1/b4_lr8E-06_(pubmed-sea)"

# Make sure the output directory for the saved figures exists.
Path("./plots").mkdir(exist_ok=True, parents=True)

def find_runs(names, runs):
    """Select the runs whose ``name`` attribute appears in ``names``.

    Args:
        names: Iterable of run names to look up; the result follows this order.
        runs: Iterable of objects exposing a ``name`` attribute.

    Returns:
        list: The matching run objects, ordered like ``names``. Names without
        a matching run are skipped, so the list can be shorter than ``names``.
        If several runs share a name, the last one in ``runs`` is used.

    Warns:
        UserWarning: When at least one requested name has no matching run.
        Callers that pair the result positionally with another list (labels,
        colours, ...) would otherwise be misaligned without any indication.
    """
    import warnings

    # Materialise once so a one-shot iterator can be scanned twice below.
    names = list(names)
    run_map = {run.name: run for run in runs}

    missing = [name for name in names if name not in run_map]
    if missing:
        warnings.warn(f"find_runs: no run found for {missing}", stacklevel=2)

    return [run_map[name] for name in names if name in run_map]

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Compare the SigLIP and CLIP losses at two batch sizes on mean Recall@1.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    "siglip/siglip2-base-patch16-224/lr1E-06_(pubmed-sea)^0",
    "siglip/siglip2-base-patch16-224/b4_lr1E-06_(pubmed-sea)^0",
    "clip/siglip2-base-patch16-224/lr5E-07_(pubmed-sea)",
    "clip/siglip2-base-patch16-224/b4_lr2E-07_(pubmed-sea)",
]
labels = [
    "siglip + bs2",
    "siglip + bs4",
    "clip + bs2",
    "clip + bs4",
]

# Look labels up by run name instead of zipping positionally: find_runs drops
# names it cannot find, which would shift every following label (and with it
# the colour and linestyle) onto the wrong curve.
label_by_name = dict(zip(names, labels))
col_name = "eval/mean_recall@1"

runs = find_runs(names, runs)
for run in runs:
    label = label_by_name[run.name]
    hist: pd.DataFrame = run.history()
    data = hist[col_name].dropna()
    # NOTE(review): x is the index of the history frame; confirm it matches
    # the training step the axis label promises.
    plt.plot(data.index, data, label=label,
             color='blue' if 'sig' in label else 'red',
             linestyle='--' if 'bs4' in label else '-')

plt.title("Recall@1 SigLIP vs CLIP loss")
plt.xlabel("Step")
plt.ylabel("Mean Recall@1")
plt.legend()
plt.savefig("plots/loss_comp.pdf", format="pdf")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Model comparison 1: SigLIP text tower vs. BioBERT vs. small BERT variants.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    "clip/siglip2-base-patch16-224/lr5E-07_(pubmed-sea)",
    biobert,
    "clip/bert-small/b4_lr5E-06_(pubmed-sea)",
    "clip/bert-mini/b4_lr5E-06_(pubmed-sea)"
]
labels = [
    "siglip",
    "biobert",
    "bert-small",
    "bert-mini",
]
plots = [
    "eval/mean_recall@1",
    "eval/pubmed-searchrefiner_recall@1",
    "eval/TAR_recall@1",
]
titles = [
    "Mean Recall@1",
    "SearchRefiner Recall@1",
    "TAR Recall@1"
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

runs = find_runs(names, runs)

# Pair each run with its label by name: find_runs skips names it cannot find,
# so zipping the result with `labels` would mislabel every later curve.
label_by_name = dict(zip(names, labels))

# Fetch each run's history once and reuse it for all panels instead of
# requesting it again for every metric.
histories = [(label_by_name[run.name], run.history()) for run in runs]

for pl, ax, title in zip(plots, axs, titles):
    for label, hist in histories:
        data = hist[pl].dropna()
        # Rolling mean over 10 logged evaluations to smooth the curve.
        smoothed = data.rolling(window=10, min_periods=1).mean()
        ax.plot(smoothed.index, smoothed, label=label)

    ax.set_title(title)
    ax.set_xlabel("Step")
    # Only the left-most panel carries the shared y-label and the legend.
    if pl == plots[0]:
        ax.set_ylabel("Recall@1")
        ax.legend()

fig.tight_layout()
fig.savefig("plots/modelcomp1.pdf", format="pdf")
fig.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Model comparison 2: BioBERT vs. general-purpose embedding models.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    biobert,
    "clip/llm-embedder/b4_lr5E-06_(pubmed-sea)",
    "clip/bge-small-en-v1.5/b4_lr5E-06_(pubmed-sea)",
]
labels = [
    "biobert",
    "llm-embedder",
    "bge-small-en",
]
plots = [
    "eval/mean_recall@1",
    "eval/pubmed-searchrefiner_recall@1",
    "eval/TAR_recall@1",
]
titles = [
    "Mean Recall@1",
    "SearchRefiner Recall@1",
    "TAR Recall@1"
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

runs = find_runs(names, runs)

# Pair each run with its label by name: find_runs skips names it cannot find,
# so zipping the result with `labels` would mislabel every later curve.
label_by_name = dict(zip(names, labels))

# Fetch each run's history once and reuse it for all panels instead of
# requesting it again for every metric.
histories = [(label_by_name[run.name], run.history()) for run in runs]

for pl, ax, title in zip(plots, axs, titles):
    for label, hist in histories:
        data = hist[pl].dropna()
        # Rolling mean over 10 logged evaluations to smooth the curve.
        smoothed = data.rolling(window=10, min_periods=1).mean()
        ax.plot(smoothed.index, smoothed, label=label)

    ax.set_title(title)
    ax.set_xlabel("Step")
    # Only the left-most panel carries the shared y-label and the legend.
    if pl == plots[0]:
        ax.set_ylabel("Recall@1")
        ax.legend()

fig.tight_layout()
fig.savefig("plots/modelcomp2.pdf", format="pdf")
fig.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Model comparison 3: all backbones, recall next to evaluation throughput.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    "clip/siglip2-base-patch16-224/lr5E-07_(pubmed-sea)",
    biobert,
    "clip/llm-embedder/b4_lr5E-06_(pubmed-sea)",
    "clip/bge-small-en-v1.5/b4_lr5E-06_(pubmed-sea)",
    "clip/bert-small/b4_lr5E-06_(pubmed-sea)",
    "clip/bert-mini/b4_lr5E-06_(pubmed-sea)",
]
labels = [
    "siglip",
    "biobert",
    "llm-embedder",
    "bge-small-en",
    "small",
    "mini"
]
plots = [
    "eval/mean_recall@1",
    "eval/mean_steps_per_second",
    "eval/TAR_recall@1",
]
# The middle panel plots `eval/mean_steps_per_second`; its title used to read
# "SearchRefiner Recall@1", which did not match the plotted metric.
# NOTE(review): if SearchRefiner recall was the intended metric, change the
# key in `plots` instead and restore the old title.
titles = [
    "Mean Recall@1",
    "Mean Steps per Second",
    "TAR Recall@1"
]
# One y-label per panel, because the panels no longer share a unit.
ylabels = [
    "Recall@1",
    "Steps / s",
    "Recall@1",
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

runs = find_runs(names, runs)

# Pair each run with its label by name: find_runs skips names it cannot find,
# so zipping the result with `labels` would mislabel every later curve.
label_by_name = dict(zip(names, labels))

# Fetch each run's history once and reuse it for all panels instead of
# requesting it again for every metric.
histories = [(label_by_name[run.name], run.history()) for run in runs]

for pl, ax, title, ylabel in zip(plots, axs, titles, ylabels):
    for label, hist in histories:
        data = hist[pl].dropna()
        # Rolling mean over 10 logged evaluations to smooth the curve.
        smoothed = data.rolling(window=10, min_periods=1).mean()
        ax.plot(smoothed.index, smoothed, label=label)

    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel(ylabel)
    # The legend is identical for all panels; show it once on the left.
    if pl == plots[0]:
        ax.legend()

fig.tight_layout()
fig.savefig("plots/modelcomp3.pdf", format="pdf")
fig.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Model comparison 4: effect of removing []-terms for BioBERT and BGE-small.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    biobert,
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-sea)no[]",
    "clip/bge-small-en-v1.5/b4_lr5E-06_(pubmed-sea)",
    "clip/bge-small-en-v1.5/b4_lr5E-06_(pubmed-sea)^0no[]"
]
labels = [
    "biobert",
    "biobert + no []-terms",
    "bge-small-en",
    "bge-small-en + no []-terms",
]
plots = [
    "eval/mean_recall@1",
    "eval/pubmed-query_recall@1",
    "eval/TAR_recall@1",
]
titles = [
    "Mean Recall@1",
    "PubMed-Query Recall@1",
    "TAR Recall@1"
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

runs = find_runs(names, runs)

# Pair each run with its label by name: find_runs skips names it cannot find,
# so zipping the result with `labels` would mislabel every later curve (and
# colour/linestyle are derived from the label).
label_by_name = dict(zip(names, labels))

# Fetch each run's history once and reuse it for all panels instead of
# requesting it again for every metric.
histories = [(label_by_name[run.name], run.history()) for run in runs]

for pl, ax, title in zip(plots, axs, titles):
    for label, hist in histories:
        data = hist[pl].dropna()
        # Rolling mean over 10 logged evaluations to smooth the curve.
        smoothed = data.rolling(window=10, min_periods=1).mean()
        # Colour encodes the backbone, a dashed line the "no []-terms" variant.
        ax.plot(smoothed.index, smoothed, label=label,
                color='blue' if 'bge' in label else 'red',
                linestyle='--' if '[]' in label else '-')

    ax.set_title(title)
    ax.set_xlabel("Step")
    # Only the left-most panel carries the shared y-label and the legend.
    if pl == plots[0]:
        ax.set_ylabel("Recall@1")
        ax.legend()

fig.tight_layout()
fig.savefig("plots/modelcomp4.pdf", format="pdf")
fig.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np  # was used below without ever being imported in this notebook
import pandas as pd

# Training-data comparison: base runs vs. runs trained on the larger data mix.
runs = api.runs("simon-doehl-ai/Boolean-Ranking")

names = [
    "clip/bge-small-en-v1.5/b4_lr5E-06_(pubmed-sea)^0no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-sea)no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^0no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^2no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^3no[]",
    "clip/biobert-v1.1/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^4no[]",
    "clip/bge-small-en-v1.5/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^0no[]",
    "clip/bge-small-en-v1.5/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)no[]",
    "clip/bge-small-en-v1.5/b4_lr8E-06_(pubmed-que_pubmed-sea_raw-jsonl)^2no[]",
]
labels = [
    "bge-small-en base",
    "biobert base",
    "bio ^0",
    "bio ^1",
    "bio ^2",
    "bio ^3",
    "bio ^4",
    "bge ^0",
    "bge ^1",
    "bge ^2",
]
plots = [
    "eval/mean_recall@1",
    "eval/pubmed-query_recall@1",
    "eval/TAR_recall@1",
]
titles = [
    "Mean Recall@1",
    "PubMed-Query Recall@1",
    "TAR Recall@1"
]

# Two fixed colours for the base runs, then the same 5-step viridis ramp for
# the "bio ^k" runs and again for the "bge ^k" runs, so equal k shares a colour.
colors = ["blue", "red"]
ramp = list(plt.cm.viridis(np.linspace(0, 1, 5)))
colors.extend(ramp)
colors.extend(ramp)


fig, axs = plt.subplots(1, 3, figsize=(15, 5))

runs = find_runs(names, runs)

# Pair each run with its label and colour by name: find_runs skips names it
# cannot find, so zipping the result with `labels`/`colors` would shift every
# later label and colour onto the wrong curve.
style_by_name = {
    name: (label, color) for name, label, color in zip(names, labels, colors)
}

# Fetch each run's history once and reuse it for all panels instead of
# requesting it again for every metric.
histories = [(*style_by_name[run.name], run.history()) for run in runs]

for pl, ax, title in zip(plots, axs, titles):
    for label, color, hist in histories:
        data = hist[pl].dropna()
        # Rolling mean over 10 logged evaluations to smooth the curve.
        smoothed = data.rolling(window=10, min_periods=1).mean()

        # "bio ^k" and "bge ^k" share a colour; previously both were solid
        # and therefore indistinguishable. Base runs stay dashed, bge runs
        # are dotted, bio runs stay solid.
        if 'base' in label:
            linestyle = '--'
        elif label.startswith('bge'):
            linestyle = ':'
        else:
            linestyle = '-'

        ax.plot(smoothed.index, smoothed, label=label, color=color,
                linestyle=linestyle)

    ax.set_title(title)
    ax.set_xlabel("Step")
    # Only the left-most panel carries the shared y-label and the legend.
    if pl == plots[0]:
        ax.set_ylabel("Recall@1")
        ax.legend()

fig.tight_layout()
fig.savefig("plots/traindatacomp.pdf", format="pdf")
fig.show()