## MLM-tuned Accuracies
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
  * MLM-tuned: `../data/models/apricot_mlm_03` (20.10)  
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from itertools import chain
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text, batched_text_gan, get_equality_constraints
from conart.sample import sample_site

In [3]:
# Run on the GPU when torch reports CUDA as available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)

## Check data is the same: fingerprint the parsed records and compare the
## first six hex digits against the known-good value.
# NOTE(review): the fingerprint is taken over pickle bytes, which presumably
# can vary with the pickle protocol / Python version — confirm before relying
# on this check across environments.
h = sha256(pickle.dumps(data))
data_hash = h.hexdigest()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
# Pass the encoding explicitly (as done for the raw data file) so the read
# does not depend on the platform's locale default.
with open("../data/cv_splits_10.json", "r", encoding="UTF-8") as fin:
    cv_splits = json.load(fin)

In [6]:
# One tokenizer is used with both models below.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

# MLM-tuned checkpoint under evaluation. To evaluate an untuned checkpoint
# instead, load 'bert-base-chinese' or 'ckiplab/bert-base-chinese' here.
model = BertForMaskedLM.from_pretrained('../data/models/apricot_mlm_03').to(device)

# ckiplab checkpoint used as the comparison model.
ckip_model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese').to(device)

## Checking input data

In [7]:
# Use the first CV fold only. The unpacking requires the fold dict to hold
# exactly two values; it presumably relies on them being stored in
# (train, test) order — TODO confirm against the splits file.
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
# Display one raw record to show the fields available on each instance.
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

In [9]:
def get_cnstr_eqs(cxinst):
    """Collect the construction-related fields of one data record.

    Args:
        cxinst: a record with at least the keys "text", "cnstr_form" and
            "cnstr_example" (see the sample record displayed above).

    Returns:
        dict with keys:
            "text"    -- the record's text tokens concatenated into one string
            "form"    -- the record's "cnstr_form" list, unchanged
            "example" -- the record's "cnstr_example" list, unchanged
            "eqs"     -- the result of get_equality_constraints(cxinst)

    Raises:
        Whatever get_equality_constraints raises; callers in this notebook
        catch ValueError — presumably raised for records it cannot handle
        (TODO confirm).
    """
    summary = {"text": "".join(chain.from_iterable(cxinst["text"]))}
    summary["form"] = cxinst["cnstr_form"]
    summary["example"] = cxinst["cnstr_example"]
    summary["eqs"] = get_equality_constraints(cxinst)
    return summary

# Build a small batch from the first five training records and list its fields.
batch = batched_text(data, train_idxs[:5], "vslot")
list(batch.keys())

['masked', 'text', 'mindex', 'mindex_bool']

## Calculate Unigram baseline

In [10]:
import re

# A form symbol counts as a slot when it starts with lowercase ASCII letters
# or digits (re.match anchors only at the start of the string).
# NOTE(review): uppercase symbols such as 'X' are not matched, although the
# sample record above tags the 'X' position as 'BV' — confirm this is intended.
slot_symbol = re.compile("[a-z0-9]+")

v_instances = []
for idx in train_idxs:
    try:
        eqs = get_cnstr_eqs(data[idx])
    except ValueError:
        # Records that get_cnstr_eqs rejects are left out of the baseline.
        continue

    # Keep the example string of every position whose form symbol is a slot.
    v_instances.extend(
        example
        for symbol, example in zip(eqs["form"], eqs["example"])
        if slot_symbol.match(symbol)
    )

In [11]:
# Number of example strings collected into v_instances from the training fold.
len(v_instances)

19778

In [12]:
from collections import Counter

# Frequency of each element obtained by iterating over every collected
# example (single characters, given that the examples are strings).
uni_freq = Counter(chain.from_iterable(v_instances))


def query_uni_freq(char):
    """Return the count of `char` in `uni_freq`, or 0 if it was never seen."""
    # Indexing a Counter yields 0 for a missing key without inserting it.
    return uni_freq[char]

In [19]:
# Every seen character, ranked from most to least frequent.
uni_base_chars = [char for char, _count in uni_freq.most_common()]
# The ten most frequent characters are the head of that ranking.
uni_top10_chars = uni_base_chars[:10]

# Map both character lists to tokenizer vocabulary ids.
uni_base_ids = tokenizer.convert_tokens_to_ids(uni_base_chars)
uni_top10_ids = tokenizer.convert_tokens_to_ids(uni_top10_chars)
uni_top10_ids

