# Transformer Architecture

![Transformer](https://miro.medium.com/v2/resize:fit:760/1*2vyKzFlzIHfSmOU_lnQE4A.png)


In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

import math
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Softmax of a vector whose entries are all -inf: every exponential is zero,
# so the normaliser is zero and each output element is NaN (see the cell
# output). This is the failure mode of filling a fully masked row with -inf.
k = torch.full((4,), float("-inf"))
# Pass `dim` explicitly; the implicit-dimension form is deprecated and warns.
F.softmax(k, dim=-1)

  F.softmax(k)


tensor([nan, nan, nan, nan])

In [None]:
# Demo: zero out the upper triangle of a random tensor, replace the zeros
# with -inf and apply softmax, as in a causally masked attention row.
k = torch.randn(32, 10, 512)
# Shape inspection only: the result is not assigned, so k is unchanged here.
k.view(32, 10, 8, -1).shape
# Keep the lower triangle of the last two dims; entries above the diagonal become 0.
k = torch.tril(k)
# Replace every exact zero with -inf so softmax gives those positions zero weight.
k = k.masked_fill(k==0, float('-Inf'))
# Normalise along dim=2 (the last dimension of this 3-D tensor).
F.softmax(k, dim=2)

tensor([[[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.6247, 0.3753, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.2739, 0.1008, 0.6253,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0757, 0.2967, 0.1330,  ..., 0.0000, 0.0000, 0.0000],
         [0.1622, 0.0216, 0.3364,  ..., 0.0000, 0.0000, 0.0000],
         [0.0973, 0.1536, 0.0484,  ..., 0.0000, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.9084, 0.0916, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.3221, 0.1061, 0.5718,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0555, 0.1129, 0.2225,  ..., 0.0000, 0.0000, 0.0000],
         [0.0434, 0.0856, 0.0370,  ..., 0.0000, 0.0000, 0.0000],
         [0.0826, 0.4050, 0.0066,  ..., 0.0000, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.4215, 0.5785, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.3405, 0.5131, 0.1464,  ..., 0.0000, 0.0000, 0.

## Positional Encoding

<img src="https://miro.medium.com/v2/resize:fit:640/format:webp/1*m2SB7rpbHdL9UGCFySE40w.png">

In [None]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to a batch of embeddings.

  A table of shape (1, seq_len, embed_dim) is precomputed once: even columns
  hold sin(position * freq) and odd columns hold cos(position * freq), where
  freq = 10000 ** (-2j / embed_dim) for the j-th (sin, cos) pair. The table
  is stored as the buffer ``pe``; it is not a trainable parameter.

  Args:
    seq_len: Number of positions in the precomputed table.
    embed_dim: Size of the embedding dimension. Odd values are supported;
      the last column is then a sine column without a cosine partner.
  """

  def __init__(self, seq_len, embed_dim):
    super(PositionalEncoding, self).__init__()

    pe = torch.zeros(seq_len, embed_dim)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)

    # One frequency per (sin, cos) pair, computed in log space.
    i = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
    pe[:, 0::2] = torch.sin(position * i)
    # With an odd embed_dim there is one fewer odd column than frequencies,
    # so only the first embed_dim // 2 frequencies get a cosine column.
    pe[:, 1::2] = torch.cos(position * i[: embed_dim // 2])

    pe = pe.unsqueeze(0)  # leading dim of 1 so the table broadcasts over the batch

    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    Args:
      x: Tensor indexed as (batch, positions, embed_dim).

    Returns:
      Tensor of the same shape as ``x``.
    """
    # Slice the table to the input length and match the input's device
    # without reassigning the registered buffer on every call.
    encoding = self.pe[:, :x.size(1), :].to(x.device)
    return x + encoding


In [None]:
p = torch.randn(32, 10, 2)
pe = PositionalEncoding(10, 2)
pe(p).shape

torch.Size([32, 10, 2])

In [None]:
# Subtracting the input from the output isolates the encoding that was added.
p = torch.randn(1, 10, 2)
pe = PositionalEncoding(10, 2)
# With embed_dim=2 the only frequency is 1, so column 0 is sin(position)
# and column 1 is cos(position).
pe(p)- p

tensor([[[ 0.0000,  1.0000],
         [ 0.8415,  0.5403],
         [ 0.9093, -0.4161],
         [ 0.1411, -0.9900],
         [-0.7568, -0.6536],
         [-0.9589,  0.2837],
         [-0.2794,  0.9602],
         [ 0.6570,  0.7539],
         [ 0.9894, -0.1455],
         [ 0.4121, -0.9111]]])

## 1. MultiHead Attention

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Query, key and value are each projected with a bias-free linear layer,
  split into ``n_heads`` heads of size ``embed_dim // n_heads``, attended
  per head, concatenated and passed through a final linear layer.

  Args:
    embed_dim: Size of the model (embedding) dimension.
    n_heads: Number of attention heads; must divide ``embed_dim`` evenly.

  Raises:
    ValueError: If ``n_heads`` is not positive or does not divide ``embed_dim``.
  """

  def __init__(self, embed_dim, n_heads):
    super(MultiHeadAttention, self).__init__()

    # Fail early with a clear message instead of an opaque size error from
    # the reshapes in forward().
    if n_heads <= 0 or embed_dim % n_heads != 0:
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})")

    self.embed_dim = embed_dim
    self.n_heads = n_heads
    self.single_head_dim = embed_dim // n_heads

    self.query_m = nn.Linear(embed_dim, embed_dim, bias=False)
    self.key_m = nn.Linear(embed_dim, embed_dim, bias=False)
    self.value_m = nn.Linear(embed_dim, embed_dim, bias=False)

    self.output = nn.Linear(embed_dim, embed_dim)

  def forward(self, query, key, value, mask=None):
    """Attend from ``query`` positions over ``key``/``value`` positions.

    Args:
      query: Tensor of shape (batch_size, query_seq_len, embed_dim).
      key: Tensor of shape (batch_size, seq_len, embed_dim).
      value: Tensor of shape (batch_size, seq_len, embed_dim).
      mask: Optional tensor broadcastable to
        (batch_size, n_heads, query_seq_len, seq_len). Positions where
        ``mask == 0`` are excluded from attention.

    Returns:
      Tensor of shape (batch_size, query_seq_len, embed_dim).
    """
    batch_size = key.shape[0]
    seq_len = key.shape[1]
    query_seq_len = query.shape[1]

    q = self.query_m(query) # (batch_size, query_seq_len, embed_dim)
    k = self.key_m(key)     # (batch_size, seq_len, embed_dim)
    v = self.value_m(value) # (batch_size, seq_len, embed_dim)

    # Split the embedding into heads and move the head axis next to batch.
    q = q.view(batch_size, query_seq_len, self.n_heads, self.single_head_dim).transpose(1, 2) # (batch_size, n_heads, query_seq_len, single_head_dim)
    k = k.view(batch_size, seq_len, self.n_heads, self.single_head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, single_head_dim)
    v = v.view(batch_size, seq_len, self.n_heads, self.single_head_dim).transpose(1, 2) # (batch_size, n_heads, seq_len, single_head_dim)

    # Scaled dot-product scores.
    qk = torch.matmul(q, k.transpose(-1, -2)) / (self.single_head_dim ** 0.5) # (batch_size, n_heads, query_seq_len, seq_len)

    if mask is not None:
      # Use the most negative finite value of the score dtype: it cannot
      # overflow in reduced precision, and unlike -inf a fully masked row
      # still produces a finite (uniform) softmax instead of NaN.
      qk = qk.masked_fill(mask == 0, torch.finfo(qk.dtype).min)

    attn_weights = F.softmax(qk, dim=-1) # (batch_size, n_heads, query_seq_len, seq_len)

    context = torch.matmul(attn_weights, v) # (batch_size, n_heads, query_seq_len, single_head_dim)

    # Merge the heads back into a single embedding dimension.
    out = context.transpose(1, 2).contiguous() # (batch_size, query_seq_len, n_heads, single_head_dim)
    out = out.view(batch_size, query_seq_len, self.n_heads*self.single_head_dim) # (batch_size, query_seq_len, embed_dim)

    out = self.output(out)  # (batch_size, query_seq_len, embed_dim)

    return out

In [None]:
m = MultiHeadAttention(512, 8)
m

MultiHeadAttention(
  (query_m): Linear(in_features=512, out_features=512, bias=False)
  (key_m): Linear(in_features=512, out_features=512, bias=False)
  (value_m): Linear(in_features=512, out_features=512, bias=False)
  (output): Linear(in_features=512, out_features=512, bias=True)
)

In [None]:
# Smoke test: 32 sequences of 10 tokens with 512-dimensional embeddings.
q = torch.randn(32, 10, 512)
v = torch.randn(32, 10, 512)
k = torch.randn(32, 10, 512)
# forward() takes its arguments in (query, key, value) order.
m(q, k, v).shape

torch.Size([32, 10, 512])

## 2. Encoder

<img src="https://www.researchgate.net/profile/Ehsan-Amjadian/publication/352239001/figure/fig1/AS:1033334390013952@1623377525434/Detailed-view-of-a-transformer-encoder-block-It-first-passes-the-input-through-an.jpg" height="480">

In [None]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and a feed-forward network.

  Each sub-layer is followed by dropout, a residual connection and then
  layer normalisation.

  Args:
    embed_dim: Size of the model (embedding) dimension.
    n_heads: Number of attention heads.
    hidden_factor: The feed-forward hidden size is ``embed_dim * hidden_factor``.
    dropout: Dropout probability used inside the feed-forward network and
      on each sub-layer output.
  """

  def __init__(self, embed_dim, n_heads, hidden_factor=4, dropout=0.1):
    super(EncoderLayer, self).__init__()

    self.self_attn = MultiHeadAttention(embed_dim, n_heads)
    self.feed_fwd = nn.Sequential(
        nn.Linear(embed_dim, embed_dim*hidden_factor),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(embed_dim*hidden_factor, embed_dim)
    )
    self.norm1 = nn.LayerNorm(embed_dim)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.drop = nn.Dropout(dropout)

  def forward(self, x, mask=None):
    """Apply the block to ``x``.

    Args:
      x: Tensor of shape (batch_size, seq_len, embed_dim).
      mask: Optional attention mask handed to the self-attention layer;
        positions where it equals 0 are not attended to.

    Returns:
      Tensor of shape (batch_size, seq_len, embed_dim).
    """
    # Forward the caller's mask so masked positions are actually excluded
    # from self-attention.
    attn = self.self_attn(x, x, x, mask) # (batch_size, seq_len, embed_dim)
    x = x + self.drop(attn) # attention residual
    x = self.norm1(x) # attention residual normalization

    feed_forward = self.feed_fwd(x)
    # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, embed_dim*hidden_factor) -> (batch_size, seq_len, embed_dim)
    x = x + self.drop(feed_forward) # feed forward residual
    x = self.norm2(x) # feed forward residual normalization

    return x

In [None]:
class Encoder(nn.Module):
  """Token embedding, positional encoding and a stack of encoder layers.

  Args:
    input_vocab: Number of rows in the source embedding table.
    embed_dim: Size of the model (embedding) dimension.
    seq_len: Number of positions given to the positional-encoding table.
    n_heads: Number of attention heads per layer.
    num_layers: Number of stacked ``EncoderLayer`` blocks.
    hidden_factor: Feed-forward expansion factor passed to each layer.
    dropout: Dropout probability for the embeddings and each layer.
  """

  def __init__(self, input_vocab, embed_dim, seq_len, n_heads=8, num_layers=2, hidden_factor=4, dropout=0.1):
    super().__init__()

    self.embedding = nn.Embedding(input_vocab, embed_dim)
    self.pe = PositionalEncoding(seq_len, embed_dim)

    blocks = []
    for _ in range(num_layers):
      blocks.append(EncoderLayer(embed_dim, n_heads, hidden_factor, dropout))
    self.layers = nn.ModuleList(blocks)

    self.drop = nn.Dropout(dropout)

  def forward(self, src, mask=None):
    """Encode a batch of token ids.

    Args:
      src: Integer tensor of token ids, shape (batch_size, seq_len).
      mask: Optional mask handed unchanged to every layer.

    Returns:
      Tensor of shape (batch_size, seq_len, embed_dim).
    """
    # ids -> embeddings -> add positions -> dropout
    hidden = self.drop(self.pe(self.embedding(src)))

    for block in self.layers:
      hidden = block(hidden, mask)

    return hidden

In [None]:
enc = Encoder(5500, 512, 10, 8, 2, 4, 0.1)
enc

Encoder(
  (embedding): Embedding(5500, 512)
  (pe): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (query_m): Linear(in_features=512, out_features=512, bias=False)
        (key_m): Linear(in_features=512, out_features=512, bias=False)
        (value_m): Linear(in_features=512, out_features=512, bias=False)
        (output): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_fwd): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=2048, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
    )
  )
  (drop): Dropout(p=0.1, inplace=False)
)

In [None]:
s = torch.randint(0, 5500, (32, 10))
so = enc(s, None)
so.shape

torch.Size([32, 10, 512])

## 3. Decoder

<img src="https://discuss.pytorch.org/uploads/default/optimized/3X/8/e/8e5d039948b8970e6b25395cb207febc82ba320a_2_177x500.png" height="480">

In [None]:
class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention and feed-forward.

  Each sub-layer is followed by dropout, a residual connection and then
  layer normalisation.

  Args:
    embed_dim: Size of the model (embedding) dimension.
    n_heads: Number of attention heads in both attention sub-layers.
    hidden_factor: The feed-forward hidden size is ``embed_dim * hidden_factor``.
    dropout: Dropout probability used inside the feed-forward network and
      on each sub-layer output.
  """

  def __init__(self, embed_dim, n_heads, hidden_factor=4, dropout=0.1):
    super(DecoderLayer, self).__init__()

    self.self_attn = MultiHeadAttention(embed_dim, n_heads)
    self.cross_attn = MultiHeadAttention(embed_dim, n_heads) # Encoder-Decoder Attention
    self.feed_fwd = nn.Sequential(
        nn.Linear(embed_dim, embed_dim*hidden_factor),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(embed_dim*hidden_factor, embed_dim)
    )
    self.norm1 = nn.LayerNorm(embed_dim)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.norm3 = nn.LayerNorm(embed_dim)
    self.drop = nn.Dropout(dropout)

  def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
    """Apply the block to the target-side activations ``x``.

    Args:
      x: Target-side tensor of shape (batch_size, tgt_len, embed_dim).
      enc_output: Encoder output used as key and value in cross-attention.
      src_mask: Optional mask over encoder positions for cross-attention;
        positions where it equals 0 are not attended to.
      tgt_mask: Optional mask for the target self-attention.

    Returns:
      Tensor of shape (batch_size, tgt_len, embed_dim).
    """
    self_attention = self.self_attn(x, x, x, tgt_mask)
    x = x + self.drop(self_attention) # residual
    x = self.norm1(x) # normalization

    # Queries come from the decoder, keys/values from the encoder. The source
    # mask is forwarded so masked encoder positions are not attended to.
    cross_attention = self.cross_attn(x, enc_output, enc_output, src_mask)
    x = x + self.drop(cross_attention) # residual
    x = self.norm2(x) # normalization

    feed_forward = self.feed_fwd(x)
    x = x + self.drop(feed_forward) # residual
    x = self.norm3(x) # normalization

    return x

In [None]:
class Decoder(nn.Module):
  """Token embedding, positional encoding, decoder layers and a vocabulary head.

  Args:
    output_vocab: Number of rows in the target embedding table and the
      width of the final projection.
    embed_dim: Size of the model (embedding) dimension.
    seq_len: Number of positions given to the positional-encoding table.
    n_heads: Number of attention heads per layer.
    num_layers: Number of stacked ``DecoderLayer`` blocks.
    hidden_factor: Feed-forward expansion factor passed to each layer.
    dropout: Dropout probability for the embeddings and each layer.
  """

  def __init__(self, output_vocab, embed_dim, seq_len, n_heads=8, num_layers=2, hidden_factor=4, dropout=0.1):
    super().__init__()

    self.embedding = nn.Embedding(output_vocab, embed_dim)
    self.pe = PositionalEncoding(seq_len , embed_dim)

    blocks = []
    for _ in range(num_layers):
      blocks.append(DecoderLayer(embed_dim, n_heads, hidden_factor, dropout))
    self.layers = nn.ModuleList(blocks)

    self.output = nn.Linear(embed_dim , output_vocab)
    self.drop = nn.Dropout(dropout)

  def forward(self, tgt, enc_output, src_mask=None, tgt_mask=None):
    """Decode a batch of target token ids against the encoder output.

    Args:
      tgt: Integer tensor of target token ids.
      enc_output: Encoder output handed to every layer.
      src_mask: Optional source mask handed to every layer.
      tgt_mask: Optional target mask handed to every layer.

    Returns:
      Softmax probabilities over ``output_vocab`` along the last dimension
      (these are not raw logits).
    """
    # ids -> embeddings -> add positions -> dropout
    hidden = self.drop(self.pe(self.embedding(tgt)))

    for block in self.layers:
      hidden = block(hidden, enc_output, src_mask, tgt_mask)

    # Project to the vocabulary and normalise over it.
    return F.softmax(self.output(hidden), dim=-1)

In [None]:
dec = Decoder(5500, 512, 10, 8, 2, 4, 0.1)
dec

Decoder(
  (embedding): Embedding(5500, 512)
  (pe): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x DecoderLayer(
      (self_attn): MultiHeadAttention(
        (query_m): Linear(in_features=512, out_features=512, bias=False)
        (key_m): Linear(in_features=512, out_features=512, bias=False)
        (value_m): Linear(in_features=512, out_features=512, bias=False)
        (output): Linear(in_features=512, out_features=512, bias=True)
      )
      (cross_attn): MultiHeadAttention(
        (query_m): Linear(in_features=512, out_features=512, bias=False)
        (key_m): Linear(in_features=512, out_features=512, bias=False)
        (value_m): Linear(in_features=512, out_features=512, bias=False)
        (output): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_fwd): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=2048, out_

In [None]:
t = torch.randint(0, 5500, (32, 10))
# Decoder.forward already ends with a softmax over the vocabulary, so its
# output is used directly rather than being normalised a second time.
dec(t, so, None, None).shape

torch.Size([32, 10, 5500])

### Masking

In [None]:
def create_mask(src, tgt, pad_token):
  """Build boolean attention masks for a source/target batch.

  Args:
    src: Tensor of source token ids, shape (batch_size, src_len).
    tgt: Tensor of target token ids, shape (batch_size, tgt_len).
    pad_token: Token id that marks padding.

  Returns:
    A ``(src_mask, tgt_mask)`` tuple of boolean tensors, True where
    attention is allowed:
      * ``src_mask`` has shape (batch_size, 1, 1, src_len) and is False at
        padded source positions.
      * ``tgt_mask`` has shape (batch_size, 1, tgt_len, tgt_len) and combines
        the target padding mask with a lower-triangular (causal) mask.
  """
  src_mask = (src != pad_token).unsqueeze(-2)  # (batch_size, 1, src_len)
  tgt_pad_mask = (tgt != pad_token).unsqueeze(-2)  # (batch_size, 1, tgt_len)

  seq_len = tgt.size(1)
  # Build the causal mask on the same device as tgt; otherwise the `&` below
  # fails with a device mismatch when the batch is not on the CPU.
  causal_mask = torch.tril(torch.ones(1, seq_len, seq_len, device=tgt.device)).bool()

  tgt_mask = tgt_pad_mask & causal_mask  # (batch_size, tgt_len, tgt_len)

  # Add a singleton head axis so the masks broadcast across attention heads.
  tgt_mask = tgt_mask.unsqueeze(1)
  src_mask = src_mask.unsqueeze(1)

  return src_mask, tgt_mask

In [None]:
s = torch.randint(1, 10, (2, 4))
s

tensor([[8, 9, 8, 5],
        [8, 7, 7, 3]])

In [None]:
sm, tm = create_mask(s, s, 5)
sm.shape, tm.shape

(torch.Size([2, 1, 1, 4]), torch.Size([2, 1, 4, 4]))

In [None]:
sm

tensor([[[[ True,  True,  True, False]]],


        [[[ True,  True,  True,  True]]]])

In [None]:
tm

tensor([[[[ True, False, False, False],
          [ True,  True, False, False],
          [ True,  True,  True, False],
          [ True,  True,  True, False]]],


        [[[ True, False, False, False],
          [ True,  True, False, False],
          [ True,  True,  True, False],
          [ True,  True,  True,  True]]]])

In [None]:
# Worked example: build padding and causal masks for two padded sequences
# and set up inputs for a small attention layer.
pad_token = 0
src = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])  # Source sequences with padding
tgt = torch.tensor([[1, 2, 0, 0], [4, 5, 6, 0]])  # Target sequences with padding

# Create masks
src_mask, tgt_mask = create_mask(src, tgt, pad_token)

# Instantiate the multi-head attention layer
embed_dim = 8
n_heads = 2
mha = MultiHeadAttention(embed_dim, n_heads)

# Example input embeddings (randomly initialized for demonstration)
query = torch.randn(2, 4, embed_dim)
key = torch.randn(2, 4, embed_dim)
value = torch.randn(2, 4, embed_dim)

# The attention call and prints are left disabled; uncomment to run them.
# output = mha(query, key, value, tgt_mask)
# print("Output:\n", output)
# print("Source Mask:\n", src_mask)
# print("Target Mask:\n", tgt_mask)

In [None]:
src.shape, tgt.shape

(torch.Size([2, 4]), torch.Size([2, 4]))

In [None]:
tm.shape

torch.Size([2, 1, 4, 4])

In [None]:
m = MultiHeadAttention(16, 4)
q = torch.randn(2, 4, 16)
sm, tm = create_mask(src, tgt, 0)
m(q, q, q, tm).shape

torch.Size([2, 4, 16])

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

  Args:
    input_vocab: Source vocabulary size.
    output_vocab: Target vocabulary size.
    seq_len: Number of positions given to both positional-encoding tables.
    embed_dim: Size of the model (embedding) dimension.
    n_heads: Number of attention heads per layer.
    num_layers: Number of layers in the encoder and in the decoder.
    hidden_factor: Feed-forward expansion factor.
    dropout: Dropout probability.
  """

  def __init__(self, input_vocab, output_vocab, seq_len, embed_dim=512, n_heads=8, num_layers=2, hidden_factor=4, dropout=0.1):
    super(Transformer, self).__init__()

    self.encoder = Encoder(input_vocab, embed_dim, seq_len, n_heads, num_layers, hidden_factor, dropout)
    self.decoder = Decoder(output_vocab, embed_dim, seq_len, n_heads, num_layers, hidden_factor, dropout)

  def forward(self, src, tgt, src_mask=None, tgt_mask=None):
    """Encode ``src`` and decode ``tgt`` against the encoder output.

    Args:
      src: Integer tensor of source token ids.
      tgt: Integer tensor of target token ids.
      src_mask: Optional source mask, given to both the encoder and the decoder.
      tgt_mask: Optional target mask, given to the decoder.

    Returns:
      Whatever the decoder returns for this batch.
    """
    # Hand the source mask to the encoder as well as the decoder.
    enc_output = self.encoder(src, src_mask)
    # Call the module rather than .forward() so registered hooks still run.
    output = self.decoder(tgt, enc_output, src_mask, tgt_mask)

    return output

In [None]:
tr = Transformer(5500, 6000, 10, 512, 8, 2, 4, 0.1)
tr

Transformer(
  (encoder): Encoder(
    (embedding): Embedding(5500, 512)
    (pe): PositionalEncoding()
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (self_attn): MultiHeadAttention(
          (query_m): Linear(in_features=512, out_features=512, bias=False)
          (key_m): Linear(in_features=512, out_features=512, bias=False)
          (value_m): Linear(in_features=512, out_features=512, bias=False)
          (output): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_fwd): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
    )
    (drop): Dropout(p=0.1, inplace=False)
 

In [None]:
# Smoke test with random token ids; id 0 is treated as padding.
s = torch.randint(0, 5500, (32, 10))
t = torch.randint(0, 6000, (32, 10))
sm, tm = create_mask(s, t, 0)
# Pass both masks; the source mask was previously computed and then discarded.
tr(s, t, src_mask=sm, tgt_mask=tm).shape

torch.Size([32, 10, 6000])

In [None]:
!pip install torchinfo



In [None]:
import torchinfo

In [None]:
torchinfo.summary(tr)

Layer (type:depth-idx)                        Param #
Transformer                                   --
├─Encoder: 1-1                                --
│    └─Embedding: 2-1                         2,816,000
│    └─PositionalEncoding: 2-2                --
│    └─ModuleList: 2-3                        --
│    │    └─EncoderLayer: 3-1                 3,150,848
│    │    └─EncoderLayer: 3-2                 3,150,848
│    └─Dropout: 2-4                           --
├─Decoder: 1-2                                --
│    └─Embedding: 2-5                         3,072,000
│    └─PositionalEncoding: 2-6                --
│    └─ModuleList: 2-7                        --
│    │    └─DecoderLayer: 3-3                 4,200,960
│    │    └─DecoderLayer: 3-4                 4,200,960
│    └─Linear: 2-8                            3,078,000
│    └─Dropout: 2-9                           --
Total params: 23,669,616
Trainable params: 23,669,616
Non-trainable params: 0