In [1]:
# Standard libraries
import math
import os
import urllib.request
from functools import partial
from urllib.error import HTTPError

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
import numpy as np

# PyTorch Lightning
import pytorch_lightning as pl
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

# Torchvision
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms
from torchvision.datasets import CIFAR100
from tqdm.notebook import tqdm

plt.set_cmap("cividis")
%matplotlib inline
matplotlib_inline.backend_inline.set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = None
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
print("Device:", device)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 42


Device: mps


<Figure size 640x480 with 0 Axes>

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()

        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model) # (seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model)) # (seq_len, d_model)
        pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model)) # (seq_len, d_model)
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x:torch.Tensor):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
        return self.dropout(x)

In [3]:
def attention(q:torch.Tensor, k:torch.Tensor, v:torch.Tensor, mask=None):
    d_k = q.size()[-1] # q,k,v : (batch, head, seq_len, embed_size_per_head)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) # (batch, head, seq_len, seq_len)
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attn_logits, v) # (batch, head, seq_len, embed_size_per_head)
    return values, attention

def init_weights(x:nn.Linear):
    with torch.no_grad():
        nn.init.xavier_uniform_(x.weight)
        x.bias.data.fill_(0)

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, input_dim:int, d_model: int, h: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h

        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h

        self.w_q = nn.Linear(input_dim, d_model) # Wq
        self.w_k = nn.Linear(input_dim, d_model) # Wk
        self.w_v = nn.Linear(input_dim, d_model) # Wv
        self.w_o = nn.Linear(d_model, d_model) # Wo

        init_weights(self.w_q)
        init_weights(self.w_k)
        init_weights(self.w_v)
        init_weights(self.w_o)

    def forward(self, q_x:torch.Tensor, k_x:torch.Tensor, v_x:torch.Tensor, mask=None):
        q:torch.Tensor = self.w_q(q_x) # (batch, seq_len, d_model)
        k:torch.Tensor = self.w_k(k_x) # (batch, seq_len, d_model)
        v:torch.Tensor = self.w_v(v_x) # (batch, seq_len, d_model)

        q_h = q.reshape(q.shape[0], q.shape[1], self.h, self.d_k).transpose(1, 2) # (batch, head, seq_len, d_k)
        k_h = k.reshape(k.shape[0], k.shape[1], self.h, self.d_k).transpose(1, 2) # (batch, head, seq_len, d_k)
        v_h = v.reshape(v.shape[0], v.shape[1], self.h, self.d_k).transpose(1, 2) # (batch, head, seq_len, d_k)

        attn_out, _ = attention(q_h, k_h, v_h, mask) # (batch, head, seq_len, embed_size_per_head)
        attn_out = attn_out.transpose(1, 2) # (batch, seq_len, head, embed_size_per_head)
        attn_out = attn_out.reshape(attn_out.shape[0], attn_out.shape[1], attn_out.shape[2]*attn_out.shape[3]) # (batch, seq_len, d_model)

        return self.w_o(attn_out) # (batch, seq_len, d_model)

In [4]:
class EncoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()

        self.self_attn = MultiHeadAttentionBlock(input_dim, input_dim, num_heads)

        self.ffn_1 = nn.Linear(input_dim, dim_feedforward)
        self.ffn_2 = nn.Linear(dim_feedforward, input_dim)

        init_weights(self.ffn_1)
        init_weights(self.ffn_2)

        self.ffn = nn.Sequential(
            self.ffn_1,
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            self.ffn_2,
        )

        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.self_attn(x, x, x, mask=mask) # (batch, seq_len, input_dim)
        x = x + self.dropout(attn_out) # (batch, seq_len, input_dim)
        x = self.norm1(x) # (batch, seq_len, input_dim)

        ffn_out = self.ffn(x) # (batch, seq_len, input_dim)
        x = x + self.dropout(ffn_out) # (batch, seq_len, input_dim)
        x = self.norm2(x) # (batch, seq_len, input_dim)

        return x

In [5]:
class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(d_model, num_heads, dim_feedforward, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0)->None:
        super().__init__()

        self.self_attn = MultiHeadAttentionBlock(input_dim, input_dim, num_heads)
        self.crss_attn = MultiHeadAttentionBlock(input_dim, input_dim, num_heads)

        self.ffn_1 = nn.Linear(input_dim, dim_feedforward)
        self.ffn_2 = nn.Linear(dim_feedforward, input_dim)

        init_weights(self.ffn_1)
        init_weights(self.ffn_2)
        
        self.ffn = nn.Sequential(
            self.ffn_1,
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            self.ffn_2,
        )

        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.norm3 = nn.LayerNorm(input_dim)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, pred_mask, pad_mask):
        self_attn_out = self.self_attn(x, x, x, mask=pred_mask) # (batch, seq_len, input_dim)
        x = x + self.dropout(self_attn_out) # (batch, seq_len, input_dim)
        x = self.norm1(x) # (batch, seq_len, input_dim)

        crss_attn_out = self.crss_attn(x, encoder_output, encoder_output, mask=pad_mask) # (batch, seq_len, input_dim)
        x = x + self.dropout(crss_attn_out) # (batch, seq_len, input_dim)
        x = self.norm2(x) # (batch, seq_len, input_dim)

        ffn_out = self.ffn(x) # (batch, seq_len, input_dim)
        x = x + self.dropout(ffn_out) # (batch, seq_len, input_dim)
        x = self.norm3(x) # (batch, seq_len, input_dim)

        return x

