In [1]:
import os
import numpy as np
from tempfile import TemporaryDirectory
import copy
import time

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
def data_process(raw_text_iter):
    """Tokenize and numericalize every item, then join them into one 1-D tensor.

    Each item of ``raw_text_iter`` is passed through the notebook-level
    ``tokenizer`` and ``vocab`` objects. Items that yield no token ids are
    dropped before concatenation.
    """
    encoded = []
    for line in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(line)), dtype=torch.long)
        # Skip items that produced no tokens so only non-empty pieces are joined.
        if ids.numel() > 0:
            encoded.append(ids)
    return torch.cat(tuple(encoded))


def batchify(data, bsz, device):
    """Fold a flat tensor into ``bsz`` parallel columns and move it to ``device``.

    Trailing elements that do not fill a complete row across all ``bsz``
    columns are discarded. The result has shape ``(len(data) // bsz, bsz)``.
    """
    rows_per_column = data.size(0) // bsz
    usable = data[: rows_per_column * bsz]
    # View as (bsz, rows) first, then transpose so each column is one contiguous stream.
    columns = usable.view(bsz, rows_per_column).t().contiguous()
    return columns.to(device)


def generate_square_subsequent_mask(sz):
    """Build an ``sz x sz`` mask: ``-inf`` strictly above the diagonal, ``0`` elsewhere.
    """
    all_neg_inf = torch.ones(sz, sz) * float("-inf")
    # Keeping only the part above the main diagonal zeroes the diagonal and below.
    return all_neg_inf.triu(diagonal=1)


def get_batch(x_src, i, bptt=35):
    """
    Slice one input chunk and its next-token targets out of ``x_src``.

    x_src is a tensor of shape (full_seq_len, batch_size); ``i`` is the first
    row of the chunk. At most ``bptt`` rows are taken, fewer near the end so
    that every input row still has a following row to serve as its target.

    Returns a tuple (data, target) where data has shape (seq_len, batch_size) and
    target has shape (seq_len * batch_size)
    """
    rows_left = len(x_src) - 1 - i
    span = bptt if bptt <= rows_left else rows_left
    stop = i + span
    inputs = x_src[i:stop]
    # Targets are the same window shifted one row ahead, flattened.
    shifted = x_src[i + 1: stop + 1]
    return inputs, shifted.reshape(-1)


