In [None]:
%load_ext autoreload
%autoreload 2

from utils.boolrank import DualSiglip2Model

# Loss variant: passed to the model constructor below and reused later in the
# notebook as the top-level folder of the run name.
# loss = "siglip"
loss = "clip"

# Model identifier handed to DualSiglip2Model. Exactly one line is active; the
# commented lines are alternative identifiers that can be swapped in.
# model = DualSiglip2Model('BAAI/llm-embedder', loss)
model = DualSiglip2Model('BAAI/bge-small-en-v1.5', loss)
# model = DualSiglip2Model('dmis-lab/biobert-v1.1', loss)
# model = DualSiglip2Model('prajjwal1/bert-mini', loss)
# model = DualSiglip2Model('prajjwal1/bert-small', loss)

In [None]:
from utils.my_processing import paths_to_dataset
import numpy as np

# --- Training hyperparameters ------------------------------------------------
batch_size = 128
# epochs = 5 * batch_size
epochs = 30
lr = 1e-6
eval_batch = 30  # per-device eval batch size; also caps the examples used in the final evaluation
power = 4        # exponent applied to the quality scores to obtain sampling weights

# --- Dataset column names ----------------------------------------------------
bool_key, nl_key, qual_key = "bool_query", "nl_query", "quality"

# --- Data files --------------------------------------------------------------
data_path = "training"
path = "data/{}.jsonl"
pubmed, TAR, sysrev = (
    path.format(stem) for stem in (data_path, "TAR_data", "sysrev_conv")
)
train_sources = ['pubmed-searchrefiner', 'pubmed-query', 'raw-jsonl']
dataset = paths_to_dataset(
    [pubmed, TAR, sysrev],
    test_only_sources=['TAR', 'sysrev'],
    train_sources=train_sources,
)
print(dataset)

# One sampling weight per training example: quality ** power.
weights = np.power(np.array(dataset["train"][qual_key]), power)

# --- Run name ----------------------------------------------------------------
# Each fragment is left out when the setting equals its baseline value
# (lr 1e-7, batch size 2, power 1).
lr_n = f"lr{lr:.0E}_" if lr != 1e-7 else ""
b_n = f"b{batch_size}_" if batch_size != 2 else ""
pow_n = f"^{power}" if power != 1 else ""
# Distinct training sources, each truncated to 10 characters.
data_n = '_'.join(map(lambda src: src[:10], np.unique(dataset['train']['source'])))

model_name = model.model_name.rsplit("/", 1)[-1]
model_path = f"{loss}/{model_name}/{b_n}{lr_n}({data_n}){pow_n}"
print(model_path)

In [None]:
import os
import torch
from transformers import Trainer, TrainingArguments
from transformers.utils.notebook import NotebookProgressCallback
from utils.custom_trainer import NotebookProgressCallbackNoTable, WandbCallbackAveraged
from utils.evaluation import compute_metrics
from torch.utils.data import WeightedRandomSampler

# Each epoch draws len(weights) indices, with replacement, using `weights`
# as the per-example sampling weights.
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True,
)


class WeightedTrainer(Trainer):
    """Trainer whose training batches are drawn through the module-level `sampler`."""

    def get_train_dataloader(self):
        """Build the training DataLoader with `sampler` in place of the default sampling."""
        loader_options = {
            "batch_size": self.args.train_batch_size,
            "sampler": sampler,
            "collate_fn": self.data_collator,
        }
        return torch.utils.data.DataLoader(self.train_dataset, **loader_options)

# Weights & Biases settings, read from the environment by the W&B integration.
os.environ.update({
    "WANDB_PROJECT": "Boolean-Ranking",
    "WANDB_LOG_MODEL": "false",
})

# epochs = 10
training_args = TrainingArguments(
    # Where checkpoints go and how the run is labelled.
    output_dir=f"models/{model_path}",
    run_name=model_path,
    # Optimisation.
    per_device_train_batch_size=batch_size,
    num_train_epochs=epochs,
    learning_rate=lr,
    optim="adamw_bnb_8bit",
    bf16=True,
    # max_steps=1000,
    # Checkpointing and logging cadence.
    save_steps=1000,
    save_total_limit=1,
    logging_steps=100,
    # Evaluation: once before training starts, then every 200 steps.
    eval_strategy="steps",
    eval_steps=200,
    eval_on_start=True,
    per_device_eval_batch_size=eval_batch,
    # The collator reads raw dataset columns, so they must not be dropped.
    remove_unused_columns=False,
)

def collate_fn(batch):
    """Collect the Boolean and natural-language query of every example in `batch`.

    Returns a dict with two parallel lists: "in_bool" (values under `bool_key`)
    and "in_text" (values under `nl_key`), in batch order.
    """
    bool_queries, text_queries = [], []
    for example in batch:
        bool_queries.append(example[bool_key])
        # Disabled variant: re.sub("\[.*?\]", "", example[nl_key])
        text_queries.append(example[nl_key])
    return {"in_bool": bool_queries, "in_text": text_queries}

# Swap in the plain `Trainer` here to train without weighted sampling.
# trainer = Trainer(
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)

# Replace the stock notebook progress callback with the table-free variant,
# then register the project's W&B callback.
trainer.remove_callback(NotebookProgressCallback)
trainer.add_callback(NotebookProgressCallbackNoTable)
trainer.add_callback(WandbCallbackAveraged)

trainer.train()
# To resume from the latest checkpoint when one exists:
# try: trainer.train(resume_from_checkpoint=True)
# except: trainer.train(resume_from_checkpoint=False)

In [None]:
from utils.evaluation import evaluate
import os  # imported here so this cell does not depend on the training cell's imports

# paths = [r"models/siglip2/old/b2-bf-8b/e4", r"models/siglip2/old/b3-bf-8b"]
# Checkpoint directories to evaluate. `None` evaluates the model currently in
# memory and writes its plots into this run's own output directory.
paths = [None]
amt = eval_batch     # number of leading examples evaluated per test set
fig_format = "pdf"   # not called `format`, which would shadow the builtin

# The loop variable is deliberately not named `path`: that name holds the
# "data/{}.jsonl" template used by the data-loading cell, and overwriting it
# here would break a later re-run of that cell.
for ckpt_dir in paths:
    if ckpt_dir is None:
        out_dir = f"models/{model_path}"
    else:
        model.load(ckpt_dir + "/model.safetensors")
        print(ckpt_dir)
        # Explicit checkpoint paths are used exactly as given (the same string
        # is passed to model.load), so no extra "models/" prefix is added.
        out_dir = ckpt_dir
    os.makedirs(out_dir, exist_ok=True)

    for key, data in dataset["test"].items():
        res = evaluate(model, data[bool_key][:amt], data[nl_key][:amt], plot=True, title=key)
        # One file per test set; a single shared file name would be overwritten
        # on every iteration and only the last plot would survive.
        res["plot"].savefig(f"{out_dir}/test_stats_{key}.{fig_format}", format=fig_format)