In [None]:
import random
from math import sqrt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

!pip install pytorch_lightning
import pytorch_lightning as pl


In [None]:
!pip install einops
from einops import rearrange

In [2]:
from torch.utils.data import DataLoader
# Data Module idea from George Hotz
class BaseDataModule(pl.LightningDataModule):
  """
  Lightning data module that builds an (X, Y) dataset once, shuffles it, and splits it
  into a training prefix and a validation suffix.

  Subclasses implement ``get_dataset(*args, **kwargs)`` returning two arrays (X, Y) whose
  first axis indexes examples; extra constructor arguments are forwarded to it.

  Params:
  batch_size -> examples per batch for both loaders
  split -> fraction of examples used for training; the remainder is used for validation
  """
  def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
    super().__init__()
    self.ds_X, self.ds_Y = self.get_dataset(*args, **kwargs)
    # One global shuffle before splitting so train and validation come from the same mix.
    shuffler = np.random.permutation(self.ds_X.shape[0])
    self.ds_X = self.ds_X[shuffler]
    self.ds_Y = self.ds_Y[shuffler]
    # Index of the first validation example (not the fraction that was passed in).
    self.split = int(self.ds_X.shape[0]*split)
    self.batch_size = batch_size

  def _loader(self, start, stop, shuffle):
    """Wrap rows [start:stop) of X/Y in a DataLoader without building a Python list of pairs."""
    from torch.utils.data import TensorDataset
    dataset = TensorDataset(torch.as_tensor(self.ds_X[start:stop]),
                            torch.as_tensor(self.ds_Y[start:stop]))
    return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

  def train_dataloader(self):
    # Reshuffle every epoch so training batches are not presented in a fixed order.
    return self._loader(0, self.split, shuffle=True)

  def val_dataloader(self):
    # Validation order does not matter; keep it deterministic.
    return self._loader(self.split, None, shuffle=False)

  
