In [None]:
import os

# Point the HuggingFace and balaur caches at cluster scratch space. This cell runs
# before transformers/datasets are imported below (presumably they read these at import).
# NOTE(review): user-specific absolute path; adjust when running elsewhere.
_SCRATCH_CACHE = "/network/scratch/m/mirceara/.cache"
os.environ["TRANSFORMERS_CACHE"] = _SCRATCH_CACHE + "/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = _SCRATCH_CACHE + "/huggingface/datasets"
os.environ["BALAUR_CACHE"] = _SCRATCH_CACHE + "/balaur"

In [None]:
import sys
sys.path.append("/home/mila/m/mirceara/balaur/experiments/pretrain/")
from run_bort import MlmModel, MlmModelConfig, MlmWnreDataModule, MlmWnreDataModuleConfig, WNRE

In [None]:
from pathlib import Path
import yaml
from itertools import chain
import copy
from functools import lru_cache
import pickle
from collections import defaultdict, Counter
import json

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
import transformers as tr
import datasets as ds
import pandas as pd
from tqdm import tqdm

from balaur import ROOT_DIR

from typing import List

# Root of the pretraining runs; model checkpoints live under <exp_dir>/<model_name>/...
exp_dir = Path("/network/scratch/m/mirceara/.cache/balaur") / "runs"

In [None]:
# Disable jedi-based tab completion in IPython (presumably to work around slow or
# broken completion in this environment).
%config Completer.use_jedi = False

## Load data

In [None]:
# Column names of the preprocessed HypCC datasets.
# NOTE(review): TEXT_COL and LABEL_COL are not referenced in the cells shown here.
TEXT_COL = 'masked_text'
LABEL_COL = 'masked_tokens'
REL_COL = "context_token"  # token id of the related (context) word in each prompt

tokenizer = 'roberta-base'
setting = 'ctx'
tokenized = {}             # task -> preprocessed datasets.Dataset
tokenized_categories = {}  # task -> sorted list of every gold label id in that task
for task in ('hypernym', 'hyponym'):
    data_path = f"../5-2_hypcc/preprocessed/hypcc_{tokenizer}_{task}_singular_singular_{setting}"
    d = ds.load_from_disk(data_path)
    tokenized_categories[task] = sorted(set(chain.from_iterable(d['labels'])))
    tokenized[task] = d

## Setup

### Model loading

In [None]:
def load_model(ckpt_dir: Path, step: int, **config_kwargs):
    """Rebuild an ``MlmModel`` from the single checkpoint in ``ckpt_dir`` saved at ``step``.

    HACK: the config is reconstructed by overlaying the checkpoint's
    ``hyper_parameters`` and then ``config_kwargs`` onto a default
    ``MlmModelConfig`` -- quick and dirty, for evaluating trained models.

    Args:
        ckpt_dir: directory containing exactly one ``*step=<step>.ckpt`` file.
        step: training step of the checkpoint to load.
        **config_kwargs: config fields that override the checkpoint's hyper-parameters.

    Returns:
        The model with the checkpoint's ``state_dict`` loaded. Checkpoint tensors
        are mapped to CPU; callers move the model to their device.

    Raises:
        AssertionError: if zero or several checkpoints match ``step``.
    """
    matches = list(ckpt_dir.glob(f"*step={step}.ckpt"))
    assert len(matches) == 1, f"Unresolvable paths: {matches}"

    # map_location="cpu": a checkpoint saved from a GPU run would otherwise be
    # restored onto that GPU and fail on CPU-only hosts.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    ckpt = torch.load(matches[0], map_location="cpu")
    config = MlmModelConfig()
    config.__dict__.update(ckpt['hyper_parameters'])
    config.__dict__.update(config_kwargs)
    model = MlmModel(config=config)
    model.load_state_dict(ckpt['state_dict'])
    return model


In [None]:
def freeze_model(model, unfreeze_str: str = "head"):
    """Freeze all parameters except those of the top-level submodule ``unfreeze_str``.

    A parameter stays trainable when the first component of its dotted name
    (e.g. ``head`` in ``head.dense.weight``) equals ``unfreeze_str``. Each
    unfrozen parameter name is printed.

    Args:
        model: any ``torch.nn.Module``.
        unfreeze_str: name of the top-level submodule to keep trainable
            (previously ignored in favour of a hard-coded ``"head"``).
    """
    for name, param in model.named_parameters():
        trainable = name.split(".")[0] == unfreeze_str
        param.requires_grad = trainable
        if trainable:
            print(f"Unfreezing: {name}")

