In [None]:
import random
from math import sqrt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

!pip install pytorch_lightning
import pytorch_lightning as pl


In [None]:
!pip install einops
from einops import rearrange

In [None]:
from torch.utils.data import DataLoader
# Data Module idea from George Hotz
class BaseDataModule(pl.LightningDataModule):
  """
  Lightning data module that holds a whole dataset in memory as two arrays.

  Subclasses implement `get_dataset(*args, **kwargs)` returning `(X, Y)`, two
  arrays indexable along axis 0 with the same number of rows. The rows are
  shuffled once at construction and then split into a leading training part
  and a trailing validation part.

  Args:
    batch_size: batch size used by both dataloaders.
    split: fraction of the rows assigned to the training part.
    *args, **kwargs: forwarded unchanged to `get_dataset`.
  """
  def __init__(self, batch_size=32, split=0.8, *args, **kwargs):
    super().__init__()
    self.ds_X, self.ds_Y = self.get_dataset(*args, **kwargs)
    # one shared permutation keeps every X row aligned with its Y row
    shuffler = np.random.permutation(self.ds_X.shape[0])
    self.ds_X = self.ds_X[shuffler]
    self.ds_Y = self.ds_Y[shuffler]
    # index of the first validation row
    self.split = int(self.ds_X.shape[0]*split)
    self.batch_size = batch_size

  def get_dataset(self, *args, **kwargs):
    """Return `(X, Y)`; must be provided by subclasses."""
    raise NotImplementedError("subclasses must implement get_dataset()")

  def train_dataloader(self):
    """DataLoader over the first `self.split` (x, y) pairs."""
    ds_X_train, ds_Y_train = self.ds_X[0:self.split], self.ds_Y[0:self.split]
    return DataLoader(list(zip(ds_X_train, ds_Y_train)), batch_size=self.batch_size)

  def val_dataloader(self):
    """DataLoader over the (x, y) pairs from `self.split` onwards."""
    ds_X_test, ds_Y_test = self.ds_X[self.split:], self.ds_Y[self.split:]
    return DataLoader(list(zip(ds_X_test, ds_Y_test)), batch_size=self.batch_size)

  
