In [1]:
from IPython.display import Image, display
from IPython.core.display import HTML

# Sparse Transformer

[Sparse Transformers](https://arxiv.org/pdf/1904.10509.pdf) 

In [24]:
Image(url='https://d3i71xaburhd42.cloudfront.net/21da617a0f79aabf94272107184606cefe90ab75/5-Figure4-1.png', height = 500, width = 300)

In [22]:
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)

test = torch.nn.LayerNorm((sentence_length, embedding_dim))
test(embedding).shape

torch.Size([20, 5, 10])

In [43]:
linear = torch.nn.Linear(in_features = embedding_dim, out_features = 10)
linear(embedding).shape

torch.Size([20, 5, 10])

In [47]:
embedding.view([20, 50])

tensor([[ 2.0206e+00,  1.5834e+00,  2.1301e+00,  1.5505e-01,  1.8897e+00,
          1.0426e+00, -7.2495e-01,  5.2339e-01, -9.4932e-01,  2.1811e-01,
         -1.7737e+00, -8.3407e-01, -1.5284e+00, -3.0882e-01,  3.5281e-01,
          1.0329e+00, -1.0169e+00,  1.7042e+00, -4.6733e-01,  1.4523e-01,
         -4.5024e-01, -1.1600e-01,  2.1275e-01,  1.0998e+00,  1.0220e+00,
         -9.7954e-01, -7.0801e-01,  4.9808e-01, -6.4611e-01, -6.4230e-01,
         -5.0815e-01, -2.0468e-01,  3.8282e-01, -2.1118e-01,  9.9657e-01,
          5.7417e-01, -7.7482e-01, -2.7934e-01, -1.0318e+00, -1.0979e+00,
          4.0156e-01,  3.9143e-01,  4.5486e-01, -4.4089e-01, -5.4288e-01,
          1.7671e+00, -2.9547e-01, -2.6591e-01, -1.6398e+00, -1.4759e+00],
        [ 6.4493e-01,  2.0873e-01, -9.3720e-01, -1.0340e+00,  7.4781e-01,
         -1.3378e+00,  1.3574e+00, -5.5806e-01,  3.2735e-01, -1.8075e+00,
          5.1112e-01, -1.0362e+00,  1.7144e+00, -1.0005e+00,  7.7210e-01,
          1.3846e+00, -3.0017e-01,  4

In [36]:
"""
Paper - https://arxiv.org/pdf/1904.10509.pdf
Building models in pytorch - https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html
LayerNorm - https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
Linear Layer - https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
MultiHeadAttention - https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
Dropout - https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
GELU - https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
"""

import torch

#how do we ensure that 2d input can go into linear layer
#where do we actually use the gelu activation?
#something doesn't look right here? How 

batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)

class ResidualBlock(torch.nn.Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.norm = torch.nn.LayerNorm((sentence_length, embedding_dim))
        self.linear = torch.nn.Linear(in_features = embedding_dim, out_features = 10)
        self.attention = torch.nn.MultiheadAttention(embed_dim = embedding_dim, num_heads = 1)
        self.dropout = torch.nn.Dropout()
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        res = self.norm(x)
        res, attn_output_weights = self.attention(res, res, res)
        res = self.dropout(res)
        res += x
        res = self.norm(res)
        res = self.linear(res)
        res = self.gelu(res)
        res = self.dropout(res)
        res += x
        return res
    
ResBlock = ResidualBlock()
ResBlock.forward(embedding).shape

torch.Size([20, 5, 10])

In [50]:
class SparseTransformer(torch.nn.Module):
    def __init__(self):
        super(SparseTransformer, self).__init__()
        self.ResBlock = ResidualBlock()
        self.linear = torch.nn.Linear(in_features = embedding_dim*5, out_features = 5)
        self.softmax = torch.nn.Softmax()
        
    def forward(self, x):
        x = self.ResBlock(x)
        x = self.linear(x.view([-1, 50]))
        x = self.softmax(x)
        return x
    
SparseTransformer = SparseTransformer()
SparseTransformer.forward(embedding).shape

  x = self.softmax(x)


torch.Size([20, 5])