In [None]:
import os
# Point the Hugging Face and balaur caches at scratch storage. Keep this cell
# before the transformers/datasets imports below — presumably those libraries
# read these variables when they are imported.
os.environ.update({
    "TRANSFORMERS_CACHE": "/network/scratch/m/mirceara/.cache/huggingface/transformers",
    "HF_DATASETS_CACHE": "/network/scratch/m/mirceara/.cache/huggingface/datasets",
    "BALAUR_CACHE": "/network/scratch/m/mirceara/.cache/balaur",
})

In [None]:
import sys
sys.path.append("/home/mila/m/mirceara/balaur/experiments/pretrain/")
from run_bort import MlmModel, MlmModelConfig, MlmWnreDataModule, MlmWnreDataModuleConfig, WNRE

In [None]:
import json
from pathlib import Path
from tqdm import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import transformers as tr
from pytorch_lightning import seed_everything

## Setup

### Model Eval

In [None]:
class EvalModel:
    """Evaluate a masked-LM on validation batches (MLM loss and MRR).

    Instances are built with `load_bort` (a local checkpoint of `MlmModel`) or
    `load_hf` (a Hugging Face masked-LM). Both set `model`, `model_name`,
    `device`, `step` (the checkpoint's `global_step`, 0 for HF models) and
    `mask_id` (the mask token id used to locate masked positions).
    """

    @classmethod
    def load_bort(cls, ckpt_path: "Path | str", device: str = 'cuda', **kwargs):
        """Load an `MlmModel` checkpoint ready for evaluation.

        Args:
            ckpt_path: checkpoint file; a `str` path is accepted as well as a `Path`.
            device: torch device the model is placed on.
            **kwargs: base `MlmModelConfig` fields; any value stored in the
                checkpoint's `hyper_parameters` overrides them.

        Returns:
            An `EvalModel` whose model is in eval mode.
        """
        em = cls()
        em.device = torch.device(device)
        ckpt_path = str(Path(ckpt_path).absolute())
        # map_location lets a checkpoint saved on GPU be loaded onto another device (e.g. CPU).
        ckpt = torch.load(ckpt_path, map_location=em.device)
        config = MlmModelConfig(**kwargs)
        config.__dict__.update(ckpt['hyper_parameters'])
        em.model_name = ckpt_path
        em.model = MlmModel(config).to(em.device)
        em.model.load_state_dict(ckpt['state_dict'])
        # A freshly constructed module is in training mode; without this, dropout
        # (and any other train-only behaviour) stays on during evaluation.
        em.model.eval()
        em.step = ckpt['global_step']
        em.mask_id = em.model.bort_config.mask_token_id
        return em

    @classmethod
    def load_hf(cls, model_name: str = 'bert-base-uncased', device: str = 'cuda'):
        """Load a Hugging Face masked-LM (and its tokenizer's mask id) for evaluation."""
        em = cls()
        em.model_name = model_name
        em.device = torch.device(device)
        em.model = tr.AutoModelForMaskedLM.from_pretrained(model_name).to(em.device)
        # Explicit for symmetry with load_bort.
        em.model.eval()
        em.step = 0
        em.mask_id = tr.AutoTokenizer.from_pretrained(model_name).mask_token_id
        return em

    @torch.inference_mode()
    def run_eval_batch(self, batch):
        """Evaluate one batch and return `(mlm_loss, mrr)` as Python floats.

        `batch` must provide `input_ids` and `labels` tensors of the same shape.
        MRR is computed only at positions whose input token equals `mask_id`.
        For `MlmModel` the loss is restricted to those positions as well; for
        HF models it is the model's own `out.loss` computed from `labels`.
        """
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)
        # Flattened boolean selector of positions whose input is the mask token.
        is_masked = input_ids.view(-1) == self.mask_id
        masked_labels = labels.view(-1)[is_masked]

        if isinstance(self.model, MlmModel):
            # MlmModel returns hidden states; only masked positions go through the LM head.
            src = self.model(dict(input_ids=input_ids, labels=labels))
            logits = self.model.head(src.view(-1, src.shape[-1])[is_masked])
            mlm_loss = F.cross_entropy(logits, masked_labels).item()
        else:
            out = self.model(input_ids=input_ids, labels=labels)
            mlm_loss = out.loss.item()
            logits = out.logits.view(-1, out.logits.shape[-1])[is_masked]

        mrr = self.compute_mrr(logits, masked_labels)
        return mlm_loss, mrr

    @torch.inference_mode()
    def compute_mrr(self, logits, labels):
        """Mean reciprocal rank of each label within its row of `logits`.

        Args:
            logits: 2-D tensor, one row of vocabulary scores per target.
            labels: 1-D tensor of target ids, one per row of `logits`.

        Returns:
            The MRR as a Python float. A target's rank counts every entry scoring
            >= the target (ties count against it); an empty input yields NaN.
        """
        assert len(labels.shape) == 1 and len(logits.shape) == 2, "logits and labels must be flattened"
        assert labels.shape[0] == logits.shape[0], "logits must be subsampled to match number of labels"
        target_logits = torch.gather(logits, -1, labels.unsqueeze(-1))
        ranks = (logits >= target_logits).sum(dim=-1)
        mrr = (1 / ranks).mean().item()
        return mrr
        

