In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import IterableDataset

import torchtext
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import matplotlib.pyplot as plt
import spacy
import numpy as np

from copy import deepcopy
import random
import math
import time
from tqdm.notebook import tqdm

from typing import Tuple

In [2]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
# copied from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

# train_iter was "consumed" by the process of building the vocab,
# so we have to create it again
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def batchify(data: Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)

batch_size = 64
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [5]:
bptt = 35
def get_batch(source: Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].T.reshape(-1)
    return data, target

# Input Embedding

In [6]:
class InputEmbedding(nn.Module) : 
    def __init__(self, vocab_size, max_length, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_length, d_model)

    def generate_dec_mask_m(self, src) :
        src_len = src.shape[1]
        src_sub_mask = torch.tril(torch.ones((src_len, src_len)), diagonal=0).bool().to(src.device) # mask subsequent token        
        return src_sub_mask
    
    def forward(self, x) : 
        emb = self.tok_emb(x)
        pos = torch.arange(0, emb.shape[1]).unsqueeze(0).repeat(emb.shape[0], 1).to(emb.device)
        summed = emb / math.sqrt(self.d_model) + self.pos_emb(pos)
        return summed

# Scaled Dot-Product Attention

In [7]:
class ScaledDotProductAttention(nn.Module) : 
    def __init__(self, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.fc = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask) :         
        score = torch.matmul(q, k.permute(0,1,3,2).contiguous())/math.sqrt(d_model)
        score = score.masked_fill(mask, -1e10)
        scaled_score = torch.softmax(score, dim=-1)
        
        attention = torch.matmul(scaled_score, v).permute(0,2,3,1).contiguous()
        attention = attention.view(attention.shape[0], attention.shape[1], self.d_model)
        
        return self.fc(attention)

___

# Multi-Head Attention

In [8]:
class MultiHeadAttention(nn.Module) : 
    def __init__(self, d_model, n_head) : 
        super().__init__()
        assert d_model % n_head == 0, f"n_head({n_head}) does not divide d_model({d_model})"

        self.n_div_head = d_model//n_head
        self.d_model = d_model
        self.n_head = n_head

        self.Q = nn.Linear(d_model,  d_model)
        self.K = nn.Linear(d_model,  d_model)
        self.V = nn.Linear(d_model,  d_model)
        
    def div_and_sort_for_multiheads(self, projected, seq_len) : 
        div = projected.view(projected.shape[0], self.n_head, seq_len, self.n_div_head)
        return div
    
    def forward(self, emb, enc_inputs=None) :
        seq_len = emb.shape[1]
        q = self.div_and_sort_for_multiheads(self.Q(emb), seq_len)
        k = self.div_and_sort_for_multiheads(self.K(emb), seq_len)
        v = self.div_and_sort_for_multiheads(self.V(emb), seq_len)

        return q,k,v

# Post-process the sub-layer
- layer normalization
- residual conection
- residual dropout

In [9]:
class PostProcessing(nn.Module) : 
    def __init__(self, d_model, p=0.1) : 
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)
        
    def forward(self, emb, attn) : 
        return self.ln(emb+self.dropout(attn))

# Position-wise FFN

In [10]:
class PositionwiseFFN(nn.Module) : 
    def __init__(self, d_model, d_ff) : 
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.act = nn.GELU()
    def forward(self, x) : 
        return self.fc2(self.act(self.fc1(x)))

# Encoder

In [11]:
class EncoderLayer(nn.Module) : 
    def __init__(self, vocab_size, d_model, d_ff, n_head, dropout_p) : 
        super().__init__()
        
        self.ma = MultiHeadAttention(d_model, n_head).to(device)
        self.sdp = ScaledDotProductAttention(d_model)
        
        self.pp1 = PostProcessing(d_model, dropout_p)
        self.pp2 = PostProcessing(d_model, dropout_p)
        
        self.positionwise_ffn = PositionwiseFFN(d_model, d_ff)

    def forward(self, emb, mask_m) :

        q,k,v = self.ma(emb)    
        attn = self.sdp(q,k,v, mask=mask_m)
        
        attn = self.pp1(emb, attn)
        z = self.positionwise_ffn(attn)

        return self.pp2(attn, z)

