## GAN training for MLM
* Inputs:
  * raw data: `../data/raw_cx_data.json` (10.01)
  * CV splits: `../data/cv_splits_10.json` (10.01)
* Outputs:
  * (none)

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pickle
from pathlib import Path
from hashlib import sha256
from tqdm.auto import tqdm
from itertools import chain
import torch
import numpy as np
from torch.nn import CrossEntropyLoss, NLLLoss
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.optim as optim
from torchmetrics import MeanMetric
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizerFast, PreTrainedModel, PretrainedConfig
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM, BertForTokenClassification
from import_conart import conart
from conart.mlm_masks import batched_text_gan
from conart import gan_utils as gu

In [3]:
device = torch.device("cuda") \
         if torch.cuda.is_available() else torch.device("cpu")

In [4]:
data_path = "../data/raw_cx_data.json"
with open(data_path, "r", encoding="UTF-8") as fin:
    data = json.load(fin)
## Check data is the same
h = sha256()
h.update(pickle.dumps(data))
data_hash = h.digest().hex()[:6]
assert data_hash == "4063b4"
len(data) # should be 11642

11642

In [5]:
## Read cv splits
with open("../data/cv_splits_10.json", "r") as fin:
    cv_splits = json.load(fin)

In [6]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')

## Checking input data

In [7]:
train_idxs, test_idxs = cv_splits[0].values()

In [8]:
xx = data[1211]
xx

{'board': 'BabyMother',
 'text': ['再', '擠', '也', '擠', '不', '出來', '了'],
 'cnstr': ['O', 'BX', 'IX', 'IX', 'IX', 'IX', 'O'],
 'slot': ['O', 'BV', 'BC', 'BV', 'BC', 'BV', 'O'],
 'cnstr_form': ['v', '也', 'v', '不', 'X'],
 'cnstr_example': ['擠', '也', '擠', '不', '出來']}

## Dataset and Dataloader

In [9]:
train_idxs_ds = TensorDataset(torch.tensor(train_idxs))
test_idxs_ds = TensorDataset(torch.tensor(test_idxs))
cx_lenc = LabelEncoder()
cx_lenc.classes_ = ["[PAD]", "BX", "IX", "O"]
slot_lenc = LabelEncoder()
slot_lenc.classes_ = ["[PAD]", "BC", "IC", "BV", "IV", "O"]
adv_lenc = LabelEncoder()
adv_lenc.classes_ = ["fake", "real"]
BV_id = slot_lenc.transform(["BV"])[0]
IV_id = slot_lenc.transform(["IV"])[0]

