## MLM slots evals
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text

In [3]:
device = torch.device("cuda") \
         if torch.cuda.is_available() else torch.device("cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese')
model = model.to(device)

## Evaluate splits

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

In [9]:
batch = batched_text(data, train_idxs[:5], "vslot")
list(batch.keys())

['masked', 'text', 'mindex', 'mindex_bool']

In [10]:
batch["mindex"]

array([[135, 137,  -1,  -1,  -1,  -1],
       [ 13,  15,  -1,  -1,  -1,  -1],
       [  3,   5,  -1,  -1,  -1,  -1],
       [  0,   1,   2,   4,   5,   6],
       [  0,   2,  -1,  -1,  -1,  -1]])

In [39]:
loss_fct = CrossEntropyLoss(reduction="none")
def make_masked_ylabels(mindex: np.ndarray, ori_ids: torch.Tensor):            
    ylabels = torch.zeros_like(ori_ids)-100
    for i in range(mindex.shape[0]):
        # add 1 for the extra [CLS] token added after tokenizer
        j = mindex[i]+1        
        j = j[(j>=0) & (j<512-2)]        
        ylabels[i, j] = ori_ids[i, j]
    return ylabels

def compute_loss(logits, ylabels):
    return loss_fct(logits, ylabels)

In [12]:
def compute_mlm_loss(train_idxs, use_mask, debug=False):
    batch_size = 16
    loss_vec = []    
    
    for i in tqdm(range(0, len(train_idxs), batch_size)):
        data_idxs = train_idxs[i:i+batch_size]
        batch = batched_text(data, data_idxs, use_mask)

        # compute perplexity    
        X = tokenizer(batch["masked"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
        X = X.to(device)
        Xori = tokenizer(batch["text"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
        ylabels = make_masked_ylabels(batch["mindex"], Xori["input_ids"])
        ylabels = ylabels.to(device)        
        with torch.no_grad():
            out = model(**X)
            batch_loss = compute_loss(out.logits.swapdims(2, 1), ylabels)
            # average over non-zero token losses
            batch_loss = batch_loss.sum(1) / (batch_loss>0).sum(1)
            loss_vec.extend(batch_loss.cpu().tolist())
        if debug: break
        
    if debug:        
        return {"loss_vec": np.array(loss_vec), 
                "batch": batch, "ylabels": ylabels, "out": out}
    else:
        return np.array(loss_vec)

## eval VSlot

In [51]:
vslot_losses = compute_mlm_loss(test_idxs, "vslot")
np.mean(vslot_losses), np.std(vslot_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

(9.508307444639984, 2.4535534203544205)

## random vslot

In [23]:
rv_losses = compute_mlm_loss(test_idxs, "random-vslot")
print("# of nan: ", np.sum(np.isnan(rv_losses)))
np.nanmean(rv_losses), np.nanstd(rv_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

# of nan:  0


(6.354940402090685, 2.3365540636265774)

## eval cslot

In [15]:
cslot_losses = compute_mlm_loss(test_idxs, "cslot")
np.mean(cslot_losses), np.std(cslot_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

(8.085005015365043, 2.435851597715492)

## Random cslot

In [24]:
rc_losses = compute_mlm_loss(test_idxs, "random-cslot")
print("# of nan: ", np.sum(np.isnan(rc_losses)))
np.nanmean(rc_losses), np.nanstd(rc_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

# of nan:  0


(7.676952773194296, 2.6099878297632646)

## Eval Construtions

In [54]:
np.roll([1,1,1,0,0,0], 2, axis=0)

array([0, 0, 1, 1, 1, 0])

In [16]:
cx_losses = compute_mlm_loss(test_idxs, "cx")
np.mean(cx_losses), np.std(cx_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

(8.558957640602864, 1.5269846388693975)

## Random constr

In [52]:
rx_losses = compute_mlm_loss(test_idxs, "random-cx")
np.mean(rx_losses), np.std(rx_losses)

  0%|          | 0/73 [00:00<?, ?it/s]

(6.123976794320114, 2.01946610733562)

## sanity check

In [124]:
batch = {'masked': [['今', '[MASK]']],
 'text': [['今', '天']],
 'mindex': np.array([[1]]),
 'mindex_bool': np.array([[ True]])}

In [125]:
X = tokenizer(batch["masked"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
X = X.to(device)
Xori = tokenizer(batch["text"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
ylabels = make_masked_ylabels(batch["mindex"], Xori["input_ids"])
ylabels = ylabels.to(device)
with torch.no_grad():
    out = model(**X, labels=ylabels)
    loss_batch = out.loss.cpu().item()
    print(loss_batch)

5.4197540283203125


In [127]:
loss = compute_loss(out.logits.swapdims(2, 1), ylabels)
(loss.sum(1) / (loss>0).sum()).cpu().item()

5.4197540283203125

## For debug

In [48]:
aa = compute_mlm_loss(test_idxs[1:2], "random-cslot", debug=True)
aa["ylabels"]

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         1044, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]], device='cuda:0')

In [50]:
aa = compute_mlm_loss([test_idxs[0]], "random-cslot", debug=True)
aa

  0%|          | 0/1 [00:00<?, ?it/s]

{'loss_vec': array([5.19510221]),
 'batch': {'masked': [['我',
    '覺',
    '得',
    '這',
    '位',
    '護',
    '理',
    '師',
    '錯',
    '就',
    '錯',
    '[MASK]',
    '抽',
    '了',
    '一',
    '[MASK]',
    '不',
    '理',
    '性',
    '孕',
    '婦',
    '的',
    '血',
    '吧',
    '你',
    '當',
    '下',
    '可',
    '以',
    '不',
    '要',
    '哭',
    '好',
    '好',
    '跟',
    '她',
    '溝',
    '通',
    '請',
    '她',
    '換',
    '人',
    '抽',
    '吧',
    '樓',
    '上',
    '你',
    '好']],
  'text': [['我',
    '覺',
    '得',
    '這',
    '位',
    '護',
    '理',
    '師',
    '錯',
    '就',
    '錯',
    '在',
    '抽',
    '了',
    '一',
    '位',
    '不',
    '理',
    '性',
    '孕',
    '婦',
    '的',
    '血',
    '吧',
    '你',
    '當',
    '下',
    '可',
    '以',
    '不',
    '要',
    '哭',
    '好',
    '好',
    '跟',
    '她',
    '溝',
    '通',
    '請',
    '她',
    '換',
    '人',
    '抽',
    '吧',
    '樓',
    '上',
    '你',
    '好']],
  'mindex': array([[11, 15]]),
  'mindex_bool': array([[ True