# Encoder-Decoder

In [12]:
class LM(nn.Module) : 
    def __init__(self,
                 vocab_size,
                 max_length,
                 d_model,
                 d_ff,
                 n_head,
                 dropout_p,
                 n_enc_layer) : 
        
        super().__init__()
        
        self.src_embber = InputEmbedding(vocab_size, max_length, d_model)
        enc = EncoderLayer(vocab_size, d_model, d_ff, n_head, dropout_p)
        self.enc = nn.ModuleList([deepcopy(enc) for _ in range(n_enc_layer)])        
        self.fc = nn.Linear(d_model, vocab_size)
        
    def forward(self, src) : 
        
        src_emb = self.src_embber(src)
        src_mask_m = self.src_embber.generate_dec_mask_m(src)
        
        for enc_layer in self.enc : 
            src_emb = enc_layer(src_emb, src_mask_m)        
        
        return self.fc(src_emb)

In [13]:
d_model = 256
d_ff = 512
n_head = 8
vocab_size = len(vocab.get_itos())
dropout_p = 0.1
n_enc_layer = 3

In [14]:
model = LM(vocab_size,
            bptt,
            d_model,
            d_ff,
            n_head,
            dropout_p,
            n_enc_layer).to(device)

model(get_batch(train_data, 0)[0].T).shape

torch.Size([64, 35, 28782])

___

# Train

In [15]:
# torch.autograd.set_detect_anomaly(True)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2.5e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0.001)

In [16]:
def train(model, data, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    cnt = 0
    
    for i in tqdm(range(0, data.shape[0] - 1, bptt), desc='train') :
        src, trg = get_batch(train_data, i)
        src = src.T # [batch_size, sequence_length]
        
        optimizer.zero_grad()        
        output = model(src) # [batch_size, sequence_length, hidden_dim]
        output = output.view(-1, vocab_size)
        
        loss = criterion(output, trg)
        loss.backward()        
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)        
        optimizer.step()        
        
        epoch_loss += loss.item()
        cnt += 1
    return epoch_loss / cnt

