In [1]:
import os
# Restrict this process to physical GPU 4. This runs before torch is imported
# in the next cell; the single visible GPU is then addressed as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import regex as re
import random
import itertools
import tqdm
import time

# Prefer the TensorBoard writer shipped with torch and fall back to
# tensorboardX only when that import is unavailable. Catching ImportError
# (rather than everything) lets unrelated failures and Ctrl-C surface.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter
    
from apex import amp
from allennlp.training.checkpointer import Checkpointer
from gpt_model import GPT2SimpleLM
# TODO why openaiadam?
#from pytorch_pretrained_bert import GPT2Tokenizer, OpenAIAdam, GPT2Model
from pytorch_pretrained_bert import OpenAIAdam
from transformers import AdamW
from transformers import WarmupLinearSchedule
from torchfly.text.tokenizers import UnifiedBPETokenizer
from torchfly.modules.losses import SequenceFocalLoss, SequenceCrossEntropyLoss
from torchfly.modules.transformers import GPT2SimpleLM, UnifiedGPT2SmallConfig

# TODO no warmup learning rate schedule?

In [3]:
# Seed every random number generator this notebook imports so that runs are
# repeatable (the training DataLoader is built with shuffle=True).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Let cuDNN autotune its kernels. NOTE(review): autotuning may pick
# non-deterministic algorithms; disable it if bit-exact reruns are required.
torch.backends.cudnn.benchmark = True