class WikipediaDataModule(BaseDataModule):
  """
  Byte-level next-token dataset built from the enwik8 corpus.

  The raw bytes are cached in the module-level global `enwik8`, so the
  archive is downloaded at most once per kernel session.
  """
  def get_dataset(self, seq_len=50):
    """
    Build input/target arrays of shape (num_sequences, seq_len).

    Args:
      seq_len: number of byte tokens per sequence.

    Returns:
      (X, Y) where Y is X shifted left by one position in the byte stream,
      i.e. Y[i, j] is the byte that follows X[i, j].

    Raises:
      requests.HTTPError: if the download does not succeed.
    """
    global enwik8
    if 'enwik8' not in globals():
      import requests
      from zipfile import ZipFile
      import io
      response = requests.get("https://data.deepai.org/enwik8.zip")
      # fail here rather than while unzipping an HTML error page
      response.raise_for_status()
      enwik8 = ZipFile(io.BytesIO(response.content)).read('enwik8')
    # np.int was removed from NumPy; astype also yields a writable copy
    en = np.frombuffer(enwik8, dtype=np.uint8).astype(np.int64)
    # keep the largest prefix of num_sequences * seq_len + 1 bytes so both
    # shifted views below reshape cleanly for any seq_len
    num_sequences = max((en.shape[0] - 1) // seq_len, 0)
    en = en[0:num_sequences * seq_len + 1]
    # clamp bytes above 127 so every token id is in [0, 127]
    en[en>127] = 127
    return en[0:-1].reshape(-1, seq_len), en[1:].reshape(-1, seq_len)

In [None]:
def scaled_attention(queries, keys, values):
    """
    Scaled dot-product attention over the last two axes.

    Computes softmax(queries @ keys^T / sqrt(d)) @ values, where d is the
    size of the last axis of `queries`.
    """
    # similarity logits, divided by sqrt of the feature size
    logits = (queries @ keys.transpose(-2, -1)) / sqrt(queries.shape[-1])
    # normalise over the key axis, then mix the values
    return logits.softmax(dim=-1) @ values

In [None]:
def project_kv(key, value, E_proj, F_proj = None):
    """
    Project key and value tensors along their sequence axis.

    Both inputs are indexed as (b, h, j, d); axis j is contracted with a
    (j, k) matrix, giving tensors indexed as (b, h, k, d).

    Args:
        key: key tensor, 4-D, indexed (b, h, j, d).
        value: value tensor, 4-D, indexed (b, h, j, d).
        E_proj: (j, k) matrix applied to `key`.
        F_proj: optional (j, k) matrix applied to `value`. When omitted,
            `E_proj` is used for the values as well (shared projection).

    Returns:
        (projected_key, projected_value)
    """
    if F_proj is None:
        F_proj = E_proj
    key = torch.einsum("bhjd, jk -> bhkd", key, E_proj)
    value = torch.einsum("bhjd, jk -> bhkd", value, F_proj)
    return key, value

In [None]:
def EF_proj(input, k, bias = True):
    """
    helper function to init E and F projections as a linear layer

    Args:
        input: number of input features of the layer.
        k: number of output features (the projected size).
        bias: whether the layer has a bias term.

    Returns:
        nn.Linear(input, k) whose weight is drawn from a normal
        distribution with mean 0 and standard deviation 1/k.

    Note: the returned layer's parameters keep requires_grad=True; call
    `.requires_grad_(False)` on the result if a fixed projection is wanted.
    """
    E_F = nn.Linear(input, k, bias)
    # init.normal_ fills a tensor in place, so it must receive the weight
    # matrix rather than the module itself
    torch.nn.init.normal_(E_F.weight, mean=0.0, std=1/k)
    return E_F


In [None]:
def EF_proj(input, k, bias = True):
    """
    helper function to init E and F projections as a raw parameter matrix

    Args:
        input: size of the axis being projected (rows of the matrix).
        k: projected size (columns of the matrix).
        bias: accepted for signature compatibility; not used here.

    Returns:
        torch.nn.Parameter of shape (input, k) filled with standard-normal
        samples. nn.Parameter defaults to requires_grad=True.
    """
    # 2-D (input, k) so it can serve as the "jk" operand of the
    # "bhjd, jk -> bhkd" einsum in project_kv
    E = torch.nn.Parameter(torch.randn(input, k))
    return E

In [None]:
class LinformerSelfAttention(nn.Module):
    '''
    Linear self-attention
    run time = O(nk) where n denotes sequence length and k denotes dim of linear projection of k

    Keys and values are projected along the sequence axis from n positions
    down to k positions before attention, so the score matrix is (n, k)
    instead of (n, n).

    Args:
        embed_dim: size of the input/output feature axis; must be divisible by `heads`.
        k: number of projected key/value positions.
        seq_len: maximum sequence length the projections are built for.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    '''
    def __init__(self, embed_dim, k, seq_len = 512, heads = 4,  dropout = .1):
        super().__init__()
        assert (embed_dim % heads) == 0
        self.heads = heads
        self.k = k

        self.dim_head = embed_dim // heads   # each head is of size embed_dim / heads
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # E projects keys, F projects values; both act on the sequence axis,
        # hence (seq_len, k) and not (embed_dim, k). Scaled to std 1/k.
        self.E_proj = nn.Parameter(torch.randn(seq_len, k) / k)
        self.F_proj = nn.Parameter(torch.randn(seq_len, k) / k)

        # linear layer that transforms embedding into q k and v vectors
        self.to_q = nn.Linear(embed_dim, embed_dim, bias = False)
        self.to_k = nn.Linear(embed_dim, embed_dim, bias = False)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias = False)

        # mixes the concatenated heads back into embed_dim features
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, **kwargs):
        """
        Args:
            x: tensor of shape (batch, n, embed_dim) with n <= seq_len.
            **kwargs: accepted and ignored.

        Returns:
            tensor of shape (batch, n, embed_dim).
        """
        assert x.dim() == 3 # batch size, length, embed_dim
        b, n, _ = x.shape
        assert n <= self.seq_len, "sequence is longer than the seq_len the projections were built for"
        h, d_h = self.heads, self.dim_head

        # (b, n, embed_dim) -> (b, h, n, d_h)
        query = self.to_q(x).reshape(b, n, h, d_h).transpose(1, 2)
        key = self.to_k(x).reshape(b, n, h, d_h).transpose(1, 2)
        value = self.to_v(x).reshape(b, n, h, d_h).transpose(1, 2)

        # compress the sequence axis n -> k; only the first n rows of the
        # projections are used when the input is shorter than seq_len
        key = torch.einsum("bhnd,nk->bhkd", key, self.E_proj[:n])
        value = torch.einsum("bhnd,nk->bhkd", value, self.F_proj[:n])

        # (b, h, n, k) attention weights over the projected positions
        scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(d_h)
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(weights, value)            # (b, h, n, d_h)

        # merge heads: (b, h, n, d_h) -> (b, n, h * d_h)
        out = out.transpose(1, 2).reshape(b, n, h * d_h)
        return self.output_linear(out)


