In [None]:
import os
# Point the Hugging Face (transformers / datasets) and Balaur caches at scratch storage.
# These assignments run before `datasets`, `transformers` and `balaur` are imported below.
# NOTE(review): the paths are absolute and user-specific; adjust them before running elsewhere.
os.environ["TRANSFORMERS_CACHE"] = "/network/scratch/m/mirceara/.cache/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/network/scratch/m/mirceara/.cache/huggingface/datasets"
os.environ["BALAUR_CACHE"] = "/network/scratch/m/mirceara/.cache/balaur"

In [None]:
import sys
# Extend the module search path so that `run_bort` (imported on the next line) can be found;
# presumably run_bort.py lives in this directory.
# NOTE(review): absolute, user-specific path; adjust before running elsewhere.
sys.path.append("/home/mila/m/mirceara/balaur/experiments/pretrain/")
from run_bort import MlmModel, MlmModelConfig, MlmWnreDataModule, MlmWnreDataModuleConfig, WNRE
from balaur.modeling.backbone.bort import BortForSequenceClassification
from balaur import ROOT_DIR

In [None]:
from collections import defaultdict
from functools import lru_cache
import math
from pathlib import Path
import random
import json

from itertools import chain
import copy


from tqdm import tqdm
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.optim import AdamW
import pandas as pd
import datasets as ds
import transformers as tr
import pytorch_lightning as pl

from typing import List

## Setup

### Model

In [None]:
def load_bort_model(ckpt_dir: Path, step: int, num_labels: int = 2, **config_kwargs):
    """Build a BORT sequence classifier whose encoder comes from a pretraining checkpoint.

    Args:
        ckpt_dir: Directory searched (non-recursively) for a file matching
            ``*step={step}.ckpt``.
        step: Training step that identifies the checkpoint; exactly one file must match.
        num_labels: Number of output classes passed to ``BortForSequenceClassification``.
        **config_kwargs: Overrides applied to the config *after* the hyper-parameters
            stored in the checkpoint, so they take precedence.

    Returns:
        A ``BortForSequenceClassification`` whose ``bort`` attribute has been replaced
        by the ``model`` of the restored ``MlmModel``.

    Raises:
        AssertionError: if zero or several checkpoint files match ``step``.
    """
    matches = list(ckpt_dir.glob(f"*step={step}.ckpt"))
    assert len(matches) == 1, f"Unresolvable paths: {matches}"
    ckpt_path = matches[0]

    # Deserialize onto the CPU: without `map_location`, tensors are restored onto the
    # device they were saved from, which fails on a CPU-only host and otherwise keeps a
    # redundant copy of the weights on the GPU. The state dict is copied into a freshly
    # built module below, so the final model is unaffected.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Start from the default config, then layer the checkpoint's hyper-parameters and
    # finally the caller's overrides on top.
    config = MlmModelConfig()
    config.__dict__.update(ckpt['hyper_parameters'])
    config.__dict__.update(config_kwargs)

    pl_model = MlmModel(config=config)
    pl_model.load_state_dict(ckpt['state_dict'])

    # Only the pretrained encoder is reused; the classifier is built fresh from the
    # same BORT config and gets the encoder swapped in.
    model = BortForSequenceClassification(pl_model.bort_config, num_labels=num_labels)
    model.bort = pl_model.model
    return model

### Dataset

In [None]:
# Tokenizer shared by dataset preprocessing and the WordNet lookup further down.
tknzr = tr.AutoTokenizer.from_pretrained('roberta-base')

# SNLI is loaded (unlabeled examples, label == -1, dropped) and its columns renamed to the
# MoNLI naming. In the active code below only its `label` feature is used, to encode
# MoNLI's string labels; the lines that would add SNLI itself to `d` are commented out.
snli = ds.load_dataset('snli', split='train').filter(lambda x: x['label'] != -1)
snli = snli.rename_column('premise', 'sentence1').rename_column('hypothesis', 'sentence2')
snli_class_label = snli.features['label']


# The two MoNLI splits are read from local JSONL files (paths relative to the working directory).
d = ds.DatasetDict(dict(
            nmonli = ds.load_dataset("json", data_files="monli/nmonli_train.jsonl", split='train'),
            pmonli = ds.load_dataset("json", data_files="monli/pmonli.jsonl", split='train'),
        ))
# Convert the string `gold_label` into integer ids using SNLI's label vocabulary, then
# expose the result under the column name `labels`.
d = d.map(lambda x: {'label': [snli_class_label.str2int(l) for l in x['gold_label']]}, batched=True)
d = d.cast_column('label', snli_class_label)
# d['snli'] = snli
d = d.rename_column('label', 'labels')

# Tokenize each (sentence1, sentence2) pair; token type ids are not requested.
d = d.map(lambda x: tknzr(x['sentence1'], x['sentence2'], return_token_type_ids=False), batched=True)

# Tag every example with the split it came from; the metrics bucket results by this value.
d['nmonli'] = d['nmonli'].add_column('source', ['nmonli' for _ in range(len(d['nmonli']))])
d['pmonli'] = d['pmonli'].add_column('source', ['pmonli' for _ in range(len(d['pmonli']))])
# d['snli'] = d['snli'].add_column('source', ['snli' for _ in range(len(d['snli']))])
# d['snli'] = d['snli'].add_column('depth', [None for _ in range(len(d['snli']))])


