## MLM-tuned evals
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
  * MLM-tuned: `../data/models/apricot_mlm_01` (20.10)
  * MLM-tuned: `../data/models/apricot_mlm_02` (20.10)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text, batched_text_gan
from conart.gan_utils import generate_adversarials

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
# model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese')
model = BertForMaskedLM.from_pretrained('../data/models/apricot_mlm_03')
model = model.to(device)

## Checking input data

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

In [9]:
batch = batched_text(data, train_idxs[:5], "vslot")
list(batch.keys())

['masked', 'text', 'mindex', 'mindex_bool']

In [10]:
batch["mindex"]

array([[135, 137,  -1,  -1,  -1,  -1],
       [ 13,  15,  -1,  -1,  -1,  -1],
       [  3,   5,  -1,  -1,  -1,  -1],
       [  0,   1,   2,   4,   5,   6],
       [  0,   2,  -1,  -1,  -1,  -1]])

## Function definitions

In [11]:
loss_fct = CrossEntropyLoss(reduction="none")
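def make_masked_ylabels(mindex: np.ndarray, ori_ids: torch.Tensor):
    # -100 is the default ignore_index of CrossEntropyLoss
    ylabels = torch.zeros_like(ori_ids)-100
    for i in range(mindex.shape[0]):
        # add 1 for the [CLS] token the tokenizer prepends
        j = mindex[i]+1
        # NOTE: padding entries (-1) become 0 here and pass the j>=0 test,
        # so position 0 ([CLS]) is also labelled for rows with padded mindex
        j = j[(j>=0) & (j<512-2)]
        ylabels[i, j] = ori_ids[i, j]
    return ylabels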

def compute_loss(logits, ylabels):
    return loss_fct(logits, ylabels)
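
The label construction above can be checked on a toy batch. This is a minimal NumPy sketch, not the notebook's own code: `masked_labels`, the `keep_cls` flag and the token ids are made up for illustration. It shows how the `-1` padding in `mindex` interacts with the `+1` shift for `[CLS]`.

```python
import numpy as np

IGNORE = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def masked_labels(mindex, ids, keep_cls):
    """Copy ids at the masked positions; every other position gets IGNORE.

    keep_cls=True mimics a `j >= 0` filter, keep_cls=False a `j > 0` filter.
    """
    labels = np.full_like(ids, IGNORE)
    for i in range(mindex.shape[0]):
        j = mindex[i] + 1                        # shift past the leading [CLS]
        j = j[j >= 0] if keep_cls else j[j > 0]  # padding (-1) maps to 0
        labels[i, j] = ids[i, j]
    return labels

# hypothetical toy batch: 101 = [CLS], 102 = [SEP], 0 = [PAD]
ids = np.array([[101, 11, 12, 13, 102],
                [101, 21, 22, 102, 0]])
mindex = np.array([[0, 2], [1, -1]])             # -1 pads the shorter row

with_cls = masked_labels(mindex, ids, keep_cls=True)
without_cls = masked_labels(mindex, ids, keep_cls=False)
print(with_cls)
print(without_cls)
```

With `keep_cls=True` the second row also receives a label at position 0, so that row's loss is averaged over one extra token.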

In [12]:
def compute_mlm_loss(model, train_idxs, use_mask, debug=False):
    batch_size = 16
    loss_vec = []    
    
    for i in range(0, len(train_idxs), batch_size):
        data_idxs = train_idxs[i:i+batch_size]
        batch = batched_text(data, data_idxs, use_mask)

        # tokenize the masked and original text; the loss is the
        # cross-entropy at the masked positions only
        X = tokenizer(batch["masked"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
        X = X.to(device)
        Xori = tokenizer(batch["text"], return_tensors="pt", is_split_into_words=True, padding=True, truncation=True)
        ylabels = make_masked_ylabels(batch["mindex"], Xori["input_ids"])
        ylabels = ylabels.to(device)        
        with torch.no_grad():
            out = model(**X)
            batch_loss = compute_loss(out.logits.swapdims(2, 1), ylabels)
            # average over non-zero token losses
            batch_loss = batch_loss.sum(1) / (batch_loss>0).sum(1)
            loss_vec.extend(batch_loss.cpu().tolist())
        if debug: break
        
    if debug:        
        return {"loss_vec": np.array(loss_vec), 
                "batch": batch, "ylabels": ylabels, "out": out}
    else:
        return np.array(loss_vec)

In [13]:
from itertools import product
Mtest = len(test_idxs)
n_split = len(cv_splits)
def make_design_iter():
    return product(["cnstr", "cslot", "vslot"], # cnstr elements 
                   ["raw", "shifted", "random"]) # masked type

def get_mask_condition(cx_elem, mask_type):
    cx_code = "cx" if cx_elem == "cnstr" else cx_elem
    mask_code = "" if mask_type == "raw" else mask_type+"-"
    return mask_code+cx_code

def print_stat(x):
    print("M={:.4f}, SD={:.4f} ({} NaNs)"
          .format(np.nanmean(x), np.nanstd(x), np.sum(np.isnan(x))))

def make_buffer(n_split=10):
    ret = {}
    for cx_elem, mtype in make_design_iter():
        ret[f"{cx_elem}_{mtype}"] = \
            np.zeros((Mtest, n_split))
    return ret
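
To make the 3 x 3 design explicit, the sketch below lists each buffer key next to the mask name it passes on to `batched_text`. It copies the mapping logic of `get_mask_condition` into a standalone function; the name `mask_condition` is only for illustration.

```python
from itertools import product

def mask_condition(cx_elem, mask_type):
    # same mapping as get_mask_condition: "cnstr" is abbreviated to "cx",
    # and non-raw mask types are prepended as "<type>-"
    cx_code = "cx" if cx_elem == "cnstr" else cx_elem
    prefix = "" if mask_type == "raw" else mask_type + "-"
    return prefix + cx_code

conditions = {
    f"{cx_elem}_{mtype}": mask_condition(cx_elem, mtype)
    for cx_elem, mtype in product(["cnstr", "cslot", "vslot"],
                                  ["raw", "shifted", "random"])
}
for key, mask_name in conditions.items():
    print(f"{key:15s} -> {mask_name}")
```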

## Computing MLM Losses

In [14]:
## masked
n_condition = len(list(make_design_iter()))
losses = make_buffer(1)
split_idx = 0
train_idxs, test_idxs = cv_splits[split_idx].values()
pbar = tqdm(total=n_condition)
for cx_elem, mtype in make_design_iter():        
    cond_text = f"{cx_elem}_{mtype}"
    pbar.set_description(f"{split_idx+1: 2d}. {cond_text}")
    m_cond = get_mask_condition(cx_elem, mtype)        
    loss_x = compute_mlm_loss(model, test_idxs, m_cond)        
    losses[cond_text][:, split_idx] = loss_x            
    pbar.update(1)
pbar.close()

In [15]:
[(k,np.mean(x)) for k, x in losses.items()]

[('cnstr_raw', 4.223324037336069),
 ('cnstr_shifted', 4.566821129403833),
 ('cnstr_random', 3.3012096010116547),
 ('cslot_raw', 2.711533956908393),
 ('cslot_shifted', 2.326555755823354),
 ('cslot_random', 2.1128639750665528),
 ('vslot_raw', 2.682894103944538),
 ('vslot_shifted', 3.078818427914438),
 ('vslot_random', 2.7265981398804775)]
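
The numbers above are mean cross-entropy losses in nats, since `CrossEntropyLoss` uses the natural logarithm. One possible way to read them, sketched below, is to exponentiate each mean. The result is a rough perplexity-style figure, not a corpus-level perplexity, because the mean is taken over per-item averages. The three values are copied (rounded) from the output above; `mean_loss` and `ppl` are illustrative names.

```python
import math

# mean losses copied (rounded) from the cell output above
mean_loss = {
    "cnstr_raw": 4.2233,
    "cslot_raw": 2.7115,
    "vslot_raw": 2.6829,
}

# exp(mean cross-entropy): a rough perplexity-style summary per condition
ppl = {cond: math.exp(loss) for cond, loss in mean_loss.items()}
for cond, value in ppl.items():
    print(f"{cond:10s} {value:8.2f}")
```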

## Generate samples

In [16]:
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
slot_lenc = LabelEncoder()
slot_lenc.classes_ = ["[PAD]", "BC", "IC", "BV", "IV", "O"]

def sample_varslots(model, train_idxs):
    batch_size = 16
    max_len = 256
    samples = []  
    
    for i in range(0, len(train_idxs), batch_size):
        data_idxs = train_idxs[i:i+batch_size]
        batch = batched_text_gan(data, data_idxs)
        
        real_tokens = tokenizer(batch["text"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          return_token_type_ids=False, return_attention_mask=False, 
                          max_length=max_len)    
        masked_tokens = tokenizer(batch["masked"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          max_length=max_len)    
        
        slot_tags = [torch.tensor(slot_lenc.transform(["[PAD]"] + x + ["[PAD]"])) 
                     for x in batch["slot_tags"]]        
        slot_tags = pad_sequence(slot_tags, batch_first=True, padding_value=0)
        slot_tags = slot_tags[:, :real_tokens.input_ids.size(1)]                   
        
        BV_id = slot_lenc.transform(["BV"])[0]
        IV_id = slot_lenc.transform(["IV"])[0]
        adv_labels = (slot_tags == BV_id).clone()        
        adv_labels = torch.logical_or(adv_labels, slot_tags==IV_id, out=adv_labels)
                
        batch["slot_tags"] = slot_tags.to(device)
        batch["real_text"] = real_tokens.to(device)
        batch["masked_text"] = masked_tokens.to(device)
        
        # generate GAN real/fake labels
        gen_labels = torch.full_like(slot_tags, 0)
        gen_labels.masked_fill_(adv_labels, 1)
        batch["gen_labels"] = gen_labels        
        
        with torch.no_grad():            
            out = model(**masked_tokens)
            adv_out = generate_adversarials(batch, out.logits.softmax(axis=2))            
            samples.append(adv_out)
        break  # stop after the first batch; only one batch of samples is returned
            
    return samples


In [17]:
test_idxs[1:3]

[6598, 4660]

In [18]:
batch["mindex"]

array([[135, 137,  -1,  -1,  -1,  -1],
       [ 13,  15,  -1,  -1,  -1,  -1],
       [  3,   5,  -1,  -1,  -1,  -1],
       [  0,   1,   2,   4,   5,   6],
       [  0,   2,  -1,  -1,  -1,  -1]])

In [19]:
probs[2, [6,8], :].argsort()[-10:]

NameError: name 'probs' is not defined

## Sample variable site

In [20]:
from conart import gan_utils as gu
import torch.nn.functional as F
from typing import List, Dict

def sample_site(batch, merge_pair2=False):
    masked_tokens = tokenizer(batch["masked"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          max_length=200)  
    with torch.no_grad():
        masked_tokens = masked_tokens.to(device)
        out = model(**masked_tokens)
        # log-probabilities over the vocabulary, shape (batch, seq_len, vocab)
        probs = F.log_softmax(out.logits, dim=2).cpu().numpy()

    mindex = batch["mindex"]

    samples = []
    for i in range(mindex.shape[0]):
        # shift by 1 for [CLS]; padding (-1) becomes 0 and is dropped
        mindex_x = mindex[i,:]+1
        mindex_x = mindex_x[mindex_x>0]
        probs_x = probs[i, mindex_x, :]
        if len(mindex_x) == 2 and merge_pair2:
            # treat the two sites as one filler: sum the log-probs, rank jointly
            probs_sum = probs_x.sum(0)
            arg_x = probs_sum.argsort()[::-1][:10]
            samples.append({"ids": arg_x,
                            "probs": probs_sum[arg_x]})
        else:
            # top-10 candidates for each masked site separately
            arg_x = probs_x.argsort(axis=1)[:, ::-1][:, :10]
            samples.append({"ids": arg_x,
                            "probs": np.take_along_axis(probs_x, arg_x, 1)})
    return samples
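
The difference between the separated and merged rankings can be seen on a toy example. This is a NumPy-only sketch with made-up logits for two masked sites and a four-token vocabulary; `log_softmax` is a local helper, not the PyTorch function used above. Summing log-probabilities across the two sites scores each token as if the same token filled both sites.

```python
import numpy as np

def log_softmax(logits):
    # numerically stable log-softmax over the last axis
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

# hypothetical logits: 2 masked sites, vocabulary of 4 tokens
logits = np.array([[2.0, 1.0, 0.0, -1.0],
                   [0.0, 3.0, 1.0, -1.0]])
logp = log_softmax(logits)

# separated: top-2 token ids per site, best first
top_sep = logp.argsort(axis=1)[:, ::-1][:, :2]

# merged: sum the log-probs over sites, then take the joint top-2
merged = logp.sum(axis=0)
top_merged = merged.argsort()[::-1][:2]

print(top_sep)     # per-site rankings
print(top_merged)  # joint ranking
```

Here token 0 is the best candidate at the first site alone, while token 1 ranks first once both sites are scored together.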

In [21]:
def generate_guesses(test_idx):
    # mask the variable slots of a single test item and print the model's
    # top-10 candidates, per site and with the two sites merged
    batch = batched_text(data, test_idxs[test_idx:test_idx+1], 'vslot')
    print("Masked", "".join(batch["masked"][0]))
    print("Origin", "".join(batch["text"][0]))
    samples = sample_site(batch)[0]
    print("Model (separated): ", tokenizer.batch_decode(samples["ids"]))
    samples = sample_site(batch, merge_pair2=True)[0]
    print("Model (merged): ", tokenizer.batch_decode(samples["ids"]))

In [22]:
generate_guesses(15)

Masked 這麼熟門熟路吳先生是誰吳先生很懂喔我連[MASK]都沒[MASK]過他都知道
Origin 這麼熟門熟路吳先生是誰吳先生很懂喔我連聽都沒聽過他都知道
Model (separated):  ['聽 看 查 想 見 去 問 找 提 講', '聽 見 看 想 問 查 找 去 提 講']
Model (merged):  ['聽', '看', '見', '想', '查', '問', '去', '找', '提', '講']


In [23]:
generate_guesses(7)

Masked 原本買小的但[MASK]一[MASK]有時候看鏡子整體的手感覺手錶好小
Origin 原本買小的但戴一戴有時候看鏡子整體的手感覺手錶好小
Model (separated):  ['洗 想 算 買 修 看 用 逛 摸 動', '想 洗 算 買 修 動 摸 看 用 玩']
Model (merged):  ['想', '洗', '算', '買', '修', '看', '動', '用', '摸', '玩']


In [24]:
generate_guesses(3)

Masked 運動強度沒有太高圖個[MASK]一[MASK]
Origin 運動強度沒有太高圖個動一動
Model (separated):  ['升 緩 加 動 忍 玩 瘦 練 晃 撐', '緩 忍 升 加 動 醒 笑 想 練 晃']
Model (merged):  ['升', '緩', '忍', '加', '動', '醒', '練', '晃', '玩', '想']


## Output pickle