### Metrics tracking

In [None]:
class HypccMetrics:
    """Accumulator for HypCC masked-token ranking metrics.

    Each ``update`` call appends one entry per example (one row of ``logits``;
    assumes one masked position per example). Rank-derived attributes
    (``nll``, ``*_rank``, ``*_acc1``, ``*_acc5``) store a nested list with one
    value per gold label of that example.
    """

    def __init__(self):
        # Rank of each gold label over the full vocabulary ("open") or restricted to
        # the closed category vocabulary ("closed"), plus top-1 / top-5 hit flags.
        self.open_acc1, self.open_acc5, self.open_rank = [], [], []
        self.closed_acc1, self.closed_acc5, self.closed_rank = [], [], []
        # Top-10 predicted token ids and their probabilities, per example.
        self.top10_labels, self.top10_probs = [], []
        # Negative log-likelihood of each gold label.
        self.nll = []
        # Gold label ids and the related/context token id of each example.
        self.labels, self.prompts = [], []
        # Probability assigned to the context token (i.e. repeating the prompt word).
        self.rep_prob = []

    @staticmethod
    def _rank_example(row_logits, row_nll, gold, closed_vocab):
        """Return ``(nlls, open_ranks, closed_ranks)`` for one example's gold labels.

        A label's rank is 1 + the number of tokens scoring strictly higher; the
        other gold labels are masked to -inf so they never outrank each other.
        """
        competitors = row_logits.clone()
        competitors[gold] = float('-inf')
        nlls, open_ranks, closed_ranks = [], [], []
        for vidx in gold:
            beaten_by = row_logits[vidx] < competitors
            nlls.append(row_nll[vidx].item())
            open_ranks.append(int(beaten_by.sum() + 1))
            closed_ranks.append(int(beaten_by[closed_vocab].sum() + 1))
        return nlls, open_ranks, closed_ranks

    def update(self,
               logits: torch.Tensor,    # (bsz,vsz)
               related: List[int],      # (bsz,)
               labels: List[List[int]], # (bsz, nlabels), nlabels varies within a batch
               closed_vocab: List[int],
               ):
        """Score one batch and append its metrics.

        Args:
            logits: vocabulary scores, one row per example.
            related: context token id of each example (used for ``rep_prob``).
            labels: gold token ids per example; the count may differ per example.
            closed_vocab: token ids forming the closed vocabulary for closed ranks.
        """
        nll_tensor = -F.log_softmax(logits, dim=-1)
        per_example = [
            self._rank_example(logits[b], nll_tensor[b], gold, closed_vocab)
            for b, gold in enumerate(labels)
        ]
        nlls = [res[0] for res in per_example]
        open_ranks = [res[1] for res in per_example]
        closed_ranks = [res[2] for res in per_example]

        probs = torch.exp(-nll_tensor)
        top_probs, top_ids = probs.topk(10, dim=1)
        self.top10_probs.extend(top_probs.tolist())
        self.top10_labels.extend(top_ids.tolist())
        self.rep_prob.extend(probs[b][vidx].item() for b, vidx in enumerate(related))

        self.nll.extend(nlls)
        self.closed_rank.extend(closed_ranks)
        self.open_rank.extend(open_ranks)
        self.closed_acc1.extend([int(r == 1) for r in ranks] for ranks in closed_ranks)
        self.closed_acc5.extend([int(r <= 5) for r in ranks] for ranks in closed_ranks)
        self.open_acc1.extend([int(r == 1) for r in ranks] for ranks in open_ranks)
        self.open_acc5.extend([int(r <= 5) for r in ranks] for ranks in open_ranks)
        self.labels.extend(labels)
        self.prompts.extend(related)

### Eval loop

In [None]:
def labels_collate(batch):
    """Collate eval examples for ``eval_loop``.

    ``input_ids`` and ``attention_mask`` go through torch's ``default_collate``
    (stacked into batch tensors). The ``REL_COL`` and ``labels`` values are kept
    as plain per-example lists, because ``labels`` can hold a different number
    of ids per example.
    """
    tensor_keys = ('input_ids', 'attention_mask')
    collated = torch.utils.data.dataloader.default_collate(
        [{key: example[key] for key in tensor_keys} for example in batch]
    )
    for passthrough_key in (REL_COL, 'labels'):
        collated[passthrough_key] = [example[passthrough_key] for example in batch]
    return collated