# Pad every example to the longest tokenized sequence found across all splits, so all
# examples end up with the same length.
max_lengths = [max([len(x) for x in _d['input_ids']]) for _d in d.values()]
max_length = max(max_lengths)
d = d.map(lambda e: tknzr.pad(e, padding='max_length', max_length=max_length), batched=True)

## WordNet annotation

In [None]:
# WNRE comes from run_bort; `is_in_wordnet` below reads its `tok2rel2syn` table.
# NOTE(review): the meaning of `rel_depth=3` is defined in run_bort and not visible here.
wnre = WNRE('roberta-base', rel_depth=3)

In [None]:
def is_in_wordnet(tok: str):
    """Tell whether `tok` is a single tokenizer entry with a non-empty `wnre.tok2rel2syn` record.

    The word is looked up in `tknzr.vocab` with a leading "Ġ" (the prefix RoBERTa's
    vocabulary uses for a token preceded by a space). If that entry does not exist the
    word is reported as not found; otherwise the result is the truthiness of the
    `wnre.tok2rel2syn` record stored under the token id.

    Args:
        tok: The surface word, without any prefix.

    Returns:
        bool: True only if the prefixed word is in the vocabulary and its record is non-empty.
    """
    tok_id = tknzr.vocab.get("Ġ" + tok, None)
    # Plain truthiness on purpose: a missing entry (None) and token id 0 are both
    # treated as "not found".
    return bool(wnre.tok2rel2syn[tok_id]) if tok_id else False
    
def annotate_single_token(example):
    """Compute the two WordNet-coverage flags for one dataset example.

    For examples whose `source` is 'snli' both flags are None. Otherwise the sign of
    `depth` decides which sentence carries the hypernym: with a positive depth it is
    `sentence2_lex` (and `sentence1_lex` is the hyponym); otherwise the roles are swapped.

    Args:
        example: Mapping with `source` and, for non-SNLI rows, `depth`,
            `sentence1_lex` and `sentence2_lex`.

    Returns:
        dict with keys `hypernym_in_lexrel` and `hyponym_in_lexrel`, each the result of
        `is_in_wordnet` for the corresponding word (or None for SNLI rows).
    """
    if example['source'] == 'snli':
        return {'hypernym_in_lexrel': None, 'hyponym_in_lexrel': None}

    # (hypernym column, hyponym column), chosen by the sign of `depth`.
    if example['depth'] > 0:
        hyper_col, hypo_col = 'sentence2_lex', 'sentence1_lex'
    else:
        hyper_col, hypo_col = 'sentence1_lex', 'sentence2_lex'
    hypernym = example[hyper_col]
    hyponym = example[hypo_col]

    return {
        'hypernym_in_lexrel': is_in_wordnet(hypernym),
        'hyponym_in_lexrel': is_in_wordnet(hyponym),
    }

# Add the `hypernym_in_lexrel` / `hyponym_in_lexrel` columns to every split of `d`.
d = d.map(annotate_single_token)

### Train-eval split

In [None]:
def systematic_train_eval_split(monli_ds: ds.Dataset, seed: int, frac=0.8):
    """Split a MoNLI dataset so that train and eval share no hypernym.

    Each row's hypernym is `sentence2_lex` when `depth` is positive and `sentence1_lex`
    otherwise. The distinct hypernyms (in order of first appearance) are shuffled with
    `random.Random(seed)`; the first `int(len * frac)` of them define the training set
    and every remaining row goes to the eval set.

    Args:
        monli_ds: Dataset with at least `depth`, `sentence1_lex` and `sentence2_lex`.
        seed: Seed for the hypernym shuffle.
        frac: Fraction of distinct hypernyms (not of rows) assigned to training.

    Returns:
        (train_ds, eval_ds), both rebuilt from plain Python columns with
        `ds.Dataset.from_dict`.
    """
    columns = monli_ds.to_dict()

    # Hypernym of every row, in row order.
    row_hypernyms = [
        columns['sentence2_lex'][row] if columns['depth'][row] > 0 else columns['sentence1_lex'][row]
        for row in range(len(monli_ds))
    ]

    # De-duplicate while keeping first-appearance order, so the shuffle is reproducible.
    shuffled = list(dict.fromkeys(row_hypernyms))
    random.Random(seed).shuffle(shuffled)
    cutoff = int(len(shuffled) * frac)
    train_hypernyms = set(shuffled[:cutoff])

    # Route each row, column by column, to the side its hypernym belongs to.
    train_cols = defaultdict(list)
    eval_cols = defaultdict(list)
    for row, hypernym in enumerate(row_hypernyms):
        target = train_cols if hypernym in train_hypernyms else eval_cols
        for name, values in columns.items():
            target[name].append(values[row])

    train_ds = ds.Dataset.from_dict(train_cols)
    eval_ds = ds.Dataset.from_dict(eval_cols)
    return train_ds, eval_ds

