# Attention is all you need
This notebook breaks down and implements the seminal paper "Attention is all you need". I felt that a lot of existing examples left the actual implementation details (e.g. the dimensions of each input/output and the actual dataflow) a bit unclear, so I decided to code my own. A cleaner version of the code is available in transformer.py. For the paper, focus on section 3 for implementing the network and section 5 for training it.

## Model Breakdown

## Implementation Notes
- p8, residual dropout: "We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of $P_{drop} = 0.1$."
- p5: "In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation... in the embedding layers, we multiply those weights by $\sqrt{d_{\text{model}}}$"
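Both notes can be made concrete in a few lines. The sketch below is an assumption-laden illustration, not the notebook's or transformer.py's code: `EmbeddingWithPosition` is a hypothetical helper, `max_len=5000` and `n_vocab=1000` are arbitrary, and the positional encoding is the sinusoidal one from section 3.5 of the paper. It scales the token embedding by $\sqrt{d_{\text{model}}}$, adds the positional encoding, applies dropout to the sum, and ties the embedding weight to a pre-softmax projection.

```python
import math

import torch
import torch.nn as nn


class EmbeddingWithPosition(nn.Module):
    """Hypothetical helper: scaled token embedding + sinusoidal positional
    encoding (section 3.5), followed by dropout on the sum (p8)."""

    def __init__(self, n_vocab: int, d_model: int, p_dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(n_vocab, d_model)
        self.dropout = nn.Dropout(p_dropout)
        # PE(pos, 2i) = sin(pos / 10000^(2i / d_model)), PE(pos, 2i + 1) = cos(same angle)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        two_i = torch.arange(0, d_model, 2, dtype=torch.float32)       # (d_model / 2,)
        inv_freq = torch.exp(-math.log(10000.0) * two_i / d_model)     # 10000^(-2i / d_model)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * inv_freq)
        pe[:, 1::2] = torch.cos(pos * inv_freq)
        # a buffer follows .to(device) but is not a trainable parameter
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch_size, seq_len) integer ids -> (batch_size, seq_len, d_model)
        x = self.embed(tokens) * math.sqrt(self.d_model)  # p5: scale embeddings by sqrt(d_model)
        return self.dropout(x + self.pe[:, : tokens.size(1)])  # p8: dropout on embedding + PE sum


n_vocab, d_model = 1000, 512  # arbitrary vocabulary size for illustration
tgt_embed = EmbeddingWithPosition(n_vocab, d_model)
proj = nn.Linear(d_model, n_vocab)  # pre-softmax projection, like LinearSoftmax.proj
# Weight tying (p5): nn.Embedding(n_vocab, d_model) and nn.Linear(d_model, n_vocab)
# both store a (n_vocab, d_model) weight, so one Parameter can serve both
proj.weight = tgt_embed.embed.weight

tokens = torch.randint(0, n_vocab, (2, 7))
out = tgt_embed(tokens)  # shape (2, 7, 512)
```

Sharing the same matrix with the source embedding as well only makes sense if source and target use one vocabulary; the sketch ties just the target side.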

## Parameters
**General Model Parameters**
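- Input sequence of $n$ symbols: $x = (x_1, \dots, x_n)$ (in the code, $n$ is the source sequence length, not the vocabulary size)
- Output sequence of $m$ symbols: $y = (y_1, \dots, y_m)$
- $d_{model}=512$: the dimensionality of the embeddings and of every sub-layer output (p3: all sub-layers and the embedding layers produce outputs of this size, which is what makes the residual connections possible)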
  
**Attention**
- $h$ is `h` in the code (`N_HEADS` in the final model cell), i.e. the # of heads in multihead attention
- $ d_k = d_{model} / h $
- $h$ must be chosen to be a factor of $d_{model}$
- $ d_v = d_k $
- $ W_i^Q \in \mathbb{R}^{d_\text{model} \times d_k }, i=\{1,...,h\} $
- $ W_i^K \in \mathbb{R}^{d_\text{model} \times d_k }, i=\{1,...,h\}  $
- $ W_i^V \in \mathbb{R}^{d_\text{model} \times d_v }, i=\{1,...,h\}  $
- $ W^O \in \mathbb{R}^{h d_v \times d_\text{model} } $ (a single output projection shared by all heads, so it has no index $i$)
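- Queries: $Q \in \mathbb{R}^{n_q \times d_\text{model}}$, one row per query position
- Keys: $K \in \mathbb{R}^{n \times d_\text{model}}$, one row per key position
- Values: $V \in \mathbb{R}^{n \times d_\text{model}}$, one row per key position (each key has a matching value)
- Per head, scaled dot-product attention operates on the projections $QW_i^Q \in \mathbb{R}^{n_q \times d_k}$, $KW_i^K \in \mathbb{R}^{n \times d_k}$ and $VW_i^V \in \mathbb{R}^{n \times d_v}$ (in self-attention $Q = K = V$, so $n_q = n$)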
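Plugging the base-model values into these definitions shows why $h$ must divide $d_{\text{model}}$ and why concatenating the heads lands back at $d_{\text{model}}$:

$$d_k = d_v = \frac{d_{\text{model}}}{h} = \frac{512}{8} = 64, \qquad h\,d_v = 8 \cdot 64 = 512 = d_{\text{model}}$$

So $W^O$ is a square $512 \times 512$ matrix, and the $h$ matrices $W_i^Q$ together hold $h \cdot d_{\text{model}} \cdot d_k = 512 \cdot 512$ entries, exactly as many as a single $d_{\text{model}} \times d_{\text{model}}$ projection (the same holds for $W_i^K$ and $W_i^V$).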

In [1]:
import copy
from math import sqrt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## Attention
$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
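A tiny worked example of the formula with hand-picked toy numbers (one query, two keys, $d_k = d_v = 2$), just to make each step visible before the batched version:

```python
import math

import torch

Q = torch.tensor([[1.0, 0.0]])              # (1, d_k): one query
K = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # (2, d_k): two keys
V = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # (2, d_v): one value per key

scores = Q @ K.T / math.sqrt(2)          # [[0.7071, 0.0]]: query matches key 0 better
weights = torch.softmax(scores, dim=-1)  # [[0.6698, 0.3302]]: sums to 1 over the keys
out = weights @ V                        # [[1.6605, 2.6605]]: weighted average of the values
```

The output is a convex combination of the value rows, pulled toward the value whose key best matches the query.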

In [24]:
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """Implements section 3.2.1

    Parameters
    ----------
    query : torch.Tensor
        shape (batch_size, h, _, d_k), where _ is the number of query positions
    key : torch.Tensor
        shape (batch_size, h, n, d_k), where n is the number of key/value positions
        (a sequence length, not the vocabulary size)
    value : torch.Tensor
        shape (batch_size, h, n, d_v) (remember d_k == d_v)
    mask : torch.Tensor, optional
        broadcastable to (batch_size, h, _, n), e.g. (batch_size, 1, 1, n);
        scores where mask == 0 are set to -1e9 before the softmax

    Returns
    -------
    torch.Tensor
        shape (batch_size, h, _, d_v), same as query (also recall d_k == d_v)

    d_k is read from query.shape[-1]. The head dimension is optional.

    TODO
    -----
    - unittest against existing implementations
    """
    d_k = query.shape[-1]
    # QK^T: every query row is dotted with every key row.
    # For tensors with more than 2 dims, @ broadcasts over the leading dims and
    # matrix-multiplies the last 2, so key must be transposed with key.transpose(-2, -1)
    # query: (batch_size, h, _, d_k), key.transpose(-2, -1): (batch_size, h, d_k, n)
    # x: (batch_size, h, _, n), the scaled score of every query against every key
    x = query @ key.transpose(-2, -1) / sqrt(d_k)
    if mask is not None:
        x = x.masked_fill(mask == 0, -1e9)
    # dim=-1 because the last dim indexes the n key positions: each query's
    # attention weights over the keys are normalized to sum to 1
    return F.softmax(x, dim=-1) @ value


# testing functionality
d_model = 512
batch_size = 1
n = 10
h = 8
d_k = d_model // h
d_v = d_k
n_q = 5  # number of query positions
q = torch.tensor(np.random.random((batch_size, h, n_q, d_k)))
k = torch.tensor(np.random.random((batch_size, h, n, d_k)))
v = torch.tensor(np.random.random((batch_size, h, n, d_v)))
# binary mask: 1 = attend, 0 = masked out (random floats would almost never equal 0)
mask = torch.tensor(np.random.randint(0, 2, (batch_size, 1, 1, n)))
assert scaled_dot_product_attention(q, k, v, mask).shape == (batch_size, h, n_q, d_v)

# the head dimension is optional: @ broadcasts over any leading dims
q = torch.tensor(np.random.random((batch_size, n_q, d_k)))
k = torch.tensor(np.random.random((batch_size, n, d_k)))
v = torch.tensor(np.random.random((batch_size, n, d_v)))
mask = torch.tensor(np.random.randint(0, 2, (batch_size, 1, n)))
assert scaled_dot_product_attention(q, k, v, mask).shape == (batch_size, n_q, d_v)
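One way to tackle the "unittest against existing implementations" TODO is to compare against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or newer where it exists. The sketch restates the cell's computation as `manual_attention` so it runs on its own; note that a boolean `attn_mask` in PyTorch marks positions that are *allowed* to attend (True = keep), which matches the `mask == 0` convention used here.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def manual_attention(query, key, value, mask=None):
    # same computation as scaled_dot_product_attention above
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return F.softmax(scores, dim=-1) @ value


batch_size, h, n_q, n, d_k = 2, 8, 5, 10, 64
q = torch.rand(batch_size, h, n_q, d_k)
k = torch.rand(batch_size, h, n, d_k)
v = torch.rand(batch_size, h, n, d_k)
mask = torch.ones(batch_size, 1, 1, n, dtype=torch.bool)
mask[..., -3:] = False  # treat the last 3 key positions as padding

ours = manual_attention(q, k, v, mask)
# PyTorch >= 2.0; default scale is 1 / sqrt(d_k), and the mask broadcasts over heads and queries
reference = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert torch.allclose(ours, reference, atol=1e-5)
```

Setting masked scores to -1e9 instead of -inf gives the same result in practice, since `exp(-1e9)` underflows to 0.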

In [36]:
class MultiHeadAttention(nn.Module):
    """Implements section 3.2.2"""

    def __init__(self, d_model: int, h: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert d_model % h == 0, "h must be a factor of d_model"
        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h
        self.d_v = self.d_k

        # W matrices from section 3.2.2: one bias-free projection per head (W_i^Q, W_i^K, W_i^V).
        # Equivalently, a single nn.Linear(d_model, h * d_k, bias=False) per matrix, split into heads.
        self.w_q = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(h)])
        self.w_k = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(h)])
        self.w_v = nn.ModuleList([nn.Linear(d_model, self.d_v, bias=False) for _ in range(h)])
        # W^O maps the concatenated heads (h * d_v) back to d_model
        self.w_o = nn.Linear(h * self.d_v, d_model, bias=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        query : torch.Tensor
            shape (batch_size, _, d_model), where _ is the number of query positions
            (it need not equal n, e.g. in encoder-decoder attention)
        key : torch.Tensor
            shape (batch_size, n, d_model), where n is the number of key/value positions
        value : torch.Tensor
            shape (batch_size, n, d_model)
        mask : torch.Tensor, optional
            shape (batch_size, 1, n) or (batch_size, _, n); a head dimension is inserted
            so it broadcasts over all h heads

        Returns
        -------
        torch.Tensor
            shape (batch_size, _, d_model)

        TODO
        ----
        - unittest against existing implementations
        """
        if mask is not None:
            mask = mask[:, None, :]
        q = torch.stack([self.w_q[i](query) for i in range(self.h)], dim=1)
        k = torch.stack([self.w_k[i](key) for i in range(self.h)], dim=1)
        v = torch.stack([self.w_v[i](value) for i in range(self.h)], dim=1)
        x = scaled_dot_product_attention(q, k, v, mask)
        # x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_v)
        return self.w_o(torch.cat([x[:,i] for i in range(self.h)], dim=-1))


# testing functionality
d_model = 512
batch_size = 1
n = 10
h = 8
n_q = 5
q = torch.tensor(np.random.random((batch_size, n_q, d_model)), dtype=torch.float32)
k = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
v = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
# binary mask: 1 = attend, 0 = masked out
mask = torch.tensor(np.random.randint(0, 2, (batch_size, 1, n)))
# forward pass only (no training)
mha = MultiHeadAttention(d_model, h)
assert mha(q, k, v, mask).shape == (batch_size, n_q, d_model)
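The per-head `nn.ModuleList` and the fused alternative (one `nn.Linear(d_model, h * d_k)` followed by a reshape into heads) compute the same thing. A minimal check, assuming the fused weight is built by stacking the per-head weights head by head; it also confirms that the `view`/`transpose` merge of the heads matches `torch.cat` along the last dim:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h, n = 512, 8, 10
d_k = d_model // h

# per-head projections, as in MultiHeadAttention.w_q
per_head = nn.ModuleList([nn.Linear(d_model, d_k, bias=False) for _ in range(h)])

# one fused projection whose weight rows are the per-head weights stacked head by head
fused = nn.Linear(d_model, h * d_k, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([layer.weight for layer in per_head], dim=0))  # (h * d_k, d_model)

x = torch.rand(1, n, d_model)
stacked = torch.stack([layer(x) for layer in per_head], dim=1)  # (1, h, n, d_k)
split = fused(x).view(1, n, h, d_k).transpose(1, 2)             # (1, h, n, d_k)
assert torch.allclose(stacked, split, atol=1e-5)

# merging heads back: concatenation vs transpose + view give the same (1, n, h * d_k) tensor
concat = torch.cat([stacked[:, i] for i in range(h)], dim=-1)
merged = split.transpose(1, 2).contiguous().view(1, n, h * d_k)
assert torch.allclose(concat, merged, atol=1e-5)
```

The fused form issues one large matrix multiply instead of h small ones, which is usually faster on a GPU; the per-head form mirrors the paper's notation more literally.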

## Position-wise Feed-Forward Networks
Section 3.3  
$\text{FFN}(x) = \text{max}(0,xW_1 + b_1)W_2 + b_2$  
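The input and output have dimensionality $d_{\text{model}}=512$, and the inner layer has dimensionality $d_{ff}=2048$. The same two linear transformations are applied to each position separately and identically (the parameters differ from layer to layer).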

In [43]:
class PositionwiseFeedForward(nn.Module):
    """Implements section 3.3"""

    def __init__(self, d_model: int, d_ff: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.d_model = d_model
        self.d_ff = d_ff
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            shape (batch_size, _, d_model)

        Returns
        -------
        torch.Tensor
            shape (batch_size, _, d_model)

        TODO
        ----
        - unittest against existing implementations
        """
        return self.ff(x)
    

# testing functionality
d_model = 512
d_ff = 2048
batch_size = 1
n = 10
x = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
ff = PositionwiseFeedForward(d_model, d_ff)
assert ff(x).shape == x.shape
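"Position-wise" means that each position is transformed independently of the others. A small check under that reading, using the same `nn.Sequential` structure as `PositionwiseFeedForward.ff`: running the network on the whole sequence matches running it on each position alone, and changing one position leaves the other outputs untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, n = 512, 2048, 10
# same structure as PositionwiseFeedForward.ff
ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.rand(1, n, d_model)
full = ff(x)  # all positions at once: (1, n, d_model)

# apply the same network to one position at a time and stack the results
per_position = torch.stack([ff(x[:, i]) for i in range(n)], dim=1)
assert torch.allclose(full, per_position, atol=1e-5)

# perturbing position 0 does not change the outputs at positions 1..n-1
x_changed = x.clone()
x_changed[:, 0] += 1.0
assert torch.allclose(ff(x_changed)[:, 1:], full[:, 1:], atol=1e-5)
```

All mixing between positions therefore has to happen in the attention sub-layers.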

## Residual Connection
Residual connection + layer normalization + dropout

In [44]:
class ResidualConnection(nn.Module):
    """Implements the residual connection along with layer normalization and dropout

    pg. 3 "we employ a residual connection...around each of the two sub-layers,
    followed by layer normalization"
    """

    def __init__(self, d_model: int, p_dropout: float, *args, **kwargs):
        """
        Parameters
        ----------
        d_model : int
            size of the model
        p_dropout : float
            dropout probability in range [0, 1]
        """
        super().__init__(*args, **kwargs)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x: torch.Tensor, sublayer: nn.Module) -> torch.Tensor:
        # x + dropout(sublayer(x)) implements the residual connection
        return self.layer_norm(x + self.dropout(sublayer(x)))
    

# testing functionality
d_model = 512
p_dropout = 0.1
batch_size = 1
n = 10
x = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
sublayer = nn.Linear(d_model, d_model)
rc = ResidualConnection(d_model, p_dropout)
assert rc(x, sublayer).shape == x.shape
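Dropout only acts in training mode, so in evaluation mode the residual connection reduces to $\text{LayerNorm}(x + \text{Sublayer}(x))$. The sketch below restates `ResidualConnection.forward` as a plain function (`residual` is a hypothetical name) to show the difference, and that LayerNorm with its default weight 1 and bias 0 gives each position zero mean and unit variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, p_dropout = 512, 0.1
layer_norm = nn.LayerNorm(d_model)
dropout = nn.Dropout(p_dropout)
sublayer = nn.Linear(d_model, d_model)


def residual(x: torch.Tensor) -> torch.Tensor:
    # same computation as ResidualConnection.forward: LayerNorm(x + Dropout(Sublayer(x)))
    return layer_norm(x + dropout(sublayer(x)))


x = torch.rand(1, 10, d_model)
no_dropout = layer_norm(x + sublayer(x))

dropout.eval()   # evaluation mode: dropout is the identity
out_eval = residual(x)

dropout.train()  # training mode: each element zeroed with prob. p, survivors scaled by 1 / (1 - p)
out_train = residual(x)

# per-position statistics after LayerNorm
mean = out_eval.mean(dim=-1)
var = ((out_eval - out_eval.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)
```

Calling `.eval()` on the whole model switches every `nn.Dropout` inside it at once, which is what inference code normally does.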

## Encoder

In [55]:
class EncoderLayer(nn.Module):
    """Encoder layer: multi-head self-attention followed by a position-wise
    feed-forward network, each wrapped in a residual connection (section 3.1)

    TODO
    ----
    - simplify the self attention stuff
    """

    def __init__(self, self_attn: nn.Module, feed_forward: nn.Module, residual_connection: nn.Module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # give each sub-layer its own LayerNorm: reusing one ResidualConnection
        # instance for both sub-layers would share its parameters
        self.residual_connections = nn.ModuleList(
            [copy.deepcopy(residual_connection) for _ in range(2)]
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        x = self.residual_connections[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.residual_connections[1](x, self.feed_forward)
    

# testing functionality
d_model = 512
h = 8
d_ff = 2048
p_dropout = 0.1
batch_size = 1
n = 10
x = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
mask = torch.ones(batch_size, 1, n)  # no padding: every position may be attended to
self_attn = MultiHeadAttention(d_model, h)
feed_forward = PositionwiseFeedForward(d_model, d_ff)
residual_connection = ResidualConnection(d_model, p_dropout)
encoder_layer = EncoderLayer(self_attn, feed_forward, residual_connection)
assert encoder_layer(x, mask).shape == x.shape
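In a real batch, the encoder's `mask` of shape (batch_size, 1, n) usually marks padding: sequences of different lengths are padded to a common n, and the padded key positions must get zero attention weight. A minimal sketch, assuming a hypothetical padding id `PAD = 0` (the notebook does not define one) and a hypothetical helper `make_padding_mask`:

```python
import torch

PAD = 0  # hypothetical padding token id


def make_padding_mask(src: torch.Tensor, pad: int = PAD) -> torch.Tensor:
    """(batch_size, n) token ids -> (batch_size, 1, n) mask, True where there is a real token."""
    return (src != pad).unsqueeze(-2)


src = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 0, 0, 0]])  # two sequences padded to n = 5
mask = make_padding_mask(src)          # shape (2, 1, 5)

# equal raw scores for every key, masked the same way as scaled_dot_product_attention
scores = torch.zeros(2, 1, 5)
weights = torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)
# weights[0, 0]: 1/3 on each of the 3 real tokens, 0 on the padding
# weights[1, 0]: 1/2 on each of the 2 real tokens, 0 on the padding
```

The singleton middle dimension lets the same mask broadcast over every query position (and, after `MultiHeadAttention` inserts a head dimension, over every head).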

In [58]:
class Encoder(nn.Module):
    """Transformer encoder module

    pg 3, figure 1: this class implements the left half of the diagram
    """

    def __init__(self, layer: nn.Module, n_encoder_layers: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layers = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(n_encoder_layers)]
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward pass for encoder"""
        for layer in self.layers:
            x = layer(x, mask)
        return x
    

# testing functionality
d_model = 512
h = 8
d_ff = 2048
p_dropout = 0.1
n_encoder_layers = 6
batch_size = 1
n = 10
x = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
mask = torch.ones(batch_size, 1, n)  # no padding: every position may be attended to
self_attn = MultiHeadAttention(d_model, h)
feed_forward = PositionwiseFeedForward(d_model, d_ff)

encoder_layer = EncoderLayer(self_attn, feed_forward, ResidualConnection(d_model, p_dropout))
encoder = Encoder(encoder_layer, n_encoder_layers)
assert encoder(x, mask).shape == x.shape
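`Encoder` relies on `copy.deepcopy` so that each of the N layers gets its own weights. The same concern applies anywhere a single module instance is used in more than one place, such as one `ResidualConnection` (and its `LayerNorm`) wrapping two sub-layers: reusing the instance shares its parameters. A small demonstration with a bare `nn.LayerNorm` (`n_params` is a hypothetical helper):

```python
import copy

import torch.nn as nn


def n_params(module: nn.Module) -> int:
    # Module.parameters() yields each distinct Parameter once, so a reused module is counted once
    return sum(p.numel() for p in module.parameters())


norm = nn.LayerNorm(512)  # weight + bias = 1024 parameters

shared = nn.ModuleList([norm, norm])                                  # same object used twice
independent = nn.ModuleList([copy.deepcopy(norm) for _ in range(2)])  # two separate copies

shared_count, independent_count = n_params(shared), n_params(independent)  # 1024, 2048
same_tensor_shared = shared[0].weight is shared[1].weight                  # True
same_tensor_independent = independent[0].weight is independent[1].weight   # False
```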

## Decoder

In [59]:
class LinearSoftmax(nn.Module):
    """Linear + Softmax Layer to map the decoder output to the vocabulary

    pg 3, figure 1: implements linear and softmax layers at the top right of the diagram
    """

    def __init__(self, d_model: int, n_vocab: int, *args, **kwargs):
        """
        Parameters
        ----------
        d_model : int
            size of the model
        n_vocab : int
            size of the vocabulary
        """
        super().__init__(*args, **kwargs)
        self.proj = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.softmax(self.proj(x), dim=-1)
    

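# testing functionality
d_model = 512
n_vocab = 1000  # arbitrary vocabulary size for the test
batch_size = 1
n = 10
x = torch.tensor(np.random.random((batch_size, n, d_model)), dtype=torch.float32)
linear_softmax = LinearSoftmax(d_model, n_vocab)
probs = linear_softmax(x)
assert probs.shape == (batch_size, n, n_vocab)
# each position's output is a probability distribution over the vocabulary
assert torch.allclose(probs.sum(dim=-1), torch.ones(batch_size, n))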
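The decoder's self-attention additionally needs a mask that stops each target position from attending to later positions (section 3.2.3), so predictions can depend only on earlier outputs. A sketch of such a `tgt_mask`, assuming a hypothetical helper `subsequent_mask` and a hypothetical padding id `PAD = 0`, combined with a padding mask:

```python
import torch

PAD = 0  # hypothetical padding token id


def subsequent_mask(size: int) -> torch.Tensor:
    """(1, size, size) boolean mask: entry [0, i, j] is True iff position i may attend to j (j <= i)."""
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))


tgt = torch.tensor([[4, 8, 15, 0]])  # (batch_size=1, m=4), last position is padding

causal = subsequent_mask(tgt.size(-1))
# combine with a padding mask over key positions: (1, 1, m) & (1, m, m) -> (1, m, m)
tgt_mask = (tgt != PAD).unsqueeze(-2) & causal

causal_rows = causal[0].int().tolist()  # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
tgt_rows = tgt_mask[0].int().tolist()   # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 0]]
```

Because `MultiHeadAttention.forward` inserts a head dimension with `mask[:, None, :]`, a (batch_size, m, m) mask like this one broadcasts over the heads just as the (batch_size, 1, n) padding mask does.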

## Final Model

In [37]:
# Hyperparameters pulled from paper
N_ENCODER_LAYERS = 6  # pg 3
N_DECODER_LAYERS = 6  # pg 3
D_MODEL = 512  # pg 3, dimensionality of the embeddings and of every sub-layer output
P_DROPOUT = 0.1  # pg 8
N_HEADS = 8  # pg 5, number of parallel attention layers (heads)
# pg 5. dimension of the key projection matrices for multi-head attention;
# integer division keeps it an int (D_MODEL / N_HEADS would give the float 64.0)
D_K = D_MODEL // N_HEADS
D_V = D_K  # pg 5. dimension of the value projection matrices for multi-head attention
D_FF = 2048  # pg. 5 dimension of feed forward network


class EncoderDecoder(nn.Module):
    """Vanilla Encoder-Decoder NN Architecture

    TODO
    ----
    - understand masking
    - understand what src_embed and tgt_embed are
    - add type hints for src_embed and tgt_embed
    - rename variables with "_" if they should be private
    - understand what the generator is and does
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        src_embed,
        tgt_embed,
        linear_softmax,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.linear_softmax = linear_softmax

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor,
        tgt_mask: torch.Tensor,
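    ) -> torch.Tensor:
        """Forward pass: encode the embedded source, then decode the embedded target
        while attending to the encoder output

        Parameters
        ----------
        src : torch.Tensor
            source input
        tgt : torch.Tensor
            target input
        src_mask : torch.Tensor
            source mask
        tgt_mask : torch.Tensor
            target mask

        Returns
        -------
        torch.Tensor
            decoder output (one d_model-sized vector per target position); linear_softmax
            is not applied here and must be called separately to get vocabulary probabilities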
        """
        return self.decoder(
            self.tgt_embed(tgt),
            self.encoder(self.src_embed(src), src_mask),
            src_mask,
            tgt_mask,
        )
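As a sanity check on the hyperparameters above, a back-of-the-envelope count of the parameters in the encoder stack as implemented in this notebook: bias-free attention projections, an FFN with biases, and one LayerNorm (weight and bias) per sub-layer, the last being an assumption that follows the paper's LayerNorm(x + Sublayer(x)) per sub-layer. Embeddings and the decoder are not included.

```python
D_MODEL, N_HEADS, D_FF, N_ENCODER_LAYERS = 512, 8, 2048, 6
D_K = D_MODEL // N_HEADS
D_V = D_K

# MultiHeadAttention: h bias-free (d_model x d_k) projections for each of Q, K, V, plus W^O
attn = 3 * N_HEADS * D_MODEL * D_K + N_HEADS * D_V * D_MODEL  # 4 * 512^2 = 1,048,576
# PositionwiseFeedForward: two Linear layers with biases
ffn = (D_MODEL * D_FF + D_FF) + (D_FF * D_MODEL + D_MODEL)     # 2,099,712
# assumption: one LayerNorm (weight + bias) per sub-layer
norms = 2 * (2 * D_MODEL)                                      # 2,048

per_layer = attn + ffn + norms               # 3,150,336
encoder_total = N_ENCODER_LAYERS * per_layer  # 18,902,016
```

Comparing this with `sum(p.numel() for p in encoder.parameters())` from the Encoder cell is a quick consistency check; a LayerNorm shared between the two sub-layers would make the per-layer count 1,024 smaller.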