In [None]:
class LinformerSelfAttention(nn.Module):
    """
    Multi-head Linformer self-attention with a configurable per-head size.

    Keys and values are projected along the sequence axis from n positions
    down to k positions before attention.

    Args:
        embed_dim: size of the input/output feature axis.
        k: number of projected key/value positions.
        seq_len: maximum sequence length the projections are built for.
        dim_head: features per head; defaults to embed_dim // heads when None.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, embed_dim, k, seq_len = 512, dim_head = None, heads = 4, dropout = .1):
        super().__init__()
        self.embed_dim = embed_dim
        self.k = k
        self.seq_len = seq_len
        self.heads = heads
        self.dim_head = dim_head if dim_head is not None else embed_dim // heads
        # total width of all heads; use the resolved self.dim_head, not the
        # raw argument, which may be None
        inner_dim = self.dim_head * heads
        self.dropout = nn.Dropout(dropout)

        # E projects keys, F projects values, both along the sequence axis
        self.E_proj = nn.Parameter(torch.randn(seq_len, k) / k)
        self.F_proj = nn.Parameter(torch.randn(seq_len, k) / k)

        self.to_q = nn.Linear(embed_dim, inner_dim, bias = False)
        self.to_k = nn.Linear(embed_dim, inner_dim, bias = False)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias = False)
        self.output_linear = nn.Linear(inner_dim, embed_dim)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, n, embed_dim) with n <= seq_len.

        Returns:
            tensor of shape (batch, n, embed_dim).
        """
        b, n, _ = x.shape
        assert n <= self.seq_len, "sequence is longer than the seq_len the projections were built for"
        d_h, h = self.dim_head, self.heads

        # (b, n, embed_dim) -> (b, h, n, d_h)
        query = self.to_q(x).reshape(b, n, h, d_h).transpose(1, 2)
        key = self.to_k(x).reshape(b, n, h, d_h).transpose(1, 2)
        value = self.to_v(x).reshape(b, n, h, d_h).transpose(1, 2)

        # compress the sequence axis n -> k
        key = torch.einsum("bhnd,nk->bhkd", key, self.E_proj[:n])
        value = torch.einsum("bhnd,nk->bhkd", value, self.F_proj[:n])

        scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(d_h)   # (b, h, n, k)
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(weights, value)                                # (b, h, n, d_h)

        # merge heads and map back to embed_dim
        out = out.transpose(1, 2).reshape(b, n, h * d_h)
        return self.output_linear(out)
        



In [None]:
class FeedForward(nn.Module):
    """
    A simple feed forward network with GELU activation function

    Applies linear(in_dim -> ff_dim), GELU, linear(ff_dim -> out_dim), then dropout.

    Args:
        in_dim: input feature size.
        ff_dim: hidden feature size.
        out_dim: output feature size.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, in_dim,ff_dim, out_dim, dropout = .1):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, out_dim)
        self.gelu = nn.GELU()     # can change to activation of your choice, typically GELU in transformers
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor; the last axis changes from in_dim to out_dim."""
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        # without this return the module silently yields None
        return x