@lru_cache
def get_train_eval_sets(seed, frac=0.8):
    """Build the combined NMoNLI + PMoNLI train and eval datasets for one seed.

    Both sources in the module-level `d` are split with `systematic_train_eval_split`.
    The training halves are reduced to the columns in `keep_cols`, concatenated and
    shuffled with `seed`; the eval halves are concatenated with all their columns.
    Both results use the torch format for the three model inputs while still returning
    the remaining columns. Results are memoized per `(seed, frac)` by `lru_cache`.

    Args:
        seed: Seed for the hypernym split and the training shuffle.
        frac: Fraction of distinct hypernyms assigned to training.

    Returns:
        (train_dataset, eval_dataset)
    """
    keep_cols = ['input_ids', 'labels','attention_mask',
                'source', 'hypernym_in_lexrel', 'hyponym_in_lexrel',
                'depth']
    tensor_cols = ['input_ids', 'labels', 'attention_mask']

    train_parts = []
    eval_parts = []
    for split_name in ('nmonli', 'pmonli'):
        train_part, eval_part = systematic_train_eval_split(d[split_name], seed=seed, frac=frac)
        # Only the training half is pruned; the eval half keeps every column.
        extra_cols = [col for col in train_part.column_names if col not in keep_cols]
        train_parts.append(train_part.remove_columns(extra_cols))
        eval_parts.append(eval_part)

    train_dataset = ds.concatenate_datasets(train_parts).shuffle(seed=seed)
    eval_dataset = ds.concatenate_datasets(eval_parts)

    for dataset in (train_dataset, eval_dataset):
        dataset.set_format(type='torch', columns=tensor_cols, output_all_columns=True)

    return train_dataset, eval_dataset

In [None]:
# Split for seed 42; `len(train_dataset)` is used further down to size the fine-tuning budget.
train_dataset, eval_dataset = get_train_eval_sets(42)

### Metrics

In [None]:
class MoNLIMetrics:
    """Accumulates per-example NLL, accuracy and gold labels, grouped into named buckets.

    Every example is recorded twice: under the catch-all bucket "all" and under a bucket
    derived from its source, WordNet coverage and absolute depth
    ("<source>_<both|hyper|hypo|none>_<depth>", or just "snli" for SNLI examples).
    """

    def __init__(self):
        # bucket name -> list of 0/1 correctness flags
        self.acc = defaultdict(list)
        # bucket name -> list of gold-label negative log-likelihoods (Python floats)
        self.nll = defaultdict(list)
        # bucket name -> list of gold-label entries, as taken from the `labels` argument
        self.labels = defaultdict(list)
        # one entry per `update` call: the batch loss that was passed in
        self.losses = []

    def update(self,
               logits: torch.Tensor,    # (bsz, num_classes)
               labels: torch.Tensor,    # one class index per example (each entry must select a single logit)
               loss: float,
               sources: List[str],
               depths: List[int],
               hypernym_in_lexrel: List[bool],
               hyponym_in_lexrel: List[bool],
               ):
        """Record one batch.

        An example counts as correct when the NLL of its gold label is less than or
        equal to the NLL of every class, i.e. the gold label is (tied for) the most
        likely class.

        Args:
            logits: Unnormalized class scores, softmax-ed over the last dimension.
            labels: Gold class index for every example in the batch.
            loss: Batch loss, appended to `self.losses` as-is.
            sources: Source name per example; 'snli' gets its own un-stratified bucket.
            depths: Per-example depth; None is treated as 0 and the sign is discarded.
            hypernym_in_lexrel: Per-example flag, read only for non-SNLI examples.
            hyponym_in_lexrel: Per-example flag, read only for non-SNLI examples.
        """
        per_class_nll = -F.log_softmax(logits, dim=-1)

        for row, gold in enumerate(labels):
            gold_nll = per_class_nll[row][gold].item()
            is_correct = int((gold_nll <= per_class_nll[row]).all())

            source = sources[row]
            raw_depth = depths[row]
            depth = abs(raw_depth) if raw_depth is not None else 0

            if source == 'snli':
                bucket = source
            else:
                has_hypernym = bool(hypernym_in_lexrel[row])
                has_hyponym = bool(hyponym_in_lexrel[row])
                if has_hypernym:
                    coverage = "both" if has_hyponym else "hyper"
                else:
                    coverage = "hypo" if has_hyponym else "none"
                bucket = f"{source}_{coverage}_{depth}"

            # Same values go to the specific bucket and to the global one.
            for key in (bucket, "all"):
                self.nll[key].append(gold_nll)
                self.acc[key].append(is_correct)
                self.labels[key].append(labels[row])

        self.losses.append(loss)

### Eval loop

In [None]:
def eval_labels_collate(batch):
    """Collate a list of examples into one batch dict.

    `input_ids`, `attention_mask` and `labels` are stacked with PyTorch's default
    collate; `source`, `hypernym_in_lexrel`, `hyponym_in_lexrel` and `depth` are
    passed through as plain per-example lists. Any other field is dropped.

    Args:
        batch: List of per-example mappings.

    Returns:
        dict with the seven keys listed above.
    """
    tensor_fields = ['input_ids', 'attention_mask', 'labels']
    passthrough_fields = ['source', 'hypernym_in_lexrel', 'hyponym_in_lexrel', 'depth']

    collated = torch.utils.data.dataloader.default_collate(
        [{field: example[field] for field in tensor_fields} for example in batch]
    )
    for field in passthrough_fields:
        collated[field] = [example[field] for example in batch]
    return collated


