## GAN training for MLM
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
from itertools import chain
import torch
import numpy as np
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.optim as optim
from torchmetrics import MeanMetric
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
from import_conart import conart
from conart.mlm_masks import batched_text_gan
from conart import gan_utils as gu

In [3]:
device = torch.device("cuda") \
         if torch.cuda.is_available() else torch.device("cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

## Checking input data

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

## Dataset and Dataloader

In [9]:
train_idxs_ds = TensorDataset(torch.tensor(train_idxs))
test_idxs_ds = TensorDataset(torch.tensor(test_idxs))
cx_lenc = LabelEncoder()
cx_lenc.classes_ = ["[PAD]", "BX", "IX", "O"]
slot_lenc = LabelEncoder()
slot_lenc.classes_ = ["[PAD]", "BC", "IC", "BV", "IV", "O"]
adv_lenc = LabelEncoder()
adv_lenc.classes_ = ["fake", "real"]
BV_id = slot_lenc.transform(["BV"])[0]
IV_id = slot_lenc.transform(["IV"])[0]

In [10]:
def gan_collate_fn(X, data, cx_lenc, slot_lenc, device="cpu"):  
    idxs = [x[0].item() for x in X]
    batch = batched_text_gan(data, idxs)
    
    max_len = 200
    real_tokens = tokenizer(batch["text"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          return_token_type_ids=False, return_attention_mask=False, 
                          max_length=max_len)    
    masked_tokens = tokenizer(batch["masked"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          max_length=max_len)    
    
    cx_tags = [torch.tensor(cx_lenc.transform(["[PAD]"] + x + ["[PAD]"]))
               for x in batch["cx_tags"]]
    cx_tags = pad_sequence(cx_tags, batch_first=True, padding_value=0)
    cx_tags = cx_tags[:, :real_tokens.input_ids.size(1)]
    
    slot_tags = [torch.tensor(slot_lenc.transform(["[PAD]"] + x + ["[PAD]"])) 
                 for x in batch["slot_tags"]]        
    slot_tags = pad_sequence(slot_tags, batch_first=True, padding_value=0)
    slot_tags = slot_tags[:, :real_tokens.input_ids.size(1)]
    batch["cx_tags"] = cx_tags.to(device)
    batch["slot_tags"] = slot_tags.to(device)
    batch["real_text"] = real_tokens.to(device)
    batch["masked_text"] = masked_tokens.to(device)
    batch.update(gu.make_gendcr_labels(batch, adv_ids=[BV_id, IV_id]))
    return batch

In [11]:
bb = gan_collate_fn([test_idxs_ds[202], test_idxs_ds[203]], data, cx_lenc, slot_lenc)

In [12]:
list(bb.keys())

['text',
 'masked',
 'cx_tags',
 'slot_tags',
 'real_text',
 'masked_text',
 'gen_labels']

In [13]:
## visual check
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["slot_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["cx_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["gen_labels"][0])])

['[CLS],0', '也,5', '是,5', '直,5', '接,5', '退,3', '一,1', '退,3', '海,5', '闊,5', '天,5', '空,5', '了,5', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],0', '也,3', '是,3', '直,3', '接,3', '退,1', '一,2', '退,2', '海,3', '闊,3', '天,3', '空,3', '了,3', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],-100', '也,-100', '是,-100', '直,-100', '接,-100', '退,1', '一,-100', '退,1', '海,-100', '闊,-100', '天,-100', '空,-100', '了,-100', '[SEP],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100']


In [14]:
bb["gen_labels"]

tensor([[-100, -100, -100, -100, -100,    1, -100,    1, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100,    1,    1, -100,
            1,    1, -100,    1, -100]])

## Model definition

In [15]:

class ConartModelApricot(BertForMaskedLM):
    def __init__(self, config):
        super(ConartModelApricot, self).__init__(config)              
    
    def forward(self, X, labels=None):
        tokens = X["masked_text"]        
        out = super().forward(**tokens, labels=labels, return_dict=True)             
        return out

## Check adversarial samples

### Generator

In [20]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
debug_loader = DataLoader(test_idxs_ds, batch_size=2, shuffle=False, 
                         collate_fn=collate_fn)

In [21]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = ConartModelApricot.from_pretrained("ckiplab/bert-base-chinese")
model = model.to(device)

In [22]:
bb = next(iter(debug_loader))
masked_ids = bb["gen_labels"].masked_scatter(bb["gen_labels"]==1, bb["real_text"].input_ids)
lm_out = model.forward(bb, masked_ids)
gu.visualize_gen(bb, masked_ids, tokenizer)

[CLS]我覺得這位護理師[31m[MASK]([4m錯[0;31m)[0m就[31m[MASK]([4m錯[0;31m)[0m在[31m[MASK]([4m抽[0;31m)[0m了一位不理性
[CLS]1km補給品[31m[MASK]([4m買[0;31m)[0m一[31m[MASK]([4m買[0;31m)[0m窩著先看這幾天發展比


In [23]:
lm_out.loss

tensor(12.0936, device='cuda:0', grad_fn=<NllLossBackward0>)

## Model Training

In [16]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
batch_size = 16
train_loader = DataLoader(train_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
test_loader = DataLoader(test_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
print("batch size: ", batch_size)
print("Training dataset:", len(train_idxs_ds))
print("Testing dataset:", len(test_idxs_ds))


batch size:  16
Training dataset: 10477
Testing dataset: 1165


In [19]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model = ConartModelApricot.from_pretrained("ckiplab/bert-base-chinese")
model = model.to(device)

In [20]:
# train generator
from transformers import get_linear_schedule_with_warmup
writer = SummaryWriter(log_dir="../data/tb_logs/train_mlm_03", 
                       comment='epoch=5, lr=1e-4, with linear scheduler')

to_train_lm = False

optim_G = optim.AdamW(model.parameters(), lr=1e-4)
data_loader = train_loader
n_epoch = 2
scheduler_G = get_linear_schedule_with_warmup(optim_G, 50, len(data_loader)*n_epoch)

lm_loss_epoch = MeanMetric()

iter_idx = 0
for epoch_i in range(n_epoch):       
    for batch_idx, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
        #  mlm prediction
        # masked_ids = batch["gen_labels"].masked_scatter(
        #                 batch["gen_labels"]==1, batch["real_text"].input_ids)
        lm_out = model.forward(batch, batch["real_text"].input_ids)
        lm_loss = lm_out.loss
        writer.add_scalar('lm_loss', lm_loss.item(), iter_idx)
        lm_loss_epoch.update(lm_loss.item())
        
        # train generator
        optim_G.zero_grad()
        lm_loss.backward()
        optim_G.step()
        scheduler_G.step()
     
        iter_idx += 1        
    writer.add_scalar("lm_loss_epoch", lm_loss_epoch.compute(), epoch_i)
    lm_loss_epoch.reset()
    
writer.close()

  0%|          | 0/655 [00:00<?, ?it/s]

  0%|          | 0/655 [00:00<?, ?it/s]

In [21]:
model.save_pretrained("../data/models/apricot_mlm_03")

In [None]:
# alternating D/G