def eval_loop(*, 
              model: MlmModel,
              tokenizer: tr.AutoTokenizer, 
              dataset: ds.Dataset, 
              categories: List[int], 
              bsz: int,
              device: torch.device = None,
             ):
    """Score ``model`` on ``dataset`` and return the accumulated ``HypccMetrics``.

    The model is switched to eval mode for the duration of the loop (so e.g.
    dropout is disabled; ``torch.inference_mode`` alone does not do this), and
    its previous train/eval mode is restored afterwards. This makes it safe to
    call from inside a training loop.

    Args:
        model: ``model(batch)`` must return hidden states indexable by the
            (batch, seq) mask of ``input_ids``; ``model.head`` maps them to
            vocabulary logits.
        tokenizer: only ``tokenizer.mask_token_id`` is used.
        dataset: eval split, presumably as produced by ``get_train_eval_sets``
            (tensor ``input_ids``/``attention_mask``; ``labels`` and ``REL_COL``
            as Python objects).
        categories: token ids of the closed vocabulary for closed-rank metrics.
        bsz: eval batch size.
        device: defaults to CPU; the model is moved there and left there.

    Returns:
        ``HypccMetrics`` for the whole dataset. Assumes exactly one mask token
        per example so logits rows line up with ``labels`` -- TODO confirm.
    """
    device = device or torch.device('cpu')
    model.to(device)
    dl = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, drop_last=False, collate_fn=labels_collate)
    metrics = HypccMetrics()

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            for batch in tqdm(dl, total=len(dl)):
                labels = batch.pop('labels')
                related = batch.pop(REL_COL)
                batch = {k: v.to(device) for k, v in batch.items()}
                mask_idx = batch['input_ids'] == tokenizer.mask_token_id
                hidden_states = model(batch)
                embeds = hidden_states[mask_idx]
                logits = model.head(embeds)

                metrics.update(
                    logits.detach().cpu().float(),
                    related,
                    labels,
                    categories
                )
    finally:
        # Restore whatever mode the caller had (finetune_loop keeps training afterwards).
        model.train(was_training)
    return metrics

### HypCC fine-tuning loop

In [None]:
def unpack_multilabel(examples):
    """Batched ``Dataset.map`` function that flattens multi-label rows.

    A row with k ``(labels[j], masked_tokens[j])`` pairs becomes k rows. Each
    output row holds one label and one masked token, and copies every other
    column unchanged. Extra pairs beyond the shorter of the two lists are
    dropped (``zip`` semantics).
    """
    flattened = {column: [] for column in examples}
    for row_idx in range(len(examples['input_ids'])):
        row = {column: values[row_idx] for column, values in examples.items()}
        row_labels = row.pop('labels')
        row_tokens = row.pop('masked_tokens')
        for label, token in zip(row_labels, row_tokens):
            for column, value in row.items():
                flattened[column].append(value)
            flattened['labels'].append(label)
            flattened['masked_tokens'].append(token)
    return flattened

def get_train_eval_sets(task: str, seed: int, split=0.8):
    """Randomly split ``tokenized[task]`` into a train set and an eval set.

    Args:
        task: key into ``tokenized`` (``'hypernym'`` or ``'hyponym'``).
        seed: seed for ``train_test_split``.
        split: fraction of rows used for training.

    Returns:
        ``(train_ds, eval_ds)``. ``train_ds`` is flattened to one label per row
        via ``unpack_multilabel``, with ``input_ids``/``attention_mask``/``labels``
        as torch tensors. ``eval_ds`` keeps multi-label rows, with only
        ``input_ids``/``attention_mask`` as tensors. Both keep the remaining
        columns as Python objects.
    """
    splits = tokenized[task].train_test_split(seed=seed, train_size=split)

    train_ds = splits['train'].map(unpack_multilabel, batched=True)
    train_ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'], output_all_columns=True)

    eval_ds = splits['test']
    eval_ds.set_format(type='torch', columns=['input_ids', 'attention_mask'], output_all_columns=True)
    return train_ds, eval_ds