In [4]:
class PersuadeDataset(Dataset):
    """Dataset of dialogs that are tokenized turn by turn on access.

    Args:
        data: indexable collection of dialogs; each dialog is an iterable of
            turn strings.
        tokenizer: object with an ``encode(str)`` method whose result supports
            ``+`` and slicing (presumably a list of token ids). Its
            ``max_len`` attribute is overwritten with 1500.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.tokenizer.max_len = 1500
        # Token ids appended to every turn as its terminator.
        self.turn_ending = tokenizer.encode("\n\n\n")
        # TODO: no ending?
        # self.dialog_ending = [tokenizer.encoder["[EOS]"]]

    def __len__(self):
        """Return the number of dialogs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(role_ids, dial_tokens)`` for the dialog at ``index``.

        ``dial_tokens`` holds one token-id sequence per turn and ``role_ids``
        holds one 0/1 speaker flag per turn.
        """
        # Encode with the tokenizer handed to the constructor; the dataset must
        # not depend on a notebook-level ``tokenizer`` variable existing.
        dial_tokens = [self.tokenizer.encode(item) + self.turn_ending for item in self.data[index]]
        # A turn belongs to role 0 when its first token id is 32, else role 1.
        # NOTE(review): 32 is presumably the id of the speaker prefix "A" in
        # this vocabulary - confirm against the tokenizer.
        role_ids = [0 if item[0] == 32 else 1 for item in dial_tokens]
        # Drop the final two ids of the last turn, which trims its terminator.
        dial_tokens[-1] = dial_tokens[-1][:-2] # + self.dialog_ending
        return role_ids, dial_tokens
        

class Collate_Function:
    """Pass-through collate callable for a ``DataLoader``.

    The tokenizer is kept on the instance but plays no part in collation.
    """

    def __init__(self, tokenizer):
        # self.EOS = self.tokenizer.encoder["[EOS]"]
        self.tokenizer = tokenizer

    def __call__(self, unpacked_data):
        """Hand back the list of samples untouched (no padding, no stacking)."""
        batch = unpacked_data
        return batch

In [5]:
tokenizer = UnifiedBPETokenizer()

#tokenizer = torch.load("/home/qingyang/Desktop/GPT2_Modification/special3_gpt2_tokenizer.pkl")
# TODO: why was the tokenizer loaded from a pickle here?
# The string literal below is disabled code kept for reference only: nothing in
# it is executed, and as the last expression of the cell it is merely echoed
# back as the cell output.
'''
class GPT2SmallConfig:
    vocab_size = 50257 + len(tokenizer.__special_tokens__)
    n_special = len(tokenizer.__special_tokens__)
    n_positions = 1024
    n_ctx = 1024
    n_embd = 768
    n_layer = 12
    n_head = 12
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    gradient_checkpointing = False
    
class GPT2MediumConfig:
    vocab_size = 50257 + len(tokenizer.__special_tokens__)
    n_special = len(tokenizer.__special_tokens__)
    n_positions = 1024
    n_ctx = 1024
    n_embd = 1024
    n_layer = 24
    n_head = 16
    resid_pdrop = 0.1
    embd_pdrop = 0.1
    attn_pdrop = 0.1
    layer_norm_epsilon = 1e-5
    initializer_range = 0.02
    gradient_checkpointing = True
'''

'\nclass GPT2SmallConfig:\n    vocab_size = 50257 + len(tokenizer.__special_tokens__)\n    n_special = len(tokenizer.__special_tokens__)\n    n_positions = 1024\n    n_ctx = 1024\n    n_embd = 768\n    n_layer = 12\n    n_head = 12\n    resid_pdrop = 0.1\n    embd_pdrop = 0.1\n    attn_pdrop = 0.1\n    layer_norm_epsilon = 1e-5\n    initializer_range = 0.02\n    gradient_checkpointing = False\n    \nclass GPT2MediumConfig:\n    vocab_size = 50257 + len(tokenizer.__special_tokens__)\n    n_special = len(tokenizer.__special_tokens__)\n    n_positions = 1024\n    n_ctx = 1024\n    n_embd = 1024\n    n_layer = 24\n    n_head = 16\n    resid_pdrop = 0.1\n    embd_pdrop = 0.1\n    attn_pdrop = 0.1\n    layer_norm_epsilon = 1e-5\n    initializer_range = 0.02\n    gradient_checkpointing = True\n'

In [6]:
# Both speaker models start from the same checkpoint. The file is read once and
# mapped to CPU so nothing is deserialized onto the device it was saved from;
# each load_state_dict call copies the values into its own model, so the two
# models do not share parameters afterwards.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
pretrained_state = torch.load("../../Checkpoint/best.th", map_location="cpu")

# To train a medium-sized pair instead, build both models from a medium config
# and load the matching checkpoint here.
model_A = GPT2SimpleLM(UnifiedGPT2SmallConfig)
model_B = GPT2SimpleLM(UnifiedGPT2SmallConfig)
model_A.load_state_dict(pretrained_state)
model_B.load_state_dict(pretrained_state)

<All keys matched successfully>

### load the data

In [7]:
# Loader settings shared by the training and validation splits.
batch_size = 1
collate_func = Collate_Function(tokenizer)

# Training split: dialogs are reshuffled on every pass.
train_data = torch.load("DataProcess/train_dialogs.pkl")
train_dataset = PersuadeDataset(train_data, tokenizer)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_func,
)

# Validation split: dialogs are visited in stored order.
val_data = torch.load("DataProcess/val_dialogs.pkl")
val_dataset = PersuadeDataset(val_data, tokenizer)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_func,
)

## Move the models to the device

In [8]:
# Use the first visible GPU when CUDA is usable; otherwise fall back to the CPU
# so the notebook still runs (slowly) on a machine without an accessible GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_A = model_A.to(device)
model_B = model_B.to(device)

In [9]:
# define the losses
# Training objective, called by train_one_iter with label_smoothing=0.02.
# NOTE(review): the exact meaning of gamma/beta is defined by torchfly's
# SequenceFocalLoss - confirm there before tuning.
criterion = SequenceFocalLoss(gamma=1.0, beta=0.0)
# Validation objective; validate() exponentiates its output into perplexity.
eval_criterion = SequenceCrossEntropyLoss()

In [10]:
def train_one_iter(batch, update_count, fp16=False):
    """Run one forward/backward pass over a single dialog.

    Relies on notebook globals: ``model_A``, ``model_B``, ``criterion``,
    ``device``, ``num_gradients_accumulation`` and, when ``fp16`` is true,
    ``amp`` and ``optimizer``. Gradients are accumulated only; no optimizer
    step is taken here.

    Args:
        batch: ``(role_ids, dialog_tokens)`` for one dialog, where
            ``dialog_tokens`` holds one token-id sequence per turn and
            ``role_ids`` holds the matching per-turn role flag.
        update_count: accepted from the caller but not used in the body.
        fp16: when true, backpropagate through ``amp.scale_loss``.

    Returns:
        ``(record_loss, perplexity)``: the loss before division by the
        accumulation factor, as a Python float, and ``exp`` of that value.
    """
    role_ids, dialog_tokens = batch

    # One (1, turn_length) tensor of token ids per turn.
    dial_inputs = []
    for turn_tokens in dialog_tokens:
        dial_inputs.append(torch.LongTensor(turn_tokens).unsqueeze(0).to(device))

    # Feed the turns in order. Role 0 is handled by model_A and every other
    # role by model_B; ``past`` returned by one call is passed to the next.
    past = None
    turn_logits = []
    for turn_num, dial_turn_inputs in enumerate(dial_inputs):
        speaker_model = model_A if role_ids[turn_num] == 0 else model_B
        logits, past = speaker_model(dial_turn_inputs, past=past)
        turn_logits.append(logits)

    # Next-token setup: the logits at position t are scored against token t+1.
    all_logits = torch.cat(turn_logits, dim=1)[:, :-1].contiguous()
    target = torch.cat(dial_inputs, dim=1)[:, 1:].contiguous()
    target_mask = torch.ones_like(target).float()

    loss = criterion(all_logits, target, target_mask, label_smoothing=0.02, reduce=True)
    loss /= num_gradients_accumulation

    if not fp16:
        loss.backward()
    else:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    # Undo the accumulation scaling so the reported value is the raw loss.
    record_loss = loss.item() * num_gradients_accumulation
    return record_loss, np.exp(record_loss)


def validate(dataloader, epoch=None, max_tokens=1024):
    """Compute the mean validation perplexity over ``dataloader``.

    Relies on notebook globals: ``progress_bar``, ``model_A``, ``model_B``,
    ``eval_criterion`` and ``device``. The caller is responsible for putting
    the models in eval mode.

    Args:
        dataloader: iterable of batches; each batch is a list of
            ``(role_ids, dialog_tokens)`` samples.
        epoch: label used in the printed summary. When omitted, the notebook
            global ``ep`` is used if it exists, otherwise ``"?"``.
        max_tokens: dialogs whose turns total more than this many tokens are
            skipped.

    Returns:
        The mean of the collected perplexities, or ``nan`` when every dialog
        was skipped.
    """
    total_ppl = []

    with torch.no_grad():
        for batch in progress_bar(dataloader):
            # Score every sample of the batch, not just the first one, so the
            # function stays correct if the loader uses batch_size > 1.
            for role_ids, dialog_tokens in batch:
                if sum(len(item) for item in dialog_tokens) > max_tokens:
                    continue

                dial_inputs = [torch.LongTensor(item).unsqueeze(0).to(device) for item in dialog_tokens]

                # Role 0 is handled by model_A, every other role by model_B;
                # ``past`` returned by one call is passed to the next.
                past = None
                all_logits = []
                for turn_num, dial_turn_inputs in enumerate(dial_inputs):
                    speaker_model = model_A if role_ids[turn_num] == 0 else model_B
                    logits, past = speaker_model(dial_turn_inputs, past=past)
                    all_logits.append(logits)

                # Next-token setup: logits at position t are scored against token t+1.
                all_logits = torch.cat(all_logits, dim=1)[:, :-1].contiguous()
                target = torch.cat(dial_inputs, dim=1)[:, 1:].contiguous()
                target_mask = torch.ones_like(target).float()

                loss = eval_criterion(all_logits, target, target_mask, label_smoothing=-1, reduce="sentence")
                total_ppl.extend(torch.exp(loss).tolist())

    if epoch is None:
        # Backward compatible with callers that rely on the loop variable ``ep``.
        epoch = globals().get("ep", "?")

    if not total_ppl:
        # Avoid numpy's empty-slice warning and say why there is no number.
        print(f"Epoch {epoch} Validation: no dialog within {max_tokens} tokens, perplexity undefined")
        return float("nan")

    mean_ppl = np.mean(total_ppl)
    print(f"Epoch {epoch} Validation Perplexity: {mean_ppl} Variance: {np.var(total_ppl)}")
    return mean_ppl

### Training

In [11]:
# Create the checkpoint directory on first use; an existing one is left as is.
os.makedirs("Checkpoint", exist_ok=True)

# NOTE(review): going by the argument names, this keeps the 5 most recent
# checkpoints plus one extra every two hours - confirm against the allennlp
# Checkpointer documentation.
checkpointer = Checkpointer(
    serialization_dir="Checkpoint",
    num_serialized_models_to_keep=5,
    keep_serialized_model_every_num_seconds=2 * 3600,
)

In [12]:
# optimizer
num_epochs = 10
num_gradients_accumulation = 1
# Planned number of optimizer steps for the whole run, passed on as t_total.
# It is derived from the dataset size, so dialogs that the training loop skips
# for being too long make the real step count smaller than this figure.
num_train_optimization_steps = len(train_dataset) * num_epochs // batch_size // num_gradients_accumulation

# The parameters of both models are optimized jointly by one optimizer.
param_optimizer = list(model_A.named_parameters()) + list(model_B.named_parameters())

# Parameters whose name contains any of these substrings get no weight decay.
# NOTE(review): confirm the layer-norm entries match this model's parameter
# names; otherwise only the 'bias' entry has any effect.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']


def _exempt_from_decay(param_name):
    """Tell whether ``param_name`` matches one of the ``no_decay`` patterns."""
    return any(nd in param_name for nd in no_decay)


optimizer_grouped_parameters = [
    {
        'params': [p for n, p in param_optimizer if not _exempt_from_decay(n)],
        'weight_decay': 0.01,
    },
    {
        'params': [p for n, p in param_optimizer if _exempt_from_decay(n)],
        'weight_decay': 0.0,
    },
]

# NOTE(review): warmup=0.1 is presumably the fraction of t_total spent warming
# up the learning rate - confirm against the OpenAIAdam documentation.
optimizer = OpenAIAdam(
    optimizer_grouped_parameters,
    lr=2e-5,
    warmup=0.1,
    max_grad_norm=1.0,
    weight_decay=0.01,
    t_total=num_train_optimization_steps,
)

In [13]:
# support fp16
# [model_A, model_B], optimizer = amp.initialize([model_A, model_B], optimizer, opt_level="O1")

In [14]:
# A bare ``import tqdm`` does not guarantee the ``notebook`` submodule is
# loaded, so import it explicitly before taking the attribute.
import tqdm.notebook

update_count = 0
progress_bar = tqdm.notebook.tqdm
start = time.time()
# Lowest validation perplexity seen so far (lower is better).
best_ppl = float('inf')

for ep in range(num_epochs):

    # Training
    pbar = progress_bar(train_dataloader)
    model_A.train()
    model_B.train()

    for batch in pbar:
        # The loader yields a list of samples; with batch_size=1 take the only one.
        batch = batch[0]
        # without relative position
        # Dialogs longer than 1024 tokens in total are skipped.
        if sum(len(item) for item in batch[1]) > 1024:
            continue

        record_loss, perplexity = train_one_iter(batch, update_count, fp16=False)

        update_count += 1

        # update_count was already incremented, so a full accumulation window
        # has been processed exactly when it is a multiple of the window size.
        if update_count % num_gradients_accumulation == 0:
            # update for gradient accumulation
            optimizer.step()
            optimizer.zero_grad()

            # speed measure
            end = time.time()
            speed = batch_size * num_gradients_accumulation / (end - start)
            start = end

            # show progress
            pbar.set_postfix(loss=record_loss, perplexity=perplexity, speed=speed)

    # Evaluation
    model_A.eval()
    model_B.eval()
    ppl = validate(val_dataloader)

    # A checkpoint is the best one only when it beats the lowest perplexity of
    # all previous epochs, not merely the previous epoch.
    is_best_so_far = ppl < best_ppl
    if is_best_so_far:
        best_ppl = ppl
    checkpointer.save_checkpoint(ep, [model_A.state_dict(), model_B.state_dict()], {"None": None}, is_best_so_far)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 0 Validation Perplexity: 11.629383223397392 Variance: 12.327919333045331


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 1 Validation Perplexity: 10.62478511187495 Variance: 11.51269794930117


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 2 Validation Perplexity: 10.249737620353699 Variance: 11.211154902251376


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 3 Validation Perplexity: 10.079965934461477 Variance: 11.183145393877847


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 4 Validation Perplexity: 10.000443351512052 Variance: 11.250616878880448


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 5 Validation Perplexity: 9.968954276065437 Variance: 11.509323373219022


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 6 Validation Perplexity: 9.929912992886134 Variance: 11.60251898158781


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 7 Validation Perplexity: 9.92309160865083 Variance: 11.715699135269219


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 8 Validation Perplexity: 9.932806625658152 Variance: 11.829686047241344


HBox(children=(IntProgress(value=0, max=891), HTML(value='')))




HBox(children=(IntProgress(value=0, max=99), HTML(value='')))


Epcoh 9 Validation Perplexity: 9.939598545736196 Variance: 11.870388044289164