[3647, 2682, 4692, 1391, 5050, 5481, 5464, 6341, 3121, 3119]

In [20]:
# Number of distinct characters counted in uni_freq.
len(uni_base_chars)

891

## Compute Top-k accuracies

In [27]:
# acc_table rows: conart (MLM-tuned), ckip, unigram, random (4)
# acc_table cols: Top1, Top5, Top10 (3)
TOP_KS = (1, 5, 10)
RANDOM_SEED = 20211020  # fixed so the random-baseline row is reproducible


def _topk_hits(tgt_idx, preds, ks=TOP_KS):
    """Return, for each k in `ks`, 1.0 if `tgt_idx` is among `preds[:k]` else 0.0.

    Args:
        tgt_idx: vocabulary id of the target token.
        preds: sliceable collection of candidate ids supporting `in`
            (assumed to be ordered best-first — TODO confirm for sample_site).
        ks: cut-offs to evaluate.

    Returns:
        float numpy array with one entry per cut-off.
    """
    return np.array([tgt_idx in preds[:k] for k in ks], dtype=float)


rng = np.random.default_rng(RANDOM_SEED)
acc_table = np.zeros((4, len(TOP_KS)))
N = 0               # number of masked positions scored
skipped_idxs = []   # test records for which sampling failed

for data_idx in tqdm(test_idxs):
    bb = batched_text(data, [data_idx], 'vslot')

    text = bb["text"][0]
    mask_locs = bb["mindex"][0]
    # Keep only positive entries — presumably non-positive values are padding
    # (TODO confirm against batched_text).
    mask_locs = mask_locs[mask_locs>0]
    try:
        conart_samples = sample_site(bb, model, tokenizer, n_sample=10, max_len=500)[0]
        ckip_samples = sample_site(bb, ckip_model, tokenizer, n_sample=10, max_len=500)[0]
    except Exception as err:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt still stops the run, and show the actual error so
        # failures other than over-long sentences are not misreported.
        print("Sentence too long, mindex exceeds max_len", data_idx, repr(err))
        skipped_idxs.append(data_idx)
        continue

    for i, mask_idx in enumerate(mask_locs):
        tgt_char = text[mask_idx]
        tgt_idx = tokenizer.convert_tokens_to_ids(tgt_char)
        # Random baseline: ten ids drawn from the seen-character vocabulary
        # (with replacement, as before), now from a seeded generator.
        random_choices = rng.choice(uni_base_ids, 10)

        acc_table[0] += _topk_hits(tgt_idx, conart_samples["ids"][i])
        acc_table[1] += _topk_hits(tgt_idx, ckip_samples["ids"][i])
        acc_table[2] += _topk_hits(tgt_idx, uni_top10_ids)
        acc_table[3] += _topk_hits(tgt_idx, random_choices)
        N += 1

  0%|          | 0/1165 [00:00<?, ?it/s]

Sentence too long, mindex exceeds max_len 7417


In [34]:
import pandas as pd

# N counts the scored masked positions. If every test record was skipped the
# division below would silently produce a table of NaNs, so fail loudly instead.
assert N > 0, "no masked positions were scored; cannot normalise acc_table"

# Convert hit counts to accuracies and label rows (models) and columns (cut-offs).
acc_dfr = pd.DataFrame(
    acc_table / N,
    columns=["Top1", "Top5", "Top10"],
    index=["CxLM", "Bert-Base", "Unigram", "Random"],
)
acc_dfr

Unnamed: 0,Top1,Top5,Top10
CxLM,0.300522,0.507031,0.599839
Bert-Base,0.061069,0.132583,0.182804
Unigram,0.047409,0.156689,0.216553
Random,0.000402,0.004821,0.011249


## Output Table

In [36]:
# Print the accuracy table as LaTeX source (plain text, for copy-paste).
print(acc_dfr.to_latex())

\begin{tabular}{lrrr}
\toprule
{} &      Top1 &      Top5 &     Top10 \\
\midrule
CxLM      &  0.300522 &  0.507031 &  0.599839 \\
Bert-Base &  0.061069 &  0.132583 &  0.182804 \\
Unigram   &  0.047409 &  0.156689 &  0.216553 \\
Random    &  0.000402 &  0.004821 &  0.011249 \\
\bottomrule
\end{tabular}