def evaluate(model, data, criterion):
    
    model.eval()
    
    epoch_loss = 0
    cnt = 0
    
    with torch.no_grad() : 
        for i in tqdm(range(0, data.shape[0] - 1, bptt), desc='valid') :
            src, trg = get_batch(data, i)
            src = src.T # [batch_size, sequence_length]

            output = model(src)
            output = output.view(-1, vocab_size)
            loss = criterion(output.view(-1, output.shape[-1]), trg)
            epoch_loss += loss.item()
            cnt += 1

    return epoch_loss / cnt

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [17]:
N_EPOCHS = 20
CLIP = 0.5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_data, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_data, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'gpt.pt')
        
    scheduler.step()
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 01 | Time: 0m 57s
	Train Loss: 6.879 | Train PPL: 971.965
	 Val. Loss: 5.771 |  Val. PPL: 320.908


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 02 | Time: 0m 58s
	Train Loss: 5.206 | Train PPL: 182.395
	 Val. Loss: 4.484 |  Val. PPL:  88.607


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 03 | Time: 0m 58s
	Train Loss: 4.229 | Train PPL:  68.671
	 Val. Loss: 3.768 |  Val. PPL:  43.294


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 04 | Time: 0m 58s
	Train Loss: 3.699 | Train PPL:  40.426
	 Val. Loss: 3.371 |  Val. PPL:  29.094


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 05 | Time: 0m 58s
	Train Loss: 3.302 | Train PPL:  27.161
	 Val. Loss: 2.978 |  Val. PPL:  19.640


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 06 | Time: 0m 55s
	Train Loss: 2.955 | Train PPL:  19.200
	 Val. Loss: 2.791 |  Val. PPL:  16.293


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 07 | Time: 0m 58s
	Train Loss: 2.721 | Train PPL:  15.194
	 Val. Loss: 2.523 |  Val. PPL:  12.466


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 08 | Time: 0m 58s
	Train Loss: 2.537 | Train PPL:  12.636
	 Val. Loss: 2.538 |  Val. PPL:  12.650


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 09 | Time: 0m 58s
	Train Loss: 2.319 | Train PPL:  10.167
	 Val. Loss: 2.198 |  Val. PPL:   9.005


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 10 | Time: 0m 58s
	Train Loss: 2.067 | Train PPL:   7.902
	 Val. Loss: 2.123 |  Val. PPL:   8.353


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 11 | Time: 0m 58s
	Train Loss: 1.862 | Train PPL:   6.439
	 Val. Loss: 1.969 |  Val. PPL:   7.165


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 12 | Time: 0m 58s
	Train Loss: 1.667 | Train PPL:   5.299
	 Val. Loss: 1.893 |  Val. PPL:   6.638


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 13 | Time: 0m 57s
	Train Loss: 1.487 | Train PPL:   4.423
	 Val. Loss: 1.673 |  Val. PPL:   5.331


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 14 | Time: 0m 58s
	Train Loss: 1.325 | Train PPL:   3.761
	 Val. Loss: 1.688 |  Val. PPL:   5.407


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 15 | Time: 0m 57s
	Train Loss: 1.204 | Train PPL:   3.334
	 Val. Loss: 1.484 |  Val. PPL:   4.412


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 16 | Time: 0m 58s
	Train Loss: 1.099 | Train PPL:   3.001
	 Val. Loss: 1.561 |  Val. PPL:   4.761


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 17 | Time: 0m 58s
	Train Loss: 1.020 | Train PPL:   2.772
	 Val. Loss: 1.412 |  Val. PPL:   4.106


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 18 | Time: 0m 58s
	Train Loss: 0.963 | Train PPL:   2.620
	 Val. Loss: 1.523 |  Val. PPL:   4.587


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 19 | Time: 0m 58s
	Train Loss: 0.912 | Train PPL:   2.488
	 Val. Loss: 1.331 |  Val. PPL:   3.784


train:   0%|          | 0/916 [00:00<?, ?it/s]

valid:   0%|          | 0/613 [00:00<?, ?it/s]

Epoch: 20 | Time: 0m 58s
	Train Loss: 0.860 | Train PPL:   2.363
	 Val. Loss: 1.560 |  Val. PPL:   4.759


___

# Test

In [17]:
model.load_state_dict(torch.load('gpt.pt'))
test_loss = evaluate(model, test_data, criterion)
print(f'Test Loss: {test_loss:.3f} |  Val. PPL: {math.exp(test_loss):7.3f}')

valid:   0%|          | 0/691 [00:00<?, ?it/s]

Test Loss: 1.317 |  Val. PPL:   3.734


In [80]:
example_idx = 1
ls = get_batch(test_data, 0)[0][:,example_idx].cpu().data.numpy().tolist()
cnt = 0
model.eval()

with torch.no_grad() : 
    for i in tqdm(range(0, test_data.shape[0] - 1, bptt), desc='valid') :
        src, _ = get_batch(test_data, i)
        src = src.T # [batch_size, sequence_length]

        output = model(src)
        ls.append(output.argmax(2)[example_idx].cpu().data.numpy()[-1])
        cnt += 1
        
        if cnt > 35 : 
            break

valid:   0%|          | 0/691 [00:00<?, ?it/s]

In [81]:
lemma_dict = {v:k for k,v in vocab.get_stoi().items()}
','.join(list(map(lambda x : lemma_dict.get(x), ls))).replace(",",' ')

'next day it joined f company to build up what became the main defensive position of the <unk> regiment in front of <unk> . north korean troops during the night passed around the right flank       s   the @ first     to   <unk> <unk> s     the the first     first of <unk> (         a the and   the  '

___

# TODO
- fine-tuning after pre-training
- fine-tuning with pre-training