In [10]:
def gan_collate_fn(X, data, cx_lenc, slot_lenc, device="cpu"):  
    idxs = [x[0].item() for x in X]
    batch = batched_text_gan(data, idxs)
    
    max_len = 200
    real_tokens = tokenizer(batch["text"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          return_token_type_ids=False, return_attention_mask=False, 
                          max_length=max_len)    
    masked_tokens = tokenizer(batch["masked"], return_tensors="pt", 
                          is_split_into_words=True, padding=True, truncation=True,
                          max_length=max_len)    
    
    cx_tags = [torch.tensor(cx_lenc.transform(["[PAD]"] + x + ["[PAD]"]))
               for x in batch["cx_tags"]]
    cx_tags = pad_sequence(cx_tags, batch_first=True, padding_value=0)
    cx_tags = cx_tags[:, :real_tokens.input_ids.size(1)]
    
    slot_tags = [torch.tensor(slot_lenc.transform(["[PAD]"] + x + ["[PAD]"])) 
                 for x in batch["slot_tags"]]        
    slot_tags = pad_sequence(slot_tags, batch_first=True, padding_value=0)
    slot_tags = slot_tags[:, :real_tokens.input_ids.size(1)]
    batch["cx_tags"] = cx_tags.to(device)
    batch["slot_tags"] = slot_tags.to(device)
    batch["real_text"] = real_tokens.to(device)
    batch["masked_text"] = masked_tokens.to(device)
    batch.update(gu.make_gendcr_labels(batch, adv_ids=[BV_id, IV_id]))
    return batch

In [11]:
bb = gan_collate_fn([test_idxs_ds[202], test_idxs_ds[203]], data, cx_lenc, slot_lenc)

In [12]:
list(bb.keys())

['text',
 'masked',
 'cx_tags',
 'slot_tags',
 'real_text',
 'masked_text',
 'gen_labels']

In [13]:
## visual check
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["slot_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["cx_tags"][0])])
print([f"{a},{b.item()}" for a,b in zip(tokenizer.convert_ids_to_tokens(bb["real_text"].input_ids[0]), bb["gen_labels"][0])])

['[CLS],0', '也,5', '是,5', '直,5', '接,5', '退,3', '一,1', '退,3', '海,5', '闊,5', '天,5', '空,5', '了,5', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],0', '也,3', '是,3', '直,3', '接,3', '退,1', '一,2', '退,2', '海,3', '闊,3', '天,3', '空,3', '了,3', '[SEP],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0', '[PAD],0']
['[CLS],-100', '也,-100', '是,-100', '直,-100', '接,-100', '退,1', '一,-100', '退,1', '海,-100', '闊,-100', '天,-100', '空,-100', '了,-100', '[SEP],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100', '[PAD],-100']


In [14]:
bb["gen_labels"]

tensor([[-100, -100, -100, -100, -100,    1, -100,    1, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100,    1,    1, -100,
            1,    1, -100,    1, -100]])

## Model definition

In [15]:

class ConartModelDonut(PreTrainedModel):
    def __init__(self, config, bertD_ckpt, bertG_ckpt):        
        super(ConartModelDonut, self).__init__(config)
        # inherit self.bert, self.cls (lm head) from super()              
        self.bertD = BertForTokenClassification.from_pretrained(bertD_ckpt, num_labels=1)
        self.bertG = BertForMaskedLM.from_pretrained(bertG_ckpt)        
    
    def G_params(self):
        return self.bertG.parameters()
    
    def D_params(self):
        return self.bertD.parameters()
    
    def forward_G(self, X, labels=None):
        tokens = X["masked_text"]
        out = self.bertG(**tokens, labels=labels, return_dict=True)        
        lmlogits = out.logits
        lmprobs = lmlogits.softmax(dim=2)
        ret = {"lm_probs": lmprobs, "logits": lmlogits, 
               "mlm_loss": out.loss}        
        return ret
        
    
    def forward_D(self, X, labels=None):
        tokens = X["masked_text"]
        out = self.bertD(**tokens, return_dict=True)
        tok_logits = out.logits
        ret = {"tok_logits": tok_logits}
        
        if labels is not None:        
            real_loss = -torch.mean(tok_logits[labels==1])
            fake_loss = torch.mean(tok_logits[labels==0])
            adv_loss = real_loss + fake_loss
            ret["adv_loss"] = adv_loss
        return ret
    
    def forward(self, X):
        tokens = X["masked_text"]
        cx_tags = X["cx_tags"]
        slot_tags = X["slot_tags"]
        bert_out = self.bert(**tokens, return_dict=True)
        return bert_out

## Check adversarial samples

### Generator

In [16]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
debug_loader = DataLoader(test_idxs_ds, batch_size=2, shuffle=False, 
                         collate_fn=collate_fn)

In [17]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
config = PretrainedConfig()
model = ConartModelDonut(config, "ckiplab/bert-base-chinese", "ckiplab/bert-base-chinese")
model = model.to(device)

Some weights of the model checkpoint at ckiplab/bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and a

In [18]:
bb = next(iter(debug_loader))
model.forward_G(bb)

{'lm_probs': tensor([[[1.3438e-08, 2.0852e-08, 2.4144e-08,  ..., 2.8087e-07,
           4.3796e-07, 4.2810e-08],
          [1.7458e-18, 1.3219e-17, 2.1442e-17,  ..., 3.6188e-15,
           3.1057e-16, 8.1137e-17],
          [2.5644e-23, 1.1808e-21, 2.4968e-22,  ..., 1.1142e-20,
           3.2149e-20, 6.1896e-22],
          ...,
          [1.0755e-19, 1.6767e-19, 2.0689e-18,  ..., 2.9581e-16,
           4.0828e-17, 4.9804e-18],
          [1.7951e-15, 1.0347e-14, 1.5952e-14,  ..., 4.8796e-13,
           1.3940e-13, 4.5419e-14],
          [1.3372e-08, 2.0787e-08, 2.4046e-08,  ..., 2.8081e-07,
           4.3771e-07, 4.2651e-08]],
 
         [[2.8199e-08, 5.6591e-08, 4.1160e-08,  ..., 5.2027e-07,
           3.3007e-07, 1.5627e-07],
          [3.6332e-14, 5.8456e-14, 1.0368e-13,  ..., 2.0944e-12,
           3.0220e-14, 9.4347e-13],
          [2.0888e-11, 2.1918e-11, 1.3978e-11,  ..., 4.3487e-11,
           2.0442e-12, 5.0283e-10],
          ...,
          [4.0348e-10, 3.0904e-09, 3.8405e-09,

In [19]:
masked_ids = bb["gen_labels"].masked_scatter(bb["gen_labels"]==1, bb["real_text"].input_ids)
lm_out = model.forward_G(bb, masked_ids)
gu.visualize_gen(bb, masked_ids, tokenizer)

[CLS]我覺得這位護理師[31m[MASK]([4m錯[0;31m)[0m就[31m[MASK]([4m錯[0;31m)[0m在[31m[MASK]([4m抽[0;31m)[0m了一位不理性
[CLS]1km補給品[31m[MASK]([4m買[0;31m)[0m一[31m[MASK]([4m買[0;31m)[0m窩著先看這幾天發展比


In [20]:
lm_out["mlm_loss"]

tensor(12.0936, device='cuda:0', grad_fn=<NllLossBackward0>)

### Adversarial sample & Discriminator

In [21]:
lm_probs = model.forward_G(bb)["lm_probs"]
adv_out = gu.generate_adversarials(bb, lm_probs)
gu.visualize_adv(adv_out, tokenizer)
tok_out = model.forward_D(bb, labels=adv_out["dcr_labels"])
print(tok_out["adv_loss"])
print(tok_out["tok_logits"].shape)

[CLS]我覺得這位護理師[31m不([4m錯[0;31m)[0m就[31m像([4m錯[0;31m)[0m在抽了一位不理性
[CLS]1km補給品[31m第([4m買[0;31m)[0m一[31m窩([4m買[0;31m)[0m窩著先看這幾天發展比
tensor(0.0739, device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([2, 50, 1])


In [22]:
adv_out["adv_ids"]

tensor([[ 101, 2769, 6221, 2533, 6857,  855, 6362, 4415, 2374,  679, 2218, 1008,
         1762, 2853,  749,  671,  855,  679, 4415, 2595, 2097, 2044, 4638, 6117,
         1416,  872, 4534,  678, 1377,  809,  679, 6206, 1526, 1962, 1962, 6656,
         1961, 3978, 6858, 6313, 1961, 2994,  782, 2853, 1416, 3559,  677,  872,
         1962,  102],
        [ 101,  122,  153,  155, 6171, 5183, 1501, 5018,  671, 4979, 4979, 5865,
         1044, 4692, 6857, 2407, 1921, 4634, 2245, 3683, 6733, 4952, 5018,  671,
         3613, 6882, 2399, 4522, 1762, 7770, 7413,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0]], device='cuda:0')

## Model Training

In [17]:
collate_fn = lambda x: gan_collate_fn(x, data, cx_lenc, slot_lenc, device)
batch_size = 16
train_loader = DataLoader(train_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
test_loader = DataLoader(test_idxs_ds, batch_size=batch_size, shuffle=True, 
                         collate_fn=collate_fn)
print("batch size: ", batch_size)
print("Training dataset:", len(train_idxs_ds))
print("Testing dataset:", len(test_idxs_ds))


batch size:  16
Training dataset: 10477
Testing dataset: 1165


In [18]:
# model = BertForMaskedLM.from_pretrained('bert-base-chinese')
config = PretrainedConfig()
model = ConartModelDonut(config, "ckiplab/bert-base-chinese", "ckiplab/bert-base-chinese")
model = model.to(device)

Some weights of the model checkpoint at ckiplab/bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and a

In [19]:
# train generator
from transformers import get_linear_schedule_with_warmup
writer = SummaryWriter(log_dir="../data/tb_logs/train_adv_01", 
                       comment='epoch=6, lr=1e-5, with linear scheduler')

to_train_G = False

data_loader = train_loader
n_epoch = 3
optim_G = optim.AdamW(model.G_params(), lr=1e-5)
optim_D = optim.AdamW(model.D_params(), lr=1e-5)
scheduler_G = get_linear_schedule_with_warmup(optim_G, 50, len(data_loader)*n_epoch)
scheduler_D = get_linear_schedule_with_warmup(optim_D, 50, len(data_loader)*n_epoch)
# optim_dcr_cls = optim.AdamW(model.tok_cls.parameters(), lr=1e-4)

mlm_loss_epoch = MeanMetric()
dcr_loss_epoch = MeanMetric()

iter_idx = 0
for epoch_i in range(n_epoch):
       
    for batch_idx, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
        #  mlm prediction
        masked_ids = batch["gen_labels"].masked_scatter(
                        batch["gen_labels"]==1, batch["real_text"].input_ids)
        lm_out = model.forward_G(batch, masked_ids)
        mlm_loss = lm_out["mlm_loss"]
        writer.add_scalar('mlm_loss', mlm_loss.item(), iter_idx)
        mlm_loss_epoch.update(mlm_loss.item())    
                
        # generate adversarial samples
        lm_probs = model.forward_G(batch)["lm_probs"]
        adv_out = gu.generate_adversarials(batch, lm_probs)
        
        if batch_idx % 3 == 0:
            to_train_G = True
        else:
            to_train_G = False
        
        if to_train_G:
            # train generator
            optim_G.zero_grad()
            mlm_loss.backward()
            optim_G.step()
            scheduler_G.step()


        # compute adv loss
        tok_out = model.forward_D(batch, labels=adv_out["dcr_labels"])
        adv_loss = tok_out["adv_loss"]
        writer.add_scalar("adv_loss", adv_loss.item(), iter_idx)
        dcr_loss_epoch.update(adv_loss.item()) 
                    
        # train discriminator
        optim_D.zero_grad()
        adv_loss.backward()
        optim_D.step()
        scheduler_D.step()

        iter_idx += 1        
    writer.add_scalar("dcr_loss_epoch", dcr_loss_epoch.compute(), epoch_i)
    writer.add_scalar("mlm_loss_epoch", mlm_loss_epoch.compute(), epoch_i)
    dcr_loss_epoch.reset()
    mlm_loss_epoch.reset()
    
writer.close()

  0%|          | 0/655 [00:00<?, ?it/s]

  0%|          | 0/655 [00:00<?, ?it/s]

  0%|          | 0/655 [00:00<?, ?it/s]

In [20]:
model.save_pretrained("../data/models/donut_adv_01")

In [None]:
# alternating D/G