In [1]:
%load_ext autoreload
%autoreload 2

from boolrank import *
from my_processing import paths_to_dataset
import numpy as np

# Loss variant handed to the model; "clip" is the other value used in this notebook.
loss = "siglip"

# Encoder checkpoint name passed to DualSiglip2Model. Alternatives tried earlier:
#   'BAAI/llm-embedder'
#   'dmis-lab/biobert-v1.1'
backbone = 'BAAI/bge-small-en-v1.5'
model = DualSiglip2Model(backbone, loss)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- Optimisation settings -------------------------------------------------
batch_size = 2      # per-device training batch size
epochs = 10         # alternative tried: 5 * batch_size
lr = 1e-7
eval_batch = 30     # per-device eval batch size; also caps the examples plotted per test set
power = 1           # exponent applied to the quality scores to form sampling weights

# --- Column names in the dataset records -----------------------------------
bool_key, nl_key, qual_key = "bool_query", "nl_query", "quality"

# All corpora are JSONL files addressed through the same template.
data_path = "training"
path = "data/{}.jsonl"
pubmed, TAR, sysrev = [path.format(stem) for stem in (data_path, "TAR_data", "sysrev_conv")]

# Sources allowed into the training split.
train_sources = ['pubmed-searchrefiner']
# train_sources += ['pubmed-query', 'raw-jsonl']

# NOTE(review): 'TAR' and 'sysrev' are presumably held out for testing only —
# confirm against my_processing.paths_to_dataset.
dataset = paths_to_dataset(
    [pubmed, TAR, sysrev],
    train_sources=train_sources,
    test_only_sources=['TAR', 'sysrev'],
)

print(dataset)

train_split = dataset["train"]

# Per-example sampling weights: quality score raised to `power`.
weights = np.array(train_split[qual_key]) ** power

# Name fragments are only emitted when a setting differs from its default,
# so default runs keep a short directory name.
lr_n = f"lr{lr:.0E}_" if lr != 1e-7 else ""
b_n = f"b{batch_size}_" if batch_size != 2 else ""
pow_n = f"^{power}" if power != 1 else ""

# Unique training sources, each truncated to its first 10 characters.
data_n = '_'.join([source[:10] for source in np.unique(train_split['source'])])

# Drop the organisation prefix from the model identifier (text after the last "/").
model_name = model.model_name.rsplit("/", 1)[-1]

run_tag = f"{b_n + lr_n}({data_n}){pow_n}"
model_path = f"{loss}/{model_name}/{run_tag}"
print(model_path)

In [None]:
import os
import re
from transformers import Trainer, TrainingArguments
from transformers.utils.notebook import NotebookProgressCallback
from custom_trainer import NotebookProgressCallbackNoTable, WandbCallbackAveraged
from evaluation import compute_metrics
from torch.utils.data import WeightedRandomSampler

# Imported explicitly: the notebook never imports `torch` itself, so referring to
# `torch.utils.data.DataLoader` only worked if `from boolrank import *` leaked the name.
from torch.utils.data import DataLoader

# Draws len(weights) training indices per pass, with replacement, using `weights`
# as the per-example sampling weights.
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)


