In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from copy import deepcopy
import random
import math
import time
from tqdm.notebook import tqdm

from typing import Tuple

import sys
sys.path.append("..")
from dataset import iterator

In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

____

# Input Embedding

In [3]:
class InputEmbedding(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=train_dataset.prep.pad_id)
        self.pos_emb = nn.Embedding(seq_length, d_model, padding_idx=train_dataset.prep.pad_id)
        
    def generate_enc_mask_m(self, src) :       
        mask_m = (src != 1).unsqueeze(1).unsqueeze(2)
        return mask_m
    
    def forward(self, txt) : 
        emb = self.tok_emb(txt)
        pos = torch.arange(0, emb.shape[1]).unsqueeze(0).repeat(emb.shape[0], 1).to(emb.device)
        summed = emb / math.sqrt(self.d_model) + self.pos_emb(pos)
        return summed

# Scaled Dot-Product Attention

In [4]:
class ScaledDotProductAttention(nn.Module) : 
    def __init__(self, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.fc = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask) :         
        score = torch.matmul(q, k.permute(0,1,3,2).contiguous())/math.sqrt(d_model)
        score = score.masked_fill(mask, -1e10)
        scaled_score = torch.softmax(score, dim=-1)
        
        attention = torch.matmul(scaled_score, v).permute(0,2,3,1).contiguous()
        attention = attention.view(attention.shape[0], attention.shape[1], self.d_model)
        
        return self.fc(attention)

___

# Multi-Head Attention

In [5]:
class MultiHeadAttention(nn.Module) : 
    def __init__(self, d_model, seq_length, n_head) : 
        super().__init__()
        assert d_model % n_head == 0, f"n_head({n_head}) does not divide d_model({d_model})"

        self.n_div_head = d_model//n_head
        self.d_model = d_model
        self.seq_len = seq_length
        self.n_head = n_head

        self.Q = nn.Linear(d_model,  d_model)
        self.K = nn.Linear(d_model,  d_model)
        self.V = nn.Linear(d_model,  d_model)
        
    def div_and_sort_for_multiheads(self, projected, seq_len) : 
        div = projected.view(projected.shape[0], self.n_head, seq_len, self.n_div_head)
        return div
    
    def forward(self, emb, enc_inputs=None) :
        q = self.div_and_sort_for_multiheads(self.Q(emb), self.seq_len)
    
        if enc_inputs is not None : # enc-dec attention
            seq_len = enc_inputs.shape[1] # takes target sequence length for k and v
            k = self.div_and_sort_for_multiheads(self.K(enc_inputs), seq_len)
            v = self.div_and_sort_for_multiheads(self.V(enc_inputs), seq_len)
        else : # self-attention
            k = self.div_and_sort_for_multiheads(self.K(emb), self.seq_len)
            v = self.div_and_sort_for_multiheads(self.V(emb), self.seq_len)

        return q,k,v

# Post-process the sub-layer
- layer normalization
- residual conection
- residual dropout

In [27]:
class PostProcessing(nn.Module) : 
    def __init__(self, d_model, p=0.1) : 
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)
        
    def forward(self, emb, attn) : 
        return emb+self.dropout(self.ln(attn))

# Position-wise FFN

In [7]:
class PositionwiseFFN(nn.Module) : 
    def __init__(self, d_model, d_ff) : 
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x) : 
        return self.fc2(torch.relu(self.fc1(x)))

# Encoder

In [8]:
class EncoderLayer(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model, d_ff, n_head, dropout_p) : 
        super().__init__()
        
        self.ma = MultiHeadAttention(d_model, seq_length, n_head).to(device)
        self.sdp = ScaledDotProductAttention(d_model)
        
        self.pp1 = PostProcessing(d_model, dropout_p)
        self.pp2 = PostProcessing(d_model, dropout_p)
        
        self.positionwise_ffn = PositionwiseFFN(d_model, d_ff)
            
    def forward(self, emb, mask_m) :

        q,k,v = self.ma(emb)    
        attn = self.sdp(q,k,v, mask=mask_m)
        
        attn = self.pp1(emb, attn)
        z = self.positionwise_ffn(attn)

        return self.pp2(attn, z)

# ROBERTa

In [28]:
class ROBERTa(nn.Module) : 
    def __init__(self,
                 vocab_size,
                 seq_length,
                 d_model,
                 d_ff,
                 n_head,
                 dropout_p,
                 n_enc_layer) : 
        
        super().__init__()
        
        self.embber = InputEmbedding(vocab_size, seq_length, d_model)
        
        enc = EncoderLayer(vocab_size, seq_length, d_model, d_ff, n_head, dropout_p)
        
        self.enc = nn.ModuleList([deepcopy(enc) for _ in range(n_enc_layer)])

    def forward(self, txt) : 
        
        emb = self.embber(txt)
        mask_m = self.embber.generate_enc_mask_m(txt)

        for enc_layer in self.enc : 
            emb = enc_layer(emb, mask_m)

        return emb

In [29]:
class ROBERTaFC(nn.Module) : 
    def __init__(self, embedder, d_model, vocab_size) : 
        super().__init__()
        self.embedder = embedder
        self.mlm_fc = nn.Linear(d_model, vocab_size)
    
    def forward(self, txt) : 
        emb = self.embedder(txt)
        return torch.log_softmax(self.mlm_fc(emb), dim=-1)

___

In [30]:
batch_size = 30
seq_len = 256

train_dataset = iterator.RobertaIterator(filename = '../../bert/data/wikitext-2-raw/prep_train.txt',
                                tokenizer_model_path = '../../bert/data/wiki_pretrained_vocab/m.model',
                                seq_len=seq_len)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