In [None]:
class Embeddings(nn.Module):
    """
    Token + learned position embeddings, followed by LayerNorm and dropout.

    Args:
        seq_len: number of positions in the position-embedding table.
        vocab_size: number of rows in the token-embedding table.
        embed_dim: embedding size.
        dropout: dropout probability; 0.5 matches the nn.Dropout() default
            this module used before the argument existed.
    """
    def __init__(self, seq_len, vocab_size, embed_dim, dropout = 0.5):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(seq_len, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim, eps = 1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """
        Args:
            input_ids: integer tensor of shape (batch, length), length <= seq_len.

        Returns:
            tensor of shape (batch, length, embed_dim).
        """
        # Create position IDs for input sequence, on the same device as the ids
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=input_ids.device).unsqueeze(0)
        # create token and position embeddings
        token_emb = self.token_emb(input_ids)
        pos_emb = self.pos_emb(position_ids)
        # Combine token and position embeddings
        embeddings = token_emb + pos_emb
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [None]:
class Embeddings(nn.Module):
  """Sum of a token-embedding lookup and a position-embedding lookup."""
  def __init__(self, seq_len, vocab_size, embed_dim):
    super().__init__()
    # lookup tables: one row per token id / per position index
    self.token_emb = nn.Embedding(vocab_size, embed_dim)
    self.pos_emb = nn.Embedding(seq_len, embed_dim)

  def forward(self, x):
    """Embed ids `x` of shape (batch, length) and add per-position vectors."""
    length = x.size(1)
    positions = torch.arange(0, length, dtype=torch.int32, device=x.device)
    # (length, embed_dim) -> (1, length, embed_dim) so it broadcasts over the batch
    positional = self.pos_emb(positions).view(1, length, -1)
    tokens = self.token_emb(x)
    return tokens + positional

In [None]:
class LinformerBlock(nn.Module):
    """
    Pre-norm transformer block: Linformer self-attention followed by a
    feed-forward network, each wrapped in a residual connection.

    Args:
        heads: number of attention heads.
        dropout: dropout probability for the attention and feed-forward parts.
        embed_dim: feature size of the block input.
        in_dim: feed-forward input size; must equal embed_dim.
        ff_dim: feed-forward hidden size.
        out_dim: feed-forward output size; must equal embed_dim for the residual sum.
        seq_len: sequence length handed to the attention layer.
        k: projected key/value length handed to the attention layer.
    """
    def __init__(self,heads, dropout, embed_dim, in_dim, ff_dim, out_dim, seq_len = 512, k = 128):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        # use the constructor arguments; the defaults for seq_len and k are
        # the values that used to be hard-coded here
        self.attention = LinformerSelfAttention(embed_dim = embed_dim, seq_len = seq_len, heads = heads, k = k, dropout = dropout)
        self.feed_forward = FeedForward(in_dim,ff_dim, out_dim, dropout = dropout)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers to `x`; the shape is unchanged."""
        # pre normalization
        hidden_state = self.layer_norm1(x)
        # residual connection
        x = x + self.attention(hidden_state)
        # feed forward with residual; normalise the tensor, not pass the module
        x = x + self.feed_forward(self.layer_norm2(x))
        return x

In [None]:
class MinLinFormer(pl.LightningModule):
    """
    An encoder only min-linformer with learned token and position embeddings.

    The network maps integer token ids of shape (batch, length) to
    log-probabilities of shape (batch, length, vocab_size).

    Args:
        k: projected key/value length. Stored on the module.
            NOTE(review): not forwarded to the blocks, because the
            LinformerBlock call below only passes the arguments listed
            there; confirm whether blocks should receive k / seq_len.
        embed_dim: embedding size.
        heads: number of attention heads per block.
        seq_len: number of positions in the position-embedding table.
        vocab_size: number of distinct token ids.
        in_dim, ff_dim, out_dim: feed-forward sizes passed to each block.
        dropout: dropout probability passed to each block.
        blocks: number of LinformerBlock layers.
    """
    def __init__(self, k, embed_dim, heads, seq_len, vocab_size, in_dim, ff_dim, out_dim,  dropout = .1,blocks = 2):
        super().__init__()
        self.k = k
        # number of output classes, used to flatten predictions in training_step
        self.max_value = vocab_size
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        self.model = nn.Sequential(
            self.embeddings,
            *[LinformerBlock(heads = heads, dropout = dropout, embed_dim = embed_dim,
            in_dim = in_dim, ff_dim = ff_dim, out_dim = out_dim) for _ in range(blocks)],
            # vocabulary head: nll_loss below expects log-probabilities
            nn.Linear(out_dim, vocab_size),
            nn.LogSoftmax(dim=2),
            )

    def forward(self, x):
        """Return per-position log-probabilities for token ids `x`."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Negative log-likelihood of the target tokens for one batch."""
        x, y = batch
        output = self.model(x)
        loss = F.nll_loss(output.reshape(-1, self.max_value), y.reshape(-1))
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the fraction of positions whose arg-max prediction equals the target."""
        x, y = val_batch
        pred = self.model(x).argmax(dim=2)
        val_accuracy = (pred == y).type(torch.float).mean()
        self.log("val_accuracy", val_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 3e-4."""
        return torch.optim.Adam(self.parameters(), lr=3e-4)

