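## MLM slot evaluations
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * per-sentence MLM losses (with CV), `ckiplab/bert-base-chinese`: `../data/ckip_bert_cnstr_losses.pkl`
  * per-sentence MLM losses (with CV), `bert-base-chinese`: `../data/bert_base_cnstr_losses.pkl`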

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)

## Check data is the same
# Fingerprint: first 6 hex digits of SHA-256 over the pickled `data`.
# NOTE(review): pickle bytes can differ across Python versions / default pickle
# protocols, so a mismatch here may be spurious on another interpreter even if
# the JSON file itself is unchanged.
data_hash = sha256(pickle.dumps(data)).hexdigest()[:6]
assert data_hash == "4063b4", f"unexpected data fingerprint {data_hash!r}"
assert len(data) == 11642, f"expected 11642 records, got {len(data)}"
len(data)

11642

In [5]:
## Read cv splits
# One entry per fold; each is a dict whose two values unpack, in order, as
# (train_idxs, test_idxs) in the cells below.
# Explicit encoding, consistent with how the raw data file is read.
with open("../data/cv_splits_10.json", "r", encoding="UTF-8") as fin:
    cv_splits = json.load(fin)

In [6]:
# The tokenizer comes from the original bert-base-chinese checkpoint, while the
# MLM weights come from ckiplab/bert-base-chinese. This pairing follows the
# usage shown on the ckiplab model cards.
# The plain bert-base-chinese model is evaluated in a later section.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese').to(device)

## Checking input data

In [7]:
# Fold 0 is used only for the quick input checks below.
# NOTE(review): relies on the JSON key order of each fold being (train, test).
first_split = cv_splits[0]
train_idxs, test_idxs = (first_split[key] for key in first_split)

In [8]:
# Inspect one raw record: its token list, per-token construction ("cnstr") and
# slot ("slot") tags, and the construction's form / matched example.
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

In [9]:
# Build a small batch (first 5 fold-0 training indices) with the "vslot" mask
# condition and list the fields batched_text returns.
batch = batched_text(data, train_idxs[:5], "vslot")
list(batch)

['masked', 'text', 'mindex', 'mindex_bool']

In [10]:
# Masked word positions per sentence; shorter rows are right-padded with -1.
batch["mindex"]

array([[135, 137,  -1,  -1,  -1,  -1],
       [ 13,  15,  -1,  -1,  -1,  -1],
       [  3,   5,  -1,  -1,  -1,  -1],
       [  0,   1,   2,   4,   5,   6],
       [  0,   2,  -1,  -1,  -1,  -1]])

## Function definitions

In [11]:
# Per-token cross-entropy; positions labeled -100 (the default ignore_index)
# contribute a loss of 0.
loss_fct = CrossEntropyLoss(reduction="none")


def make_masked_ylabels(mindex: np.ndarray, ori_ids: torch.Tensor) -> torch.Tensor:
    """Build MLM labels that score only the masked positions.

    Args:
        mindex: (batch, k) word indices of the masked positions of each
            sentence, right-padded with -1 (as produced by ``batched_text``).
        ori_ids: (batch, seq_len) ``input_ids`` of the *unmasked* text, as
            returned by the tokenizer (position 0 is the [CLS] token).

    Returns:
        A tensor shaped like ``ori_ids`` holding the original token id at every
        masked position and -100 (ignored by ``CrossEntropyLoss``) elsewhere.
        Padding entries and positions removed by truncation are never labeled.
    """
    ylabels = torch.full_like(ori_ids, -100)
    # The last column is [SEP] (or padding); a real word token can only sit
    # strictly before it, so anything at or past it was lost to truncation.
    max_pos = ori_ids.shape[1] - 1
    for i, row in enumerate(np.asarray(mindex)):
        # Drop the -1 padding BEFORE shifting: shifting first would turn it
        # into position 0 and wrongly label/score the [CLS] token.
        pos = row[row >= 0] + 1  # +1 for the [CLS] token prepended by the tokenizer
        pos = pos[pos < max_pos]
        pos = torch.as_tensor(pos, dtype=torch.long, device=ori_ids.device)
        ylabels[i, pos] = ori_ids[i, pos]
    return ylabels


def compute_loss(logits, ylabels):
    """Per-position cross-entropy.

    ``logits`` must be (batch, vocab, seq_len), i.e. class dim second as
    ``CrossEntropyLoss`` expects; returns a (batch, seq_len) tensor with 0 at
    ignored (-100) positions.
    """
    return loss_fct(logits, ylabels)

In [29]:
def compute_mlm_loss(model, train_idxs, use_mask, debug=False, batch_size=16):
    """Per-sentence mean masked-LM cross-entropy over the masked positions.

    Despite its name, ``train_idxs`` is any sequence of indices into the global
    ``data``; the CV loops pass each fold's test indices.

    Args:
        model: a masked-LM model (e.g. ``BertForMaskedLM``) already on ``device``.
        train_idxs: indices into ``data`` to score, processed in order.
        use_mask: masking condition forwarded to ``batched_text``
            (e.g. "cx", "shifted-vslot", "random-cslot").
        debug: if True, stop after the first batch and also return its
            intermediates.
        batch_size: number of sentences per forward pass.

    Returns:
        ``np.ndarray`` with one loss per scored sentence (NaN for a sentence
        with no labeled position). With ``debug=True``, a dict with keys
        "loss_vec", "batch", "ylabels" and "out" (the last three are None
        when ``train_idxs`` is empty).
    """
    loss_vec = []
    # Pre-bind so debug mode can return even if the loop never runs.
    batch = ylabels = out = None
    tok_kwargs = dict(return_tensors="pt", is_split_into_words=True,
                      padding=True, truncation=True)

    for start in range(0, len(train_idxs), batch_size):
        data_idxs = train_idxs[start:start + batch_size]
        batch = batched_text(data, data_idxs, use_mask)

        # Mean cross-entropy (log-perplexity) of the original tokens at the
        # masked positions. NOTE(review): assumes the masked and original texts
        # tokenize position-for-position (one token per word, "[MASK]" is one
        # token) so label positions line up with the masked input.
        X = tokenizer(batch["masked"], **tok_kwargs).to(device)
        Xori = tokenizer(batch["text"], **tok_kwargs)
        ylabels = make_masked_ylabels(batch["mindex"], Xori["input_ids"]).to(device)
        with torch.no_grad():
            out = model(**X)
            token_loss = compute_loss(out.logits.swapdims(2, 1), ylabels)
            # Ignored positions contribute 0 to the sum; divide by the number of
            # labeled positions (not by the number of non-zero losses, which
            # would miscount a labeled position whose loss is exactly 0).
            n_labeled = (ylabels != -100).sum(1)
            batch_loss = token_loss.sum(1) / n_labeled
            loss_vec.extend(batch_loss.cpu().tolist())
        if debug:
            break

    if debug:
        return {"loss_vec": np.array(loss_vec),
                "batch": batch, "ylabels": ylabels, "out": out}
    return np.array(loss_vec)

In [18]:
from itertools import product
# Rows per loss buffer = size of one CV test fold. Derived from cv_splits
# (second value of each fold, matching the (train, test) unpacking used in the
# loops) instead of whatever `test_idxs` currently holds, since the loops
# rebind it. Every fold fills one column of the same (Mtest, n_split) array,
# so all test folds must have the same size.
n_split = len(cv_splits)
test_fold_sizes = sorted({len(list(split.values())[1]) for split in cv_splits})
assert len(test_fold_sizes) == 1, f"CV test folds differ in size: {test_fold_sizes}"
Mtest = test_fold_sizes[0]
def make_design_iter():
    """Iterate over the full (construction element, mask type) design grid.

    Yields 9 pairs in row-major order: each of "cnstr", "cslot", "vslot"
    combined with each of "raw", "shifted", "random".
    """
    cx_elems = ("cnstr", "cslot", "vslot")
    mask_types = ("raw", "shifted", "random")
    return product(cx_elems, mask_types)

def get_mask_condition(cx_elem, mask_type):
    """Translate a design cell into the mask code understood by batched_text.

    "cnstr" is abbreviated to "cx"; any mask type other than "raw" is
    prefixed with a hyphen, e.g. ("vslot", "shifted") -> "shifted-vslot" and
    ("cnstr", "raw") -> "cx".
    """
    if cx_elem == "cnstr":
        elem_code = "cx"
    else:
        elem_code = cx_elem
    if mask_type == "raw":
        return elem_code
    return mask_type + "-" + elem_code

def print_stat(x):
    """Print the NaN-ignoring mean and SD of ``x`` plus how many NaNs it has."""
    mean = np.nanmean(x)
    sd = np.nanstd(x)
    n_nan = np.sum(np.isnan(x))
    print(f"M={mean:.4f}, SD={sd:.4f} ({n_nan} NaNs)")

def make_buffer():
    """Allocate one zero-filled (Mtest, n_split) array per design condition.

    Keys are "<cx_elem>_<mask_type>" (e.g. "cnstr_raw"); row r / column s will
    hold the loss of the r-th test sentence of CV split s.
    """
    return {
        f"{cx_elem}_{mtype}": np.zeros((Mtest, n_split))
        for cx_elem, mtype in make_design_iter()
    }

## Computing MLM Losses

In [22]:
## masked
# For every CV split, score each (construction element x mask type) condition
# on that split's held-out indices and store the per-sentence losses in column
# `split_idx` of the condition's buffer.
losses = make_buffer()
n_condition = len(list(make_design_iter()))
for split_idx in range(n_split):
    train_idxs, test_idxs = cv_splits[split_idx].values()
    # The context manager closes the progress bar even if a condition raises.
    with tqdm(total=n_condition) as pbar:
        for cx_elem, mtype in make_design_iter():
            cond_text = f"{cx_elem}_{mtype}"
            pbar.set_description(f"{split_idx+1: 2d}. {cond_text}")
            m_cond = get_mask_condition(cx_elem, mtype)
            loss_x = compute_mlm_loss(model, test_idxs, m_cond)
            losses[cond_text][:, split_idx] = loss_x
            pbar.update(1)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

## Output pickle

In [23]:
# Persist the CKIP-BERT loss buffers (dict: condition name -> (Mtest, n_split) array).
ckip_losses_path = Path("../data/ckip_bert_cnstr_losses.pkl")
with ckip_losses_path.open("wb") as fout:
    pickle.dump(losses, fout)

## Losses with bert-base-chinese

In [28]:
# Swap in the original Google bert-base-chinese weights; the tokenizer is unchanged.
model = BertForMaskedLM.from_pretrained('bert-base-chinese').to(device)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [30]:
## masked
# Same CV scoring procedure as for the CKIP model, now with bert-base-chinese:
# each condition's per-sentence test losses go into column `split_idx`.
losses = make_buffer()
n_condition = len(list(make_design_iter()))
for split_idx in range(n_split):
    train_idxs, test_idxs = cv_splits[split_idx].values()
    # The context manager closes the progress bar even if a condition raises.
    with tqdm(total=n_condition) as pbar:
        for cx_elem, mtype in make_design_iter():
            cond_text = f"{cx_elem}_{mtype}"
            pbar.set_description(f"{split_idx+1: 2d}. {cond_text}")
            m_cond = get_mask_condition(cx_elem, mtype)
            loss_x = compute_mlm_loss(model, test_idxs, m_cond)
            losses[cond_text][:, split_idx] = loss_x
            pbar.update(1)

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

In [31]:
# Persist the bert-base-chinese loss buffers (dict: condition name -> (Mtest, n_split) array).
bert_base_losses_path = Path("../data/bert_base_cnstr_losses.pkl")
with bert_base_losses_path.open("wb") as fout:
    pickle.dump(losses, fout)

## Workarea

### Check: the built-in MLM loss equals the custom `compute_loss`

In [13]:
# A one-sentence toy batch in batched_text's format: "天" of "今天" is masked.
batch = dict(
    masked=[["今", "[MASK]"]],
    text=[["今", "天"]],
    mindex=np.array([[1]]),
    mindex_bool=np.array([[True]]),
)

In [14]:
# Score the toy batch with the model's built-in MLM loss (labels passed to the model).
tok_kwargs = dict(return_tensors="pt", is_split_into_words=True,
                  padding=True, truncation=True)
X = tokenizer(batch["masked"], **tok_kwargs).to(device)
Xori = tokenizer(batch["text"], **tok_kwargs)
ylabels = make_masked_ylabels(batch["mindex"], Xori["input_ids"]).to(device)
with torch.no_grad():
    out = model(**X, labels=ylabels)
    loss_batch = out.loss.cpu().item()
    print(loss_batch)

5.4197540283203125


In [15]:
# Reproduce the built-in loss by hand: transformers' BertForMaskedLM averages the
# cross-entropy over every labeled (!= -100) position in the batch. Dividing by
# the labeled-position count (rather than by the number of non-zero losses)
# keeps this correct for any batch size.
loss = compute_loss(out.logits.swapdims(2, 1), ylabels)
loss_manual = loss.sum() / (ylabels != -100).sum()
assert torch.isclose(loss_manual, out.loss), (loss_manual.item(), out.loss.item())
loss_manual.cpu().item()

5.4197540283203125

In [16]:
# -100 (ignored) everywhere except the masked position, which holds the original
# token id; that is position 2, after [CLS] and "今".
ylabels

tensor([[-100, -100, 1921, -100]], device='cuda:0')

## For debug

In [17]:
# Debug one test sentence under the "shifted-vslot" condition.
# compute_mlm_loss takes the model as its first argument; omitting it (as an
# earlier version of this cell did) raises a TypeError.
# NOTE(review): `test_idxs` here is whichever fold the last CV loop left bound.
aa = compute_mlm_loss(model, test_idxs[1:2], "shifted-vslot", debug=True)
aa

{'loss_vec': array([3.59199357]),
 'batch': {'masked': [['1',
    '[MASK]',
    'm',
    '補',
    '給',
    '品',
    '買',
    '一',
    '買',
    '窩',
    '著',
    '先',
    '看',
    '這',
    '幾',
    '天',
    '發',
    '展',
    '比',
    '較',
    '穩',
    '第',
    '一',
    '次',
    '過',
    '年',
    '留',
    '在',
    '高',
    '[MASK]']],
  'text': [['1',
    'k',
    'm',
    '補',
    '給',
    '品',
    '買',
    '一',
    '買',
    '窩',
    '著',
    '先',
    '看',
    '這',
    '幾',
    '天',
    '發',
    '展',
    '比',
    '較',
    '穩',
    '第',
    '一',
    '次',
    '過',
    '年',
    '留',
    '在',
    '高',
    '雄']],
  'mindex': array([[ 1, 29]]),
  'mindex_bool': array([[ True,  True]])},
 'ylabels': tensor([[-100, -100,  153, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, 7413, -100]], device='cuda:0'),
 'out': MaskedLMOutput(loss=None, logits=tensor([[[ -7.845

In [50]:
# Debug one test sentence under the "random-cslot" condition; the model must be
# passed explicitly as compute_mlm_loss's first argument.
# NOTE(review): `test_idxs` here is whichever fold the last CV loop left bound.
aa = compute_mlm_loss(model, [test_idxs[0]], "random-cslot", debug=True)
aa

  0%|          | 0/1 [00:00<?, ?it/s]

{'loss_vec': array([5.19510221]),
 'batch': {'masked': [['我',
    '覺',
    '得',
    '這',
    '位',
    '護',
    '理',
    '師',
    '錯',
    '就',
    '錯',
    '[MASK]',
    '抽',
    '了',
    '一',
    '[MASK]',
    '不',
    '理',
    '性',
    '孕',
    '婦',
    '的',
    '血',
    '吧',
    '你',
    '當',
    '下',
    '可',
    '以',
    '不',
    '要',
    '哭',
    '好',
    '好',
    '跟',
    '她',
    '溝',
    '通',
    '請',
    '她',
    '換',
    '人',
    '抽',
    '吧',
    '樓',
    '上',
    '你',
    '好']],
  'text': [['我',
    '覺',
    '得',
    '這',
    '位',
    '護',
    '理',
    '師',
    '錯',
    '就',
    '錯',
    '在',
    '抽',
    '了',
    '一',
    '位',
    '不',
    '理',
    '性',
    '孕',
    '婦',
    '的',
    '血',
    '吧',
    '你',
    '當',
    '下',
    '可',
    '以',
    '不',
    '要',
    '哭',
    '好',
    '好',
    '跟',
    '她',
    '溝',
    '通',
    '請',
    '她',
    '換',
    '人',
    '抽',
    '吧',
    '樓',
    '上',
    '你',
    '好']],
  'mindex': array([[11, 15]]),
  'mindex_bool': array([[ True