In [1]:
import torch
from torch import nn

We will be implementing a custom seq2seq model with a better attention mechanism in today's post. Ever since they were first introduced in encoder-decoder architectures, attention mechanisms have become increasingly important. Attention is also at the core of more advanced transformer models like BERT and GPT. It is a natural way of encouraging the model to learn which portions of the input and output data correspond to each other. Gradually, through gradient descent, the model learns to see patterns in sequential data. It is also worth noting that there are variations of the classic transformer attention architecture. Allen AI's Longformer, for instance, allows BERT-style transformer models to process long sequences by distinguishing between local and global attention, combined with a sliding window approach. Ultimately, this allows the attention calculation to scale linearly with sequence length, as opposed to quadratically.

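The equation for scaled dot-product attention can be written as follows, where $Q$, $K$, and $V$ are the query, key, and value matrices and $d$ is the dimension of the keys: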

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V
$$
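To make the formula concrete, here is a minimal sketch of it as a standalone function. The function name `scaled_dot_product_attention` and the tensor shapes are assumptions made for illustration; this is not the `SelfAttention` module we build in this post, and it leaves out masking and multiple heads.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    # q: (..., query_len, d), k: (..., key_len, d), v: (..., key_len, d_v)
    d = q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    # Normalize over the key dimension so each query's weights sum to 1
    weights = torch.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors
    return weights @ v, weights


torch.manual_seed(0)
q = torch.randn(2, 5, 16)  # batch of 2, 5 queries, d = 16
k = torch.randn(2, 7, 16)  # 7 keys
v = torch.randn(2, 7, 16)  # 7 values
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape)      # torch.Size([2, 5, 16])
print(weights.shape)  # torch.Size([2, 5, 7])
```

Note that the `scores` tensor has one entry per query-key pair, which is where the quadratic cost in sequence length comes from.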

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, heads):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads
        
        if self.head_dim * heads != embed_dim:
            raise ValueError("`embed_dim` must be a multiple of `heads`")
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc = nn.Linear(embed_dim, embed_dim)
    
    def forward(self, values, keys, queries):
        N = queries.size(0)  # batch size