In [None]:
# Sequence length must agree between the data (tokens per sample) and the
# model (rows in the position-embedding table).
SEQ_LEN = 50
# WikipediaDataModule clamps every byte to <= 127, so ids lie in [0, 127].
VOCAB_SIZE = 128

# `k` is a required MinLinFormer argument; 16 is a placeholder choice (< SEQ_LEN).
model = MinLinFormer(k = 16, embed_dim = 512, heads = 2, seq_len = SEQ_LEN, vocab_size = VOCAB_SIZE,
                     in_dim = 512, ff_dim = 1024, out_dim  = 512)
# `gpus=` is no longer accepted by current Lightning; accelerator/devices replace it.
trainer = pl.Trainer(enable_progress_bar=True, max_epochs=5, accelerator="auto", devices=1)
# WikipediaDataModule is the data module defined in this notebook.
data = WikipediaDataModule(batch_size=64, seq_len=SEQ_LEN)

trainer.fit(model, data)

In [None]:
class LinformerLM(nn.Module):
    """
    Language model: token + position embeddings, a stack of LinformerBlock
    layers, and a linear head producing per-token logits.

    Args:
        num_tokens: vocabulary size.
        dim: embedding / model width.
        seq_len: number of positions in the position-embedding table.
        depth: number of LinformerBlock layers.
        k: projected key/value length. Stored on the module.
            NOTE(review): not forwarded to the blocks, because the
            LinformerBlock call below only passes heads, dropout and the
            feature sizes; confirm whether blocks should receive k / seq_len.
        heads: attention heads per block.
        dropout: dropout probability passed to each block.
        ff_mult: feed-forward hidden size as a multiple of `dim`.
    """
    def __init__(self, num_tokens, dim, seq_len, depth, k = 256, heads = 4, dropout = 0., ff_mult = 4):
        super().__init__()
        self.k = k
        self.seq_len = seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(seq_len, dim)
        # Stack blocks directly: MinLinFormer embeds token ids itself and has
        # a different constructor, so it cannot consume the embedded tensor.
        self.linformer = nn.Sequential(*[
            LinformerBlock(heads = heads, dropout = dropout, embed_dim = dim,
                           in_dim = dim, ff_dim = ff_mult * dim, out_dim = dim)
            for _ in range(depth)
        ])
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        """Map token ids (batch, length) to logits (batch, length, num_tokens)."""
        x = self.token_emb(x)
        # (length, dim) position vectors broadcast over the batch axis
        x = self.pos_emb(torch.arange(x.shape[1], device=x.device)) + x
        x = self.linformer(x)
        out = self.to_logits(x)
        return out