### Dataset

In [None]:
BSZ = 512

# Validation data module: wikibooks corpus, roberta-base tokenizer, WNRE enabled.
dc = MlmWnreDataModuleConfig(
    # data source
    dataset_name='wikibooks',
    tokenizer='roberta-base',
    complete_docs=True,
    # objective
    wnre=True,
    # loading
    per_device_bsz=BSZ,
    num_dataloader_workers=1,
)

dm = MlmWnreDataModule(dc)
dm.prepare_data()
dm.setup()

# Initial validation loader; the evaluation cells below rebuild it after adjusting dm.config.
vdl = dm.val_dataloader()

## MLM-only performance

In [None]:
SEED = 42
# NOTE(review): MAX_STEP is not used below — presumably meant to cap which checkpoints are evaluated.
MAX_STEP = 25_000


MODEL_NAME = "mlm_only"
ckpt_dir = f"/home/mila/m/mirceara/scratch/.cache/balaur/runs/{MODEL_NAME}/balaur/{MODEL_NAME}/checkpoints/"
dm.config.wnre_only_mask = False
seed_everything(seed=SEED)
vdl = dm.val_dataloader()
mlm_dir = Path(ckpt_dir).parent / "mlm_eval"
mlm_dir.mkdir(exist_ok=True)
# Evaluate every saved checkpoint in the order it was written. Per-batch losses
# and MRRs are dumped as JSON lists to mlm_dir/mlm_<step> and mlm_dir/mrr_<step>.
for ckpt in sorted(Path(ckpt_dir).glob("*step=*.ckpt"), key=os.path.getmtime):
    print(ckpt)
    model = EvalModel.load_bort(ckpt)
    # Use the step stored in the checkpoint rather than parsing the file name.
    step = model.step
    # Re-seed right before iterating so every checkpoint is scored on the same
    # random masks: model construction and earlier passes advance the global RNGs,
    # and each new DataLoader iterator draws its worker base seed from them.
    # (Assumes the validation masking is driven by the RNGs seed_everything seeds.)
    seed_everything(seed=SEED)
    losses = []
    mrrs = []
    for batch in tqdm(vdl):
        with torch.amp.autocast("cuda"):
            mlm_loss, mrr = model.run_eval_batch(batch)
        losses.append(mlm_loss)
        mrrs.append(mrr)
    print(np.mean(losses))
    (mlm_dir / f"mlm_{step}").write_text(json.dumps(losses))
    (mlm_dir / f"mrr_{step}").write_text(json.dumps(mrrs))
    del model

## MLM+WNRE performance

In [None]:
SEED = 42
# NOTE(review): MAX_STEP is not used below — presumably meant to cap which checkpoints are evaluated.
MAX_STEP = 25_000


MODEL_NAME = "mlm_wnre"
ckpt_dir = f"/home/mila/m/mirceara/scratch/.cache/balaur/runs/{MODEL_NAME}/balaur/{MODEL_NAME}/checkpoints/"
dm.config.wnre_only_mask = False
seed_everything(seed=SEED)
vdl = dm.val_dataloader()
mlm_dir = Path(ckpt_dir).parent / "mlm_eval"
mlm_dir.mkdir(exist_ok=True)
# Evaluate every saved checkpoint in the order it was written. Per-batch losses
# and MRRs are dumped as JSON lists to mlm_dir/mlm_<step> and mlm_dir/mrr_<step>.
for ckpt in sorted(Path(ckpt_dir).glob("*step=*.ckpt"), key=os.path.getmtime):
    print(ckpt)
    # wnre_factor is a base config value; EvalModel.load_bort lets the
    # checkpoint's stored hyper_parameters override it.
    model = EvalModel.load_bort(ckpt, wnre_factor=0.75)
    # Use the step stored in the checkpoint rather than parsing the file name.
    step = model.step
    # Re-seed right before iterating so every checkpoint is scored on the same
    # random masks: model construction and earlier passes advance the global RNGs,
    # and each new DataLoader iterator draws its worker base seed from them.
    # (Assumes the validation masking is driven by the RNGs seed_everything seeds.)
    seed_everything(seed=SEED)
    losses = []
    mrrs = []
    for batch in tqdm(vdl):
        with torch.amp.autocast("cuda"):
            mlm_loss, mrr = model.run_eval_batch(batch)
        losses.append(mlm_loss)
        mrrs.append(mrr)
    print(np.mean(losses))
    (mlm_dir / f"mlm_{step}").write_text(json.dumps(losses))
    (mlm_dir / f"mrr_{step}").write_text(json.dumps(mrrs))
    del model