valid_dataset = iterator.RobertaIterator(filename = '../../bert/data/wikitext-2-raw/prep_valid.txt',
                                tokenizer_model_path = '../../bert/data/wiki_pretrained_vocab/m.model',
                                seq_len=seq_len)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [31]:
d_model = 64
d_ff = d_model * 4
n_head = 8
vocab_size = 30000
dropout_p = 0.1
n_enc_layer = 1
seq_length = train_dataset.seq_len
vocab = train_dataset.vocab

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [32]:
roberta_ember = ROBERTa(vocab_size,
             seq_length,
             d_model,
             d_ff,
             n_head,
             dropout_p,
             n_enc_layer).to(device) 
# generate embeding using transformer architecture

roberta_predicter = nn.DataParallel(ROBERTaFC(roberta_ember, d_model, vocab_size)).to(device)
# return 2 fc layer for training bi-directional representation

___

# Train

In [33]:
# torch.autograd.set_detect_anomaly(True) # -- for debugging the train progress --
N_EPOCHS = 20

criterion1 = nn.NLLLoss(ignore_index=0).to(device)
criterion2 = nn.NLLLoss().to(device)

optimizer = optim.AdamW(roberta_predicter.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-1, pct_start=0.01, 
                                          steps_per_epoch=len(train_dataloader), epochs=N_EPOCHS, 
                                          total_steps=N_EPOCHS * len(train_dataloader), anneal_strategy='linear')

In [34]:
def train(model, dataloader, optimizer, scheduler):
    
    model.train()
    
    epoch_total_loss = 0
    epoch_mlm_acc = 0    
    cnt = 0
    
    for data in tqdm(dataloader, desc='train') :
        
        data = {k:v.to(device) for k,v in data.items()}
        
        optimizer.zero_grad()        
        mlm_pred = model(data['text'])

        # calculate loss from masked language modeling
        loss = criterion1(mlm_pred.transpose(1,2), data['mlm'])
                
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()        
        
        # calculate acc for mlm
        acc = (mlm_pred.view(-1,vocab_size).argmax(1) == data['mlm'].view(-1)).sum().item() / data['mlm'].view(-1).shape[0]
        epoch_mlm_acc += acc
        
        epoch_total_loss += loss.item()
        cnt += 1
        scheduler.step()
        
    print(f'\tTrain Total Loss: {epoch_total_loss / cnt:.3f} | MLM ACC : {epoch_mlm_acc / cnt: .3f} | Learning Rate : {scheduler.get_last_lr()[0]:.3f}')

def evaluate(model, dataloader):
    
    model.eval()
    
    epoch_total_loss = 0
    epoch_mlm_acc = 0    
    cnt = 0
    
    with torch.no_grad() : 
        for data in tqdm(dataloader, desc='valid') :
            
            data = {k:v.to(device) for k,v in data.items()}

            mlm_pred = model(data['text'])

            # calculate loss from masked language modeling
            loss = criterion1(mlm_pred.transpose(1,2), data['mlm'])
            
            acc = (mlm_pred.view(-1,vocab_size).argmax(1) == data['mlm'].view(-1)).sum().item() / data['mlm'].view(-1).shape[0]
            epoch_mlm_acc += acc

            epoch_total_loss += loss.item()
            cnt += 1

        print(f'\tValid Total Loss: {epoch_total_loss / cnt:.3f} | MLM ACC : {epoch_mlm_acc / cnt: .3f}')

In [35]:
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):    
    train(roberta_predicter, train_dataloader, optimizer, scheduler)
    evaluate(roberta_predicter, train_dataloader)

train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 7.304 | MLM ACC :  0.011 | Learning Rate : 0.096


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.865 | MLM ACC :  0.015


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.883 | MLM ACC :  0.016 | Learning Rate : 0.091


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.774 | MLM ACC :  0.016


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.821 | MLM ACC :  0.016 | Learning Rate : 0.086


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.702 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.753 | MLM ACC :  0.017 | Learning Rate : 0.081


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.641 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.718 | MLM ACC :  0.017 | Learning Rate : 0.076


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.632 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.701 | MLM ACC :  0.017 | Learning Rate : 0.071


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.612 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.668 | MLM ACC :  0.017 | Learning Rate : 0.066


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.588 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.660 | MLM ACC :  0.017 | Learning Rate : 0.061


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.569 | MLM ACC :  0.017


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.627 | MLM ACC :  0.018 | Learning Rate : 0.056


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.546 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.601 | MLM ACC :  0.018 | Learning Rate : 0.050


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.529 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.579 | MLM ACC :  0.018 | Learning Rate : 0.045


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.505 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.546 | MLM ACC :  0.018 | Learning Rate : 0.040


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.487 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.515 | MLM ACC :  0.018 | Learning Rate : 0.035


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.467 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.487 | MLM ACC :  0.018 | Learning Rate : 0.030


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.450 | MLM ACC :  0.018


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.459 | MLM ACC :  0.019 | Learning Rate : 0.025


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.428 | MLM ACC :  0.019


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.442 | MLM ACC :  0.019 | Learning Rate : 0.020


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.400 | MLM ACC :  0.019


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.414 | MLM ACC :  0.019 | Learning Rate : 0.015


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.386 | MLM ACC :  0.019


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.394 | MLM ACC :  0.019 | Learning Rate : 0.010


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.357 | MLM ACC :  0.019


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.360 | MLM ACC :  0.019 | Learning Rate : 0.005


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.343 | MLM ACC :  0.019


train:   0%|          | 0/270 [00:00<?, ?it/s]

	Train Total Loss: 6.337 | MLM ACC :  0.020 | Learning Rate : -0.000


valid:   0%|          | 0/270 [00:00<?, ?it/s]

	Valid Total Loss: 6.326 | MLM ACC :  0.020
