I'm going to implement [Attention Is All You Need](https://arxiv.org/abs/1706.03762) as a way to learn how transformers work.

In [1]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017, Sec. 3.2).

    The embedding dimension is split into ``heads`` slices of size
    ``heads_dim``. Each slice is projected, attended over independently, and
    the per-head results are concatenated and mixed by ``fc_out``.
    """

    def __init__(self, embed_size, heads) -> None:
        """
        Args:
            embed_size (int): Embedding size of the input.
            heads (int): Number of heads to split the embedding into for
                multi-head attention. Must divide ``embed_size`` evenly.
        """
        super().__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.heads_dim = embed_size // heads

        assert (self.heads_dim * heads == embed_size), "Embedding size needs to be divisible by heads"

        # Per-head projections, applied to the last (heads_dim) axis.
        # NOTE: these weights are shared by all heads (heads_dim -> heads_dim).
        # The paper instead uses a separate d_model -> d_k projection per head,
        # which is equivalent to one embed_size -> embed_size Linear.
        self.values = nn.Linear(self.heads_dim, self.heads_dim, bias=False)
        self.keys = nn.Linear(self.heads_dim, self.heads_dim, bias=False)
        self.queries = nn.Linear(self.heads_dim, self.heads_dim, bias=False)

        # Mixes the concatenated head outputs back into embed_size features
        # (heads * heads_dim == embed_size; written this way for clarity).
        self.fc_out = nn.Linear(heads * self.heads_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        """Attend from ``queries`` over ``keys`` / ``values``.

        Args:
            values:  Tensor of shape (N, value_len, embed_size).
            keys:    Tensor of shape (N, key_len, embed_size). key_len must
                     equal value_len, since each key scores one value.
            queries: Tensor of shape (N, query_len, embed_size). query_len may
                     differ from key_len (e.g. decoder queries attending over
                     encoder output).
            mask:    None, or a tensor broadcastable to
                     (N, heads, query_len, key_len). Positions where
                     ``mask == 0`` get (effectively) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding into (heads, heads_dim). Each tensor keeps its
        # OWN sequence length; they are not required to match the queries.
        values = values.reshape(N, value_len, self.heads, self.heads_dim)
        keys = keys.reshape(N, key_len, self.heads, self.heads_dim)
        queries = queries.reshape(N, query_len, self.heads, self.heads_dim)

        # Apply the learned projections (nn.Linear acts on the last axis).
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = dot(queries[n, q, h, :], keys[n, k, h, :])
        # energy.shape = (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k), where d_k is the per-head dimension (paper Eq. 1).
        # This keeps large dot products from saturating the softmax.
        energy = energy / (self.heads_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value for the tensor's dtype. A fixed
            # -1e20 is not representable in float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Normalize over key_len (dim 3): for each query, the weights over the
        # source positions sum to 1.
        attention = torch.softmax(energy, dim=3)

        # out[n, q, h, :] = sum_l attention[n, h, q, l] * values[n, l, h, :]
        # 'l' is the shared key/value length.
        # out.shape = (N, query_len, heads, heads_dim)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)

        # Concatenate heads: (N, query_len, heads, heads_dim) -> (N, query_len, embed_size)
        out = out.reshape(N, query_len, self.heads * self.heads_dim)

        return self.fc_out(out)

In [None]:
class TransformerBlock(nn.Module):
    """One transformer layer: multi-head attention followed by a position-wise
    feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``, as
    described in the paper (Sec. 3.1 and 5.4).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion) -> None:
        """
        Args:
            embed_size (int): Model / embedding dimension.
            heads (int): Number of attention heads; must divide embed_size.
            dropout (float): Dropout probability applied to each sub-layer's
                output before the residual add.
            forward_expansion (int): Hidden width of the feed-forward network
                as a multiple of embed_size.
        """
        super().__init__()

        self.attention = SelfAttention(embed_size, heads)

        # LayerNorm normalizes each token's feature vector on its own (over the
        # last, embed_size, axis). Unlike BatchNorm, its statistics therefore do
        # not depend on batch size or on other (e.g. padded) positions.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and the feed-forward network, each with a residual.

        Args:
            value, key, query: Inputs forwarded to ``SelfAttention``. ``query``
                is also the residual input, so it must have the attention
                output's shape (N, query_len, embed_size).
            mask: Forwarded unchanged to ``SelfAttention``.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        attention = self.attention(value, key, query, mask)

        # The paper applies dropout to the sub-layer output BEFORE the residual
        # add and normalization, not to the normalized result.
        x = self.norm1(query + self.dropout(attention))

        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))

        return out