## GAN preparation for MLM
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
import torch
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast, BertForMaskedLM, BertModel
from import_conart import conart
from conart.mlm_masks import batched_text_gan
from conart import gan_utils as gu

In [3]:
device = torch.device("cuda") \
         if torch.cuda.is_available() else torch.device("cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

## Checking input data

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

## Dataset and Dataloader

In [9]:
train_idxs_ds = TensorDataset(torch.tensor(train_idxs))
test_idxs_ds = TensorDataset(torch.tensor(test_idxs))
cx_lenc = LabelEncoder()
cx_lenc.classes_ = ["[PAD]", "BX", "IX", "O"]
slot_lenc = LabelEncoder()
slot_lenc.classes_ = ["[PAD]", "BC", "IC", "BV", "IV", "O"]
adv_lenc = LabelEncoder()
adv_lenc.classes_ = ["fake", "real"]
BV_id = slot_lenc.transform(["BV"])[0]
IV_id = slot_lenc.transform(["IV"])[0]

In [20]:
def gan_collate_fn(X, data, cx_lenc, slot_lenc, device="cpu"):  
    idxs = [x[0].item() for x in X]
    batch = batched_text_gan(data, idxs)
    
    real_tokens = tokenizer(batch["text"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          return_token_type_ids=False, return_attention_mask=False)    
    masked_tokens = tokenizer(batch["masked"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True)    
    
    cx_tags = [torch.tensor(cx_lenc.transform(["[PAD]"] + x + ["[PAD]"]))
               for x in batch["cx_tags"]]
    cx_tags = pad_sequence(cx_tags, batch_first=True, padding_value=0)
    cx_tags = cx_tags[:, :real_tokens.input_ids.size(1)]
    
    slot_tags = [torch.tensor(slot_lenc.transform(["[PAD]"] + x + ["[PAD]"])) 
                 for x in batch["slot_tags"]]        
    slot_tags = pad_sequence(slot_tags, batch_first=True, padding_value=0)
    slot_tags = slot_tags[:, :real_tokens.input_ids.size(1)]
    batch["cx_tags"] = cx_tags.to(device)
    batch["slot_tags"] = slot_tags.to(device)
    batch["real_text"] = real_tokens.to(device)
    batch["masked_text"] = masked_tokens.to(device)
    batch.update(gu.make_gendcr_labels(batch, adv_ids=[BV_id, IV_id]))
    return batch

In [21]:
bb = gan_collate_fn([test_idxs_ds[202], test_idxs_ds[203]], data, cx_lenc, slot_lenc)

In [22]:
list(bb.keys())

['text',
 'masked',
 'cx_tags',
 'slot_tags',
 'real_text',
 'masked_text',
 'gen_labels',
 'dcr_labels']

In [23]:
## visual check
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["slot_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["cx_tags"][0])])

['[CLS],0', '也,5', '是,5', '直,5', '接,5', '退,3', '一,1', '退,3', '海,5', '闊,5', '天,5', '空,5', '了,5', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],0', '也,3', '是,3', '直,3', '接,3', '退,1', '一,2', '退,2', '海,3', '闊,3', '天,3', '空,3', '了,3', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']


## Eye-balling fake/real labels

In [24]:
## should only show variable sites
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["gen_labels"]!=1, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] [PAD] 退 [PAD] 退 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [25]:
## should only show non-fake sites (discriminator should say 'real')
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["dcr_labels"]!=1, 0)[0, :40])

'[PAD] 也 是 直 接 [PAD] 一 [PAD] 海 闊 天 空 了 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [26]:
## should only show fake sites (discriminator should say 'fake')
tokenizer.decode(bb["real_text"].input_ids.masked_fill(bb["dcr_labels"]!=0, 0)[0, :40])

'[PAD] [PAD] [PAD] [PAD] [PAD] 退 [PAD] 退 [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

## Model definition

In [27]:
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
import torch.nn as nn
import torch.optim as optim

In [28]:
class ConartModelApricot(BertForMaskedLM):
    def __init__(self, config):
        super(ConartModelApricot, self).__init__(config)
        # inherit self.bert, self.cls (lm head) from super()
        self.lm_cls = self.cls
        self.tok_cls = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
    
    def G_params(self):
        return [self.bert.parameters(), self.cls.parameters()]
    
    def D_params(self):
        return [self.bert.parameters(), self.tok_cls.parameters()]
    
    def forward_G(self, X):
        out = self.forward(X)
        return self.lm_cls(out.last_hidden_state)
    
    def forward_D(self, X):
        out = self.forward(X)
        return self.tok_cls(out.last_hidden_state)
    
    def forward(self, X):
        tokens = X["masked_text"]
        cx_tags = X["cx_tags"]
        slot_tags = X["slot_tags"]
        bert_out = self.bert(**tokens, return_dict=True)
        return bert_out

## Model instantiation

In [29]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
batch_size = 16
train_loader = DataLoader(train_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
test_loader = DataLoader(test_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)

In [30]:
print("batch size: ", batch_size)
print("Training dataset:", len(train_idxs_ds))
print("Testing dataset:", len(test_idxs_ds))

batch size:  16
Training dataset: 10477
Testing dataset: 1165


In [31]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = ConartModelApricot.from_pretrained("ckiplab/bert-base-chinese", num_labels=len(cx_lenc.classes_))
model = model.to(device)

Some weights of ConartModelApricot were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['lm_cls.predictions.bias', 'lm_cls.predictions.transform.dense.bias', 'lm_cls.predictions.transform.LayerNorm.weight', 'lm_cls.predictions.transform.LayerNorm.bias', 'tok_cls.weight', 'lm_cls.predictions.decoder.weight', 'tok_cls.bias', 'lm_cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
bb = next(iter(test_loader))
model.forward(bb).last_hidden_state.shape

torch.Size([16, 171, 768])

In [33]:
lm_probs = model.forward_G(bb)[0]

In [34]:
lm_probs.shape

torch.Size([171, 21128])

## Generate Adversarial sample

In [39]:
lm_probs = model.forward_G(bb).softmax(dim=2)
adv_out = gu.generate_adversarials(bb, lm_probs)
gu.visualize_adv(adv_out, tokenizer)

[CLS]我們再[31m看([4m翻[0;31m)[0m一[31m遍([4m翻[0;31m)[0m繪本[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]愛[31m他([4m吃[0;31m)[0m不[31m敢([4m吃[0;31m)[0m很傷腦筋[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS][31m這([4m算[0;31m)[0m一[31m舉([4m算[0;31m)[0m應該比上次89萬投他票的人還多吧
[CLS]密碼都[31m不([4m改[0;31m)[0m一[31m樣([4m改[0;31m)[0m吧[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]結果意識形態[31m是([4m抹[0;31m)[0m一[31m不([4m抹[0;31m)[0m[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]開始冒出一[31m次([4m滴[0;31m)[0m又一[31m層([4m滴[0;31m)[0m的冷汗[SEP][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]你怎麼不趕快去[31m;([4m死[0;31m)[0m一[31m在([4m死[0;31m)[0m[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]微軟也沒在管不過沒升win10其實有點
[CLS][31m大([4m推[0;31m)[0m[31m查([4m推[0;31m)[0m看這系列文章[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]
[CLS]剩下的都是開出來給高齡員工養老用的我以
[CLS]最後他只差了15票就當選了因為說真的他
[CLS]推樓上[31m、([4m吵[0;31m)[0m歸[31m不([4m吵[0;31m)[0m[SEP][PAD][PAD][

## Discriminator

In [None]:
model.forward_D(bb).argmax(2)[0]