<a href="https://colab.research.google.com/github/justinqbui/transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
from math import sqrt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

!pip install pytorch_lightning
import pytorch_lightning as pl


In [2]:
# Select the compute device once; num_gpus is the GPU count handed to the Trainer below.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_gpus = int(device == "cuda")

In [3]:
class BaseDataModule(pl.LightningDataModule):
    """
    Lightning data module that shuffles a dataset once and splits it into
    training and validation parts (adapted from George Hotz).
    Subclasses provide get_dataset(*args, **kwargs), which must return an
    (X, Y) pair of arrays that can be indexed along their first axis.
    Params:
    batch_size -> number of samples per batch for both dataloaders
    split -> fraction of the samples that goes to the training set
    *args, **kwargs -> forwarded unchanged to get_dataset
    """

    def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
        super().__init__()
        inputs, targets = self.get_dataset(*args, **kwargs)
        # a single permutation is applied to both arrays so (x, y) pairs stay aligned
        order = np.random.permutation(inputs.shape[0])
        self.ds_X = inputs[order]
        self.ds_Y = targets[order]
        # index of the first validation sample
        self.split = int(self.ds_X.shape[0] * split)
        self.batch_size = batch_size

    def train_dataloader(self):
        """Return a DataLoader over the samples before the split index."""
        pairs = list(zip(self.ds_X[:self.split], self.ds_Y[:self.split]))
        return torch.utils.data.DataLoader(pairs, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a DataLoader over the samples from the split index onwards."""
        pairs = list(zip(self.ds_X[self.split:], self.ds_Y[self.split:]))
        return torch.utils.data.DataLoader(pairs, batch_size=self.batch_size)

class ReverseDataModule(BaseDataModule):
    """Toy task: the target is the input digit sequence in reverse order."""

    def get_dataset(self, cnt=10000, seq_len=6):
        """
        Generate random digit sequences and their reversed counterparts
        Params:
        cnt -> number of sequences to generate
        seq_len -> number of digits (each in 0..9) per sequence
        Returns:
        (inputs, targets), both integer arrays of shape (cnt, seq_len)
        """
        digits = np.random.randint(0, 10, size=(cnt, seq_len))
        # copy() materialises the negatively-strided view as a contiguous array
        flipped = digits[:, ::-1].copy()
        return digits, flipped


In [4]:
class FeedForward(nn.Module):
    """
    Position-wise feed forward network applied after attention:
    Linear -> GELU -> Linear -> Dropout
    Params:
    embed_dim -> embedding dimension of the input and output vectors
    hidden_dim -> width of the inner layer, generally 2-4x larger than embed_dim
    dropout -> dropout probability applied to the output during training
    """

    def __init__(self, embed_dim, hidden_dim, dropout):
        super().__init__()
        # expand to the hidden width, then project back to the embedding width
        expand = nn.Linear(embed_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, embed_dim)
        self.FFN = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        return self.FFN(x)

In [5]:
def attention(query, key, value):
    """
    Calculate scaled dot product attention: softmax(Q @ K^T / sqrt(dim)) @ V
    Params:
    query -> query tensor of shape (..., q_len, dim)
    key -> key tensor of shape (..., k_len, dim)
    value -> value tensor of shape (..., k_len, value_dim)
    Returns:
    tensor of shape (..., q_len, value_dim)
    """
    dim = query.shape[-1]
    # (query @ transpose(key)) / sqrt(dim)
    # torch.matmul batches over any leading dimensions, so besides the 3-D
    # (batch, seq, dim) case that bmm handled, 2-D and 4-D inputs also work
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim)
    # normalise over the key axis so each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [6]:
class AttentionHead(nn.Module):
    """
    Compute self-attention of one head of multihead attention
    Params:
    embed_dim -> size of each embedding vector
    head_dim -> dimensionality of projected x embedding to key, query, and value vectors
    """
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.to_q = nn.Linear(embed_dim, head_dim)
        self.to_k = nn.Linear(embed_dim, head_dim)
        self.to_v = nn.Linear(embed_dim, head_dim)

    def forward(self, x):
        # project the same input into separate query, key and value spaces;
        # the key must come from to_k (it was previously computed with to_q,
        # which left to_k unused and made the scores Q @ Q^T)
        query = self.to_q(x)
        key = self.to_k(x)
        value = self.to_v(x)
        return attention(query, key, value)


In [7]:
class MultiHeadAttention(nn.Module):
    """
    Calculates the Multi-Headed attention
    Params:
    heads -> number of heads to use, each head_dim is calculated as embed_dim // heads
    embed_dim -> embedding dimension of vector x, must be divisible by heads
    Raises:
    ValueError if heads is not positive or does not evenly divide embed_dim
    """
    def __init__(self, heads, embed_dim):
        super().__init__()
        # the concatenated head outputs have size heads * (embed_dim // heads); unless
        # that equals embed_dim the output projection fails with a shape error in
        # forward, so reject the configuration up front with a clear message
        if heads <= 0 or embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, self.head_dim) for _ in range(heads)]
        )
        self.to_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # calculate attention for each head and concatenate along the feature axis
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        # mix the information from the individual heads
        x = self.to_out(x)
        return x

In [8]:
class Embeddings(nn.Module):
    """
    Converts input tokens into token embeddings and adds learnable positional embeddings
    Params:
    seq_len -> maximum number of tokens in a sequence (size of the position table)
    vocab_size -> number of tokens in total vocabulary
    embed_dim -> size of each embedding vector
    """
    def __init__(self, seq_len, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(seq_len, embed_dim)

    def forward(self, x):
        # x holds token ids with the sequence along dim 1; inputs longer than
        # seq_len would index past the end of the position table
        # create the position ids on the same device as x, otherwise the lookup
        # fails with a device mismatch when the model runs on GPU
        pos = torch.arange(0, x.size(1), device=x.device)
        # the (seq, embed_dim) positional term broadcasts over the batch dimension
        return self.token_emb(x) + self.pos_emb(pos)

In [9]:
class Block(nn.Module):
    """
    One pre-norm encoder block: self-attention followed by a feed forward
    network, each wrapped in a residual connection
    Params:
    heads -> number of attention heads, each head_dim is calculated as embed_dim // heads
    dropout -> float value for the dropout layer of the feed forward network
    embed_dim -> embedding dimension of vector x
    hidden_dim -> hidden dimension in the FF network, generally 2-4x larger than embed_dim
    """

    def __init__(self, heads, dropout, embed_dim, hidden_dim):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadAttention(heads=heads, embed_dim=embed_dim)
        self.feed_forward = FeedForward(embed_dim, hidden_dim, dropout=dropout)

    def forward(self, x):
        # attention sub-layer: normalise first, then add the result onto the input
        attended = x + self.attention(self.layer_norm1(x))
        # feed forward sub-layer uses the same normalise-then-residual pattern
        return attended + self.feed_forward(self.layer_norm2(attended))

In [16]:
class MiniTransformer(pl.LightningModule):
    """
    An encoder only Transformer with learned positional embeddings that
    predicts one token per input position
    Params:
    embed_dim -> embedding dimension of vector x
    heads -> number of heads to use, each head_dim is calculated as embed_dim // heads
    seq_len -> maximum number of tokens in a sequence
    vocab_size -> number of tokens in total vocabulary
    hidden_dim -> hidden dimension in the FF network, generally 2-4x larger than embed_dim
    dropout -> float value for dropout layers
    blocks -> number of encoder blocks to stack
    """
    def __init__(self, embed_dim = 128, heads = 4, seq_len = 6, vocab_size = 10, hidden_dim = 256, dropout = 0.1 ,blocks = 2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        # embeddings -> encoder blocks -> per-token log-probabilities over the vocabulary
        self.model = nn.Sequential(
            self.embeddings,
            *[Block(heads = heads, dropout = dropout, embed_dim = embed_dim,
            hidden_dim = hidden_dim) for _ in range(blocks)],
            nn.Linear(embed_dim, vocab_size),
            nn.LogSoftmax(dim = -1))

    def forward(self, x):
        """Return log-probabilities over the vocabulary for every position of x."""
        # the result was previously discarded, so calling the module returned None
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch."""
        x, y = batch
        output = self.model(x)
        # flatten batch and sequence dimensions so every token is one sample
        loss = F.nll_loss(output.view(-1, self.vocab_size), y.view(-1))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the per-token accuracy for one validation batch."""
        x, y = val_batch
        # most likely token at each position
        pred = self.model(x).argmax(dim=2)
        val_accuracy = (pred == y).type(torch.float).mean()
        # the metric was previously computed and then dropped; log it so that
        # validation actually reports something
        self.log("val_accuracy", val_accuracy, prog_bar=True)
        return val_accuracy

    def configure_optimizers(self):
        """Return the AdamW optimizer over all model parameters."""
        return torch.optim.AdamW(self.parameters(), lr=3e-4)

In [17]:
# Build the model, the toy reversal dataset and the trainer, then run training.
# NOTE(review): the model is sized for vocab_size=128 / seq_len=128 while the data
# holds digits 0-9 in sequences of length 20 - this runs, but the model is larger
# than the task needs; confirm whether that is intended.
model = MiniTransformer(seq_len=128, vocab_size=128)
data = ReverseDataModule(seq_len=20, cnt=1000)
trainer = pl.Trainer(max_epochs=5, gpus=num_gpus, enable_progress_bar=True)

trainer.fit(model, data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type       | Params
------------------------------------------
0 | embeddings | Embeddings | 32.8 K
1 | model      | Sequential | 314 K 
------------------------------------------
314 K     Trainable params
0         Non-trainable params
314 K     Total params
1.257     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]