<a href="https://colab.research.google.com/github/justinqbui/min_linformer/blob/main/linformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
from math import sqrt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

!pip install pytorch_lightning
import pytorch_lightning as pl


In [None]:
!pip install einops
from einops import rearrange

In [18]:
def scaled_attention(queries, keys, values):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    ``d`` is the size of the last axis of ``queries``. Scores are normalised
    with a softmax over the last (key) axis before weighting ``values``.

    Args:
        queries: tensor whose last two axes are (num_queries, d).
        keys: tensor whose last two axes are (num_keys, d).
        values: tensor whose second-to-last axis is num_keys.

    Returns:
        The attention-weighted combination of ``values``.
    """
    scale = sqrt(queries.shape[-1])
    # raw dot-product scores, divided by sqrt(d) to keep the softmax well-conditioned
    logits = queries @ keys.transpose(-2, -1) / scale
    return F.softmax(logits, dim=-1) @ values

In [13]:
def project_kv(key, value, E_proj, F_proj=None):
    """
    Project key and value tensors along the sequence axis from length n to k.

    This is the Linformer low-rank step: the sequence axis (``j`` in the
    einsum) is contracted against an (n x k) projection matrix, so the
    subsequent attention scores are (n x k) instead of (n x n).

    Args:
        key: 4-D tensor indexed (b, h, n, d).
        value: 4-D tensor indexed (b, h, n, d).
        E_proj: (n, k) tensor used to project ``key``.
        F_proj: optional (n, k) tensor used to project ``value``. When None
            (the default), ``E_proj`` is reused for the values, i.e. keys and
            values share one projection.

    Returns:
        (key, value), each indexed (b, h, k, d).
    """
    if F_proj is None:
        F_proj = E_proj
    key = torch.einsum("bhjd, jk -> bhkd", key, E_proj)
    value = torch.einsum("bhjd, jk -> bhkd", value, F_proj)
    return key, value

In [3]:
def EF_proj(input, k, bias = True, *, trainable = False):
    """
    Build a randomly initialised linear map used as a Linformer E/F projection.

    In this variation the projection is frozen (not trained) by default; pass
    ``trainable=True`` to let the optimizer update it.

    Args:
        input: number of input features of the linear map (parameter name kept
            for backward compatibility even though it shadows the builtin).
        k: number of output features (the projected size).
        bias: whether the linear map has a bias term.
        trainable: when False, ``requires_grad`` is turned off on all of the
            layer's parameters.

    Returns:
        ``nn.Linear(input, k, bias)`` whose weight is drawn from a normal
        distribution with mean 0 and standard deviation 1/k. The bias, if
        any, keeps nn.Linear's default initialisation.
    """
    E_F = nn.Linear(input, k, bias)
    # init functions operate on tensors, not modules: initialise the weight
    torch.nn.init.normal_(E_F.weight, mean=0.0, std=1 / k)
    E_F.requires_grad_(trainable)
    return E_F


In [25]:
class LinformerSelfAttention(nn.Module):
    '''
    Linformer multi-head self-attention.

    Keys are projected along the sequence axis with ``E_proj`` and values with
    ``F_proj``, both of shape (seq_len, k), so each head's score matrix is
    (n x k) instead of (n x n): run time O(nk) where n is the sequence length
    and k the projected length.

    In this variation E_proj and F_proj are random, frozen
    (``requires_grad=False``) parameters with entries drawn from a normal
    distribution with mean 0 and standard deviation 1/k.

    Args:
        embed_dim: model width; must be divisible by ``heads``.
        seq_len: maximum sequence length the projections support.
        heads: number of attention heads.
        k: projected length of keys and values.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``heads``.
    '''
    def __init__(self, embed_dim, seq_len = 512, heads = 4, k = 128, dropout = .1):
        super().__init__()
        if embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.k = k

        self.dim_head = embed_dim // heads   # each head is of size embed_dim / heads
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # E projects keys, F projects values; both contract the sequence axis (seq_len -> k)
        self.E_proj = nn.Parameter(torch.randn(seq_len, k) / k, requires_grad=False)
        self.F_proj = nn.Parameter(torch.randn(seq_len, k) / k, requires_grad=False)

        # linear layers that transform the embedding into q, k and v vectors
        self.to_q = nn.Linear(embed_dim, embed_dim, bias = False)
        self.to_k = nn.Linear(embed_dim, embed_dim, bias = False)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias = False)

        # mixes the concatenated heads back into the model width
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, **kwargs):
        """
        Apply Linformer self-attention.

        Args:
            x: tensor of shape (batch, n, embed_dim) with n <= seq_len.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            Tensor of shape (batch, n, embed_dim).

        Raises:
            ValueError: if ``x`` is not 3-D or n exceeds ``seq_len``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input (batch, seq_len, embed_dim), got {x.dim()}-D"
            )
        b, n, _ = x.shape
        if n > self.seq_len:
            raise ValueError(f"sequence length {n} exceeds seq_len={self.seq_len}")
        h, d_h = self.heads, self.dim_head

        def split_heads(t):
            # (b, n, h * d_h) -> (b, h, n, d_h)
            return t.reshape(b, n, h, d_h).transpose(1, 2)

        query = split_heads(self.to_q(x))
        key = split_heads(self.to_k(x))
        value = split_heads(self.to_v(x))

        # compress the sequence axis n -> k; slicing lets shorter inputs reuse the projections
        key = torch.einsum("bhnd,nk->bhkd", key, self.E_proj[:n])
        value = torch.einsum("bhnd,nk->bhkd", value, self.F_proj[:n])

        scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(d_h)  # (b, h, n, k)
        weights = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(weights, value)  # (b, h, n, d_h)

        # merge heads: (b, h, n, d_h) -> (b, n, h * d_h)
        out = out.transpose(1, 2).reshape(b, n, h * d_h)
        return self.output_linear(out)


