## GAN for MLM
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text_gan

In [3]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
# Guard against silently using a different version of the raw data: fingerprint
# the loaded object (pickled, SHA-256) and compare a 6-hex-digit prefix.
h = sha256(pickle.dumps(data))
data_hash = h.hexdigest()[:6]
assert data_hash == "4063b4"
len(data)  # expected: 11642

11642

In [5]:
# Read the 10-fold CV splits. The encoding is explicit (as for the raw data
# above) so the result does not depend on the platform's default encoding.
with open("../data/cv_splits_10.json", "r", encoding="UTF-8") as fin:
    cv_splits = json.load(fin)

In [6]:
# Fast tokenizer shared by the collate function (real and masked text) and the
# visual checks below.
# NOTE(review): the model further down is loaded from ckiplab/bert-base-chinese;
# presumably its vocabulary matches bert-base-chinese — confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

## Checking input data

In [7]:
# Unpack the first CV fold into train/test record indices.
# NOTE(review): this relies on the key order of the fold dict in the JSON
# (presumably train first, then test) — confirm.
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
# Inspect one raw record to see its fields (text, cnstr, slot, ...).
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

## Dataset and Dataloader

In [19]:
# Index datasets: each item is a 1-tuple holding one record index into `data`;
# gan_collate_fn turns a list of such items into a model-ready batch.
train_idxs_ds = TensorDataset(torch.tensor(train_idxs))
test_idxs_ds = TensorDataset(torch.tensor(test_idxs))

# Label encoders with a fixed, hand-chosen class order. "[PAD]" must stay at
# index 0: tag tensors are padded with 0 and 0 is treated as padding when
# building the GAN labels. `classes_` is an ndarray (as a fitted sklearn
# encoder has): a plain list cannot be fancy-indexed by inverse_transform and
# has no `.dtype`, which newer transform() implementations read.
cx_lenc = LabelEncoder()
cx_lenc.classes_ = np.array(["[PAD]", "BX", "IX", "O"])
slot_lenc = LabelEncoder()
slot_lenc.classes_ = np.array(["[PAD]", "BC", "IC", "BV", "IV", "O"])
adv_lenc = LabelEncoder()
adv_lenc.classes_ = np.array(["fake", "real"])

# Ids of the variable-slot tags, used to decide which positions are generated.
BV_id = slot_lenc.transform(["BV"])[0]
IV_id = slot_lenc.transform(["IV"])[0]

In [20]:
def make_gendcr_labels(batch):
    """Derive generator and discriminator targets from the slot tags.

    A position is a *variable* site when its slot tag is BV or IV. Id 0 marks
    padding.

    Returns a dict with two tensors shaped like ``batch["slot_tags"]``:
      * ``gen_labels``: 1 at variable sites, -100 everywhere else.
      * ``dcr_labels``: 1 ("real") at non-padding, non-variable sites;
        0 ("fake") at variable sites; -100 at padding.
    """
    tags = batch["slot_tags"]
    not_pad = tags != 0
    is_variable = (tags == BV_id) | (tags == IV_id)

    gen_labels = torch.full_like(tags, -100).masked_fill(is_variable, 1)

    dcr_labels = torch.full_like(tags, -100)
    dcr_labels = dcr_labels.masked_fill(not_pad & ~is_variable, 1)
    dcr_labels = dcr_labels.masked_fill(not_pad & is_variable, 0)
    return {"gen_labels": gen_labels, "dcr_labels": dcr_labels}

In [48]:
# Probe how the masked words tokenize ([MASK] is id 103).
# NOTE(review): `bb` is assigned in a later cell; run that cell first.
tokenizer(bb["masked"], is_split_into_words=True, return_token_type_ids=False, return_attention_mask=False)

{'input_ids': [[101, 738, 3221, 4684, 2970, 103, 671, 103, 3862, 7295, 1921, 4958, 749, 102], [101, 872, 738, 2523, 4735, 3146, 1921, 3819, 5582, 5632, 2346, 872, 4511, 1351, 1168, 2419, 3221, 1914, 2483, 3255, 6865, 103, 103, 6963, 103, 103, 679, 103, 102]]}