def ft_labels_collate(batch):
    """Collate training examples, keeping only the tensor fields the loss needs.

    Every other column (e.g. raw text) is dropped before ``default_collate``
    stacks ``input_ids``, ``attention_mask`` and ``labels``.
    """
    wanted = ('input_ids', 'attention_mask', 'labels')
    trimmed = []
    for example in batch:
        trimmed.append({key: example[key] for key in wanted})
    return torch.utils.data.dataloader.default_collate(trimmed)
    
# Parameter-name fragments excluded from weight decay (biases and LayerNorm weights).
_NO_DECAY_KEYS = ("bias", "LayerNorm.weight", "layer_norm.weight")


def _build_optimizer(model, lr: float, wd: float, betas) -> AdamW:
    """AdamW over all of ``model``'s parameters, without weight decay on biases/LayerNorm weights."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if any(key in name for key in _NO_DECAY_KEYS) else decay).append(param)
    # Frozen parameters never receive a .grad, and AdamW skips parameters without one.
    return AdamW(
        [{"params": decay, "weight_decay": wd}, {"params": no_decay, "weight_decay": 0.0}],
        lr=lr, weight_decay=wd, betas=tuple(betas),
    )


def _masked_lm_loss(model, batch) -> torch.Tensor:
    """Cross-entropy of ``model.head`` at the mask-token positions of ``batch``.

    Assumes exactly one mask token per example, so the selected rows line up
    with ``batch['labels']`` (one label id per example) -- TODO confirm.
    """
    hidden = model(batch)
    hidden = hidden.view(-1, hidden.shape[-1])
    is_mask = batch['input_ids'].view(-1) == model.bort_config.mask_token_id
    logits = model.head(hidden[is_mask])
    return F.cross_entropy(logits, batch['labels'])


def _evaluate_in_eval_mode(model, tokenizer, dataset, categories, device):
    """Run ``eval_loop`` with ``model`` in eval mode, then put it back into train mode."""
    model.eval()
    try:
        return eval_loop(
            model=model,
            tokenizer=tokenizer,
            dataset=dataset,
            categories=categories,
            bsz=1024,
            device=device,
        )
    finally:
        model.train()


def finetune_loop(*, 
                  model: MlmModel, 
                  tokenizer: tr.AutoTokenizer,
                  task: str = 'hypernym',
                  bsz: int = 16, 
                  lr: float = 4e-4, 
                  wd: float = 0.1,
                  device: torch.device = None,
                  num_seeds: int = 5, start_seed: int = 1337,
                  warmup_ratio: float = 0.06,
                  betas: tuple[float, float] = (0.9, 0.98),
                  max_steps: int = 4096,
                  split: float = 0.8,
                  freeze: bool = True,
                  schedule: str = 'linear',
                 ):
    """Fine-tune copies of ``model`` on HypCC ``task`` for several seeds, tracking eval metrics.

    For each seed ``start_seed, start_seed + 1, ..., start_seed + num_seeds - 1``:
    deep-copy ``model``, optionally freeze all but its head, re-split the task
    data with that seed, and train the copy with AdamW plus a ``schedule`` LR
    scheduler (``warmup_ratio`` warmup).

    The copy is evaluated on the held-out split:
      * before any training (key ``0``),
      * after ``bsz, 2*bsz, 4*bsz, ...`` training examples,
      * once more when the ``max_steps`` budget is reached.

    Args:
        max_steps: training budget in *examples* (the loop adds ``bsz`` per
            batch, including a possibly smaller final batch of an epoch). The
            LR schedule is sized to the matching ``ceil(max_steps / bsz)``
            optimizer updates.
        Other arguments are the usual optimisation hyper-parameters; ``betas``
        may be any 2-element sequence.

    Returns:
        ``(metrics, losses)``:
          * ``metrics[seed][n_examples]`` is a ``HypccMetrics``.
          * ``losses[seed][n_examples]`` is the training loss of the batch that
            preceded that evaluation (there is no loss entry at ``0``).

    Raises:
        ValueError: if the training split is empty (the loop would never end).
    """
    device = device or torch.device('cpu')
    categories = tokenized_categories[task]
    # The scheduler is stepped once per batch, so size it in optimizer updates, not examples.
    num_updates = -(-max_steps // bsz)
    num_warmup = int(num_updates * warmup_ratio)

    metrics = {}
    losses = {}
    for seed_incr in range(num_seeds):
        seed = start_seed + seed_incr

        seed_model = copy.deepcopy(model)
        seed_model.train()
        if freeze:
            freeze_model(seed_model)
        seed_model = seed_model.to(device)
        optimizer = _build_optimizer(seed_model, lr, wd, betas)
        lr_scheduler = tr.get_scheduler(schedule, optimizer, num_warmup, num_updates)

        train_ds, eval_ds = get_train_eval_sets(task, seed, split)
        g = torch.Generator()
        g.manual_seed(seed)
        dl = torch.utils.data.DataLoader(
            train_ds, batch_size=bsz, shuffle=True, drop_last=False,
            generator=g, collate_fn=ft_labels_collate,
        )
        if len(dl) == 0:
            raise ValueError(f"Empty training split for task {task!r} with seed {seed}")
        torch.manual_seed(seed)

        # Zero-shot evaluation, before any parameter update.
        metrics[seed] = {
            0: _evaluate_in_eval_mode(seed_model, tokenizer, eval_ds, categories, device)
        }
        losses[seed] = {}
        steps = 0
        next_eval_step = bsz
        while steps < max_steps:
            print("new epoch")
            for batch in dl:
                batch = {k: v.to(device) for k, v in batch.items()}
                optimizer.zero_grad()
                loss = _masked_lm_loss(seed_model, batch)
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                steps += bsz  # count examples *after* training on them

                reached_end = steps >= max_steps
                if steps == next_eval_step or reached_end:
                    metrics[seed][steps] = _evaluate_in_eval_mode(
                        seed_model, tokenizer, eval_ds, categories, device
                    )
                    losses[seed][steps] = loss.item()
                    if steps == next_eval_step:
                        next_eval_step *= 2
                if reached_end:
                    break

    return metrics, losses

## Run fine-tuning

In [None]:
MODEL_NAMES = ["mlm_only", "mlm_wnre"]
STEP = 25_000  # pretraining step of the checkpoints to compare


def _checkpoint_dir(model_name: str) -> Path:
    """Checkpoint directory of a pretraining run under ``exp_dir``."""
    return Path(f"{exp_dir}/{model_name}/balaur/{model_name}/checkpoints/")


models = {name: load_model(_checkpoint_dir(name), STEP) for name in MODEL_NAMES}
tknzr = tr.AutoTokenizer.from_pretrained('roberta-base')

In [None]:
SKIP = True # this takes quite some time, set to False to run.
# Fine-tune on both tasks: the plotting cells below read eval_metrics['hypernym']
# and eval_metrics['hyponym'], so running only 'hypernym' made them fail with a KeyError.
FT_TASKS = ['hypernym', 'hyponym']

if not SKIP:
    eval_metrics = defaultdict(dict)  # task -> model name -> seed -> step -> HypccMetrics
    train_losses = defaultdict(dict)  # task -> model name -> seed -> step -> loss
    for m, model in models.items():
        for task in FT_TASKS:
            # NOTE(review): autocast("cuda") defaults to float16 and finetune_loop uses no
            # GradScaler, so small gradients may underflow -- consider dtype=torch.bfloat16
            # or adding a GradScaler.
            with torch.amp.autocast("cuda"):
                eval_metrics[task][m], train_losses[task][m] = finetune_loop(
                    model=model, 
                    tokenizer=tknzr, 
                    task=task,
                    bsz=16, 
                    lr=4e-4,
                    device=torch.device('cuda'),
                    num_seeds=5,
                    max_steps=2048*4*8,  # training examples, not optimizer steps
                )

In [None]:
# Cache the fine-tuning metrics so the plots below can be regenerated with SKIP = True.
if not SKIP:
    with open("ft_results.pkl", "wb") as results_fh:
        pickle.dump(eval_metrics, results_fh)

In [None]:
# Reuse previously saved metrics instead of re-running fine-tuning. The pickle holds
# HypccMetrics instances, so that class must be defined before this cell runs.
# NOTE: pickle.load can execute arbitrary code -- only load a trusted ft_results.pkl.
if SKIP:
    with open("ft_results.pkl", "rb") as results_fh:
        eval_metrics = pickle.load(results_fh)

## Plots

In [None]:
from collections import defaultdict
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

from IPython.core.display import display, HTML
from statistics import quantiles

from typing import *

In [None]:
def plot_trace(x, y, legend: str, rgb='0,100,80', dash=None, mode='lines', z=None):
    """Build a single-element list holding one plotly Scatter trace.

    Args:
        x, y: point coordinates.
        legend: trace name shown in the legend.
        rgb: line colour as an ``"r,g,b"`` string.
        dash: optional plotly dash style.
        mode: plotly scatter mode (``'lines'``, ``'markers'``, ...).
        z: optional per-point hover text.
    """
    line_style = dict(color=f'rgb({rgb})', dash=dash)
    scatter = go.Scatter(name=legend, x=x, y=y, hovertext=z, line=line_style, mode=mode)
    return [scatter]

def plot_mrr(task: str, model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7'):
    """Plot open-vocabulary mean reciprocal rank vs. number of fine-tuning examples.

    For each model, reads ``eval_metrics[task][model]`` (seed -> step -> HypccMetrics).
    Each eval example's reciprocal ranks are averaged over its gold labels.
    These per-example values are pooled over all seeds at each step, and their
    mean gives one point per step. Each model gets one line.

    Steps are taken from every seed, rather than from whichever seed the loop
    happened to visit last. A model with no seeds yields an empty trace instead
    of an UnboundLocalError.
    """
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        mrrs = defaultdict(list)
        for per_step in eval_metrics[task][obj].values():
            for step, metric in per_step.items():
                mrrs[step].extend(np.mean([1 / rank for rank in ranks]) for ranks in metric.open_rank)
        steps = sorted(mrrs)
        traces.extend(plot_trace(steps, [np.mean(mrrs[step]) for step in steps], leg, rgb))
    return go.Figure(traces)

def plot_rep(task: str, model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7'):
    """Plot the top-1 repetition rate vs. number of fine-tuning examples for two models.

    An eval example counts as a repetition when the model's top-1 prediction
    equals its context token (``metric.prompts``). The indicator is pooled over
    all examples of all seeds in ``eval_metrics[task][model]`` at each step.

    Steps are taken from every seed, rather than from whichever seed the loop
    happened to visit last. A model with no seeds yields an empty trace instead
    of an UnboundLocalError.
    """
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        reps = defaultdict(list)
        for per_step in eval_metrics[task][obj].values():
            for step, metric in per_step.items():
                reps[step].extend(
                    int(preds[0] == ctxt) for preds, ctxt in zip(metric.top10_labels, metric.prompts)
                )
        steps = sorted(reps)
        traces.extend(plot_trace(steps, [np.mean(reps[step]) for step in steps], leg, rgb))
    return go.Figure(traces)

### Open MRR

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_mrr(
        'hypernym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored in the bottom-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.01))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="MRR (Hypernym)")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_openmrr_hypernym.pdf", width=1.5*300, height=0.75*300)


In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_mrr(
        'hyponym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored in the bottom-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.01))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="MRR (Hyponym)")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_openmrr_hyponym.pdf", width=1.5*300, height=0.75*300)


### Repetitions

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_rep(
        'hypernym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored near the top-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.80))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="Hypernym Repetition")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_repetition_hypernym.pdf", width=1.5*300, height=0.75*300)


In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_rep(
        'hyponym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored near the top-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.80))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="Hyponym Repetition")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_repetition_hyponym.pdf", width=1.5*300, height=0.75*300)


## Error analysis

In [None]:
def plot_intruding_preds(task: str, model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7'):
    """Plot the class-intrusion rate vs. number of fine-tuning examples for two models.

    A top-10 prediction "intrudes" when it is a HypCC category
    (``tokenized_categories[task]``) but not a gold label of its example. The
    rate is intrusions / all top-10 predictions, computed per seed and step,
    then averaged over seeds.

    Steps are taken from every seed, rather than from whichever seed the loop
    happened to visit last. Empty eval sets are skipped instead of dividing by zero.
    """
    categories = set(tokenized_categories[task])  # hoisted: O(1) membership in the inner loop
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        intruder_rates = defaultdict(list)
        for per_step in eval_metrics[task][obj].values():
            for step, metric in per_step.items():
                num_preds = 0
                num_intrd = 0
                for gold, preds in zip(metric.labels, metric.top10_labels):
                    num_preds += len(preds)
                    num_intrd += sum(1 for pred in preds if pred not in gold and pred in categories)
                if num_preds:
                    intruder_rates[step].append(num_intrd / num_preds)
        steps = sorted(intruder_rates)
        traces.extend(plot_trace(steps, [np.mean(intruder_rates[step]) for step in steps], leg, rgb))
    return go.Figure(traces)

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_intruding_preds(
        'hypernym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored in the bottom-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.01))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="Class Intrusion Rate")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_intrusion_hypernym.pdf", width=1.5*300, height=0.75*300)

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    plot_intruding_preds(
        'hyponym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="linear")
    .update_layout(showlegend=True, template='simple_white')
    # legend anchored in the bottom-right corner
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99, y=0.01))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Number of training examples", yaxis_title="Class Intrusion Rate")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_intrusion_hyponym.pdf", width=1.5*300, height=0.75*300)

In [None]:
tknzr = tr.AutoTokenizer.from_pretrained('roberta-base')
# Per task: how often each gold label id occurs across the full (unsplit) dataset.
class_counters = {
    task_name: Counter(chain.from_iterable(dataset['labels']))
    for task_name, dataset in tokenized.items()
}

def convert_token(token_id: int) -> str:
    """Return the surface string of ``token_id`` for display.

    The leading 'Ġ' that RoBERTa's byte-level BPE uses to mark a preceding
    space is stripped.
    """
    raw = tknzr.convert_ids_to_tokens(token_id)
    return raw[1:] if raw[0] == 'Ġ' else raw

def error_analysis(task: str, model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7'):
    """Scatter each intruding class's dataset frequency against its intrusion rate, for two models.

    Only the last eval step of every seed is used. A top-10 prediction
    "intrudes" when it is a HypCC category (``tokenized_categories[task]``)
    but not a gold label of its example.
      * A class's intrusion rate is its number of intrusions divided by the
        number of eval examples pooled over seeds.
      * x is the class's label count from ``class_counters[task]``.
      * Hovering a point shows the class token.
    """
    categories = set(tokenized_categories[task])  # hoisted: O(1) membership in the inner loop
    class_freq = class_counters[task]
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        n = 0
        intrusions = Counter()
        for per_step in eval_metrics[task][obj].values():
            # do this only for last step
            metric = per_step[max(per_step)]
            for gold, preds in zip(metric.labels, metric.top10_labels):
                intrusions.update(pred for pred in preds if pred not in gold and pred in categories)
                n += 1
        intrusion_rate = {k: v / n for k, v in intrusions.items()}

        # Order points by descending class frequency, keeping only classes that intruded.
        idx = [i for i, _ in class_freq.most_common() if i in intrusion_rate]
        x = [class_freq[i] for i in idx]
        y = [intrusion_rate[i] for i in idx]
        z = [convert_token(i) for i in idx]
        traces.extend(plot_trace(x, y, leg, rgb, mode='markers', z=z))
    return go.Figure(traces)

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    error_analysis(
        'hypernym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="log")  # class frequencies span orders of magnitude
    .update_layout(showlegend=True, template='simple_white')
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.4, y=0.7))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Class Frequency in HypCC", yaxis_title="Intrusion Rate")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_freq-intrusion_hypernym.pdf", width=1.5*300, height=0.75*300)

In [None]:
MODEL_LEGENDS = dict(mlm_only="BERT (OURS)", mlm_wnre="BERT+BALAUR")
MODEL_RGBS = dict(mlm_only="255,127,80", mlm_wnre="65,105,225")

# m1 (BALAUR) is drawn first, m2 (baseline) second.
m1, m2 = "mlm_wnre", "mlm_only"
fig = (
    error_analysis(
        'hyponym', m1, m2,
        MODEL_LEGENDS[m1], MODEL_LEGENDS[m2],
        MODEL_RGBS[m1], MODEL_RGBS[m2],
    )
    .update_yaxes(type="linear")
    .update_xaxes(type="log")  # class frequencies span orders of magnitude
    .update_layout(showlegend=True, template='simple_white')
    .update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.4, y=0.7))
    .update_layout(width=500, height=250, font_family="Serif", font_size=12,
                   margin_l=5, margin_t=5, margin_b=5, margin_r=5)
    .update_layout(xaxis_title="Class Frequency in HypCC", yaxis_title="Intrusion Rate")
)
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_hypcc_freq-intrusion_hyponym.pdf", width=1.5*300, height=0.75*300)