## MLM losses analysis
* Input
  * MLM losses - CKIP bert: `../data/ckip_bert_cnstr_losses.pkl` (10.10)
  * MLM losses - Bert base: `../data/bert_base_cnstr_losses.pkl` (10.10)
* Note
  * raw: the mask is applied at the locations of constructions
  * shifted: a translated copy of the construction mask, i.e. the same mask moved by a random offset. Masked positions that run past the end of the sequence wrap around to its beginning.
  * random: a control condition. The same number of tokens is masked, but at randomly chosen positions in the sequence.

In [1]:
import pickle
import numpy as np
import pandas as pd

## CKIP bert

In [2]:
# Load the per-condition MLM losses computed with CKIP BERT.
# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
with open("../data/ckip_bert_cnstr_losses.pkl", mode="rb") as ckip_file:
    ckip_losses = pickle.load(ckip_file)

In [3]:
def compute_cv_stat(losses_dict):
    """Summarize each condition's losses while ignoring NaN entries.

    For every ``(cond, losses)`` pair, the losses are first averaged along
    axis 0 with ``np.nanmean``. The resulting means are then reduced to their
    overall mean and their standard deviation (``np.nanstd``, ddof=0).
    NOTE(review): the "cv" naming suggests axis 0 indexes items and the
    remaining axis indexes cross-validation runs. Confirm this against the
    code that produced the .pkl files.

    Args:
        losses_dict: mapping from condition name to an array-like of losses.

    Returns:
        dict mapping each condition name (in input order) to a dict with the
        keys "cv_mean", "cv_std" and "n_nan". "n_nan" is the number of NaN
        entries in that condition's losses.
    """
    def _summarize(losses):
        # Collapse axis 0 first, then describe the spread of those means.
        axis0_means = np.nanmean(losses, 0)
        return {
            "cv_mean": np.nanmean(axis0_means),
            "cv_std": np.nanstd(axis0_means),
            "n_nan": np.sum(np.isnan(losses)),
        }

    return {cond: _summarize(losses) for cond, losses in losses_dict.items()}

def compute_contrasts(losses_dict, slots=("cslot", "vslot"),
                      conditions=("raw", "shifted", "random"),
                      reference="cnstr"):
    """Compute slot-minus-reference loss differences for each masking condition.

    For every slot and condition, the result stores
    ``losses_dict["<slot>_<cond>"] - losses_dict["<reference>_<cond>"]``
    under the key ``"<cond>(<slot>-<reference>)"``. Keys are ordered
    slot-major. With the defaults this yields, in order:
    raw(cslot-cnstr), shifted(cslot-cnstr), random(cslot-cnstr),
    raw(vslot-cnstr), shifted(vslot-cnstr), random(vslot-cnstr).

    Note: per the notebook's own remark, in the "shifted" condition the slot
    masks and the cnstr mask are shifted independently. Those contrasts
    therefore do not compare aligned positions.

    Args:
        losses_dict: mapping such as ``{"cnstr_raw": arr, "cslot_raw": arr, ...}``.
            Values must support element-wise ``-`` (e.g. NumPy arrays).
            NOTE(review): the arrays are assumed to have matching or
            broadcastable shapes.
        slots: slot prefixes to contrast against the reference.
        conditions: masking-condition suffixes.
        reference: prefix of the baseline losses that are subtracted.

    Returns:
        dict mapping contrast names to the element-wise differences.

    Raises:
        KeyError: if a required "<prefix>_<cond>" entry is missing.
    """
    contrasts = {}
    for slot in slots:
        for cond in conditions:
            name = "{}({}-{})".format(cond, slot, reference)
            contrasts[name] = (losses_dict["{}_{}".format(slot, cond)]
                               - losses_dict["{}_{}".format(reference, cond)])
    return contrasts