In [7]:
class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dim_feedforward, dropout):
        super().__init__()
        self.layers = nn.ModuleList([DecoderBlock(d_model, num_heads, dim_feedforward, dropout) for _ in range(num_layers)])

    def forward(self, x, encoder_output, pred_mask=None, pad_mask=None):
        for layer in self.layers:
            x = layer(x, encoder_output, pred_mask, pad_mask)
        return x

In [8]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, seq_len, d_model, num_heads, dim_feedforward, num_encoder_layers, num_decoder_layers, dropout=0.0) -> None:
        super(Transformer, self).__init__()

        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, seq_len, dropout)

        self.encoder_block = Encoder(num_encoder_layers, d_model, num_heads, dim_feedforward, dropout)
        self.decoder_block = Decoder(num_decoder_layers, d_model, num_heads, dim_feedforward, dropout)

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        init_weights(self.fc)

        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)       


    def generate_mask(self, src:torch.Tensor, tgt:torch.Tensor):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2) # (batch, 1, 1, seq_len)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3) # (batch, 1, seq_len, 1)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device=device) # (1, seq_len, seq_len)
        tgt_mask = tgt_mask & nopeak_mask # (batch, 1, seq_len, seq_len)
        return src_mask, tgt_mask
    

    def forward(self, src:torch.Tensor, tgt:torch.Tensor):
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        src_embedded = self.encoder_embedding(src) # (batch, seq_len, d_model)
        tgt_embedded = self.decoder_embedding(tgt) # (batch, seq_len, d_model)

        src_embedded = self.positional_encoding(src_embedded) # (batch, seq_len, d_model)
        tgt_embedded = self.positional_encoding(tgt_embedded) # (batch, seq_len, d_model)

        src_embedded = self.dropout(src_embedded) # (batch, seq_len, d_model)
        tgt_embedded = self.dropout(tgt_embedded) # (batch, seq_len, d_model)

        enc_output = self.encoder_block(src_embedded, src_mask) # (batch, seq_len, d_model)
        dec_output = self.decoder_block(tgt_embedded, enc_output, tgt_mask, src_mask) # (batch, seq_len, d_model)

        return self.fc(dec_output) # (batch, seq_len, tgt_vocab_size)

In [9]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [47]:
import random

def generate_data(n=10000, start_rand=100, max_seq_length=100):
    vocab_size = 0
    data_src, data_tgt = [], []

    for _ in range(n):
        sumv = 0
        maxv = 0
        seq = []

        for j in range(2*max_seq_length):
            if j <= 1:
                d = random.randint(0, start_rand)
            else:
                d = abs(sumv-maxv)

            while d > 1000:
                d = d/2 if d % 2 == 0 else (d+1)/2
                d = int(d)

            vocab_size = max(vocab_size, d+1)
            seq += [d]

            sumv += d
            maxv = max(maxv, d)

        data_src += [seq[:max_seq_length]]
        data_tgt += [seq[max_seq_length:]]
    
    return torch.tensor(data_src, dtype=torch.int64), torch.tensor(data_tgt, dtype=torch.int64), vocab_size

In [48]:
n = 100000
m = int(0.8*n)
data_src, data_tgt, vocab_size = generate_data(n)

data_src_train, data_src_test = data_src[:m], data_src[m:]
data_tgt_train, data_tgt_test = data_tgt[:m], data_tgt[m:]

In [51]:
d_model = 64
num_heads = 2
num_layers = 1
d_ff = 64
max_seq_length = 100
dropout = 0.0

transformer = Transformer(vocab_size, vocab_size, max_seq_length, d_model, num_heads, d_ff, num_layers, num_layers, dropout).to(device=device)

In [52]:
n_epochs = 20    # number of epochs to run
batch_size = 128  # size of each batch
batches_per_epoch = data_tgt_train.shape[0] // batch_size

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.001)
lr_scheduler = CosineWarmupScheduler(optimizer, warmup=50, max_iters=batches_per_epoch*n_epochs)

transformer.train()

for epoch in range(n_epochs):
    for i in range(batches_per_epoch):
        optimizer.zero_grad()
        start = i * batch_size

        data_src_train_batch = data_src_train[start:start+batch_size]
        data_tgt_train_batch = data_tgt_train[start:start+batch_size]

        output:torch.Tensor = transformer(data_src_train_batch.to(device=device), data_tgt_train_batch[:, :-1].to(device=device))
        loss:torch.Tensor = criterion(output.contiguous().view(-1, vocab_size), data_tgt_train_batch[:, 1:].to(device=device).contiguous().view(-1))

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        
        print(f"Epoch: {epoch+1}, Batch: {i+1}, Loss: {loss.item()}")

