In [1]:
from IPython.display import Image, display
from IPython.core.display import HTML
import torch

  from .autonotebook import tqdm as notebook_tqdm


# Sparse Transformer

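[Sparse Transformers](https://arxiv.org/pdf/1904.10509.pdf) (Child et al., 2019) introduce sparse factorizations of the attention matrix that reduce its cost from $O(n^2)$ to $O(n\sqrt{n})$ in the sequence length $n$. In addition, the paper proposes:

1. a variation of the transformer architecture and its initialization (to train deeper networks);
2. a memory-efficient way to compute attention matrices (recomputing them during the backward pass);
3. fast attention kernels to improve speed.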

In [2]:
# Architecture figure from the paper (the hosted file name indicates Figure 4 on page 5),
# rendered at a fixed 300x500 px so it fits alongside the code.
Image(
    url='https://d3i71xaburhd42.cloudfront.net/21da617a0f79aabf94272107184606cefe90ab75/5-Figure4-1.png',
    width=300,
    height=500,
)

In [3]:
# Toy input: `batch` sequences of `sentence_length` tokens, each a vector of size `embedding_dim`.
batch, sentence_length, embedding_dim = 20, 5, 10
# Seed before sampling so Restart & Run All reproduces the same toy input.
torch.manual_seed(0)
embedding = torch.randn(batch, sentence_length, embedding_dim)

embedding.shape

torch.Size([20, 5, 10])

In [4]:
"""
Paper - https://arxiv.org/pdf/1904.10509.pdf
Building models in pytorch - https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html
LayerNorm - https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
Linear Layer - https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
MultiHeadAttention - https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
Dropout - https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
GELU - https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
Softmax - https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html
"""

import torch

batch, sentence_length, embedding_dim = 20, 5, 10
# Seed before sampling so re-running this cell gives the same toy input; modules
# constructed later in this cell are then initialised deterministically as well.
torch.manual_seed(0)
embedding = torch.randn(batch, sentence_length, embedding_dim)

class ResidualBlock(torch.nn.Module):
    """Pre-norm transformer residual block: an attention sub-layer, then a feed-forward sub-layer.

    Computes

        h   = x + dropout(attention(norm(x)))
        out = h + dropout(gelu(linear(norm_ff(h))))

    Each sub-layer has its own LayerNorm, and the second residual adds ``h`` (the
    output of the attention sub-layer), not the block input ``x``.

    Args:
        embedding_dim: size of the last (feature) dimension of the input. The default
            of 10 matches this notebook's toy ``embedding_dim``.
        num_heads: number of attention heads; must divide ``embedding_dim``.
        dropout: dropout probability. The default 0.5 is ``torch.nn.Dropout``'s own
            default, which the original version of this block used implicitly.

    Shape:
        input ``(batch, sequence_length, embedding_dim)`` -> output of the same shape.
        Any sequence length is accepted.
    """

    def __init__(self, embedding_dim: int = 10, num_heads: int = 1, dropout: float = 0.5):
        super().__init__()
        # LayerNorm over the feature dimension only. Normalising over (sequence, feature)
        # would tie the block to one fixed sequence length.
        self.norm = torch.nn.LayerNorm(embedding_dim)     # before attention
        self.norm_ff = torch.nn.LayerNorm(embedding_dim)  # before the feed-forward layer
        # batch_first=True because inputs are (batch, seq, dim); with the default
        # (seq, batch, dim) layout the block would attend across samples in the batch.
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=num_heads, batch_first=True
        )
        # Output width equals input width so the residual addition is well-defined.
        # NOTE: a standard transformer feed-forward has a second Linear after the
        # activation; this block keeps the single-layer form of the original notebook.
        self.linear = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.gelu = torch.nn.GELU()
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, sequence_length, embedding_dim)``."""
        normed = self.norm(x)
        # Attention weights are not used, so skip computing them.
        attended, _ = self.attention(normed, normed, normed, need_weights=False)
        h = x + self.dropout(attended)
        ff = self.dropout(self.gelu(self.linear(self.norm_ff(h))))
        return h + ff
    
ResBlock = ResidualBlock()
# Call the module rather than .forward() so registered hooks run.
res_block_output = ResBlock(embedding)
# A residual block must preserve the input shape.
assert res_block_output.shape == embedding.shape
res_block_output.shape

torch.Size([20, 5, 10])

In [13]:
class SparseTransformer(torch.nn.Module):
    """Toy classifier: one ResidualBlock, then a linear head and a softmax.

    Args:
        sentence_length: sequence length the linear head is sized for (default 5).
        embedding_dim: feature size of each token (default 10).
        num_classes: number of output probabilities per sample (default 5).

    The defaults reproduce the original hard-coded sizes (``embedding_dim * 5`` = 50
    input features, 5 outputs). Input is assumed to be
    ``(batch, sentence_length, embedding_dim)``; output is ``(batch, num_classes)``
    with each row summing to 1.
    """

    def __init__(self, sentence_length: int = 5, embedding_dim: int = 10, num_classes: int = 5):
        # Zero-argument super(): the explicit super(SparseTransformer, self) form looks the
        # class up by its global name at call time, which breaks once that name is rebound
        # to an instance (as `SparseTransformer = SparseTransformer()` does).
        super().__init__()
        # NOTE: ResidualBlock is built with its own defaults; its feature size must equal
        # embedding_dim for the shapes below to line up.
        self.ResBlock = ResidualBlock()
        self.linear = torch.nn.Linear(
            in_features=sentence_length * embedding_dim, out_features=num_classes
        )
        # NOTE(review): output is already normalised; if this model is trained with
        # torch.nn.CrossEntropyLoss (which expects unnormalised logits), drop this layer.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        x = self.ResBlock(x)
        # flatten keeps the batch dimension explicit. A wrongly-sized input then raises
        # in the linear layer instead of being silently re-split into another batch size
        # (as view(-1, 50) would), and non-contiguous tensors are handled too.
        x = self.linear(x.flatten(start_dim=1))
        return self.softmax(x)
    
# NOTE(review): this rebinds the name `SparseTransformer` from the class to an instance,
# so the class can no longer be constructed by name after this cell runs.
SparseTransformer = SparseTransformer()
# Call the module rather than .forward() so registered hooks run.
probabilities = SparseTransformer(embedding)
# Softmax over dim=1: every row is a probability distribution over the classes.
assert torch.allclose(probabilities.sum(dim=1), torch.ones(probabilities.shape[0]))
probabilities.shape

torch.Size([20, 5])

In [9]:
# Small random 2x3 matrix used to show which axis Softmax normalises.
data = torch.randn((2, 3))

In [12]:
# Softmax along dim 0 normalises down each column of `data`,
# so every column of `output` sums to 1.
m = torch.nn.Softmax(0)
output = m(data)
output

tensor([[0.9110, 0.6456, 0.3931],
        [0.0890, 0.3544, 0.6069]])