class WikipediaDataModule(BaseDataModule):
  """
  Next-byte prediction data built from the enwik8 Wikipedia dump.

  Keyword arguments not consumed by BaseDataModule (here: seq_len) are forwarded to
  get_dataset by the base constructor.
  """
  ENWIK8_URL = "https://data.deepai.org/enwik8.zip"

  def get_dataset(self, seq_len=50):
    """
    Download enwik8 (once per kernel) and slice it into fixed-length input/target windows.

    Params:
    seq_len -> length of each window (must be >= 1)
    Returns:
    (X, Y) int64 arrays of shape (num_windows, seq_len), where Y is X shifted forward by
    one byte. Bytes above 127 are clamped to 127, so every id lies in 0..127. Trailing
    bytes that cannot fill a complete window (plus its shifted target) are dropped.
    Raises:
    ValueError -> if seq_len < 1
    """
    # The raw bytes are cached in a notebook global so rebuilding the module skips the download.
    global enwik8
    if seq_len < 1:
      raise ValueError(f"seq_len must be >= 1, got {seq_len}")
    if 'enwik8' not in globals():
      import io
      from zipfile import ZipFile
      import requests
      response = requests.get(self.ENWIK8_URL, timeout=300)
      # Fail with an HTTP error instead of a confusing BadZipFile on an error page.
      response.raise_for_status()
      with ZipFile(io.BytesIO(response.content)) as archive:
        enwik8 = archive.read('enwik8')
    # np.int was removed in NumPy 1.24; int64 arrays become LongTensors for embeddings/targets.
    en = np.frombuffer(enwik8, dtype=np.uint8).astype(np.int64)
    en[en > 127] = 127
    # Each window needs one extra byte for its shifted target.
    n_windows = max((len(en) - 1) // seq_len, 0)
    usable = n_windows * seq_len
    return (en[:usable].reshape(n_windows, seq_len),
            en[1:usable + 1].reshape(n_windows, seq_len))

In [74]:
class FeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer applied after attention.

    Computes Dropout(Linear_2(GELU(Linear_1(x)))), where Linear_1 widens the last
    dimension from embed_dim to hidden_dim and Linear_2 maps it back to embed_dim,
    so the output has the same trailing size as the input.
    """
    def __init__(self, embed_dim, hidden_dim, dropout):
        """
        Params:
        embed_dim -> size of the last dimension of the input and output
        hidden_dim -> width of the inner layer, commonly 2-4x embed_dim
        dropout -> dropout probability applied to the output (a fraction, not a percent)
        """
        super().__init__()
        expand = nn.Linear(embed_dim, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, embed_dim)
        regularize = nn.Dropout(dropout)
        self.FFN = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x):
        return self.FFN(x)

In [66]:
def attention(query, key, value):
    """
    Scaled dot-product attention: softmax(query @ key^T / sqrt(d)) @ value.

    Params:
    query -> tensor of shape (..., n_q, d)
    key -> tensor of shape (..., n_k, d)
    value -> tensor of shape (..., n_k, d_v)
    Leading batch dimensions follow torch.matmul broadcasting. Inputs may therefore be 3-D,
    e.g. (batch, seq, dim), or carry extra dims such as (batch, heads, seq, dim).
    Returns:
    tensor of shape (..., n_q, d_v)
    """
    dim = query.shape[-1]
    # (query @ transpose(key)) / sqrt(dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [67]:
def get_EF(input_dim, k, bias = True, trainable = True):
    """
    Build one low-rank Linformer projection (E or F) as an nn.Linear(input_dim, k).

    The layer maps an axis of length input_dim down to length k. Its weight is
    re-initialised with Xavier-normal initialisation.

    Params:
    input_dim -> length of the axis being compressed
    k -> compressed length
    bias -> whether the linear layer has a bias term
    trainable -> when False, all parameters are frozen (requires_grad=False). Defaults to
                 True, i.e. the projection is optimised like any other nn.Linear.
    Returns:
    the configured nn.Linear module
    """
    E_F = nn.Linear(input_dim, k, bias)
    torch.nn.init.xavier_normal_(E_F.weight)
    if not trainable:
        E_F.requires_grad_(False)
    return E_F
 

In [68]:
class AttentionHead(nn.Module):
    """
    A single Linformer attention head.

    Queries, keys and values are linear projections of x. Keys and values are then
    compressed along the sequence axis (length n -> k) by E_i / F_i before scaled
    dot-product attention, so the score matrix is n x k instead of n x n.

    Params:
    embed_dim -> size of the last dimension of x
    head_dim -> size of each query/key/value vector produced by this head
    k -> compressed sequence length (not used directly here; E_i / F_i carry it)
    E_i -> key projection over the sequence axis: either an nn.Module acting on the last
           dimension (e.g. an nn.Linear(n, k)) or a (k, n) tensor that is left-multiplied
    F_i -> value projection with the same conventions; when None, or when share_kv is True,
           E_i is used for the values as well
    share_kv -> if True, keys and values are both projected with E_i
    """
    def __init__(self, embed_dim, head_dim, k, E_i, F_i = None, share_kv = False):
        super().__init__()
        self.to_q = nn.Linear(embed_dim, head_dim)
        self.to_k = nn.Linear(embed_dim, head_dim)
        self.to_v = nn.Linear(embed_dim, head_dim)
        self.E_i = E_i
        self.F_i = F_i
        self.share_kv = share_kv

    @staticmethod
    def _project_seq(proj, t):
        """Compress the sequence axis (dim -2) of t: (..., n, d) -> (..., k, d)."""
        if isinstance(proj, nn.Module):
            # A module such as nn.Linear(n, k) acts on the last axis, so move the
            # sequence axis there, project, and move it back.
            return proj(t.transpose(-2, -1)).transpose(-2, -1)
        # A raw (k, n) matrix: left-multiply, broadcasting over leading batch dims.
        return torch.matmul(proj, t)

    def forward(self, x):
        """x is assumed to be (batch, n, embed_dim); returns (batch, n, head_dim)."""
        value_proj = self.E_i if (self.share_kv or self.F_i is None) else self.F_i
        down_k = self._project_seq(self.E_i, self.to_k(x))
        down_v = self._project_seq(value_proj, self.to_v(x))
        return attention(self.to_q(x), down_k, down_v)


In [69]:
class MultiHeadAttention(nn.Module):
    """
    Linformer multi-headed attention.

    Runs num_heads AttentionHeads on the same input, concatenates their outputs along the
    last axis and mixes them with a final Linear(embed_dim, embed_dim).

    Params:
    num_heads -> number of heads; each head has head_dim = embed_dim // num_heads
    embed_dim -> size of the last dimension of x (must be divisible by num_heads)
    k -> compressed sequence length produced by the E/F projections
    share_kv -> if True, every head projects both keys and values with E
    share_headwise -> if True, a single E/F pair is shared by all heads; otherwise each
                      head receives its own freshly initialised pair
    seq_len -> length of the axis the E/F projections compress. The projections are
               nn.Linear(seq_len, k). Defaults to embed_dim when omitted.
    Raises:
    ValueError -> if embed_dim is not divisible by num_heads (the concatenated head outputs
                  would not match to_out's input size)
    """
    def __init__(self, num_heads, embed_dim, k, share_kv = False, share_headwise = True, seq_len = None):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        proj_dim = embed_dim if seq_len is None else seq_len

        self.E_i = get_EF(proj_dim, k)
        self.F_i = get_EF(proj_dim, k)

        if share_headwise:
            heads = (
                AttentionHead(self.embed_dim, self.head_dim, k, self.E_i, self.F_i, share_kv)
                for _ in range(num_heads)
            )
        else:
            heads = (
                AttentionHead(self.embed_dim, self.head_dim, k,
                              get_EF(proj_dim, k), get_EF(proj_dim, k), share_kv)
                for _ in range(num_heads)
            )
        self.heads = nn.ModuleList(heads)
        self.to_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # Run every head on the same input and concatenate along the feature axis.
        x = torch.cat([head(x) for head in self.heads], dim = -1)
        return self.to_out(x)

In [70]:
class Embeddings(nn.Module):
  """
  Sum of a learned token embedding and a learned absolute position embedding.

  Params:
  seq_len -> number of position embeddings (inputs may have at most this many positions)
  vocab_size -> number of token embeddings
  embed_dim -> size of each embedding vector
  """
  def __init__(self, seq_len, vocab_size, embed_dim):
    super().__init__()
    self.token_emb = nn.Embedding(vocab_size, embed_dim)
    self.pos_emb = nn.Embedding(seq_len, embed_dim)

  def forward(self, x):
    """x holds token ids; x.size(1) is taken as the sequence length (assumes (batch, seq))."""
    n_positions = x.size(1)
    position_ids = torch.arange(0, n_positions, dtype=torch.int32, device=x.device)
    position_vectors = self.pos_emb(position_ids).unsqueeze(0)
    token_vectors = self.token_emb(x)
    return token_vectors + position_vectors

In [83]:
class LinformerBlock(nn.Module):
    """
    One pre-norm Linformer encoder block:
        x = x + MultiHeadAttention(LayerNorm1(x))
        x = x + FeedForward(LayerNorm2(x))

    Params:
    heads -> number of attention heads
    dropout -> dropout probability inside the feed-forward network
    embed_dim -> model width (last dimension of x)
    hidden_dim -> feed-forward inner width
    k -> Linformer compressed sequence length
    share_kv -> forwarded to MultiHeadAttention (share the key/value projection)
    share_headwise -> forwarded to MultiHeadAttention (share E/F across heads)
    """
    def __init__(self, heads, dropout, embed_dim, hidden_dim, k, share_kv, share_headwise):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        # Honour the caller's sharing options instead of hard-coding False/True.
        self.attention = MultiHeadAttention(num_heads=heads, embed_dim=embed_dim, k=k,
                                            share_kv=share_kv, share_headwise=share_headwise)
        self.feed_forward = FeedForward(embed_dim, hidden_dim, dropout=dropout)

    def forward(self, x):
        # attention on the normalised input, with residual connection
        x = x + self.attention(self.layer_norm1(x))
        # feed-forward on the normalised activations (not the LayerNorm module), with residual
        x = x + self.feed_forward(self.layer_norm2(x))
        return x

In [84]:
class MinLinFormer(pl.LightningModule):
    """
    Minimal encoder-only Linformer language model with learned token and position
    embeddings, trained to predict the next token at every position.

    Pipeline: Embeddings -> blocks x LinformerBlock -> Linear(embed_dim, vocab_size) logits.

    Params:
    k -> Linformer compressed sequence length passed to every block
    embed_dim -> model width
    heads -> attention heads per block
    seq_len -> number of position embeddings (inputs may have at most seq_len positions)
    vocab_size -> number of token ids / output classes
    hidden_dim -> feed-forward inner width
    share_kv, share_headwise -> forwarded to every LinformerBlock
    dropout -> feed-forward dropout probability
    blocks -> number of LinformerBlocks

    NOTE(review): seq_len is not forwarded to the blocks. MultiHeadAttention sizes its E/F
    projections from embed_dim, so inputs presumably need exactly embed_dim positions.
    Confirm this before changing the shapes.
    """
    def __init__(self, k, embed_dim, heads, seq_len, vocab_size, hidden_dim, share_kv, share_headwise, dropout, blocks):
        super().__init__()
        self.vocab_size = vocab_size
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        self.model = nn.Sequential(
            self.embeddings,
            *[LinformerBlock(heads=heads, dropout=dropout, embed_dim=embed_dim,
                             hidden_dim=hidden_dim, k=k, share_kv=share_kv,
                             share_headwise=share_headwise) for _ in range(blocks)]
        )
        # Map hidden states to per-token vocabulary logits so the loss and the argmax
        # are taken over vocabulary entries rather than embedding channels.
        self.to_logits = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """x: token ids (batch, seq). Returns logits (batch, seq, vocab_size)."""
        return self.to_logits(self.model(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # cross_entropy = log_softmax + nll_loss over the vocabulary axis
        loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), y.reshape(-1))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self(x)
        val_loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), y.reshape(-1))
        pred = logits.argmax(dim=-1)
        val_accuracy = (pred == y).type(torch.float).mean()
        self.log("val_loss", val_loss)
        self.log("val_accuracy", val_accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=3e-4)

In [None]:
# Configuration shared by the model and the data module so the two cannot drift apart.
SEED = 0
EMBED_DIM = 256
# MultiHeadAttention builds its Linformer E/F projections with input size embed_dim and
# applies them over the sequence axis, so the window length must equal EMBED_DIM.
# (256 also divides enwik8's 10**8 bytes evenly.)
SEQ_LEN = EMBED_DIM
# WikipediaDataModule.get_dataset clamps every byte to <= 127, so ids fit in 128 classes.
VOCAB_SIZE = 128

# Seeds python, numpy (the data-module shuffle) and torch (weight init).
pl.seed_everything(SEED)

model = MinLinFormer(embed_dim=EMBED_DIM, heads=2, k=128, seq_len=SEQ_LEN,
                     vocab_size=VOCAB_SIZE, hidden_dim=1024,
                     blocks=2, dropout=.1, share_headwise=True, share_kv=True)
# `gpus=` was removed in Lightning 2.0; accelerator="cpu" requests the same CPU-only run.
trainer = pl.Trainer(enable_progress_bar=True, max_epochs=5, accelerator="cpu")
# The data module defaults to seq_len=50, which must agree with the model's position table.
data = WikipediaDataModule(batch_size=64, seq_len=SEQ_LEN)

trainer.fit(model, data)

In [86]:
class LinformerLM(nn.Module):
    """
    Stand-alone Linformer language model (plain nn.Module, no Lightning training hooks).

    Pipeline: token embedding + learned position embedding -> depth x LinformerBlock ->
    Linear(dim, num_tokens) logits.

    Params:
    num_tokens -> vocabulary size
    dim -> model width
    seq_len -> number of position embeddings (inputs may have at most seq_len positions)
    depth -> number of LinformerBlocks
    k -> Linformer compressed sequence length passed to every block
    heads -> attention heads per block (dim must be divisible by heads)
    dropout -> feed-forward dropout probability
    hidden_dim -> feed-forward inner width; defaults to 4 * dim
    share_kv -> forwarded to every LinformerBlock

    NOTE(review): MultiHeadAttention sizes its E/F projections from embed_dim, so this
    presumably only runs when inputs have exactly dim positions. Confirm this before use.
    """
    def __init__(self, num_tokens, dim, seq_len, depth, k = 256, heads = 4, dropout = 0.,
                 hidden_dim = None, share_kv = False):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        ff_dim = 4 * dim if hidden_dim is None else hidden_dim
        # Stack the blocks directly: MinLinFormer embeds token ids itself, so it cannot
        # consume the embeddings computed here.
        self.linformer = nn.Sequential(*[
            LinformerBlock(heads=heads, dropout=dropout, embed_dim=dim, hidden_dim=ff_dim,
                           k=k, share_kv=share_kv, share_headwise=True)
            for _ in range(depth)
        ])
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        """x: token ids (batch, seq). Returns logits (batch, seq, num_tokens)."""
        positions = torch.arange(x.shape[1], device=x.device)
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.linformer(h)
        return self.to_logits(h)