## Plotting

In [None]:
import wandb
import plotly.graph_objects as go
import plotly.io as pio
from IPython.display import display, HTML
import numpy as np

In [None]:
def errorband_trace(x, y, yhi, ylo, legend: str, rgb='0,100,80', add_errorband: bool = True):
    """Build plotly traces for a line and, optionally, a shaded band around it.

    Args:
        x: x-coordinates (list, tuple or 1-D array).
        y: line values aligned with `x`.
        yhi: upper edge of the band, aligned with `x`.
        ylo: lower edge of the band, aligned with `x`.
        legend: legend entry for the line.
        rgb: comma-separated "R,G,B" string; the band reuses it at 20% opacity.
        add_errorband: when False only the line trace is returned.

    Returns:
        A list holding the line `go.Scatter`, followed by the band `go.Scatter`
        when `add_errorband` is True.
    """
    traces = [
        go.Scatter(
            name=legend,
            x=x,
            y=y,
            line=dict(color=f'rgb({rgb})'),
            mode='lines',
        ),
    ]
    if add_errorband:
        # The band is one closed polygon: along the upper edge left-to-right, then
        # back along the lower edge right-to-left. Converting to lists first makes
        # `+` concatenate; on numpy arrays it would silently add element-wise.
        xs = list(x)
        traces.append(go.Scatter(
            x=xs + xs[::-1],
            y=list(yhi) + list(ylo)[::-1],
            fill='toself',
            fillcolor=f'rgba({rgb},0.2)',
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False,
        ))
    return traces

In [None]:
def get_step(p: Path):
    """Return the training step stored as the last `_`-separated field of `p`'s name (e.g. `mlm_2500` -> 2500)."""
    return int(p.name.split("_")[-1])


def _mean_std_by_step(d, prefix: str):
    """Shared helper: per-step statistics of the JSON number lists in `d/<prefix>_<step>`.

    Args:
        d: directory (str or Path) holding the per-step files.
        prefix: file-name prefix, e.g. "mlm" or "mrr".

    Returns:
        `(steps, avg, hi, lo)` ordered by ascending step, where `hi`/`lo` are
        mean +/- standard deviation (numpy's default population std).
    """
    steps, avg, hi, lo = [], [], [], []
    for f in sorted(Path(d).glob(f"{prefix}_*"), key=get_step):
        values = json.loads(f.read_text())
        mean, std = np.mean(values), np.std(values)
        steps.append(get_step(f))
        avg.append(mean)
        hi.append(mean + std)
        lo.append(mean - std)
    return steps, avg, hi, lo


def get_mrr_mean_std(d):
    """Per-step MRR `(steps, mean, mean+std, mean-std)` from the `mrr_<step>` files in `d`."""
    return _mean_std_by_step(d, "mrr")


def get_mlm_mean_std(d):
    """Per-step MLM loss `(steps, mean, mean+std, mean-std)` from the `mlm_<step>` files in `d`."""
    return _mean_std_by_step(d, "mlm")

In [None]:
# Validation MLM loss over training steps: MLM+WNRE run vs. MLM-only run.
MODEL1 = "mlm_wnre"
DIR1 = f"/home/mila/m/mirceara/scratch/.cache/balaur/runs/{MODEL1}/balaur/{MODEL1}/mlm_eval"
MODEL2 = "mlm_only"
DIR2 = f"/home/mila/m/mirceara/scratch/.cache/balaur/runs/{MODEL2}/balaur/{MODEL2}/mlm_eval"

# Per-step mean loss plus mean +/- std (the band is switched off below).
x1, y1, y1_hi, y1_lo = get_mlm_mean_std(DIR1)
x2, y2, y2_hi, y2_lo = get_mlm_mean_std(DIR2)

traces = [
    *errorband_trace(x1, y1, y1_hi, y1_lo, "BERT+BALAUR", '65,105,225', add_errorband=False),
    *errorband_trace(x2, y2, y2_hi, y2_lo, "BERT (OURS)", '255,127,80', add_errorband=False),
]
fig = go.Figure(traces)

# Linear step axis, log-scaled loss axis.
fig.update_xaxes(title="Training steps")
fig.update_yaxes(type="log", title="MLM Loss")
fig.update_layout(
    showlegend=True,
    template='simple_white',
    legend=dict(yanchor='top', xanchor='right', x=0.99, y=0.99),
    width=500,
    height=250,
    font_family="Serif",
    font_size=12,
    margin_l=5, margin_t=5, margin_b=5, margin_r=5,
)
fig.update_traces(line=dict(width=1.5))

display(HTML(fig.to_html()))
# The exported PDF uses its own size (450 x 225 px), independent of the on-screen layout.
pio.write_image(fig, "eval_mlm_loss.pdf", width=1.5*300, height=0.75*300)