class WeightedTrainer(Trainer):
    """Trainer that feeds training batches through the module-level weighted `sampler`."""

    def get_train_dataloader(self):
        """Build the training DataLoader using the weighted sampler.

        Returns:
            A DataLoader over `self.train_dataset` that uses the configured
            train batch size, the module-level `sampler` for index selection,
            and `self.data_collator` to assemble each batch.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
        )

# Weights & Biases settings, read from the environment by the W&B integration.
os.environ.update({
    "WANDB_PROJECT": "Boolean-Ranking",
    "WANDB_LOG_MODEL": "false",
})

# epochs = 10
training_kwargs = dict(
    # Where checkpoints go and how the run is labelled.
    output_dir="models/" + model_path,
    run_name=model_path,
    # Optimisation.
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    learning_rate=lr,
    bf16=True,
    optim="adamw_bnb_8bit",
    # Checkpointing: keep only the latest checkpoint.
    save_steps=1000,
    save_total_limit=1,
    # The collator reads raw text columns, so they must not be pruned.
    remove_unused_columns=False,
    # Logging and evaluation cadence.
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=200,
    eval_on_start=True,
    per_device_eval_batch_size=eval_batch,
    # max_steps=1000,
)
training_args = TrainingArguments(**training_kwargs)

def collate_fn(batch):
    """Turn a list of example records into the two text lists the model consumes.

    Args:
        batch: sequence of mapping-like examples, each holding the `bool_key`
            and `nl_key` fields.

    Returns:
        dict with "in_bool" (the Boolean queries) and "in_text" (the natural
        language queries), each a list in batch order.
    """
    def column(key):
        # Collect one field across the whole batch, preserving order.
        return [example[key] for example in batch]

    # Variant tried earlier for "in_text": strip bracketed tags first, i.e.
    #   [re.sub("\[.*?\]", "", ex[nl_key]) for ex in batch]
    return {"in_bool": column(bool_key), "in_text": column(nl_key)}

# WeightedTrainer swaps in the weighted sampler; plain `Trainer` takes the same arguments.
# trainer = Trainer(
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)

# Replace the stock notebook progress callback with the table-free variant,
# then attach the averaged W&B logger.
trainer.remove_callback(NotebookProgressCallback)
for callback in (NotebookProgressCallbackNoTable, WandbCallbackAveraged):
    trainer.add_callback(callback)

# Baseline metrics on the TAR test set before any training.
trainer.evaluate(dataset["test"]["TAR"])
# trainer.train()
# try: trainer.train(resume_from_checkpoint=True)
# except: trainer.train(resume_from_checkpoint=False)

In [None]:
# Run the model directly on the whole TAR test split.
tar = dataset["test"]["TAR"]
tar_bool_queries, tar_nl_queries = tar[bool_key], tar[nl_key]
# NOTE(review): the meaning of the third positional argument (False) is defined
# by DualSiglip2Model — confirm against boolrank.
model(tar_bool_queries, tar_nl_queries, False)

In [8]:
from pathlib import Path
from evaluation import evaluate_on_generated

# NOTE(review): DIR is not used by any cell visible in this notebook.
DIR = Path("data", "combined_outputs")

# Score the in-memory model on the generated queries: first with the default
# grouping (the table below is per generating model), then grouped by prompt id.
evaluate_on_generated(model)
evaluate_on_generated(model, ["prompt_id", "id"])

Unnamed: 0,model,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,HuggingfaceH4,0.195,0.533,4.462,5.0
1,gpt-3.5-turbo-0125,0.292,0.485,4.9,5.0
2,gpt-3.5-turbo-1106,0.197,0.554,5.625,6.0
3,gpt-4-1106-preview,0.597,0.331,3.324,3.0
4,gpt-4o-mini,0.435,0.439,5.05,5.0
5,meta-llama,0.232,0.556,5.513,6.0
6,mistralai,0.255,0.483,4.205,4.0
7,o1-2024-12-17,-0.069,0.652,4.436,5.0
8,open-mistral-7b,0.285,0.534,4.333,5.0
9,open-mixtral-8x7b,0.173,0.578,4.692,5.0




Unnamed: 0,id,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,1,0.259,0.519,5.714,5.0
1,2,0.092,0.635,4.0,4.0
2,3,-0.036,0.666,6.714,7.0
3,4,0.12,0.643,5.0,5.0
4,6,0.543,0.343,6.143,7.0
5,7,-0.131,0.687,6.286,7.0
6,8,-0.164,0.732,5.143,6.0
7,10,-0.286,0.791,7.0,7.0
8,11,-0.066,0.677,5.571,6.0
9,12,-0.144,0.745,4.833,5.5


In [10]:
# Load a saved checkpoint (batch 16, lr 1e-05, quality power 4) and re-run the
# generated-query evaluation. Forward slashes keep the path valid on every OS;
# the previous backslash-separated raw string only resolved on Windows.
# NOTE(review): the model in the first cell is built with loss="siglip" while this
# checkpoint lives under models/clip — confirm the loaded weights match the model.
model.load("models/clip/bge-small-en-v1.5/b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4/checkpoint-11288/model.safetensors")
# model.load("models/clip/biobert-v1.1/b16_lr1E-05_(pubmed-que_pubmed-sea_raw-jsonl)^4/checkpoint-14110/model.safetensors")
evaluate_on_generated(model)
evaluate_on_generated(model, ["prompt_id", "id"])

Unnamed: 0,model,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,HuggingfaceH4,0.008,0.636,4.462,5.0
1,gpt-3.5-turbo-0125,-0.001,0.631,4.9,5.0
2,gpt-3.5-turbo-1106,0.076,0.606,5.625,6.0
3,gpt-4-1106-preview,0.145,0.553,3.324,3.0
4,gpt-4o-mini,0.005,0.64,5.05,5.0
5,meta-llama,0.01,0.653,5.513,6.0
6,mistralai,-0.023,0.648,4.205,4.0
7,o1-2024-12-17,-0.074,0.688,4.436,5.0
8,open-mistral-7b,0.027,0.625,4.333,5.0
9,open-mixtral-8x7b,-0.008,0.647,4.692,5.0




Unnamed: 0,id,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,1,0.039,0.619,5.714,5.0
1,2,0.241,0.52,4.0,4.0
2,3,0.024,0.658,6.714,7.0
3,4,-0.189,0.734,5.0,5.0
4,6,0.288,0.449,6.143,7.0
5,7,-0.257,0.742,6.286,7.0
6,8,0.27,0.56,5.143,6.0
7,10,-0.19,0.717,7.0,7.0
8,11,-0.117,0.718,5.571,6.0
9,12,-0.386,0.81,4.833,5.5


In [13]:
# Load the batch-16 checkpoint trained with the default learning rate (no lr tag
# in the run name) and re-run the generated-query evaluation. Forward slashes keep
# the path valid on every OS; the previous backslash-separated raw string only
# resolved on Windows.
# NOTE(review): the model in the first cell is built with loss="siglip" while this
# checkpoint lives under models/clip — confirm the loaded weights match the model.
model.load("models/clip/bge-small-en-v1.5/b16_(pubmed-que_pubmed-sea_raw-jsonl)^4/checkpoint-11288/model.safetensors")
evaluate_on_generated(model)
evaluate_on_generated(model, ["prompt_id", "id"])

Unnamed: 0,model,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,HuggingfaceH4,0.153,0.56,4.462,5.0
1,gpt-3.5-turbo-0125,0.261,0.497,4.9,5.0
2,gpt-3.5-turbo-1106,0.183,0.58,5.625,6.0
3,gpt-4-1106-preview,0.561,0.33,3.324,3.0
4,gpt-4o-mini,0.371,0.467,5.05,5.0
5,meta-llama,0.134,0.581,5.513,6.0
6,mistralai,0.101,0.581,4.205,4.0
7,o1-2024-12-17,-0.049,0.645,4.436,5.0
8,open-mistral-7b,0.181,0.593,4.333,5.0
9,open-mixtral-8x7b,0.128,0.594,4.692,5.0




Unnamed: 0,id,spearman,norm_offset_sum,avg_queries_per_prompt,med_queries_per_prompt
0,1,0.3,0.459,5.714,5.0
1,2,0.12,0.595,4.0,4.0
2,3,-0.02,0.668,6.714,7.0
3,4,0.057,0.671,5.0,5.0
4,6,0.49,0.396,6.143,7.0
5,7,-0.22,0.735,6.286,7.0
6,8,-0.211,0.74,5.143,6.0
7,10,-0.159,0.729,7.0,7.0
8,11,-0.178,0.708,5.571,6.0
9,12,-0.048,0.694,4.833,5.5


In [None]:
from pathlib import Path

from evaluation import evaluate

# Checkpoint directories to evaluate; None means "the model currently in memory".
# paths = [r"models/siglip2/old/b2-bf-8b/e4", r"models/siglip2/old/b3-bf-8b"]
paths = [None]
amt = eval_batch          # number of examples per test set fed to `evaluate`
plot_format = "pdf"       # named to avoid shadowing the builtin `format`

# Loop variable is `checkpoint_dir` so the `path` template from the config cell
# is not clobbered.
for checkpoint_dir in paths:
    if checkpoint_dir is not None:
        model.load(checkpoint_dir + "/model.safetensors")
        print(checkpoint_dir)
        # Explicit paths already point at the checkpoint directory, so save
        # next to the weights instead of prefixing another "models/".
        out_dir = Path(checkpoint_dir)
    else:
        out_dir = Path("models") / model_path
    # The run directory may not exist yet if no training run has saved a checkpoint.
    out_dir.mkdir(parents=True, exist_ok=True)

    for key, data in dataset["test"].items():
        res = evaluate(model, data[bool_key][:amt], data[nl_key][:amt], plot=True, title=key)
        # One file per test set; a shared filename made each split overwrite the last.
        res["plot"].savefig(out_dir / f"test_stats_{key}.{plot_format}", format=plot_format)