In [55]:
def _encode_tags(lenc, tag_seqs, width):
    """Encode per-sentence tag lists into a (batch, width) integer tensor.

    Each list is framed with "[PAD]" for the [CLS]/[SEP] positions, encoded
    with ``lenc``, right-padded with 0 (the "[PAD]" id in this notebook's
    encoders), then padded or truncated to exactly ``width`` columns.
    """
    encoded = [torch.tensor(lenc.transform(["[PAD]"] + seq + ["[PAD]"]))
               for seq in tag_seqs]
    tags = pad_sequence(encoded, batch_first=True, padding_value=0)
    missing = width - tags.size(1)
    if missing > 0:
        # pad_sequence only pads to the longest *tag* list. If the tokenizer
        # produced more pieces than there are tags (e.g. a word split into
        # several word-pieces), widen with 0 so the shape matches input_ids.
        # This prevents a shape mismatch downstream; it does not re-align
        # tags with sub-word pieces.
        tags = torch.cat([tags, tags.new_zeros(tags.size(0), missing)], dim=1)
    return tags[:, :width]


def gan_collate_fn(X, data, cx_lenc, slot_lenc, device="cpu"):
    """Collate index-dataset items into a GAN batch.

    Args:
        X: items of an index ``TensorDataset``; ``x[0]`` is a 0-d index tensor.
        data: raw records, passed through to ``batched_text_gan``.
        cx_lenc: label encoder for construction tags.
        slot_lenc: label encoder for slot tags.
        device: where the returned tensors are placed.

    Returns:
        The dict from ``batched_text_gan``, with ``cx_tags``/``slot_tags``
        replaced by integer tensors of exactly the real-text token width.
        It is extended with ``real_text`` and ``masked_text`` (tokenizer
        outputs) and with ``gen_labels``/``dcr_labels`` from
        ``make_gendcr_labels``.
    """
    idxs = [x[0].item() for x in X]
    batch = batched_text_gan(data, idxs)

    real_tokens = tokenizer(batch["text"], return_tensors="pt",
                            is_split_into_words=True, padding=True, truncation=True,
                            return_token_type_ids=False, return_attention_mask=False)
    # NOTE(review): masked_text is padded independently of real_text. If a
    # masked sentence tokenizes to a different length, the two widths can
    # differ — confirm batched_text_gan keeps them token-aligned.
    masked_tokens = tokenizer(batch["masked"], return_tensors="pt",
                              is_split_into_words=True, padding=True, truncation=True)
    width = real_tokens.input_ids.size(1)

    batch["cx_tags"] = _encode_tags(cx_lenc, batch["cx_tags"], width).to(device)
    batch["slot_tags"] = _encode_tags(slot_lenc, batch["slot_tags"], width).to(device)
    batch["real_text"] = real_tokens.to(device)
    batch["masked_text"] = masked_tokens.to(device)
    batch.update(make_gendcr_labels(batch))
    return batch

In [56]:
# Collate two test items on the default device ("cpu") for inspection.
bb = gan_collate_fn([test_idxs_ds[202], test_idxs_ds[203]], data, cx_lenc, slot_lenc)

In [57]:
# Keys produced by gan_collate_fn.
list(bb.keys())

['text',
 'masked',
 'cx_tags',
 'slot_tags',
 'real_text',
 'masked_text',
 'gen_labels',
 'dcr_labels']

In [59]:
# Visual check: each token of the first sequence paired with its slot-tag id
# (first line) and with its construction-tag id (second line). Ids index the
# label encoders' `classes_`; [CLS]/[SEP]/padding carry 0 in the output.
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["slot_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["cx_tags"][0])])

['[CLS],0', '也,5', '是,5', '直,5', '接,5', '退,3', '一,1', '退,3', '海,5', '闊,5', '天,5', '空,5', '了,5', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],0', '也,3', '是,3', '直,3', '接,3', '退,1', '一,2', '退,2', '海,3', '闊,3', '天,3', '空,3', '了,3', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']


## Eyeballing the fake/real labels

In [62]:
# Should only show variable sites: tokens where gen_labels != 1 are replaced
# by id 0, which decodes as [PAD].
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["gen_labels"]!=1, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] [PAD] 退 [PAD] 退 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [63]:
# Should only show non-fake sites, i.e. where the discriminator target is
# 'real' (dcr_labels == 1); everything else decodes as [PAD].
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["dcr_labels"]!=1, 0)[0, :40])

'[PAD] 也 是 直 接 [PAD] 一 [PAD] 海 闊 天 空 了 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [64]:
# Should only show fake sites, i.e. where the discriminator target is 'fake'
# (dcr_labels == 0); everything else decodes as [PAD].
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["dcr_labels"]!=0, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] [PAD] 退 [PAD] 退 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