In [5]:
class FeedForward(nn.Module):
    """
    A simple feed forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        in_dim: size of the input's last axis.
        ff_dim: hidden width.
        out_dim: size of the output's last axis.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, in_dim,ff_dim, out_dim, dropout = .1):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, out_dim)
        self.gelu = nn.GELU()     # can change to activation of your choice, typically GELU in transformers
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` (..., in_dim) to a tensor of shape (..., out_dim)."""
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        return self.dropout(x)



In [8]:
class Embeddings(nn.Module):
    """
    Token embedding plus learned absolute position embedding, followed by
    LayerNorm and dropout.

    Args:
        seq_len: number of position embeddings (maximum sequence length).
        vocab_size: number of token embeddings.
        embed_dim: embedding width.
        dropout: dropout probability; 0.5 matches ``nn.Dropout()``'s default,
            which this module previously used implicitly.
    """
    def __init__(self, seq_len, vocab_size, embed_dim, dropout = 0.5):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(seq_len, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim, eps = 1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """
        Embed a (batch, length) tensor of token ids.

        Returns:
            Tensor of shape (batch, length, embed_dim).
        """
        # position ids 0..length-1, created on the same device as the inputs
        seq_length = input_ids.size(1)
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        # create token and position embeddings
        token_emb = self.token_emb(input_ids)
        pos_emb = self.pos_emb(position_ids)
        # combine token and position embeddings (positions broadcast over the batch)
        embeddings = token_emb + pos_emb
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

In [30]:
class LinformerBlock(nn.Module):
    """
    Pre-norm Linformer encoder block:
        x = x + Attention(LayerNorm(x))
        x = x + FeedForward(LayerNorm(x))

    Args:
        heads: number of attention heads.
        dropout: dropout probability for attention weights and feed forward.
        embed_dim: model width.
        in_dim, ff_dim, out_dim: feed forward sizes; for the residual additions
            to line up, in_dim and out_dim must equal embed_dim.
        seq_len: maximum sequence length supported by the attention layer.
        k: projected key/value length of the attention layer.
    """
    def __init__(self,heads, dropout, embed_dim, in_dim, ff_dim, out_dim, seq_len = 512, k = 128):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.attention = LinformerSelfAttention(
            embed_dim = embed_dim, seq_len = seq_len, heads = heads, k = k, dropout = dropout
        )
        self.feed_forward = FeedForward(in_dim, ff_dim, out_dim, dropout = dropout)

    def forward(self, x):
        # pre normalization + residual connection around attention
        hidden_state = self.layer_norm1(x)
        x = x + self.attention(hidden_state)
        # pre normalization + residual connection around the feed forward
        x = x + self.feed_forward(self.layer_norm2(x))
        return x

In [32]:
class MinLinFormer(pl.LightningModule):
    """
    An encoder-only minimal Linformer.

    Token ids pass through ``Embeddings`` (learned token and learned position
    embeddings) and then ``blocks`` stacked ``LinformerBlock`` layers.

    Note: ``self.embeddings`` is also the first element of ``self.model``, so
    the same module appears under two names in ``state_dict``.
    """
    def __init__(self, embed_dim, heads, seq_len, vocab_size, in_dim, ff_dim, out_dim, dropout = .1,blocks = 2):
        super().__init__()
        self.embeddings = Embeddings(seq_len, vocab_size, embed_dim)
        self.model = nn.Sequential(
            self.embeddings,
            *[LinformerBlock(heads = heads, dropout = dropout, embed_dim = embed_dim,
            in_dim = in_dim, ff_dim = ff_dim, out_dim = out_dim) for _ in range(blocks)]
            )

    def forward(self, x):
        """Run token ids ``x`` through the embeddings and encoder blocks."""
        return self.model(x)