def eval_loop(*, model: BortForSequenceClassification, 
              tokenizer: tr.AutoTokenizer, 
              dataset: ds.Dataset, 
              bsz: int,
              device: torch.device = None,
             ):
    """Run `model` over `dataset` without gradients and collect MoNLI metrics.

    The model is switched to evaluation mode for the duration of the loop (so dropout
    and similar layers are deterministic) and its previous train/eval mode is restored
    afterwards, even if evaluation raises. This matters because `finetune_loop` calls
    this function on a model that is in training mode.

    Args:
        model: Classifier to evaluate; it is moved to `device` in place.
        tokenizer: Unused; kept so existing call sites keep working.
        dataset: Examples providing the fields read by `eval_labels_collate`.
        bsz: Evaluation batch size.
        device: Target device (a `torch.device` or anything `torch.device` accepts);
            defaults to the CPU.

    Returns:
        MoNLIMetrics filled with one `update` per batch.
    """
    # setup
    device = torch.device(device) if device is not None else torch.device('cpu')
    model.to(device)
    dl = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, drop_last=False, 
                                     collate_fn=eval_labels_collate)
    metrics = MoNLIMetrics()

    # fp16 autocast is only meaningful on CUDA; on any other device the region is
    # disabled instead of requesting CUDA autocast for tensors that are not on a GPU.
    use_amp = device.type == 'cuda'

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            for batch in dl:
                with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                    out = model(
                        batch['input_ids'].to(device),
                        labels=batch['labels'].to(device),
                        attention_mask=batch['attention_mask'].to(device)
                    )
                metrics.update(
                    out['logits'],
                    batch['labels'],
                    out['loss'].item(),
                    sources=batch['source'],
                    depths=batch['depth'],
                    hypernym_in_lexrel=batch['hypernym_in_lexrel'],
                    hyponym_in_lexrel=batch['hyponym_in_lexrel'],
                )
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return metrics 

### Finetune loop

In [None]:
def ft_labels_collate(batch):
    """Collate a fine-tuning batch.

    Training and evaluation batches have exactly the same layout, so this delegates to
    `eval_labels_collate` instead of keeping a second copy of the same code: the three
    tensor fields are stacked and the metadata fields are returned as per-example lists.

    Args:
        batch: List of per-example mappings.

    Returns:
        The dict produced by `eval_labels_collate(batch)`.
    """
    return eval_labels_collate(batch)


