## GAN for MLM
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text_gan

In [3]:
device = torch.device("cuda") \
         if torch.cuda.is_available() else torch.device("cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')


## Checking input data

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

## Dataset and Dataloader

In [9]:
train_idxs_ds = TensorDataset(torch.tensor(train_idxs))
test_idxs_ds = TensorDataset(torch.tensor(test_idxs))
cx_lenc = LabelEncoder()
cx_lenc.classes_ = ["[PAD]", "BX", "IX", "O"]
slot_lenc = LabelEncoder()
slot_lenc.classes_ = ["[PAD]", "BC", "IC", "BV", "IV", "O"]

In [10]:
def gan_collate_fn(X, data, cx_lenc, slot_lenc, device="cpu"):  
    idxs = [x[0].item() for x in X]
    batch = batched_text_gan(data, idxs)
    
    tokens = tokenizer(batch["text"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True)    
    
    cx_tags = [torch.tensor(cx_lenc.transform(["[PAD]"] + x + ["[PAD]"]))
               for x in batch["cx_tags"]]
    cx_tags = pad_sequence(cx_tags, batch_first=True, padding_value=0)
    cx_tags = cx_tags[:, :tokens.input_ids.size(1)]
    
    slot_tags = [torch.tensor(slot_lenc.transform(["[PAD]"] + x + ["[PAD]"])) 
                 for x in batch["slot_tags"]]        
    slot_tags = pad_sequence(slot_tags, batch_first=True, padding_value=0)
    slot_tags = slot_tags[:, :tokens.input_ids.size(1)]
    batch["cx_tags"] = cx_tags.to(device)
    batch["slot_tags"] = slot_tags.to(device)
    batch["text"] = tokens.to(device)
    
    return batch

In [11]:
bb = gan_collate_fn([test_idxs_ds[202], test_idxs_ds[203]], data, cx_lenc, slot_lenc)

In [12]:
bb

{'text': {'input_ids': tensor([[ 101,  738, 3221, 4684, 2970, 6842,  671, 6842, 3862, 7295, 1921, 4958,
           749,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0],
         [ 101,  872,  738, 2523, 4735, 3146, 1921, 3819, 5582, 5632, 2346,  872,
          4511, 1351, 1168, 2419, 3221, 1914, 2483, 3255, 6865, 6134, 6888, 6963,
          6134, 6888,  679, 1962,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1]])},
 'cx_tags': tensor([[0, 3, 3, 3, 3, 1, 2, 2, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  

In [13]:
## visual check
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["text"].input_ids[0]), bb["slot_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["text"].input_ids[0]), bb["cx_tags"][0])])

['[CLS],0', '也,5', '是,5', '直,5', '接,5', '退,3', '一,1', '退,3', '海,5', '闊,5', '天,5', '空,5', '了,5', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],0', '也,3', '是,3', '直,3', '接,3', '退,1', '一,2', '退,2', '海,3', '闊,3', '天,3', '空,3', '了,3', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']


In [14]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
batch_size = 16
train_loader = DataLoader(train_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
test_loader = DataLoader(test_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)

In [15]:
print("batch size: ", batch_size)
print("Training dataset:", len(train_idxs_ds))
print("Testing dataset:", len(test_idxs_ds))

batch size:  16
Training dataset: 10477
Testing dataset: 1165


## Model definition

In [16]:
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
import torch.nn as nn
import torch.optim as optim

In [17]:
class ConartModelApricot(BertForMaskedLM):
    def __init__(self, config):
        super(ConartModelApricot, self).__init__(config)
        # inherit self.bert, self.cls (lm head) from super()
        self.lm_cls = self.cls
        self.tok_cls = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
    
    def G_params(self):
        return [self.bert.parameters(), self.cls.parameters()]
    
    def D_params(self):
        return [self.bert.parameters(), self.tok_cls.parameters()]
    
    def forward_G(self, X):
        out = self.forward(X)
        return self.lm_cls(out.last_hidden_state)
    
    def forward_D(self, X):
        out = self.forward(X)
        return self.tok_cls(out.last_hidden_state)
    
    def forward(self, X):
        tokens = X["text"]
        cx_tags = X["cx_tags"]
        slot_tags = X["slot_tags"]
        bert_out = self.bert(**tokens, return_dict=True)
        return bert_out

In [18]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = ConartModelApricot.from_pretrained("ckiplab/bert-base-chinese", num_labels=len(cx_lenc.classes_))
model = model.to(device)

Some weights of ConartModelApricot were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['lm_cls.predictions.transform.LayerNorm.weight', 'lm_cls.predictions.transform.dense.weight', 'lm_cls.predictions.bias', 'lm_cls.predictions.decoder.weight', 'tok_cls.weight', 'tok_cls.bias', 'lm_cls.predictions.transform.dense.bias', 'lm_cls.predictions.transform.LayerNorm.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
bb = next(iter(test_loader))
model.forward(bb).last_hidden_state.shape

torch.Size([16, 73, 768])

In [23]:
lm_probs = model.forward_G(bb)[0]

In [24]:
lm_probs.shape

torch.Size([73, 21128])

In [29]:
slot_lenc.classes_

['[PAD]', 'BC', 'IC', 'BV', 'IV', 'O']

## Prepare fake/real mask

In [97]:
BV_id = slot_lenc.transform(["BV"])[0]
IV_id = slot_lenc.transform(["IV"])[0]
slot_mask = bb["slot_tags"] != 0
adv_labels = (bb["slot_tags"] == BV_id).clone()
print(adv_labels.data_ptr())
adv_labels = torch.logical_or(adv_labels, bb["slot_tags"] == IV_id, out=adv_labels)
print(adv_labels.data_ptr())

140646856712704
140646856712704


In [98]:
# generate GAN real/fake labels
gen_labels = torch.full_like(bb["slot_tags"], -100)
gen_labels.masked_fill_(adv_labels, 1)
dcr_real_mask = torch.logical_and(slot_mask, adv_labels.logical_not())
dcr_fake_mask = torch.logical_and(slot_mask, adv_labels)
dcr_labels = torch.full_like(bb["slot_tags"], -100)
dcr_labels.masked_fill_(dcr_real_mask, 1)
dcr_labels.masked_fill_(dcr_fake_mask, 0)

tensor([[-100,    1,    1,  ..., -100, -100, -100],
        [-100,    1,    0,  ..., -100, -100, -100],
        [-100,    1,    1,  ..., -100, -100, -100],
        ...,
        [-100,    1,    0,  ..., -100, -100, -100],
        [-100,    1,    1,  ..., -100, -100, -100],
        [-100,    1,    1,  ..., -100, -100, -100]], device='cuda:0')

## Eye-balling fake/real labels

In [100]:
## should only show variable sites
tokenizer.decode(bb["text"].input_ids.masked_fill(gen_labels!=1, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] 看 [PAD] [PAD] 看 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [102]:
## should only show non-fake sites (discriminator should say 'real')
tokenizer.decode(bb["text"].input_ids.masked_fill(dcr_labels!=1, 0)[0, :40])

'[PAD] 內 文 連 [PAD] 都 沒 [PAD] 的 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [103]:
## should only show fake sites (discriminator should say 'fake')
tokenizer.decode(bb["text"].input_ids.masked_fill(dcr_labels!=0, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] 看 [PAD] [PAD] 看 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

## Generate Adversarial sample

In [111]:
adv_ids = bb["text"].input_ids.clone()
lm_probs = model.forward_G(bb).softmax(dim=2)
adv_ids[gen_labels==1] = torch.multinomial(lm_probs[gen_labels==1], 1).squeeze()

In [118]:
tokenizer.batch_decode(adv_ids[:, :30])

['[CLS] 內 文 連 看 都 沒 看 的 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 1 喝 都 喝 了 也 沒 辦 法 催 吐 吧 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 可 以 一 推 再 推 呀 愈 多 人 推 就 代 表 那 首 愈 讚 [UNK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 連 拆 都 沒 拆 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 後 來 改 吃 鹽 埕 區 的 水 源 羊 肉 了 要 試 一 試 元 豐 羊 肉 嗎 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 人 一 多 就 整 個 炸 掉 了 吧 過 年 人 車 都 回 來 了 [UNK] 對 阿 附 近 多 出 一 堆 車 亂',
 '[CLS] 叫 也 叫 不 回 說 剛 剛 那 個 動 作 讓 她 覺 得 在 別 人 面 前 丟 臉 看 電 影 [SEP] [PAD] [PAD]',
 '[CLS] 多 買 一 片 是 一 片 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 '[CLS] 你 要 嘛 聽 一 聽 就 好 [SEP] [P

## Discriminator

In [None]:
model.forward_D(bb).argmax(2)[0]