## Model definition

In [65]:
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
import torch.nn as nn
import torch.optim as optim

In [70]:
class ConartModelApricot(BertForMaskedLM):
    """Generator/discriminator pair sharing one BERT encoder.

    Inherits ``self.bert`` (the encoder) and ``self.cls`` (the masked-LM head)
    from ``BertForMaskedLM``. The generator is encoder + MLM head. The
    discriminator is encoder + ``tok_cls``, a linear per-token head with
    ``config.num_labels`` outputs. Both encode ``X["masked_text"]``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Alias of the inherited MLM head. Registering the same module under a
        # second name duplicates its entries in the state dict as "lm_cls.*".
        # That is why from_pretrained lists those keys as newly initialized,
        # although they are the very same tensors as "cls.*".
        self.lm_cls = self.cls
        self.tok_cls = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    @staticmethod
    def _unique_params(*modules):
        """Collect the parameters of ``modules`` in order, each tensor once."""
        seen = set()
        params = []
        for module in modules:
            for param in module.parameters():
                if id(param) not in seen:
                    seen.add(id(param))
                    params.append(param)
        return params

    def G_params(self):
        """Generator parameters (encoder + MLM head) as a flat list.

        The list can be passed straight to a ``torch.optim`` optimizer.
        Duplicates are dropped because the MLM decoder weight may be tied to
        the encoder's word embeddings.
        """
        return self._unique_params(self.bert, self.lm_cls)

    def D_params(self):
        """Discriminator parameters (encoder + token head) as a flat list."""
        return self._unique_params(self.bert, self.tok_cls)

    def forward_G(self, X):
        """Return MLM logits (pre-softmax) for every position of ``X["masked_text"]``."""
        out = self.forward(X)
        return self.lm_cls(out.last_hidden_state)

    def forward_D(self, X):
        """Return per-token logits from the discriminator head ``tok_cls``."""
        out = self.forward(X)
        return self.tok_cls(out.last_hidden_state)

    def forward(self, X):
        """Encode ``X["masked_text"]`` (tokenizer output) with the shared encoder."""
        return self.bert(**X["masked_text"], return_dict=True)

## Model training

In [67]:
def collate_fn(x):
    """Collate index items with this notebook's data, label encoders and device."""
    return gan_collate_fn(x, data, cx_lenc, slot_lenc, device)


batch_size = 16
# Both splits share the same loader settings.
loader_opts = {"batch_size": batch_size, "shuffle": True, "collate_fn": collate_fn}
train_loader = DataLoader(train_idxs_ds, **loader_opts)
test_loader = DataLoader(test_idxs_ds, **loader_opts)

In [68]:
# Report the batch size and the number of records in each split.
print("batch size: ", batch_size)
print("Training dataset:", len(train_idxs_ds))
print("Testing dataset:", len(test_idxs_ds))

batch size:  16
Training dataset: 10477
Testing dataset: 1165


In [71]:
# The discriminator head `tok_cls` is trained on binary targets:
# make_gendcr_labels emits dcr_labels in {0: fake, 1: real}, which matches the
# class order of adv_lenc. Size the head by adv_lenc rather than by the
# construction tag set.
model = ConartModelApricot.from_pretrained("ckiplab/bert-base-chinese",
                                           num_labels=len(adv_lenc.classes_))
model = model.to(device)