def finetune_loop(*, model: BortForSequenceClassification, tokenizer: tr.AutoTokenizer,
                  bsz: int = 16, lr: float = 1e-5, wd: float = 0.1, max_grad_norm=1.0,
                  warmup: float = 0.06,
                  device: torch.device = None,
                  num_seeds: int = 5, start_seed: int = 42,
                  max_train_examples: int = 4096,
                  eval_every_n_examples: int = 512,
                  betas: list[float] = [0.9, 0.98],
                 ):
    """Fine-tune independent copies of `model` on MoNLI, one per seed, evaluating periodically.

    For every seed a deep copy of `model` is trained with AdamW (no weight decay on
    biases and layer-norm weights) and a linear warmup/decay schedule, on the train
    split returned by `get_train_eval_sets(seed)`. Whenever the number of examples seen
    reaches the next multiple of `eval_every_n_examples` (including 0, i.e. before the
    first update), the current training batch is scored and the model is evaluated on
    the matching eval split. `model` itself is never modified.

    Seeds are `start_seed, start_seed + 1, ..., start_seed + num_seeds - 1`.
    NOTE: the previous implementation accumulated the increment (`seed += seed_incr`),
    which yielded 42, 43, 45, 48, 52 for the defaults; results keyed by those seeds are
    not directly comparable with runs of this version.

    Args:
        model: Classifier to copy for every seed.
        tokenizer: Forwarded to `eval_loop`.
        bsz: Training batch size.
        lr: Peak learning rate.
        wd: Weight decay for parameters that are not biases / layer-norm weights.
        max_grad_norm: Gradient-norm clipping threshold.
        warmup: Fraction of the training steps used for linear warmup.
        device: Training device; defaults to the CPU.
        num_seeds: Number of independent runs.
        start_seed: Seed of the first run.
        max_train_examples: Training budget in examples; every step counts as `bsz`.
        eval_every_n_examples: Minimum number of seen examples between evaluations.
        betas: AdamW betas.

    Returns:
        (train_metrics, eval_metrics): two dicts mapping seed -> examples seen ->
        MoNLIMetrics. The training entry covers only the batch at hand (scored in
        training mode); the eval entry covers the whole eval split.

    Raises:
        ValueError: if the training split for a seed is empty while
            `max_train_examples` is positive (the loop could never finish).
    """
    device = device or torch.device('cpu')
    train_metrics = {}
    eval_metrics = {}

    training_steps = math.ceil(max_train_examples / bsz)
    warmup_steps = int(training_steps * warmup)
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]

    for seed_incr in range(num_seeds):
        # Derived from `start_seed` every time so that seeds are consecutive.
        seed = start_seed + seed_incr

        seed_model = copy.deepcopy(model)
        seed_model.train()
        seed_model = seed_model.to(device)
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in seed_model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": wd,
            },
            {
                "params": [p for n, p in seed_model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        # tuple(): the optimizer keeps a reference to its betas; do not hand it the
        # shared mutable default list.
        optimizer = AdamW(optimizer_grouped_parameters, lr=lr, weight_decay=wd, betas=tuple(betas))
        lr_scheduler = tr.get_scheduler('linear', optimizer, warmup_steps, training_steps)

        train_dataset, eval_dataset = get_train_eval_sets(seed)
        g = torch.Generator()
        g.manual_seed(seed)
        dl = torch.utils.data.DataLoader(
            train_dataset, batch_size=bsz, shuffle=True, drop_last=False, 
            generator=g, collate_fn=ft_labels_collate,
        )
        if len(dl) == 0 and max_train_examples > 0:
            # Without this guard the `while` loop below would spin forever.
            raise ValueError(f"Training split for seed {seed} is empty")
        torch.manual_seed(seed)

        train_metrics[seed] = {}
        eval_metrics[seed] = {}
        seen = 0
        next_eval_seen = 0
        while seen < max_train_examples:
            with tqdm(dl, total=len(dl), desc="Training (loss=0)") as pbar:
                for batch in pbar:
                    # Full-precision forward pass. The former float32 autocast region was
                    # a no-op and asked for CUDA autocast even when training on the CPU.
                    optimizer.zero_grad()
                    out = seed_model(
                        batch['input_ids'].to(device),
                        labels=batch['labels'].to(device),
                        attention_mask=batch['attention_mask'].to(device)
                    )
                    if seen >= next_eval_seen:
                        next_eval_seen += eval_every_n_examples
                        train_metric = MoNLIMetrics()
                        # Metrics need no autograd graph.
                        with torch.no_grad():
                            train_metric.update(
                                out['logits'].detach(),
                                batch['labels'],
                                out['loss'].item(),
                                sources=batch['source'],
                                depths=batch['depth'],
                                hypernym_in_lexrel=batch['hypernym_in_lexrel'],
                                hyponym_in_lexrel=batch['hyponym_in_lexrel'],
                            )
                        train_metrics[seed][seen] = train_metric
                        # Evaluate with dropout disabled, then resume training mode. The
                        # pending backward pass is unaffected: its graph is already built.
                        seed_model.eval()
                        try:
                            eval_metrics[seed][seen] = eval_loop(
                                model=seed_model,
                                tokenizer=tokenizer, 
                                dataset=eval_dataset,
                                bsz=32, 
                                device=device,
                            )
                        finally:
                            seed_model.train()
                    loss = out['loss']
                    # Show the scalar value rather than the tensor repr.
                    pbar.set_description(f"Training (loss={loss.item():.4f})")
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(seed_model.parameters(), max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()
                    # Every step is counted as a full batch, including a final partial one.
                    seen += bsz
                    if seen >= max_train_examples:
                        break 
    return train_metrics, eval_metrics

## Evaluation

In [None]:
# The two pretraining runs being compared; the names double as directory names below
# and as keys of the metric dictionaries.
MODEL_NAMES = ["mlm_only", "mlm_wnre"]
# Pretraining step of the checkpoint to fine-tune from.
STEP = 25_000
# NOTE(review): absolute, user-specific path; adjust before running elsewhere.
exp_dir = Path("/network/scratch/m/mirceara/.cache/balaur/runs")

# One binary classifier (num_labels=2) per run, built from that run's checkpoint at STEP.
models = {
    m: load_bort_model(Path(f"{exp_dir}/{m}/balaur/{m}/checkpoints/"), STEP, num_labels=2)
    for m in MODEL_NAMES
}
# Re-creates the tokenizer already defined in the dataset cell (same 'roberta-base' name).
tknzr = tr.AutoTokenizer.from_pretrained('roberta-base')

In [None]:
# Fine-tune both models with identical hyper-parameters; results are keyed by model name.
train_metrics = {}
eval_metrics = {}
num_epochs = 10
for m in  ["mlm_wnre", "mlm_only"]:
    model = models[m]
    # The budget (10 passes) and the evaluation interval (half a pass) are derived from
    # the size of the seed-42 training split loaded above.
    # NOTE(review): `finetune_loop` re-splits the data for every seed, so the actual
    # training-set size of a given run may differ from `len(train_dataset)`.
    train_metrics[m], eval_metrics[m] = finetune_loop(
        model=model, 
        tokenizer=tknzr, 
        bsz=32,
        lr=8e-5,
        device=torch.device('cuda'),
        num_seeds=5,
        max_train_examples=len(train_dataset)*num_epochs,
        eval_every_n_examples=int(len(train_dataset)/2),
    )

In [None]:
import pickle

# Persist both metric dictionaries (model -> seed -> examples seen -> MoNLIMetrics) to
# the working directory.
# NOTE(review): only unpickle these files from a trusted source; `pickle.load` can run
# arbitrary code.
with open("train_metrics.pkl", "wb") as f:
    pickle.dump(train_metrics, f)
    
with open("eval_metrics.pkl", "wb") as f:
    pickle.dump(eval_metrics, f)

## Plots

In [None]:
from collections import defaultdict
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.core.display import display, HTML
from statistics import quantiles
import plotly.io as pio


from typing import *

In [None]:
def errorband_trace(x, y, ylo, yhi, legend: str, rgb='0,100,80', showlegend: bool = True):
    """Build a line trace plus a translucent band between `ylo` and `yhi`.

    Args:
        x: x coordinates (any sequence: list, tuple, numpy array, ...).
        y: Values of the central line, one per x.
        ylo: Lower edge of the band, one per x.
        yhi: Upper edge of the band, one per x.
        legend: Legend entry of the central line.
        rgb: Colour as an "r,g,b" string; the band uses it with 0.2 opacity.
        showlegend: Whether the central line appears in the legend. The band never does.

    Returns:
        list of two `go.Scatter` traces: the line, then the filled band.
    """
    # The band outline walks along the upper edge and back along the lower edge. Work on
    # list copies: `+` on numpy arrays would add element-wise instead of concatenating,
    # and tuples/ranges would not combine as expected.
    xs, lower, upper = list(x), list(ylo), list(yhi)
    band_x = xs + xs[::-1]
    band_y = upper + lower[::-1]

    trace = [
        go.Scatter(
            name=legend,
            x=x,
            y=y,
            line=dict(color=f'rgb({rgb})'),
            mode='lines',
            showlegend=showlegend,
        ),
        go.Scatter(
            x=band_x,
            y=band_y,
            fill='toself',
            fillcolor=f'rgba({rgb},0.2)',
            # Fully transparent outline: only the fill is visible.
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False
        )
    ]
    return trace

In [None]:
def get_nll_traces(metrics: "dict[str, dict[int, dict[int, MoNLIMetrics]]]", model1: str, model2: str, stratification_key: str, rgb1: str='45,114,178', rgb2: str='194,24,7', showlegend: bool = True):
    """Build NLL-vs-step traces (quartile band plus dashed mean) for two models.

    For every model, the NLL values stored under `stratification_key` are pooled across
    seeds per step. Steps with at least two values contribute the median and the
    inter-quartile band; every step with at least one value contributes to the mean line.

    Args:
        metrics: model name -> seed -> step -> metrics object exposing `.nll`.
        model1, model2: Keys of `metrics` to plot; also used as legend names.
        stratification_key: Exact bucket name to read from `.nll`.
        rgb1, rgb2: "r,g,b" colours for the two models.
        showlegend: Whether the line traces appear in the legend.

    Returns:
        list of traces: for each model the two traces of `errorband_trace` followed by
        the dashed mean line.
    """
    traces = []
    for obj, rgb in zip([model1, model2], [rgb1, rgb2]):
        nlls = defaultdict(list)
        for per_step in metrics[obj].values():
            for step, metric in per_step.items():
                nlls[step].extend(metric.nll[stratification_key])

        x, y, ylo, yhi = [], [], [], []
        x_mean, y_mean = [], []
        # Iterate over every step seen for any seed; the previous code reused the leaked
        # `seed` loop variable and therefore only looked at the last seed's steps.
        for step in sorted(nlls):
            values = nlls[step]
            if len(values) >= 2:
                lower, median, upper = quantiles(values, n=4)
                x.append(step)
                y.append(median)
                ylo.append(lower)
                yhi.append(upper)
            if values:
                x_mean.append(step)
                y_mean.append(np.mean(values))

        traces.extend(errorband_trace(x, y, ylo, yhi, obj, rgb, showlegend))
        traces.append(
            go.Scatter(
                # The mean has its own x axis: it also covers steps with a single value,
                # which have no quartiles and are therefore missing from `x`.
                x=x_mean,
                y=y_mean,
                line=dict(color=f'rgb({rgb})', dash='dash'),
                name=f"{obj} (mean)",
                showlegend=showlegend
            )
        )
    return traces


def get_fixed_nll_traces(metrics: "dict[str, dict[int, dict[int, MoNLIMetrics]]]", model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7', stratification_key: str = 'all', showlegend: bool = True):
    """Build mean-NLL-vs-step traces for two models.

    For every model, NLL values from all buckets whose name *contains*
    `stratification_key` are pooled across seeds per step, and the mean per step is
    plotted. The error band is intentionally collapsed onto the mean (lower == upper),
    so only the line is visible.

    Args:
        metrics: model name -> seed -> step -> metrics object exposing `.nll`.
        model1, model2: Keys of `metrics` to plot.
        leg1, leg2: Legend names for the two models.
        rgb1, rgb2: "r,g,b" colours for the two models.
        stratification_key: Substring selecting the buckets to pool (e.g. 'all', 'nmonli').
        showlegend: Whether the line traces appear in the legend.

    Returns:
        list of traces: the two traces of `errorband_trace` for each model.
    """
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        nlls = defaultdict(list)
        for per_step in metrics[obj].values():
            for step, metric in per_step.items():
                for key, values in metric.nll.items():
                    if stratification_key in key:
                        nlls[step].extend(values)

        x, y = [], []
        # Every step seen for any seed; the previous code reused the leaked `seed` loop
        # variable and therefore only looked at the last seed's steps.
        for step in sorted(nlls):
            if nlls[step]:
                x.append(step)
                y.append(np.mean(nlls[step]))

        # Degenerate band: both edges equal the mean.
        traces.extend(errorband_trace(x, y, list(y), list(y), leg, rgb, showlegend))
    return traces


def get_acc_traces(metrics: "dict[str, dict[int, dict[int, MoNLIMetrics]]]", model1: str, model2: str, leg1: str, leg2: str, rgb1: str='45,114,178', rgb2: str='194,24,7', stratification_key: str = 'all', showlegend: bool = True):
    """Build accuracy-vs-step traces for two models.

    For every model, the 0/1 correctness flags from all buckets whose name *contains*
    `stratification_key` are pooled across seeds per step, and their mean is plotted.
    The error band is intentionally collapsed onto the mean (lower == upper), so only
    the line is visible.

    Args:
        metrics: model name -> seed -> step -> metrics object exposing `.acc`.
        model1, model2: Keys of `metrics` to plot.
        leg1, leg2: Legend names for the two models.
        rgb1, rgb2: "r,g,b" colours for the two models.
        stratification_key: Substring selecting the buckets to pool (e.g. 'all', 'nmonli').
        showlegend: Whether the line traces appear in the legend.

    Returns:
        list of traces: the two traces of `errorband_trace` for each model.
    """
    traces = []
    for obj, leg, rgb in zip([model1, model2], [leg1, leg2], [rgb1, rgb2]):
        accs = defaultdict(list)
        for per_step in metrics[obj].values():
            for step, metric in per_step.items():
                for key, flags in metric.acc.items():
                    if stratification_key in key:
                        accs[step].extend(flags)

        x, y = [], []
        # Every step seen for any seed; the previous code reused the leaked `seed` loop
        # variable and therefore only looked at the last seed's steps.
        for step in sorted(accs):
            if accs[step]:
                x.append(step)
                y.append(np.mean(accs[step]))

        # Degenerate band: both edges equal the mean.
        traces.extend(errorband_trace(x, y, list(y), list(y), leg, rgb, showlegend))
    return traces




In [None]:
# Display name and "r,g,b" colour of each model in the figure.
MODEL_LEGENDS = {
    "mlm_only": "BERT (OURS)",
    "mlm_wnre": "BERT+BALAUR"
}
MODEL_RGBS = {
    "mlm_only": "255,127,80",
    "mlm_wnre": "65,105,225"
}

# Eval accuracy over all examples ('all' bucket), pooled across seeds;
# mlm_wnre (MODEL_NAMES[1]) is passed first, then mlm_only (MODEL_NAMES[0]).
acc_traces = get_acc_traces(
    eval_metrics, 
    MODEL_NAMES[1], MODEL_NAMES[0],
    MODEL_LEGENDS[MODEL_NAMES[1]], MODEL_LEGENDS[MODEL_NAMES[0]],
    MODEL_RGBS[MODEL_NAMES[1]], MODEL_RGBS[MODEL_NAMES[0]],
    stratification_key='all'
)

fig = go.Figure(acc_traces)


fig.update_yaxes(type="linear")
fig.update_xaxes(type="linear")

fig.update_layout(showlegend=True, template='simple_white')
# Legend anchored to the bottom-right corner of the plot area (top-left variant kept below).
# fig.update_layout(legend=dict(yanchor='top', xanchor='left', x=0.01,y=0.99))
fig.update_layout(legend=dict(yanchor='bottom', xanchor='right', x=0.99,y=0.01))

# Compact layout with small margins.
fig.update_layout(width =500, height=250, 
                  font_family="Serif", 
                  font_size=12, 
                  margin_l=5, margin_t=5, margin_b=5, margin_r=5)

fig.update_layout(
    xaxis_title="Number of training examples",
    yaxis_title="Accuracy (MoNLI)",
)
# Show the interactive figure inline, then export a PDF to the working directory.
# NOTE(review): the export is given 450x225 while the layout above says 500x250;
# presumably the export arguments take precedence, so confirm.
display(HTML(fig.to_html()))
pio.write_image(fig, "ft_monli.pdf", width=1.5*300, height=0.75*300)

In [None]:

def get_stratified_accs(metrics: "dict[str, dict[int, dict[int, MoNLIMetrics]]]", model1: str, model2: str, stratification_key: str):
    """Final accuracy (in percent) of two models on one bucket, pooled across seeds.

    For every seed only the metrics recorded at the largest step are used; the 0/1
    correctness flags stored under `stratification_key` are concatenated over seeds and
    averaged.

    Args:
        metrics: model name -> seed -> step -> metrics object exposing `.acc`.
        model1, model2: Keys of `metrics` to report.
        stratification_key: Exact bucket name to read from `.acc`.

    Returns:
        dict mapping each model name to its accuracy in [0, 100], or NaN when the
        bucket holds no example for that model (callers test this with `math.isnan`).
    """
    final_accs = {}
    for obj in [model1, model2]:
        flags = []
        for per_step in metrics[obj].values():
            last_step = max(per_step.keys())
            flags.extend(per_step[last_step].acc[stratification_key])

        # Explicit NaN for an empty bucket: same value `np.mean([])` used to produce,
        # without the RuntimeWarning. The unused `model = f"{obj}_{step}"` local, which
        # raised UnboundLocalError for a model without any seed, is gone.
        final_accs[obj] = 100 * np.mean(flags) if flags else float('nan')
    return final_accs
        
    
        

# Per-bucket example counts and final eval accuracies.
# Buckets are named "<source>_<coverage>_<depth>", as produced by MoNLIMetrics.update.
DEPTHS = list(range(1,7)) # technically includes up to depth 10, but plots are very sparse and uninterpretable beyond depth 6
SOURCES = ['nmonli', 'pmonli']
# NOTE(review): MoNLIMetrics also produces 'hypo' buckets; they are not listed here.
SINGLE_TOKENS = ['both', 'hyper', 'none']


# Number of eval examples per bucket, averaged over seeds and rounded up. The counts
# are read from the labels recorded at key 0, i.e. the evaluation stored at 0 examples
# seen, for the 'mlm_wnre' runs only.
stratified_counts = {}
for source in SOURCES:
    for depth in DEPTHS:
        for single_token in SINGLE_TOKENS:
            stratification_key = f"{source}_{single_token}_{depth}"
            num_examples = 0
            for seed in eval_metrics['mlm_wnre']:
                num_examples += len(eval_metrics['mlm_wnre'][seed][0].labels[stratification_key])
            num_examples /= len(eval_metrics['mlm_wnre'])
            num_examples = math.ceil(num_examples)
            stratified_counts[stratification_key] = num_examples

# Final accuracy per model and bucket (see get_stratified_accs).
# The enumerate indices i1/i2/i3 are not used.
stratified_accs = {m: {} for m in ['mlm_wnre', 'mlm_only']}
for i1, source in enumerate(SOURCES):
    for i2, depth in enumerate(DEPTHS):
        for i3, single_token in enumerate(SINGLE_TOKENS):
            stratification_key = f"{source}_{single_token}_{depth}"
            accs = get_stratified_accs(eval_metrics, 'mlm_only', 'mlm_wnre', stratification_key)
            for m, acc in accs.items():
                stratified_accs[m][stratification_key] = acc


print(json.dumps(stratified_counts, indent=2))
print(json.dumps(stratified_accs, indent=2))
            

In [None]:
# NMoNLI and PMoNLI accuracy, overall and per WordNet-coverage group.
# Each aggregate is the count-weighted mean of the per-bucket accuracies in
# `stratified_accs`, with weights taken from `stratified_counts`; buckets are selected
# by substring tests on their key.
agg_fns = dict(
    nmonli_overall=lambda k: "nmonli" in k,
    nmonli_both=lambda k: "nmonli" in k and "both" in k,
    nmonli_hyper=lambda k: "nmonli" in k and "hyper" in k,
    nmonli_none=lambda k: "nmonli" in k and "none" in k,
    pmonli_overall=lambda k: "pmonli" in k,
    pmonli_both=lambda k: "pmonli" in k and "both" in k,
    pmonli_hyper=lambda k: "pmonli" in k and "hyper" in k,
    pmonli_none=lambda k: "pmonli" in k and "none" in k,

)
for agg, agg_fn in agg_fns.items():
    print()
    print(agg)
    for m, saccs in stratified_accs.items():
        acc_sum = 0
        acc_div = 0
        for k, sacc in saccs.items():
            if agg_fn(k):
                c = stratified_counts[k]
                # A NaN accuracy marks an empty bucket; it must carry zero weight.
                if math.isnan(sacc):
                    sacc = 0
                    assert c == 0

                acc_sum += sacc*c
                acc_div += c   
        # None when no selected bucket contains any example.
        acc_avg = acc_sum / acc_div if acc_div else None
        print(m, acc_avg)

In [None]:
# NMoNLI and PMoNLI accuracy per depth: the count-weighted mean of the per-bucket
# accuracies in `stratified_accs`, with weights taken from `stratified_counts`.
#
# Bucket keys look like "<source>_<coverage>_<depth>". Source and depth are compared
# exactly against the first and last "_"-separated fields. The previous substring test
# (`"1" in k`) would also have selected depth 10 once deeper buckets are included.
def _is_source_depth(key: str, source: str, depth: int) -> bool:
    """True when `key` belongs to exactly this source and this depth."""
    fields = key.split("_")
    return fields[0] == source and fields[-1] == str(depth)


# One aggregate per (source, depth), named "<source>_<depth>", in the order of SOURCES
# and DEPTHS. Defaults bind the loop values at definition time.
agg_fns = {
    f"{source}_{depth}": (lambda k, source=source, depth=depth: _is_source_depth(k, source, depth))
    for source in SOURCES
    for depth in DEPTHS
}

for agg, agg_fn in agg_fns.items():
    print()
    print(agg)
    for m, saccs in stratified_accs.items():
        acc_sum = 0
        acc_div = 0
        for k, sacc in saccs.items():
            if agg_fn(k):
                c = stratified_counts[k]
                # A NaN accuracy marks an empty bucket; it must carry zero weight.
                if math.isnan(sacc):
                    sacc = 0
                    assert c == 0

                acc_sum += sacc*c
                acc_div += c   
        # None when no selected bucket contains any example.
        acc_avg = acc_sum / acc_div if acc_div else None
        print(m, acc_avg)