<a href="https://colab.research.google.com/github/justinqbui/min_linformer/blob/main/linformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
from math import sqrt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

!pip install pytorch_lightning
import pytorch_lightning as pl


In [31]:
# Choose the accelerator once so every later cell agrees on it.
if torch.cuda.is_available():
    device, num_gpus = "cuda", 1
else:
    device, num_gpus = "cpu", 0

In [32]:
# from https://github.com/geohot
class BaseDataModule(pl.LightningDataModule):
  """
  Shuffles a dataset once and serves fixed train / validation splits.
  Subclasses supply get_dataset, which must return an (X, Y) pair of arrays
  that expose .shape and can be indexed by an integer permutation.
  Params:
  batch_size -> number of samples per batch in both dataloaders
  split -> fraction of the samples that goes to the training split
  *args, **kwargs -> forwarded unchanged to get_dataset
  """
  def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
    super().__init__()
    inputs, targets = self.get_dataset(*args, **kwargs)
    # a single shared permutation keeps every input aligned with its target
    order = np.random.permutation(inputs.shape[0])
    self.ds_X, self.ds_Y = inputs[order], targets[order]
    # `split` arrives as a fraction but is stored as the boundary row index
    self.split = int(self.ds_X.shape[0] * split)
    self.batch_size = batch_size

  def _paired_loader(self, rows):
    """Wrap the selected rows of X and Y in a DataLoader of (x, y) pairs."""
    pairs = list(zip(self.ds_X[rows], self.ds_Y[rows]))
    return torch.utils.data.DataLoader(pairs, batch_size=self.batch_size)

  def train_dataloader(self):
    """Batches drawn from the first `self.split` shuffled samples."""
    return self._paired_loader(slice(0, self.split))

  def val_dataloader(self):
    """Batches drawn from the shuffled samples after `self.split`."""
    return self._paired_loader(slice(self.split, None))

class ReverseDataModule(BaseDataModule):
  """Toy task: the target is the input digit sequence reversed."""

  def get_dataset(self, cnt=10000, seq_len=6):
    """
    Generate random digit sequences together with their reversals.
    Params:
    cnt -> number of sequences to generate
    seq_len -> number of digits in each sequence
    Returns (sequences, reversed sequences), both of shape (cnt, seq_len)
    with integer values drawn from [0, 10).
    """
    sequences = np.random.randint(0, 10, size=(cnt, seq_len))
    # ravel() on the reversed view yields a fresh C-ordered array, so the
    # targets do not stay a negatively-strided view of `sequences`
    flipped = sequences[:, ::-1].ravel().reshape(cnt, seq_len)
    print(sequences, flipped)
    return sequences, flipped