def train(train_data, model, optimizer, criterion, bptt, device, epoch=None):
    """Run one training pass over ``train_data`` and print periodic progress.

    Args:
        train_data: tensor of shape (full_seq_len, batch_size), as consumed by
            ``get_batch``.
        model: module called as ``model(data, src_mask)``; its output is
            flattened over all but the last dimension before the loss.
        optimizer: optimizer stepped once per chunk; its current learning
            rate(s) are shown in the progress line.
        criterion: loss called as ``criterion(flat_output, targets)``.
        bptt: chunk length, used both for slicing and for the attention mask.
        device: device the attention mask is moved to.
        epoch: optional epoch label for the progress line. When omitted, the
            notebook-level ``epoch`` variable is used if it exists.
    """
    model.train()
    total_loss = 0.0
    log_interval = 200
    start_time = time.time()
    full_mask = generate_square_subsequent_mask(bptt).to(device)
    # Keep working for callers that rely on the notebook's global loop counter.
    epoch_label = epoch if epoch is not None else globals().get("epoch", "?")

    num_batches = len(train_data) // bptt
    # Set size - 1 as upper bound since targets come from a look-ahead
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        # Pass bptt through so the chunk length always matches the mask size
        # (previously the default of 35 was silently used).
        data, targets = get_batch(train_data, i, bptt)
        seq_len = data.size(0)
        # The final chunk may be shorter; crop the mask to fit it.
        src_mask = full_mask if seq_len == bptt else full_mask[:seq_len, :seq_len]
        yhat = model(data, src_mask)
        # Flatten using the model's own output width instead of a global.
        loss = criterion(yhat.view(-1, yhat.size(-1)), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            # Read the learning rate from the optimizer rather than a global scheduler.
            lr = [group["lr"] for group in optimizer.param_groups]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = np.exp(cur_loss)
            print((
                f"| epoch {epoch_label} | batch {batch}/{num_batches} "
                f"| lr {np.round(lr, 5)} | ms/batch {np.round(ms_per_batch, 2)} "
                f"| loss {np.round(cur_loss, 2)} | ppl {np.round(ppl, 2)}"
            ))
            total_loss = 0.0
            start_time = time.time()
            

def evaluate(eval_data, model, optimizer, criterion, bptt, device):
    """Compute the length-weighted average loss of ``model`` over ``eval_data``.

    Args:
        eval_data: tensor of shape (full_seq_len, batch_size), as consumed by
            ``get_batch``.
        model: module called as ``model(data, src_mask)``.
        optimizer: unused; kept so existing call sites keep working.
        criterion: loss called as ``criterion(flat_output, targets)``.
        bptt: chunk length, used both for slicing and for the attention mask.
        device: device the attention mask is moved to.

    Returns:
        Sum over chunks of ``seq_len * criterion(...)`` divided by
        ``len(eval_data) - 1``.
    """
    model.eval()
    total_loss = 0.0
    full_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        # Set size - 1 as upper bound since targets come from a look-ahead
        for i in range(0, eval_data.size(0) - 1, bptt):
            # Pass bptt through so the chunk length always matches the mask size
            # (previously the default of 35 was silently used).
            data, targets = get_batch(eval_data, i, bptt)
            seq_len = data.size(0)
            # The final chunk may be shorter; crop the mask to fit it.
            src_mask = full_mask if seq_len == bptt else full_mask[:seq_len, :seq_len]
            yhat = model(data, src_mask)
            # Flatten using the model's own output width instead of a global.
            yhat_flat = yhat.view(-1, yhat.size(-1))
            total_loss += seq_len * criterion(yhat_flat, targets).item()
    return total_loss / (len(eval_data) - 1)


In [3]:
class TransformerModel(nn.Module):
    """Embedding -> positional encoding -> Transformer encoder stack -> linear head."""

    def __init__(self, n_tokens, d_model, n_heads, d_hid, n_layers, dropout=0.5):
        """Build the sub-modules and apply the custom weight initialisation.

        ``n_tokens`` sizes both the embedding table and the output layer;
        ``d_model`` is the embedding width; ``n_heads``, ``d_hid`` and
        ``dropout`` are forwarded to ``TransformerEncoderLayer``; ``n_layers``
        is the number of stacked encoder layers.
        """
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.encoder = nn.Embedding(n_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, n_heads, d_hid, dropout),
            n_layers,
        )
        self.decoder = nn.Linear(d_model, n_tokens)

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise embedding and output weights; zero the output bias."""
        bound = 0.1
        for weight in (self.encoder.weight, self.decoder.weight):
            nn.init.uniform_(weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x_src, x_src_mask):
        """Map token ids to per-token scores over the vocabulary.

        ``x_src`` holds token indices and ``x_src_mask`` is forwarded to the
        encoder stack as its mask argument.
        """
        # Scale embeddings by sqrt(d_model) before adding positional information.
        scaled = self.encoder(x_src) * np.sqrt(self.d_model)
        hidden = self.transformer_encoder(self.pos_encoder(scaled), x_src_mask)
        return self.decoder(hidden)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to embeddings, then applies dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the encoding table for positions ``0 .. max_len - 1``.

        Args:
            d_model: embedding width; may be even or odd.
            dropout: dropout probability applied after the addition.
            max_len: number of positions stored in the table.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        # Even embedding dimensions (0, 2, 4, ...) carry the sine component.
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # Odd embedding dimensions (1, 3, 5, ...) carry the cosine component.
        # For odd d_model there is one fewer odd dimension than even ones, so
        # trim div_term to match instead of failing on a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: saved with the module but not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x is a tensor of shape (seq_len, batch_size, embedding_dim)

        Returns ``x`` plus the first ``seq_len`` rows of the table (broadcast
        over the batch dimension), passed through dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [4]:
# Build the vocabulary from the training split only; unseen tokens map to "<unk>".
tokenizer = get_tokenizer("basic_english")
train_iter = WikiText2(split="train")
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])

In [5]:
# Fetch all three splits and turn each into a flat tensor of token ids.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = map(
    data_process, (train_iter, val_iter, test_iter)
)

In [6]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Where the best model checkpoint is written during training.
chkpt_path = "../checkpoints/lang_model_transformer.pth"

# Number of parallel token streams for training and for evaluation.
batch_size, eval_batch_size = 20, 10

In [7]:
# Reshape each flat split into (seq_len, batch) columns on the target device.
# Note: this rebinds the names, so the cell must not be re-run on its own output.
train_data, val_data, test_data = (
    batchify(train_data, batch_size, device),
    batchify(val_data, eval_batch_size, device),
    batchify(test_data, eval_batch_size, device),
)

In [8]:
# --- Model hyperparameters ---
n_tokens = len(vocab)  # vocabulary size
emsize = 200           # embedding dimension
d_hid = 200            # feed-forward width inside each TransformerEncoderLayer
n_layers = 2           # number of stacked TransformerEncoderLayer modules
n_head = 2             # attention heads in MultiheadAttention
dropout = 0.2          # dropout probability (used by all network modules)

# --- Data chunking ---
bptt = 35              # maximum number of rows per training/evaluation chunk

In [9]:
# Instantiate the language model and place it on the selected device.
model = TransformerModel(
    n_tokens=n_tokens,
    d_model=emsize,
    n_heads=n_head,
    d_hid=d_hid,
    n_layers=n_layers,
    dropout=dropout,
).to(device)

In [10]:
# Loss, optimizer and learning-rate schedule for training.
criterion = nn.CrossEntropyLoss()
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 after every scheduler step.
# step_size is an integer count of steps, so pass an int rather than 1.0.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [11]:
best_val_loss = float("inf")
epochs = 5

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
chkpt_dir = os.path.dirname(chkpt_path)
if chkpt_dir:
    os.makedirs(chkpt_dir, exist_ok=True)

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data, model, optimizer, criterion, bptt, device)
    val_loss = evaluate(val_data, model, optimizer, criterion, bptt, device)
    val_ppl = np.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    
    print("-" * 59)
    print((
        f"| end of epoch {epoch} | time {np.round(elapsed, 2)} "
        f"| valid loss {np.round(val_loss, 2)} | valid ppl {np.round(val_ppl, 2)}"
    ))
    print("-" * 59)
    
    # Keep only the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
    
        # NOTE(review): the criterion object itself is pickled into the
        # checkpoint; confirm this loads under the torch.load defaults of the
        # PyTorch version in use.
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "criterion": criterion
        }
        torch.save(checkpoint, chkpt_path)
    # Decay the learning rate once per epoch.
    scheduler.step()
    
# Restore the best checkpoint; map_location lets a checkpoint written on one
# device be loaded onto whichever device is selected now.
checkpoint = torch.load(chkpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
criterion = checkpoint["criterion"]


| epoch 1 | batch 200/2928 | lr [5.] | ms/batch 15.85 | loss 8.22 | ppl 3717.18
| epoch 1 | batch 400/2928 | lr [5.] | ms/batch 11.78 | loss 6.88 | ppl 974.82
| epoch 1 | batch 600/2928 | lr [5.] | ms/batch 11.1 | loss 6.45 | ppl 630.48
| epoch 1 | batch 800/2928 | lr [5.] | ms/batch 10.61 | loss 6.31 | ppl 547.87
| epoch 1 | batch 1000/2928 | lr [5.] | ms/batch 11.51 | loss 6.19 | ppl 486.05
| epoch 1 | batch 1200/2928 | lr [5.] | ms/batch 10.85 | loss 6.16 | ppl 471.79
| epoch 1 | batch 1400/2928 | lr [5.] | ms/batch 10.69 | loss 6.11 | ppl 452.55
| epoch 1 | batch 1600/2928 | lr [5.] | ms/batch 10.76 | loss 6.11 | ppl 449.5
| epoch 1 | batch 1800/2928 | lr [5.] | ms/batch 10.87 | loss 6.02 | ppl 411.21
| epoch 1 | batch 2000/2928 | lr [5.] | ms/batch 11.45 | loss 6.02 | ppl 410.79
| epoch 1 | batch 2200/2928 | lr [5.] | ms/batch 10.96 | loss 5.89 | ppl 361.48
| epoch 1 | batch 2400/2928 | lr [5.] | ms/batch 11.13 | loss 5.97 | ppl 392.38
| epoch 1 | batch 2600/2928 | lr [5.] | ms/ba

In [13]:
# Score the restored model on the held-out test split and report loss / perplexity.
test_loss = evaluate(test_data, model, optimizer, criterion, bptt, device)
test_ppl = np.exp(test_loss)
print(
    "-" * 59,
    (
        f"| End of training | test loss {np.round(test_loss, 2)} "
        f"test ppl {np.round(test_ppl, 2)}"
    ),
    "-" * 59,
    sep="\n",
)

-----------------------------------------------------------
| End of training | test loss 5.49 test ppl 242.57
-----------------------------------------------------------
