In [None]:
import os
os.environ["TRANSFORMERS_CACHE"] = "/network/scratch/m/mirceara/.cache/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/network/scratch/m/mirceara/.cache/huggingface/datasets"
os.environ["BALAUR_CACHE"] = "/network/scratch/m/mirceara/.cache/balaur"

In [None]:
# Make the pretraining script importable so the model/config/data-module classes
# imported from run_bort can be reused for evaluation in this notebook.
# NOTE(review): absolute, user-specific path — consider making it configurable.
import sys
sys.path.append("/home/mila/m/mirceara/balaur/experiments/pretrain/")
from run_bort import MlmModel, MlmModelConfig, MlmWnreDataModule, MlmWnreDataModuleConfig, WNRE

In [None]:
from itertools import chain
import pandas as pd
from pathlib import Path
from tqdm import tqdm, trange
import json


import numpy as np
import torch
import torch.nn.functional as F
import datasets as ds
import transformers as tr

from typing import List, Union

# Show every column when displaying wide metric tables.
pd.options.display.max_columns = None

## Load data

In [None]:
TEXT_COL = 'masked_text'
LABEL_COL = 'masked_tokens'
REL_COL = "context_token"

# Load the preprocessed hypernym/hyponym cloze datasets for every tokenizer and
# record, per (tokenizer, task), the sorted set of label token ids occurring in
# the data; these are later passed as the closed vocabulary for ranking.
tokenized = {}
tokenizers = ['roberta-base', 'bert-base-uncased']
setting = 'ctx'
tokenized_categories = {}
for t in tokenizers:
    per_task_data, per_task_categories = {}, {}
    tokenized[t], tokenized_categories[t] = per_task_data, per_task_categories
    for task in ('hypernym', 'hyponym'):
        d = ds.load_from_disk(f"preprocessed/hypcc_{t}_{task}_singular_singular_{setting}")
        per_task_categories[task] = sorted(set(chain.from_iterable(d['labels'])))
        per_task_data[task] = d

## Setup

### Model loading

In [None]:
def load_model(ckpt_dir: Path, step: int, **config_kwargs):
    """Rebuild an ``MlmModel`` from the checkpoint saved at a given training step.

    Args:
        ckpt_dir: directory containing ``*.ckpt`` files.
        step: training step; exactly one file matching ``*step={step}.ckpt``
            must exist in ``ckpt_dir``.
        **config_kwargs: attribute overrides applied on top of the checkpoint's
            ``hyper_parameters`` when rebuilding the config.

    Returns:
        The model with the checkpoint weights loaded, on CPU and in eval mode.
    """
    # get checkpoint
    matches = list(ckpt_dir.glob(f"*step={step}.ckpt"))
    assert len(matches) == 1, f"Unresolvable paths for step={step} in {ckpt_dir}: {matches}"
    ckpt_path = matches[0]

    # load model and config
    # HACK: load model by reconstructing config from ckpt
    #       for quick and dirty evaluation of trained models.
    # map_location="cpu" avoids materialising GPU-saved tensors on the GPU just
    # to copy them into a CPU model; eval loops move the model themselves.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    config = MlmModelConfig()
    config.__dict__.update(ckpt['hyper_parameters'])
    config.__dict__.update(config_kwargs)
    model = MlmModel(config=config)
    model.load_state_dict(ckpt['state_dict'])
    # A freshly constructed torch module is in training mode; turn off dropout
    # and similar train-only behaviour since this model is used for evaluation.
    model.eval()
    return model

### Metrics tracking

In [None]:
class HypccMetrics:
    """Accumulates per-example ranking metrics for masked-token (cloze) predictions.

    Every call to ``update`` appends one entry per example to each list below.
    Rank, accuracy and NLL entries are themselves lists holding one value per
    gold label of that example (examples may have different numbers of labels).
    "Open" metrics rank a label against the full vocabulary, "closed" metrics
    only against the ``closed_vocab`` token ids passed to ``update``.
    """

    def __init__(self):
        self.open_acc1 = []     # 1 if the label's open rank is 1, else 0
        self.open_acc5 = []     # 1 if the label's open rank is <= 5, else 0
        self.open_rank = []     # 1-based rank among all vocabulary tokens
        self.closed_acc1 = []   # 1 if the label's closed rank is 1, else 0
        self.closed_acc5 = []   # 1 if the label's closed rank is <= 5, else 0
        self.closed_rank = []   # 1-based rank among the closed-vocabulary tokens
        self.top10_labels = []  # ids of the 10 most probable tokens
        self.top10_probs = []   # probabilities of those 10 tokens
        self.nll = []           # negative log-probability of each gold label
        self.labels = []        # gold label ids, as passed in
        self.prompts = []       # related/context token id of each example
        self.rep_prob = []      # probability assigned to the related token

    def update(self,
               logits: torch.Tensor,    # (bsz,vsz)
               related: List[int],      # (bsz,)
               labels: List[List[int]], # (bsz, nlabels), nlabels varies within a batch
               closed_vocab: List[int],
               ):
        """Score one batch and append its results to the running lists.

        Args:
            logits: one row of vocabulary scores per example.
            related: the related (context) token id of each example.
            labels: gold token ids per example. Each label is ranked with the
                other gold labels of the same example removed from competition.
            closed_vocab: token ids forming the closed vocabulary.

        Raises:
            ValueError: if ``logits``, ``related`` and ``labels`` do not have the
                same number of entries — e.g. when some input contained zero or
                several mask tokens, which would silently misalign all metrics.
        """
        n_rows = logits.shape[0]
        if not (n_rows == len(labels) == len(related)):
            raise ValueError(
                f"Got {n_rows} logit rows for {len(labels)} label lists and "
                f"{len(related)} related tokens; expected one masked position per example."
            )

        nll_tensor = -F.log_softmax(logits, dim=-1)
        probs = torch.exp(-nll_tensor)
        # Build the closed-vocabulary index once per batch instead of once per label.
        closed_idx = torch.as_tensor(closed_vocab, dtype=torch.long, device=logits.device)

        nlls, open_ranks, closed_ranks = [], [], []
        for bidx, vidxs in enumerate(labels):
            # Hide all gold labels so they never outrank one another.
            rank_scores = logits[bidx].clone()
            rank_scores[vidxs] = float('-inf')
            ex_nll, ex_open, ex_closed = [], [], []
            for vidx in vidxs:
                ex_nll.append(nll_tensor[bidx][vidx].item())
                # Tokens scoring strictly higher than this label (ties favour the label).
                label_gt = logits[bidx][vidx] < rank_scores
                ex_open.append(int(label_gt.sum() + 1))
                ex_closed.append(int(label_gt[closed_idx].sum() + 1))
            nlls.append(ex_nll)
            open_ranks.append(ex_open)
            closed_ranks.append(ex_closed)

        top10 = probs.topk(10, dim=1)
        self.top10_probs += top10.values.tolist()
        self.top10_labels += top10.indices.tolist()
        self.rep_prob += [probs[bidx][vidx].item() for bidx, vidx in enumerate(related)]

        self.nll += nlls
        self.closed_rank += closed_ranks
        self.open_rank += open_ranks
        self.closed_acc1 += [[int(r == 1) for r in rs] for rs in closed_ranks]
        self.closed_acc5 += [[int(r <= 5) for r in rs] for rs in closed_ranks]
        self.open_acc1 += [[int(r == 1) for r in rs] for rs in open_ranks]
        self.open_acc5 += [[int(r <= 5) for r in rs] for rs in open_ranks]
        self.labels += labels
        self.prompts += related

### Eval loop

In [None]:
def labels_collate(batch):
    """Collate dataset rows for the eval loops.

    ``input_ids`` and ``attention_mask`` go through torch's ``default_collate``;
    ``labels`` (a varying number per example) and the ``REL_COL`` column are
    kept as plain Python lists with one entry per example.
    """
    tensor_fields = ('input_ids', 'attention_mask')
    collated = torch.utils.data.dataloader.default_collate(
        [{field: row[field] for field in tensor_fields} for row in batch]
    )
    for field in ('labels', REL_COL):
        collated[field] = [row[field] for row in batch]
    return collated

def _run_masked_eval(*, model, tokenizer, dataset, categories, bsz, device, mask_logits_fn, desc=None):
    """Shared evaluation driver for ``eval_loop`` and ``hf_bert_eval_loop``.

    ``mask_logits_fn(model, batch, mask_idx)`` must return vocabulary logits for
    the masked positions only (one row per mask token). Returns the filled
    ``HypccMetrics``; the model is moved back to CPU afterwards, even on error.
    """
    device = device or torch.device('cpu')
    # Modules rebuilt from a checkpoint start in training mode; disable dropout.
    model.eval()
    model.to(device)
    dl = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, drop_last=False, collate_fn=labels_collate)
    metrics = HypccMetrics()
    try:
        with torch.inference_mode():
            for batch in tqdm(dl, total=len(dl), desc=desc):
                labels = batch.pop('labels')
                related = batch.pop(REL_COL)
                batch = {k: v.to(device) for k, v in batch.items()}
                mask_idx = batch['input_ids'] == tokenizer.mask_token_id
                metrics.update(
                    mask_logits_fn(model, batch, mask_idx),
                    related,
                    labels,
                    categories,
                )
    finally:
        # Release accelerator memory even if evaluation fails midway.
        model.to(torch.device('cpu'))
    return metrics


def _balaur_mask_logits(model, batch, mask_idx):
    """MlmModel returns hidden states; project the masked ones with its head."""
    hidden_states = model(batch)
    return model.head(hidden_states[mask_idx])


def _hf_mask_logits(model, batch, mask_idx):
    """HF masked-LM models return logits for every position; keep masked ones."""
    out = model(**batch, output_hidden_states=False)
    return out.logits[mask_idx]


def eval_loop(*, 
              model_name: str,
              model: Union[MlmModel, tr.AutoModel],
              tokenizer: tr.AutoTokenizer, 
              dataset: ds.Dataset, 
              categories: List[int], 
              bsz: int,
              device: torch.device = None):
    """Evaluate a Balaur ``MlmModel`` on a cloze dataset.

    Args:
        model_name: label shown on the progress bar.
        model: model whose forward returns hidden states and whose ``head``
            maps them to vocabulary logits.
        tokenizer: supplies ``mask_token_id`` to locate masked positions.
        dataset: rows with ``input_ids``, ``attention_mask``, ``labels`` and ``REL_COL``.
        categories: closed-vocabulary token ids for closed ranking.
        bsz: batch size.
        device: device to evaluate on (CPU when None).

    Returns:
        A ``HypccMetrics`` with one entry per dataset row.
    """
    return _run_masked_eval(model=model, tokenizer=tokenizer, dataset=dataset,
                            categories=categories, bsz=bsz, device=device,
                            mask_logits_fn=_balaur_mask_logits, desc=model_name)


def hf_bert_eval_loop(*, 
              model_name: str,
              model: tr.AutoModel,
              tokenizer: tr.AutoTokenizer, 
              dataset: ds.Dataset, 
              categories: List[int], 
              bsz: int,
              device: torch.device = None):
    """Evaluate a Hugging Face masked-LM model (output has ``.logits``) on a cloze dataset.

    Arguments and return value are the same as for ``eval_loop``.
    """
    return _run_masked_eval(model=model, tokenizer=tokenizer, dataset=dataset,
                            categories=categories, bsz=bsz, device=device,
                            mask_logits_fn=_hf_mask_logits, desc=model_name)

## Evaluation

In [None]:
bsz = 512
step = 25000
device = torch.device('cuda')


def get_ckpt_dir(m: str) -> Path:
    """Checkpoint directory of the Balaur run named ``m``."""
    return Path(f"/home/mila/m/mirceara/scratch/.cache/balaur/runs/{m}/balaur/{m}/checkpoints/")


# (run or HF model name, tokenizer/dataset key used to evaluate it, eval function)
models = [
    ("mlm_wnre", "roberta-base", eval_loop),
    ("mlm_only", "roberta-base", eval_loop),
#     ("roberta-base", "roberta-base", hf_bert_eval_loop),
#     ("roberta-large", "roberta-base", hf_bert_eval_loop),
#     ("bert-base-uncased", "bert-base-uncased", hf_bert_eval_loop),
    ("bert-large-uncased", "bert-base-uncased", hf_bert_eval_loop),

]
TOKENIZERS = ['bert-base-uncased', 'roberta-base']
tknzrs = {t: tr.AutoTokenizer.from_pretrained(t) for t in TOKENIZERS}


metrics = {}
for m, t, eval_fn in models:
    metrics[m] = {}
    if eval_fn is hf_bert_eval_loop:
        model = tr.AutoModelForMaskedLM.from_pretrained(m)
    else:
        model = load_model(get_ckpt_dir(m), step)
    # A model rebuilt from a checkpoint starts in training mode; disable dropout.
    model.eval()
    print(m)
    for task, dataset in tokenized[t].items():
        # Mixed precision only on CUDA; on other devices the context is disabled.
        with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
            metrics[m][task] = eval_fn(
                model_name=m,
                model=model,
                tokenizer=tknzrs[t],
                dataset=dataset,
                categories=tokenized_categories[t][task],
                bsz=bsz,
                device=device,
            )

In [None]:
def summarize_hypcc_metrics(task_metrics) -> dict:
    """Reduce one ``HypccMetrics`` object to display-ready metric strings.

    Each example's values are first averaged over its gold labels, then averaged
    over examples. ``prob`` is the total probability placed on an example's gold
    labels. Accuracy, probability and MRR are shown as percentages; every value
    is formatted with 3 decimals.
    """
    closed_mrr = [np.mean([1 / r for r in ranks]) for ranks in task_metrics.closed_rank]
    open_mrr = [np.mean([1 / r for r in ranks]) for ranks in task_metrics.open_rank]
    nll = [np.mean(xx) for xx in task_metrics.nll]
    prob = [np.sum(np.exp(np.negative(xx))) for xx in task_metrics.nll]
    col = dict(
        closed_mrr=np.mean(closed_mrr),
        open_mrr=np.mean(open_mrr),
        closed_acc1=np.mean([np.mean(xx) for xx in task_metrics.closed_acc1]),
        closed_acc5=np.mean([np.mean(xx) for xx in task_metrics.closed_acc5]),
        open_acc1=np.mean([np.mean(xx) for xx in task_metrics.open_acc1]),
        open_acc5=np.mean([np.mean(xx) for xx in task_metrics.open_acc5]),
        nll=np.mean(nll),
        prob=np.mean(prob),
    )
    percent_keys = ('acc', 'prob', 'mrr')
    return {
        k: f"{(v * 100 if any(p in k for p in percent_keys) else v):.3f}"
        for k, v in col.items()
    }


# Per-task tables (rows = metrics, columns = models) before transposing for display.
# The loop no longer rebinds the global ``model`` (the last evaluated nn.Module).
results_df = {}
for task in ['hypernym', 'hyponym']:
    df_rows = {
        model_name: summarize_hypcc_metrics(per_task[task])
        for model_name, per_task in metrics.items()
    }
    results_df[task] = pd.DataFrame.from_dict(df_rows)

for k, df in results_df.items():
    display(k)
    display(df.transpose())

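## Analysis: Repetition Rate

Fraction of examples whose top-1 prediction is the context (prompt) token itself, i.e. how often a model simply repeats the related word.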

In [None]:
def get_repetitions(_metrics: "HypccMetrics", m: str, t: str, task: str = 'hypernym'):
    """Flag, per example, whether the top-1 prediction equals the context token.

    ``m``, ``t`` and ``task`` are accepted for call-site symmetry and are unused.
    Returns a list of booleans, one per example.
    """
    return [ctxt == top10[0] for top10, ctxt in zip(_metrics.top10_labels, _metrics.prompts)]


def get_repetition_rate(*args, **kwargs):
    """Fraction of examples whose top-1 prediction repeats the context token.

    Takes the same arguments as ``get_repetitions``. Returns NaN when there are
    no examples rather than raising ZeroDivisionError.
    """
    is_repeat = get_repetitions(*args, **kwargs)
    if not is_repeat:
        return float('nan')
    return sum(is_repeat) / len(is_repeat)

# One table per task: repetition rate (as a percentage string) for every model.
for task in ('hypernym', 'hyponym'):
    rates = {}
    for m, t, _ in models:
        rate = get_repetition_rate(metrics[m][task], m, t, task)
        rates[m] = [f"{(100*rate):.2f}%"]
    repetition_df = pd.DataFrame(rates)
    print(task)
    display(repetition_df.transpose())

## Analysis: Predictions

In [None]:
def convert_token(token_id: int, m: str, tokenizer_map: dict = None):
    """Map a token id to a display string using the tokenizer family of model ``m``.

    Models whose name starts with ``roberta`` or ``mlm`` use the ``roberta-base``
    tokenizer: a leading ``Ġ`` is stripped, and tokens without one get ``Ġ``
    prepended. Models starting with ``bert`` use ``bert-base-uncased`` and the
    token is returned unchanged.

    Args:
        token_id: vocabulary id to convert.
        m: model name; its prefix selects the tokenizer.
        tokenizer_map: optional mapping from tokenizer name to tokenizer;
            defaults to the global ``tknzrs``.

    Raises:
        ValueError: if ``m`` matches neither tokenizer family.
    """
    tokenizer_map = tknzrs if tokenizer_map is None else tokenizer_map
    if m.startswith(('roberta', 'mlm')):
        token = tokenizer_map['roberta-base'].convert_ids_to_tokens(token_id)
        return token[1:] if token.startswith('Ġ') else 'Ġ' + token
    if m.startswith('bert'):
        return tokenizer_map['bert-base-uncased'].convert_ids_to_tokens(token_id)
    raise ValueError(f"No tokenizer known for model {m!r}")

In [None]:
task = 'hypernym'
context_word = 'church'


def print_prediction_rows(m: str, t: str, task: str, context_word: str,
                          token_header: str, prob_header: str, k: int = 5):
    """Print LaTeX table rows with the top-``k`` predictions of model ``m``.

    For every example of ``tokenized[t][task]`` whose context token (converted
    via tokenizer family ``t``) equals ``context_word``, print the masked text,
    a row of the top-``k`` predicted tokens and a row of their probabilities in
    percent.

    Returns:
        The gold labels (converted to strings) of the last printed example, or
        None if no example matched.
    """
    task_metrics = metrics[m][task]
    # Fetch the text column once; indexing the dataset column inside the loop
    # may re-materialise the whole column for every example.
    texts = tokenized[t][task][TEXT_COL]
    last_labels = None
    for i, preds in enumerate(task_metrics.top10_labels):
        if convert_token(task_metrics.prompts[i], t) != context_word:
            continue
        toks = [convert_token(p, m) for p in preds[:k]]
        probs = task_metrics.top10_probs[i][:k]
        print(texts[i])
        print(token_header + " &\t".join(toks) + " \\\\")
        print(prob_header + " &\t".join(f"${100 * p:.2f}$" for p in probs) + " \\\\")
        last_labels = [convert_token(label_id, m) for label_id in task_metrics.labels[i]]
    return last_labels


print('mlm_wnre')
print_prediction_rows('mlm_wnre', 'roberta-base', task, context_word,
                      "\\textsc{Bert}\t\t&\t", "\\textsc{+Balaur}\t&\t")
print("\n")
print('mlm_only')
labels = print_prediction_rows('mlm_only', 'roberta-base', task, context_word,
                               "\\textsc{Bert}\t\t&\t", "\t\t\t&\t")

print()
print("labels")
print(", ".join(labels) if labels is not None else "(no example matched)")