In [33]:
class FeedForward(pl.LightningModule):
    """
    Feed forward network applied after scaled dot product attention
    Params:
    embed_dim -> size of the input and output vectors
    hidden_dim -> width of the hidden layer, generally 2-4x larger than embed_dim
    dropout -> dropout probability applied to the network output
    """
    def __init__(self, embed_dim, hidden_dim, dropout):
        super().__init__()
        layers = [
            nn.Linear(embed_dim, hidden_dim),   # expand to the hidden width
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),   # project back to embed_dim
            nn.Dropout(dropout),
        ]
        self.FFN = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through expand -> GELU -> project -> dropout."""
        return self.FFN(x)

In [34]:
class Embeddings(pl.LightningModule):
    """
    Converts input token ids into token embeddings and adds learnable positional embeddings
    Params:
    seq_len -> number of positions in the position table (longest supported sequence)
    vocab_size -> number of tokens in total vocabulary
    embed_dim -> size of each embedding vector
    """
    def __init__(self, seq_len, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(seq_len, embed_dim)

    def forward(self, x):
        """
        x -> integer token ids; dimension 1 is used as the sequence axis
        Returns the token embeddings plus the position embeddings, which are
        broadcast over the leading (batch) dimension
        """
        # Build the position indices on the input's own device rather than on a
        # notebook-level global, so the module works wherever it has been moved
        pos = torch.arange(x.size(1), device=x.device)
        return self.token_emb(x) + self.pos_emb(pos)

In [35]:
def attention(query, key, value):
    """
    calculate scaled dot product attention given a q, k and v
    Params:
    query -> tensor of shape (..., n_q, d)
    key -> tensor of shape (..., n_kv, d)
    value -> tensor of shape (..., n_kv, d_v)
    Returns a tensor of shape (..., n_q, d_v): for every query, the
    softmax-weighted average of the value rows, with weights computed from
    the query/key dot products scaled by 1 / sqrt(d)
    """
    dim = query.shape[-1]
    # matmul gives the same result as bmm for 3-D inputs and additionally
    # accepts any number of leading batch / head dimensions
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim)
    # normalise over the key axis so each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [36]:
def get_EF(input_dim, k, bias = True):
    """
    build one learnable E / F projection layer
    Params:
    input_dim -> size of the axis being projected
    k -> size of that axis after projection
    bias -> whether the linear layer carries a bias term
    Returns an nn.Linear(input_dim, k) whose weight has been re-initialised
    with Xavier normal initialisation
    """
    projection = nn.Linear(in_features=input_dim, out_features=k, bias=bias)
    nn.init.xavier_normal_(projection.weight)
    return projection

In [37]:
class AttentionHead(pl.LightningModule):
    """
    Single attention head with down projected keys and values
    Params:
    embed_dim -> size of each embedding vector
    head_dim -> dimensionality of projected x embedding to key, query, and value vectors
    k -> dimensionality to reduce the key and value vectors (accepted for
         signature compatibility; the head itself only uses E_i / F_i)
    E_i -> module applied to the key vectors to down project them
    F_i -> module applied to the value vectors to down project them,
           required unless share_kv is True
    share_kv -> bool -> both key and value vectors down projected by E_i
    Raises:
    ValueError -> share_kv is False and no F_i was given

    NOTE(review): E_i / F_i are applied directly to the output of to_k / to_v,
    i.e. along the last (feature) axis, so the query/key product in attention()
    only lines up when the projected size equals head_dim and the sequence axis
    is not shortened. The Linformer formulation projects along the sequence
    axis instead - TODO confirm whether this is intended.
    """
    def __init__(self, embed_dim, head_dim, k, E_i, F_i = None, share_kv = False):
        super().__init__()
        # Fail at construction time with a clear message instead of a
        # "'NoneType' object is not callable" on the first forward pass
        if not share_kv and F_i is None:
            raise ValueError("F_i must be provided when share_kv is False")
        self.to_q = nn.Linear(embed_dim, head_dim)
        self.to_k = nn.Linear(embed_dim, head_dim)
        self.to_v = nn.Linear(embed_dim, head_dim)
        self.E_i = E_i
        self.F_i = F_i
        self.share_kv = share_kv

    def forward(self, x):
        """Project x to q / k / v, down project k and v, and attend."""
        # K is always down projected by E; V reuses E when sharing, else uses F
        value_proj = self.E_i if self.share_kv else self.F_i
        down_k = self.E_i(self.to_k(x))
        down_v = value_proj(self.to_v(x))
        return attention(self.to_q(x), down_k, down_v)

In [38]:
class MultiHeadAttention(pl.LightningModule):
    """
    Calculates the Multi-Headed linear attention
    Params:
    num_heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads
    embed_dim -> embedding dimension of vector x, must be divisible by num_heads
    k -> dimensionality to reduce the key and value vectors
    share_kv -> bool -> both key and value vectors down projected by E_i matrix
    share_headwise -> bool -> every head uses the same E_i and F_i projections;
                      when False each head gets its own freshly initialised pair
    Raises:
    ValueError -> embed_dim is not divisible by num_heads
    """
    def __init__(self, num_heads, embed_dim, k, share_kv = False, share_headwise = True):
        super().__init__()
        # Without this the concatenated head outputs would not match the input
        # width of to_out, and the mismatch would only surface during forward
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Shared projections. They are created in both modes (and are unused
        # by the heads when share_headwise is False) so the parameter layout
        # of the module stays the same as before.
        self.E_i = get_EF(self.head_dim, k)
        self.F_i = get_EF(self.head_dim, k)

        if share_headwise:
            heads = [
                AttentionHead(self.embed_dim, self.head_dim, k, self.E_i, self.F_i, share_kv)
                for _ in range(num_heads)
            ]
        else:
            heads = [
                AttentionHead(self.embed_dim, self.head_dim, k,
                              get_EF(self.head_dim, k), get_EF(self.head_dim, k), share_kv)
                for _ in range(num_heads)
            ]
        self.heads = nn.ModuleList(heads)
        self.to_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Run every head on x, concatenate along the last axis, and mix with to_out."""
        # calculate attention for each head and concatenate tensor
        x = torch.cat([head(x) for head in self.heads], dim = -1)
        return self.to_out(x)

