## MLM-tuned Accuracies
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
  * MLM-tuned: `../data/models/apricot_mlm_03` (20.10)  
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from itertools import chain
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text, batched_text_gan, get_equality_constraints
from conart.sample import sample_site

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642
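
The check above hashes `pickle.dumps(data)`. Pickle bytes can vary with the Python version and pickle protocol, so the assertion may fail on another setup even when the JSON content is identical. A minimal sketch of a more portable fingerprint follows, assuming the data is plain JSON-serializable. The helper name `json_fingerprint` is hypothetical, and its value will not equal the pickle-based `4063b4`.

```python
import hashlib
import json


def json_fingerprint(obj, n_hex=6):
    """Return the first n_hex hex digits of SHA-256 over canonical JSON."""
    # sort_keys and fixed separators make the serialization deterministic;
    # ensure_ascii=False keeps CJK text as-is before UTF-8 encoding.
    payload = json.dumps(
        obj, sort_keys=True, ensure_ascii=False, separators=(",", ":")
    ).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:n_hex]


# Toy record (not the real dataset): key order does not change the result.
toy_a = [{"board": "BabyMother", "text": ["再", "擠"]}]
toy_b = [{"text": ["再", "擠"], "board": "BabyMother"}]
print(json_fingerprint(toy_a) == json_fingerprint(toy_b))
```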

In [6]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [7]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
# model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese')
model = BertForMaskedLM.from_pretrained('../data/models/apricot_mlm_03')
model = model.to(device)
ckip_model = BertForMaskedLM.from_pretrained('ckiplab/bert-base-chinese')
ckip_model = ckip_model.to(device)

## Checking input data

In [8]:
train_idxs, test_idxs = cv_splits[0].values()
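
Unpacking `cv_splits[0].values()` relies on each split being a dict whose first value is the train indices and whose second is the test indices. A small sanity check can make that assumption explicit. This is only a sketch: the helper `check_split` and the keys `"train"` and `"test"` in the toy split are hypothetical, since the notebook does not show the actual key names.

```python
def check_split(split, n_items):
    """Check that a train/test split is disjoint and within range."""
    # Same unpacking as the notebook cell: depends on dict insertion order.
    train_idxs, test_idxs = split.values()
    train, test = set(train_idxs), set(test_idxs)
    assert not (train & test), "train and test indices overlap"
    assert all(0 <= i < n_items for i in train | test), "index out of range"
    return len(train), len(test)


# Toy split with hypothetical key names, not the real cv_splits_10.json.
toy_split = {"train": [0, 1, 2, 3], "test": [4, 5]}
print(check_split(toy_split, n_items=6))
```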

In [9]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}
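
The `cnstr` and `slot` fields are per-token tags aligned with `text`. The sketch below shows how the tags line up with the other fields of the record above. It assumes that any `cnstr` tag other than `O` marks a token inside the construction, and that `slot` tags ending in `V` and `C` mark variable and constant slots respectively. That reading is inferred from this one example, not from the `conart` source, and `tagged_tokens` is a hypothetical helper.

```python
def tagged_tokens(tokens, tags, keep):
    """Return the tokens whose aligned tag satisfies the `keep` predicate."""
    return [tok for tok, tag in zip(tokens, tags) if keep(tag)]


record = {
    "text": ["再", "擠", "也", "擠", "不", "出來", "了"],
    "cnstr": ["O", "BX", "IX", "IX", "IX", "IX", "O"],
    "slot": ["O", "BV", "BC", "BV", "BC", "BV", "O"],
    "cnstr_example": ["擠", "也", "擠", "不", "出來"],
}

# Tokens inside the construction span (assumed: any tag other than "O").
span = tagged_tokens(record["text"], record["cnstr"], lambda t: t != "O")
# Assumed reading of the slot tags: *V = variable slot, *C = constant slot.
variable = tagged_tokens(record["text"], record["slot"], lambda t: t.endswith("V"))
constant = tagged_tokens(record["text"], record["slot"], lambda t: t.endswith("C"))

print(span == record["cnstr_example"])
print(variable, constant)
```

Under this reading, the variable slots are the positions that `batched_text(..., "vslot")` would plausibly mask, which matches the `[MASK]` positions printed later in the notebook.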

In [10]:
batch = batched_text(data, train_idxs[:5], "vslot")
def get_cnstr_eqs(cxinst):
    cnstr_eqs = {
        "text": "".join(chain.from_iterable(cxinst["text"])),
        "form": cxinst["cnstr_form"],
        "example": cxinst["cnstr_example"],
        "eqs": get_equality_constraints(cxinst)
    }
    return cnstr_eqs
list(batch.keys())

['masked', 'text', 'mindex', 'mindex_bool']

## Generate samples

In [13]:
def generate_guesses(test_idx):
    """Print the top guesses for the masked variable slots of one test item."""
    batch = batched_text(data, test_idxs[test_idx:test_idx+1], 'vslot')
    print("Masked", "".join(batch["masked"][0]))
    print("Origin", "".join(batch["text"][0]))
    # Default call: one candidate list per [MASK] site.
    samples = sample_site(batch, model, tokenizer)[0]
    print("Model (separated): ", tokenizer.batch_decode(samples["ids"]))
    # merge_pair2=True: a single candidate list for the paired [MASK] sites.
    samples = sample_site(batch, model, tokenizer, merge_pair2=True)[0]
    print("Model (merged): ", tokenizer.batch_decode(samples["ids"]))

In [14]:
generate_guesses(100)

Masked 裡面的紅蘿蔔和馬鈴薯也很入味泡麵這樣[MASK]一[MASK]味道很棒耶
Origin 裡面的紅蘿蔔和馬鈴薯也很入味泡麵這樣拌一拌味道很棒耶
Model (separated):  ['拌 煮 吃 炒 炸 涮 夾 泡 蒸 烤', '拌 吃 煮 炸 涮 炒 泡 夾 烤 做']
Model (merged):  ['拌', '吃', '煮', '炸', '炒', '涮', '夾', '泡', '烤', '蒸']
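
The separated output gives one ranked list per `[MASK]`, while the merged output gives a single list for both sites, which must hold the same verb in this reduplicated construction. One plausible way to combine two sites is to rank the shared candidates by their joint log-probability. The sketch below illustrates that idea only. It is not necessarily what `sample_site(..., merge_pair2=True)` does, the helper `merge_tied_sites` is hypothetical, and the probabilities are made-up toy values rather than model outputs.

```python
import math


def merge_tied_sites(site_probs, top_k=3):
    """Rank tokens by summed log-probability across tied mask sites."""
    # Only tokens proposed at every site can fill all tied slots.
    shared = set(site_probs[0])
    for probs in site_probs[1:]:
        shared &= set(probs)
    # Summing logs equals multiplying the per-site probabilities.
    scores = {tok: sum(math.log(p[tok]) for p in site_probs) for tok in shared}
    return sorted(scores, key=scores.get, reverse=True)[:top_k]


# Toy per-site distributions (illustrative numbers, all strictly positive).
site_a = {"拌": 0.50, "煮": 0.20, "吃": 0.15, "炒": 0.10}
site_b = {"拌": 0.45, "吃": 0.25, "煮": 0.15, "做": 0.05}
print(merge_tied_sites([site_a, site_b]))
```

Candidates that appear at only one site (`炒`, `做` in the toy data) are dropped, since they cannot satisfy the equality constraint between the two slots.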


In [23]:
generate_guesses(40)

Masked 什麼叫不[MASK]白不[MASK]當那兩千多塊不用付嗎
Origin 什麼叫不刷白不刷當那兩千多塊不用付嗎
Model (separated):  ['吃 賺 拿 給 買 付 算 領 做 收', '上 敢 用 能 拿 要 難 好 擔 吃']
Model (merged):  ['拿', '吃', '敢', '用', '給', '上', '能', '要', '買', '做']


In [24]:
generate_guesses(3)

Masked 運動強度沒有太高圖個[MASK]一[MASK]
Origin 運動強度沒有太高圖個動一動
Model (separated):  ['升 緩 加 動 忍 玩 瘦 練 晃 撐', '緩 忍 升 加 動 醒 笑 想 練 晃']
Model (merged):  ['升', '緩', '忍', '加', '動', '醒', '練', '晃', '玩', '想']
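
The three merged outputs above can be scored by checking whether the original verb appears among the top-k guesses. The sketch below does this for the examples printed in this notebook, taking the targets from the `Origin` lines. The helper `hit_at_k` is hypothetical and three items are far too few to say anything about the model's accuracy; a real evaluation would loop over every test index in each CV split.

```python
def hit_at_k(guesses, target, k):
    """Return True if `target` is among the first k ranked guesses."""
    return target in guesses[:k]


# (original token, merged guesses) copied from the outputs above.
runs = [
    ("拌", ["拌", "吃", "煮", "炸", "炒", "涮", "夾", "泡", "烤", "蒸"]),
    ("刷", ["拿", "吃", "敢", "用", "給", "上", "能", "要", "買", "做"]),
    ("動", ["升", "緩", "忍", "加", "動", "醒", "練", "晃", "玩", "想"]),
]

hits = {k: sum(hit_at_k(g, t, k) for t, g in runs) for k in (1, 5, 10)}
for k, n_hit in hits.items():
    print(f"hit@{k}: {n_hit}/{len(runs)}")
```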