Some weights of ConartModelApricot were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['lm_cls.predictions.transform.dense.weight', 'tok_cls.weight', 'tok_cls.bias', 'lm_cls.predictions.transform.dense.bias', 'lm_cls.predictions.decoder.weight', 'lm_cls.predictions.transform.LayerNorm.weight', 'lm_cls.predictions.transform.LayerNorm.bias', 'lm_cls.predictions.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [72]:
# Smoke test: take one test batch and check the encoder output shape
# (batch, sequence length, hidden size per the output below).
bb = next(iter(test_loader))
model.forward(bb).last_hidden_state.shape

torch.Size([16, 46, 768])

In [73]:
# Generator output for the first sequence of the batch.
# NOTE: forward_G returns logits (no softmax), despite the variable name; the
# softmax is applied separately before sampling below.
lm_probs = model.forward_G(bb)[0]

In [74]:
# (sequence length, vocabulary size), per the output below.
lm_probs.shape

torch.Size([46, 21128])

## Generate adversarial samples

In [102]:
def visualize_adv(adv_data, maxlen=20):
    """Print each adversarial sequence against its real counterpart.

    Positions whose ``dcr_labels`` entry is truthy (token unchanged) are printed
    as the real token. Other positions are printed in red as ``adv(real)``, with
    the real token underlined. Only the first ``maxlen`` positions of each row
    are shown.
    """
    decode = tokenizer.decode
    rows = zip(adv_data["adv_ids"].tolist(),
               adv_data["real_ids"].tolist(),
               adv_data["dcr_labels"].tolist())
    for adv_row, real_row, same_row in rows:
        pieces = []
        for adv_tok, real_tok, same in zip(adv_row[:maxlen], real_row[:maxlen],
                                           same_row[:maxlen]):
            real_ch = decode(real_tok)
            if same:
                pieces.append(real_ch)
            else:
                adv_ch = decode(adv_tok)
                pieces.append(f"\x1b[31m{adv_ch}(\x1b[4m{real_ch}\x1b[0;31m)\x1b[0m")
        print("".join(pieces))

In [103]:
def generate_adversarials(batch, lm_prob):
    """Build adversarial inputs by sampling generator tokens at variable sites.

    Args:
        batch: a batch from ``gan_collate_fn``. Reads ``masked_text.input_ids``,
            ``real_text.input_ids`` and ``gen_labels``.
        lm_prob: per-position vocabulary weights, indexed like ``input_ids``
            with a trailing vocabulary axis (e.g. a softmax over
            ``model.forward_G`` logits). Rows must be non-negative with a
            positive sum (``torch.multinomial`` requirement).

    Returns:
        dict with:
          * ``adv_ids``: a copy of the masked ids with one sampled token
            written at every ``gen_labels == 1`` position.
          * ``real_ids``: the unmasked ids.
          * ``dcr_labels``: bool, True where ``adv_ids`` equals ``real_ids``.
    """
    adv_ids = batch["masked_text"].input_ids.clone()
    real_ids = batch["real_text"].input_ids
    sample_mask = batch["gen_labels"] == 1
    if sample_mask.any():
        # One draw per selected position. Squeeze only the sample axis so that
        # a single selected position still yields a 1-D tensor.
        sampled = torch.multinomial(lm_prob[sample_mask], 1).squeeze(-1)
        adv_ids[sample_mask] = sampled
    dcr_labels = adv_ids == real_ids
    return {"adv_ids": adv_ids, "real_ids": real_ids,
            "dcr_labels": dcr_labels}
# Turn generator logits into per-position vocabulary probabilities, sample
# adversarial tokens at the variable sites and display them against the real text.
lm_probs = torch.softmax(model.forward_G(bb), dim=2)
adv_out = generate_adversarials(bb, lm_probs)
visualize_adv(adv_out)

[CLS]不是有醫生說成人都不見得能整天了何況是
[CLS]我怎麼[31m想([4m找[0;31m)[0m都[31m辦([4m找[0;31m)[0m不到[UNK][UNK]追殺列車攤大的[UNK]h
[CLS]這種店[31m鎮([4m收[0;31m)[0m一[31m,([4m收[0;31m)[0m反正也沒人想去吃走一下進度
[CLS]說[31m完([4m黑[0;31m)[0m就[31m乾([4m黑[0;31m)[0m吧[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]沒作業就讓我們[31m幹([4m動[0;31m)[0m一[31m下([4m動[0;31m)[0m嘛[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]書名說媽媽有病控制狂有毒的書都去看一[31m覽([4m看[0;31m)[0m
[CLS]能[31m是([4m醒[0;31m)[0m一[31m定([4m個[0;31m)[0m是一[31m樣([4m個[0;31m)[0m吧我認為[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]希望可以把樂高雄這個名字也[31m來([4m改[0;31m)[0m一[31m起([4m改[0;31m)[0m[SEP][PAD][PAD]
[CLS]敢[31m不([4m玩[0;31m)[0m敢[31m。([4m冒[0;31m)[0m[31m##llow([4m險[0;31m)[0m[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]但想先[31m來([4m問[0;31m)[0m[31m聽([4m問[0;31m)[0m看大家的意見[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]頭也是頭髮在肩上[31m,([4m甩[0;31m)[0m來[31m滾([4m甩[0;31m)[0m去[SEP][PAD][PA

## Discriminator

In [None]:
# Discriminator's argmax class per token for the first sequence of `bb`.
# NOTE(review): no discriminator training step is visible above, so these
# predictions come from the freshly initialized `tok_cls` head.
model.forward_D(bb).argmax(2)[0]