In [4]:
# Summaries for CKIP BERT: the slot-minus-construction contrasts (shown in the
# next cell) and the raw per-condition losses (shown here).
ckip_contrasts = compute_cv_stat(compute_contrasts(ckip_losses))
ckip_stats = compute_cv_stat(ckip_losses)
ckip_stat_dfr = pd.DataFrame.from_dict(data=ckip_stats, orient="index")
ckip_stat_dfr

Unnamed: 0,cv_mean,cv_std,n_nan
cnstr_raw,8.511247,0.039296,1
cnstr_shifted,7.350479,0.0583,1
cnstr_random,5.961418,0.051417,0
cslot_raw,8.066977,0.048429,1
cslot_shifted,7.828488,0.051511,1
cslot_random,7.658355,0.061763,1
vslot_raw,9.500527,0.040104,1
vslot_shifted,6.630606,0.081591,3
vslot_random,6.327774,0.066782,0


In [5]:
# Each row summarizes (slot loss - construction loss) for one masking condition.
# Note that in shifted(cslot-cnstr) and shifted(vslot-cnstr), cslot and vslot are
# not aligned with cnstr; they are shifted independently.
ckip_contrast_dfr = pd.DataFrame.from_dict(data=ckip_contrasts, orient="index")
ckip_contrast_dfr

Unnamed: 0,cv_mean,cv_std,n_nan
raw(cslot-cnstr),-0.444271,0.055797,1
shifted(cslot-cnstr),0.477202,0.052125,2
random(cslot-cnstr),1.696492,0.068649,1
raw(vslot-cnstr),0.989372,0.043886,2
shifted(vslot-cnstr),-0.720103,0.088001,4
random(vslot-cnstr),0.366356,0.093851,0


## Bert-base

In [6]:
# Same analysis as for CKIP BERT, applied to the bert-base losses.
# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
with open("../data/bert_base_cnstr_losses.pkl", mode="rb") as base_file:
    base_losses = pickle.load(base_file)
base_contrasts = compute_cv_stat(compute_contrasts(base_losses))
base_stats = compute_cv_stat(base_losses)
base_stat_dfr = pd.DataFrame.from_dict(data=base_stats, orient="index")
base_stat_dfr

Unnamed: 0,cv_mean,cv_std,n_nan
cnstr_raw,8.097117,0.036832,1
cnstr_shifted,6.947395,0.047075,3
cnstr_random,5.575383,0.072184,0
cslot_raw,7.364918,0.064958,1
cslot_shifted,7.158829,0.0531,1
cslot_random,6.997634,0.071926,0
vslot_raw,8.900065,0.050399,1
vslot_shifted,6.151401,0.090277,2
vslot_random,5.883541,0.08183,1


In [7]:
# bert-base contrast table: one row per (slot - construction) contrast and masking condition.
base_contrast_dfr = pd.DataFrame.from_dict(data=base_contrasts, orient="index")
base_contrast_dfr

Unnamed: 0,cv_mean,cv_std,n_nan
raw(cslot-cnstr),-0.732199,0.067939,1
shifted(cslot-cnstr),0.210853,0.059142,4
random(cslot-cnstr),1.422251,0.083828,0
raw(vslot-cnstr),0.803012,0.033653,2
shifted(vslot-cnstr),-0.794815,0.101628,4
random(vslot-cnstr),0.307919,0.074492,1


In [8]:
# Render this notebook to HTML. nbformat.read loads the copy saved on disk
# (path relative to the working directory), so save the notebook before running
# this cell, or unsaved changes will be missing from the export.
import nbformat
from nbconvert import HTMLExporter

this_nb = nbformat.read("10.20-mlm-data-alys.ipynb", as_version=4)
html_export = HTMLExporter()
body, res = html_export.from_notebook_node(this_nb)

In [9]:
# Save the HTML rendered by the previous cell (`body`).
# Pass encoding explicitly. Without it, open() uses the locale's preferred
# encoding, which varies by platform (e.g. a legacy code page on Windows) and
# may fail on, or silently mis-encode, characters in the exported HTML.
with open("../data/output/10.20-mlm-data-alys.html", "w", encoding="utf-8") as fout:
    fout.write(body)