## Transformer 

This notebook builds the Transformer architecture step by step in PyTorch, to understand how its components fit together.

For more detail see: [d2l: Transformer Architecture](https://d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html)


- Transformer Architecture
  - Embeddings
  - Position-wise FFN
  - LayerNorm
  - AddNorm
  - PositionalEncoding
  - Scaled Dot-Product Attention
  - Multi-Head Attention
  - TransformerEncoderBlock (single layer)
  - TransformerEncoder
  - TransformerDecoderBlock
  - TransformerDecoder
  - Transformer Model

In [2]:
import math
import torch
from torch import nn
import numpy as np

#### Embeddings

In [11]:
class Embedding(nn.Module):
    """Token embedding layer whose outputs are scaled by sqrt(embed_dim).

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: dimensionality of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super(Embedding, self).__init__()
        self.embed_dim = embed_dim
        # Lookup table of shape (vocab_size, embed_dim).
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        """Map integer token ids ``x`` to embeddings multiplied by sqrt(embed_dim).

        The output has ``x``'s shape plus a trailing axis of size ``embed_dim``.
        """
        # ``math.sqrt``: a bare ``sqrt`` is not defined in this notebook and
        # raised NameError on the first forward pass.
        return self.embed(x) * math.sqrt(self.embed_dim)

In [12]:
# A 100-token vocabulary embedded into 512 dimensions; the trailing bare
# name makes Jupyter display the module.
e = Embedding(vocab_size=100, embed_dim=512)
e

Embedding(
  (embed): Embedding(100, 512)
)

In [13]:
def check_shape(a, shape):
    """Assert that tensor ``a`` has exactly the expected ``shape``.

    Raises:
        AssertionError: naming both the actual and the expected shape.
    """
    matches = a.shape == shape
    assert matches, f'tensor\'s shape {a.shape} != expected shape {shape}'

#### Positional FFN

The position-wise feed-forward network (PFFN) transforms the representation at every sequence position using the same MLP (multi-layer perceptron).

Here, an input `X` of shape (batch size, number of time steps, i.e. sequence length in tokens, number of hidden units, i.e. feature dimension) is transformed by a two-layer MLP into an output tensor of shape (batch size, number of time steps, `ffn_num_outputs`).

In [14]:
# Silence every warning for the rest of the session; presumably meant to hide
# the notices PyTorch emits for the Lazy* layers used below — confirm, since a
# blanket filter also hides unrelated warnings.
import warnings
warnings.filterwarnings(action='ignore')

In [15]:
class PositionWiseFFN(nn.Module):
    """The position-wise feed-forward network.

    Applies the same two-layer MLP (Linear -> ReLU -> Linear) over the last
    axis of the input, i.e. independently at every position, so an input of
    shape (..., d_in) becomes (..., ffn_num_outputs).

    Args:
        ffn_num_hiddens: width of the hidden layer.
        ffn_num_outputs: width of the output layer.
    """

    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        # Lazy layers infer in_features from their first input, so the input
        # width never has to be passed in (weights are created on first call).
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

In [16]:
ffn = PositionWiseFFN(ffn_num_hiddens=4, ffn_num_outputs=8)
ffn.eval()  # switch to eval mode; the returned module is displayed by Jupyter

PositionWiseFFN(
  (dense1): LazyLinear(in_features=0, out_features=4, bias=True)
  (relu): ReLU()
  (dense2): LazyLinear(in_features=0, out_features=8, bias=True)
)

In [17]:
# Example: inspect the PFFN output. A single MLP is shared by every position,
# so identical inputs at all positions yield identical outputs — the three
# rows shown below for the first batch element are equal.
ffn(torch.ones(2, 3, 4))[0]

tensor([[ 0.4662, -0.2785, -0.1906, -0.0019, -0.0851,  0.4234, -0.0755,  0.2674],
        [ 0.4662, -0.2785, -0.1906, -0.0019, -0.0851,  0.4234, -0.0755,  0.2674],
        [ 0.4662, -0.2785, -0.1906, -0.0019, -0.0851,  0.4234, -0.0755,  0.2674]],
       grad_fn=<SelectBackward0>)

#### Layer Normalization

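Layer normalization normalizes each example over its own feature dimension, so its statistics depend neither on the batch size nor on the other examples in the batch; it then applies a learned per-element scale and bias. Batch normalization, by contrast, computes statistics across the batch and applies a scalar scale and bias to each entire channel/plane. This independence from the batch makes LayerNorm the better fit for sequence models.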

In [18]:
ln = nn.LayerNorm(normalized_shape=2)
X = torch.tensor([[10, 2], [2, 5]], dtype=torch.float32)
# Each row is normalized on its own over the last dimension (size 2), so
# every row comes out as approximately [1, -1] or [-1, 1].
ln(X)

tensor([[ 1.0000, -1.0000],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>)

In [19]:
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
# Normalize over the last (embedding) dimension of each token.
layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

# First token of the first sentence: a vector of embedding_dim features,
# normalized to approximately zero mean and unit variance.
layer_norm(embedding[0, 0])

tensor([ 0.1380, -1.9568,  0.5463,  1.5227,  0.4377,  0.5928,  0.7892,  0.0591,
        -1.2869, -0.8421], grad_fn=<NativeLayerNormBackward0>)

#### AddNorm

In [20]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(X + Dropout(Y))``, where ``X`` is a sublayer's input
    and ``Y`` its output (shapes must be broadcast-compatible, with the last
    axis matching ``norm_shape``).
    """

    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        residual = X + self.dropout(Y)
        return self.ln(residual)

In [21]:
add_norm = AddNorm(norm_shape=4, dropout=0.5)
shape = (2, 3, 4)  # (batch, seq_len, features) used by the check below

In [22]:
# Residual + LayerNorm must preserve the input's shape.
check_shape(add_norm(torch.ones(*shape), torch.ones(*shape)), shape)

#### Attention

In [23]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` that ignores positions past ``valid_lens``.

    Args:
        X: 3-D scores tensor, (batch, num_queries, num_keys).
        valid_lens: ``None`` (no masking); a 1-D tensor with one length per
            batch element (shared by all its queries); or a 2-D tensor
            (batch, num_queries) with one length per query.

    Returns:
        Tensor shaped like ``X``; key positions at or beyond the valid length
        receive (numerically) zero weight.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for each of its queries.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    flat = X.reshape(-1, shape[-1])
    key_positions = torch.arange(shape[-1], device=X.device)
    keep = key_positions[None, :] < valid_lens.to(X.device)[:, None]
    # A large negative score (instead of -inf) keeps rows with no valid keys
    # finite (uniform) rather than NaN. masked_fill does not mutate ``X``.
    flat = flat.masked_fill(~keep, -1e6)
    return nn.functional.softmax(flat.reshape(shape), dim=-1)


class Attention(nn.Module):
    """Basic scaled dot product attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` with batched matrix products,
    masking key positions beyond ``valid_lens``. The attention weights of the
    most recent forward pass are kept in ``self.attention_weights``.
    """

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """
        Args:
            queries: (batch, num_queries, d).
            keys: (batch, num_keys, d).
            values: (batch, num_keys, value_dim).
            valid_lens: optional lengths, see ``masked_softmax``.

        Returns:
            Tensor of shape (batch, num_queries, value_dim).
        """
        d = queries.shape[-1]
        # Scale the dot products by 1/sqrt(d) (scaled dot-product attention).
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

#### MultiHead Attention

In [24]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Queries, keys and values are each projected to ``num_hiddens`` features,
    split into ``num_heads`` heads, attended with scaled dot-product attention
    in parallel (heads folded into the batch axis), concatenated back and
    projected by ``W_o`` to ``num_hiddens`` features.

    NOTE(review): ``transpose_qkv`` / ``transpose_output`` are attached to this
    class in separate cells via ``add_to_class``; their reshapes presumably
    require ``num_hiddens`` to be divisible by ``num_heads`` — confirm.
    """

    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = Attention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Projections are applied in q, k, v order (lazy layers initialize on
        # first call). Results: (batch * num_heads, seq_len, head_dim).
        q_heads = self.transpose_qkv(self.W_q(queries))
        k_heads = self.transpose_qkv(self.W_k(keys))
        v_heads = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # Every head of a batch element shares that element's lengths.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # (batch * num_heads, num_queries, head_dim)
        per_head = self.attention(q_heads, k_heads, v_heads, valid_lens)
        # (batch, num_queries, num_heads * head_dim) -> output projection
        return self.W_o(self.transpose_output(per_head))

In [25]:
def add_to_class(Class):
    """Decorator factory that registers a function as a method of ``Class``.

    Usage::

        @add_to_class(SomeClass)
        def method(self, ...):
            ...

    The function is attached with ``setattr`` under its own ``__name__`` and
    returned unchanged, so the decorated module-level name still refers to
    the function rather than being rebound to ``None``.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj
    return wrapper

In [26]:
@add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Split the feature axis into heads and fold the heads into the batch.

    Input X: (batch, seq_len, num_hiddens).
    Output: (batch * num_heads, seq_len, num_hiddens / num_heads).
    """
    batch_size, seq_len = X.shape[0], X.shape[1]
    # (batch, seq_len, num_heads, head_dim)
    split = X.reshape(batch_size, seq_len, self.num_heads, -1)
    # (batch, num_heads, seq_len, head_dim)
    heads_first = split.permute(0, 2, 1, 3)
    return heads_first.reshape(-1, seq_len, heads_first.shape[3])

In [27]:
@add_to_class(MultiHeadAttention)
def transpose_output(self, X):
    """Reverse ``transpose_qkv``: unfold heads from the batch and concatenate them.

    Input X: (batch * num_heads, seq_len, head_dim).
    Output: (batch, seq_len, num_heads * head_dim).
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    # (batch, num_heads, seq_len, head_dim)
    grouped = X.reshape(-1, self.num_heads, seq_len, head_dim)
    # (batch, seq_len, num_heads, head_dim)
    seq_first = grouped.permute(0, 2, 1, 3)
    return seq_first.reshape(seq_first.shape[0], seq_first.shape[1], -1)

#### Positional Encoding

In [28]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding.

    Precomputes a table ``P`` of shape (1, max_len, num_hiddens) with
    ``P[0, pos, 2i] = sin(pos / 10000^(2i / num_hiddens))`` and
    ``P[0, pos, 2i + 1] = cos(pos / 10000^(2i / num_hiddens))``; ``forward``
    adds the first ``seq_len`` rows to the input and applies dropout.

    Args:
        num_hiddens: feature dimension of the inputs (odd values are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the table covers.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        # One divisor per even feature index -> ceil(num_hiddens / 2) columns.
        divisors = torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        X = positions / divisors
        self.P[:, :, 0::2] = torch.sin(X)
        # Odd feature indices number only floor(num_hiddens / 2); slicing X
        # avoids a shape mismatch when num_hiddens is odd (no-op when even).
        self.P[:, :, 1::2] = torch.cos(X[:, :num_hiddens // 2])

    def forward(self, X):
        # X: (batch, seq_len, num_hiddens); seq_len must not exceed max_len.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

#### Transformer Encoder Block

In [29]:
class TransformerEncoderBlock(nn.Module):
    """The Transformer encoder block.

    Multi-head self-attention followed by a position-wise FFN, each wrapped in
    a residual connection + LayerNorm (``AddNorm``). Input and output have
    shape (batch, seq_len, num_hiddens).

    Args:
        num_hiddens: model width.
        ffn_num_hiddens: hidden width of the position-wise FFN.
        num_heads: number of attention heads.
        dropout: dropout probability used throughout.
        use_bias: whether the attention projections use a bias term.
    """

    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        """
        Args:
            X: (batch, seq_len, num_hiddens) input.
            valid_lens: optional valid lengths of the sequences in ``X``; key
                positions past them are masked out of self-attention. The
                default ``None`` attends over every position (a hard-coded
                tensor default would only fit one particular batch size).
        """
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

In [30]:
X = torch.ones((2, 100, 24))       # (batch, seq_len, num_hiddens)
valid_lens = torch.tensor([3, 2])  # example lengths, one per batch element
encoder_blk = TransformerEncoderBlock(
    num_hiddens=24, ffn_num_hiddens=48, num_heads=8, dropout=0.5)
encoder_blk.eval()

TransformerEncoderBlock(
  (attention): MultiHeadAttention(
    (attention): Attention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (W_q): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_k): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_v): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_o): LazyLinear(in_features=0, out_features=24, bias=False)
  )
  (addnorm1): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
    (ln): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
  )
  (ffn): PositionWiseFFN(
    (dense1): LazyLinear(in_features=0, out_features=48, bias=True)
    (relu): ReLU()
    (dense2): LazyLinear(in_features=0, out_features=24, bias=True)
  )
  (addnorm2): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
    (ln): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
  )
)

In [31]:
class TransformerEncoder(nn.Module):
    """The Transformer encoder.

    Embeds token ids, scales them by sqrt(num_hiddens), adds positional
    encodings and applies ``num_blks`` stacked ``TransformerEncoderBlock``s.

    Args:
        vocab_size: size of the source vocabulary.
        num_hiddens: model width.
        ffn_num_hiddens: hidden width of each position-wise FFN.
        num_heads: number of attention heads per block.
        num_blks: number of stacked encoder blocks.
        dropout: dropout probability used throughout.
        use_bias: whether the attention projections use a bias term.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens=None):
        """Encode token ids ``X`` of shape (batch, seq_len).

        Args:
            X: integer token ids.
            valid_lens: optional valid lengths of the sequences in ``X``;
                positions past them are masked out of self-attention. The
                default ``None`` attends over every position (a hard-coded
                tensor default would only fit one particular batch size).

        Returns:
            Tensor of shape (batch, seq_len, num_hiddens). Afterwards,
            ``self.attention_weights[i]`` holds block ``i``'s attention weights.
        """
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

#### Transformer Decoder

In [32]:
class TransformerDecoderBlock(nn.Module):
    """The ``i``-th Transformer decoder block.

    Masked self-attention, encoder-decoder attention and a position-wise FFN,
    each followed by ``AddNorm``.

    ``state`` is the list ``[enc_outputs, enc_valid_lens, cache]``. This block
    concatenates its inputs onto ``cache[i]`` along the sequence axis and uses
    the result as self-attention keys/values (presumably so step-by-step
    decoding can attend to earlier positions). ``cache[i]`` is updated in place.
    """

    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = MultiHeadAttention(num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = MultiHeadAttention(num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens, cache = state[0], state[1], state[2]
        previous = cache[self.i]
        key_values = X if previous is None else torch.cat((previous, X), dim=1)
        cache[self.i] = key_values

        dec_valid_lens = None
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask: every row is [1, 2, ..., num_steps], so the t-th
            # query (1-based) only sees the first t keys.
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)

        # Masked self-attention over the (cached) decoder inputs.
        self_attended = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, self_attended)
        # Encoder-decoder attention: (batch_size, num_steps, num_hiddens).
        cross_attended = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, cross_attended)
        return self.addnorm3(Z, self.ffn(Z)), state

In [33]:
class TransformerDecoder(nn.Module):
    """The Transformer decoder.

    Embedding (scaled by sqrt(num_hiddens)) + positional encoding, then
    ``num_blks`` ``TransformerDecoderBlock``s, then a linear projection to
    ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for idx in range(num_blks):
            self.blks.add_module(f"block{idx}", TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, idx))
        self.dense = nn.LazyLinear(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        """Return ``[enc_outputs, enc_valid_lens, cache]`` with one empty (None) cache slot per block."""
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        num_layers = len(self.blks)
        self_attn_weights = [None] * num_layers
        cross_attn_weights = [None] * num_layers
        # [0]: decoder self-attention weights, [1]: encoder-decoder weights.
        self._attention_weights = [self_attn_weights, cross_attn_weights]
        for idx, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self_attn_weights[idx] = blk.attention1.attention.attention_weights
            cross_attn_weights[idx] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    def attention_weights(self):
        """Return ``[self_attention_weights, cross_attention_weights]`` from the last forward pass."""
        return self._attention_weights

In [34]:
encoder = TransformerEncoder(vocab_size=200, num_hiddens=24, ffn_num_hiddens=48,
                             num_heads=8, num_blks=2, dropout=0.5)

In [35]:
# Display the encoder's module tree (Jupyter renders the last expression's repr).
encoder

TransformerEncoder(
  (embedding): Embedding(200, 24)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (blks): Sequential(
    (block0): TransformerEncoderBlock(
      (attention): MultiHeadAttention(
        (attention): Attention(
          (dropout): Dropout(p=0.5, inplace=False)
        )
        (W_q): LazyLinear(in_features=0, out_features=24, bias=False)
        (W_k): LazyLinear(in_features=0, out_features=24, bias=False)
        (W_v): LazyLinear(in_features=0, out_features=24, bias=False)
        (W_o): LazyLinear(in_features=0, out_features=24, bias=False)
      )
      (addnorm1): AddNorm(
        (dropout): Dropout(p=0.5, inplace=False)
        (ln): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
      )
      (ffn): PositionWiseFFN(
        (dense1): LazyLinear(in_features=0, out_features=48, bias=True)
        (relu): ReLU()
        (dense2): LazyLinear(in_features=0, out_features=24, bias=True)
      )
      (addnorm2): AddN

In [36]:
# Note: this is a single decoder *block* (index 0), not a full TransformerDecoder.
decoder = TransformerDecoderBlock(num_hiddens=24, ffn_num_hiddens=48,
                                  num_heads=8, dropout=0.5, i=0)

In [37]:
# Display the module tree of `decoder`, a single TransformerDecoderBlock
# (not a full TransformerDecoder).
decoder

TransformerDecoderBlock(
  (attention1): MultiHeadAttention(
    (attention): Attention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (W_q): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_k): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_v): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_o): LazyLinear(in_features=0, out_features=24, bias=False)
  )
  (addnorm1): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
    (ln): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
  )
  (attention2): MultiHeadAttention(
    (attention): Attention(
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (W_q): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_k): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_v): LazyLinear(in_features=0, out_features=24, bias=False)
    (W_o): LazyLinear(in_features=0, out_features=24, bias=False)
  )
  (addnorm2): AddNorm(
    (dropout): Dropout(p=0.5, inplace=False)
  

#### Transformer Model

In [44]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        embed_dim: model width (``num_hiddens``) shared by encoder and decoder.
        src_vocab_size: size of the source vocabulary.
        target_vocab_size: size of the target vocabulary.
        ffn: hidden width of every position-wise FFN.
        num_blocks: number of encoder blocks and of decoder blocks.
        heads: number of attention heads.
        dropout: dropout probability used throughout.
    """

    def __init__(self,
                 embed_dim,
                 src_vocab_size,
                 target_vocab_size,
                 ffn,
                 num_blocks=6,
                 heads=8,
                 dropout=0.2):
        super(Transformer, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.encoder = TransformerEncoder(
                        vocab_size=src_vocab_size,
                        num_hiddens=embed_dim,
                        ffn_num_hiddens=ffn,
                        num_heads=heads,
                        num_blks=num_blocks,
                        dropout=dropout)
        self.decoder = TransformerDecoder(
                        vocab_size=target_vocab_size,
                        num_hiddens=embed_dim,
                        ffn_num_hiddens=ffn,
                        num_heads=heads,
                        num_blks=num_blocks,
                        dropout=dropout)
        # NOTE(review): the decoder already ends in its own projection to
        # target_vocab_size logits, so this layer (input width embed_dim) does
        # not fit its output and is not applied in ``forward``. It is kept
        # only so existing attribute access / state_dict keys keep working.
        self.fc_out = nn.Linear(embed_dim, target_vocab_size)

    def make_trg_mask(self, trg):
        """Return a causal mask of shape (batch, 1, trg_len, trg_len).

        Entry ``[b, 0, q, k]`` is 1 when ``k <= q`` and 0 otherwise. The mask
        is created on ``trg``'s device. ``forward`` does not use it: the
        decoder blocks build their own causal mask during training.
        """
        batch_size, trg_len = trg.shape
        # returns the lower triangular part of matrix filled with ones
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(batch_size, 1, trg_len, trg_len)

    def forward(self, source, target, src_valid_lens=None):
        """Run the encoder on ``source`` and the decoder on ``target``.

        Args:
            source: source token ids, (batch, src_len).
            target: target token ids, (batch, trg_len).
            src_valid_lens: optional valid lengths of ``source``; padding past
                them is masked in encoder self-attention and in the decoder's
                encoder-decoder attention. ``None`` masks nothing.

        Returns:
            Probabilities over the target vocabulary,
            (batch, trg_len, target_vocab_size).

        NOTE(review): decoder blocks only apply a causal mask while
        ``self.training`` is true; in eval mode a full ``target`` can attend to
        future positions, so eval-time use presumably should decode one step
        at a time — confirm.
        """
        enc_out = self.encoder(source, src_valid_lens)
        # The decoder's signature is (X, state); state carries the encoder
        # outputs, their valid lengths and a fresh per-block cache.
        state = self.decoder.init_state(enc_out, src_valid_lens)
        logits, _ = self.decoder(target, state)
        return nn.functional.softmax(logits, dim=-1)

In [47]:
# Smoke test: build a small model and print its module tree.
model = Transformer(
    embed_dim=512,
    src_vocab_size=12,
    target_vocab_size=12,
    ffn=22,
    num_blocks=6,
    heads=8,
)

print(model)

Transformer(
  (encoder): TransformerEncoder(
    (embedding): Embedding(12, 512)
    (pos_encoding): PositionalEncoding(
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (blks): Sequential(
      (block0): TransformerEncoderBlock(
        (attention): MultiHeadAttention(
          (attention): Attention(
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (W_q): LazyLinear(in_features=0, out_features=512, bias=False)
          (W_k): LazyLinear(in_features=0, out_features=512, bias=False)
          (W_v): LazyLinear(in_features=0, out_features=512, bias=False)
          (W_o): LazyLinear(in_features=0, out_features=512, bias=False)
        )
        (addnorm1): AddNorm(
          (dropout): Dropout(p=0.2, inplace=False)
          (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (ffn): PositionWiseFFN(
          (dense1): LazyLinear(in_features=0, out_features=22, bias=True)
          (relu): ReLU()
          (dense2): LazyLin