In [39]:
class LinformerBlock(pl.LightningModule):
    """
    One pre-norm encoder block: attention and feed forward sub-layers, each
    wrapped in a residual connection
    Params:
    heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads
    dropout -> float value for dropout layers
    embed_dim -> embedding dimension of vector x
    hidden_dim -> hidden dimension in the FF network, generally 2-4x larger than embed_dim
    k -> dimensionality to reduce the key and value vectors
    share_kv -> bool -> both key and value vectors down projected by E_i matrix
    share_headwise -> share E_i and F_i projection matrices
    """
    def __init__(self,heads, dropout, embed_dim, hidden_dim , k, share_kv, share_headwise):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadAttention(
            num_heads=heads,
            embed_dim=embed_dim,
            k=k,
            share_kv=share_kv,
            share_headwise=share_headwise,
        )
        self.feed_forward = FeedForward(embed_dim, hidden_dim, dropout=dropout)

    def forward(self, x):
        """Apply norm -> attention -> residual, then norm -> feed forward -> residual."""
        attended = x + self.attention(self.layer_norm1(x))
        return attended + self.feed_forward(self.layer_norm2(attended))

In [40]:
from pytorch_lightning.core.datamodule import LightningDataModule
class Linformer(pl.LightningModule):
    """
    An encoder only min-linformer with learned positional embeddings that
    outputs per-token log-probabilities over the vocabulary
    Params:
    k -> dimensionality to reduce the key and value vectors
    embed_dim -> embedding dimension of vector x
    heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads
    seq_len -> number of tokens in sequence
    vocab_size -> number of tokens in total vocabulary
    hidden_dim -> hidden dimension in the FF network, generally 2-4x larger than embed_dim
    share_kv -> bool -> both key and value vectors down projected by E_i matrix
    share_headwise -> share E_i and F_i projection matrices
    dropout -> float value for dropout layers
    blocks -> number of encoder blocks to stack
    """
    def __init__(self, k, embed_dim, heads, seq_len, vocab_size, hidden_dim , share_kv, share_headwise, dropout ,blocks):
        super().__init__()
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        self.vocab_size = vocab_size
        self.model = nn.Sequential(
            self.embeddings,
            *[LinformerBlock(heads = heads, dropout = dropout, embed_dim = embed_dim,
            hidden_dim = hidden_dim, k = k, share_kv = share_kv,
            share_headwise = share_headwise) for _ in range(blocks)],
            nn.Linear(embed_dim, vocab_size),
            nn.LogSoftmax(dim = -1))

    def forward(self, x):
        """Return the log-probabilities over the vocabulary for every token of x."""
        # the result has to be returned, otherwise calling the module yields None
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Negative log-likelihood of the targets, flattened over batch and sequence."""
        x, y = batch
        output = self(x)
        # the model already ends in LogSoftmax, so nll_loss is the matching loss
        loss = F.nll_loss(output.reshape(-1, self.vocab_size), y.reshape(-1))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the fraction of token positions whose argmax prediction equals the target."""
        x, y = val_batch
        pred = self(x).argmax(dim=-1)
        val_accuracy = (pred == y).type(torch.float).mean()
        # report the metric; previously it was computed and then discarded
        self.log("val_accuracy", val_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """AdamW over all parameters with a fixed learning rate of 3e-4."""
        return torch.optim.AdamW(self.parameters(), lr=3e-4)

In [41]:
# 1000 random digit sequences of length 20, each paired with its reversal
data = ReverseDataModule(
    seq_len=20,
    cnt=1000,
)

[[0 6 2 ... 6 5 5]
 [1 8 0 ... 3 8 0]
 [0 1 7 ... 9 2 3]
 ...
 [3 8 7 ... 6 5 8]
 [8 7 3 ... 5 6 3]
 [5 7 2 ... 2 7 4]] [[5 5 6 ... 2 6 0]
 [0 8 3 ... 0 8 1]
 [3 2 9 ... 7 1 0]
 ...
 [8 5 6 ... 7 8 3]
 [3 6 5 ... 3 7 8]
 [4 7 2 ... 2 7 5]]


In [42]:
# Small two-block model for the sequence-reversal demo. seq_len matches the
# data module above; vocab_size is far larger than the ten digits in the data.
model = Linformer(
    k=128,
    embed_dim=512,
    heads=4,
    seq_len=20,
    vocab_size=1024,
    hidden_dim=1024,
    share_kv=True,
    share_headwise=True,
    dropout=.1,
    blocks=2,
)
trainer = pl.Trainer(
    enable_progress_bar=True,
    max_epochs=5,
    gpus=num_gpus,
)

# five epochs over the train / validation loaders provided by `data`
trainer.fit(model, data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | embeddings | Embeddings | 534 K 
1 | model      | Sequential | 5.3 M 
------------------------------------------
5.3 M     Trainable params
0         Non-trainable params
5.3 M     Total params
21.326    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [43]:
class LinformerLM(pl.LightningModule):
    """
    An encoder only min-linformer with learned positional embeddings that
    outputs raw per-token logits over the vocabulary and is trained with
    cross-entropy against the targets in each batch
    Params:
    k -> dimensionality to reduce the key and value vectors
    embed_dim -> embedding dimension of vector x
    heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads
    seq_len -> number of tokens in sequence
    vocab_size -> number of tokens in total vocabulary
    hidden_dim -> hidden dimension in the FF network, generally 2-4x larger than embed_dim
    share_kv -> bool -> both key and value vectors down projected by E_i matrix
    share_headwise -> share E_i and F_i projection matrices
    dropout -> float value for dropout layers
    blocks -> number of encoder blocks to stack
    """
    def __init__(self, k, embed_dim, heads, seq_len, vocab_size, hidden_dim , share_kv, share_headwise, dropout ,blocks):
        super().__init__()
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        self.vocab_size = vocab_size
        self.model = nn.Sequential(
            self.embeddings,
            *[LinformerBlock(heads = heads, dropout = dropout, embed_dim = embed_dim,
            hidden_dim = hidden_dim, k = k, share_kv = share_kv,
            share_headwise = share_headwise) for _ in range(blocks)],
            nn.Linear(embed_dim, vocab_size))

    def forward(self, x):
        """Return the unnormalised logits over the vocabulary for every token of x."""
        # the result has to be returned, otherwise calling the module yields None
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy between the logits and the targets, flattened over batch and sequence."""
        x, y = batch
        logits = self(x)
        # The network ends in a plain Linear layer, so its outputs are logits.
        # nll_loss expects log-probabilities; cross_entropy applies the
        # log-softmax itself and is the correct loss for raw logits.
        loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), y.reshape(-1))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the fraction of token positions whose argmax prediction equals the target."""
        x, y = val_batch
        pred = self(x).argmax(dim=-1)
        val_accuracy = (pred == y).type(torch.float).mean()
        # report the metric; previously it was computed and then discarded
        self.log("val_accuracy", val_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """AdamW over all parameters with a fixed learning rate of 3e-4."""
        return torch.optim.AdamW(self.parameters(), lr=3e-4)

In [44]:
# Same hyper-parameters as the previous run, this time with the LinformerLM
# variant, trained on the same `data` module.
model = LinformerLM(
    k=128,
    embed_dim=512,
    heads=4,
    seq_len=20,
    vocab_size=1024,
    hidden_dim=1024,
    share_kv=True,
    share_headwise=True,
    dropout=.1,
    blocks=2,
)
trainer = pl.Trainer(
    enable_progress_bar=True,
    max_epochs=5,
    gpus=num_gpus,
)

# five epochs over the train / validation loaders provided by `data`
trainer.fit(model, data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | embeddings | Embeddings | 534 K 
1 | model      | Sequential | 5.3 M 
------------------------------------------
5.3 M     Trainable params
0         Non-trainable params
5.3 M     Total params
21.326    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]