Epoch: 1, Batch: 1, Loss: 6.992865085601807
Epoch: 1, Batch: 2, Loss: 6.9928507804870605
Epoch: 1, Batch: 3, Loss: 6.988949298858643
Epoch: 1, Batch: 4, Loss: 6.983707904815674
Epoch: 1, Batch: 5, Loss: 6.977285385131836
Epoch: 1, Batch: 6, Loss: 6.972637176513672
Epoch: 1, Batch: 7, Loss: 6.95846700668335
Epoch: 1, Batch: 8, Loss: 6.949999809265137
Epoch: 1, Batch: 9, Loss: 6.938541412353516
Epoch: 1, Batch: 10, Loss: 6.9290771484375
Epoch: 1, Batch: 11, Loss: 6.9148969650268555
Epoch: 1, Batch: 12, Loss: 6.9016313552856445
Epoch: 1, Batch: 13, Loss: 6.895565032958984
Epoch: 1, Batch: 14, Loss: 6.878353118896484
Epoch: 1, Batch: 15, Loss: 6.8649139404296875
Epoch: 1, Batch: 16, Loss: 6.846013069152832
Epoch: 1, Batch: 17, Loss: 6.825469493865967
Epoch: 1, Batch: 18, Loss: 6.807492256164551
Epoch: 1, Batch: 19, Loss: 6.788546562194824
Epoch: 1, Batch: 20, Loss: 6.765722751617432
Epoch: 1, Batch: 21, Loss: 6.745671272277832
Epoch: 1, Batch: 22, Loss: 6.724663734436035
Epoch: 1, Batch: 2

In [None]:
def predict(model:nn.Module, n=100):
    model.eval()
    with torch.no_grad():
        preds:torch.Tensor = model(data_src_test[:n,:].to(device=device), data_tgt_test[:n, :-1].to(device=device))
        preds = preds.argmax(dim=-1)
        return preds
    
def evaluate(model:nn.Module, n=100):
    preds:torch.Tensor = predict(model, n)
    preds = preds.flatten().tolist()
    actuals = data_tgt_test[:n, 1:].flatten().tolist()
    accuracy = sum([preds[i] == actuals[i] for i in range(len(preds))])/len(preds)
    return accuracy

In [63]:
evaluate(transformer, n=data_tgt_test.shape[0])

0.978920202020202

In [53]:
# evaluate trained model with test set
transformer.eval()
with torch.no_grad():
    preds:torch.Tensor = transformer(data_src_test[:10,:].to(device=device), data_tgt_test[:10, :-1].to(device=device))

In [58]:
preds.argmax(dim=-1)

tensor([[ 529,  533,  537,  542,  546,  550,  554,  559,  563,  568,  572,  576,
          581,  586,  590,  595,  599,  604,  609,  613,  618,  623,  628,  633,
          638,  643,  648,  653,  658,  663,  668,  674,  679,  684,  689,  695,
          700,  706,  711,  717,  722,  728,  734,  739,  745,  751,  757,  763,
          769,  775,  781,  787,  793,  799,  806,  812,  818,  825,  831,  838,
          844,  851,  857,  864,  871,  878,  884,  891,  898,  905,  912,  920,
          927,  934,  941,  949,  956,  963,  971,  979,  986,  994,  501,  503,
          505,  507,  509,  511,  513,  515,  517,  519,  521,  523,  525,  527,
          529,  531,  533],
        [ 533,  537,  542,  546,  550,  554,  559,  563,  567,  572,  576,  581,
          585,  590,  595,  599,  604,  609,  613,  618,  623,  628,  633,  638,
          643,  648,  653,  658,  663,  668,  673,  679,  684,  689,  695,  700,
          706,  711,  717,  722,  728,  734,  739,  745,  751,  757,  763,  769,


In [59]:
data_tgt_test[:10, 1:]

tensor([[ 529,  533,  537,  542,  546,  550,  554,  559,  563,  568,  572,  576,
          581,  586,  590,  595,  599,  604,  609,  613,  618,  623,  628,  633,
          638,  643,  648,  653,  658,  663,  668,  674,  679,  684,  689,  695,
          700,  706,  711,  717,  722,  728,  734,  739,  745,  751,  757,  763,
          769,  775,  781,  787,  793,  799,  806,  812,  818,  825,  831,  838,
          844,  851,  857,  864,  871,  878,  884,  891,  898,  905,  912,  920,
          927,  934,  941,  949,  956,  963,  971,  979,  986,  994,  501,  503,
          505,  507,  509,  511,  513,  515,  517,  519,  521,  523,  525,  527,
          529,  531,  533],
        [ 533,  537,  542,  546,  550,  554,  559,  563,  567,  572,  576,  581,
          585,  590,  595,  599,  604,  609,  613,  618,  623,  628,  633,  638,
          643,  648,  653,  658,  663,  668,  673,  679,  684,  689,  695,  700,
          706,  711,  717,  722,  728,  734,  739,  745,  751,  757,  763,  769,
