## Decoder Part of the "Attention Is All You Need" Paper

In [1]:
# MODEL CONFIGURATION

NUM_ATTENTION_HEADS = 8
D_MODEL = 512
DROP_PROB = 0.2
BATCH_SIZE = 32
MAX_SEQUENCE_LENGTH = 200
FFN_HIDDEN = 2048
NUM_LAYERS = 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

## Scaled Dot Product

##### transpose torch

In [3]:
## Conceptual code for scaled dot product

key_ = torch.rand(BATCH_SIZE, NUM_ATTENTION_HEADS, MAX_SEQUENCE_LENGTH, D_MODEL // NUM_ATTENTION_HEADS)

key_transposed_ = key_.transpose(2, 3) # swap dims 2 and 3; equivalent to transpose(-2, -1)

key_.shape, key_transposed_.shape

(torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 64, 200]))

##### masked fill torch

In [4]:
mask = torch.ones(5, 5)
mask = torch.triu(mask, diagonal=1)
mask

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

In [5]:
masked = mask.masked_fill(mask == 1, -torch.inf)

masked

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [6]:
attention_scores_ = torch.randn(2, 1, 5, 5)
mask_ = torch.full((5, 5), -torch.inf)
mask_ = torch.triu(mask_, diagonal=1)

attention_scores_masked_ = attention_scores_ + mask_

attention_scores_masked_ # shape: (batch=2, heads=1, seq=5, seq=5)

tensor([[[[ 2.4976,    -inf,    -inf,    -inf,    -inf],
          [-0.9841,  0.3592,    -inf,    -inf,    -inf],
          [ 0.1644,  0.6086,  0.8830,    -inf,    -inf],
          [-0.3268, -0.1089, -0.5998, -0.5775,    -inf],
          [-0.9658,  0.1995,  1.8666, -0.4802, -1.0530]]],


        [[[ 2.2161,    -inf,    -inf,    -inf,    -inf],
          [-0.5256,  1.7169,    -inf,    -inf,    -inf],
          [-0.2487,  0.8131,  1.1253,    -inf,    -inf],
          [ 0.2340,  0.0390,  0.6614, -0.4125,    -inf],
          [ 1.4189, -0.0896, -0.0468, -1.7056,  0.0746]]]])
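The reason for adding `-inf` rather than some large negative number becomes visible once softmax is applied: `exp(-inf)` is exactly zero, so future positions receive no attention weight at all. A minimal sketch of this, assuming a hypothetical helper name `make_causal_mask` that is not part of the notebook:

```python
import torch
import torch.nn.functional as F


def make_causal_mask(seq_len):
    """Hypothetical helper: additive mask with -inf strictly above the diagonal."""
    mask = torch.full((seq_len, seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)


torch.manual_seed(0)
scores = torch.randn(2, 1, 5, 5)  # (batch, heads, seq, seq)
weights = F.softmax(scores + make_causal_mask(5), dim=-1)

# Boolean selector for the "future" entries (strictly above the diagonal).
future = torch.triu(torch.ones(5, 5), diagonal=1).bool()

print(weights[..., future].abs().max())  # largest weight placed on a future position
print(weights.sum(dim=-1))               # each row still sums to 1
```

The first row of every mask contains a single allowed entry, so the first token always attends only to itself.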

### Scaled Dot Product

In [7]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
       query: Tensor of shape (batch_size, num_heads, seq_len, d_k)
       key: Tensor of shape (batch_size, num_heads, seq_len, d_k)
       value: Tensor of shape (batch_size, num_heads, seq_len, d_v)
       mask: Optional additive mask broadcastable to (batch_size, num_heads, seq_len, seq_len),
             with 0 where attention is allowed and -inf where it is blocked
    Returns:
       attention_output: Attention weights multiplied by value, shape (batch_size, num_heads, seq_len, d_v)
       attention_weights: Softmax-normalized attention scores, shape (batch_size, num_heads, seq_len, seq_len)
    """
    d_k = query.size(-1) # per-head dimension of the query/key vectors
    attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # QK^T / sqrt(d_k)
    if mask is not None:
        attention_scores = attention_scores + mask
        print(f"Mask shape: {mask.shape}")

   
    attention_weights = F.softmax(attention_scores, dim=-1)
    attention_output = torch.matmul(attention_weights, value)
    
    # log
    print(f"Query Key Value shape in Self Attention: {query.shape}, {key.shape}, {value.shape}")
    print(f"Attention scores shape in Self Attention: {attention_scores.shape}")
    print(f"Attention weights shape in Self Attention: {attention_weights.shape}")
    print(f"Attention output shape in Self Attention: {attention_output.shape}")
        
    return attention_output, attention_weights

In [8]:
#### testing
query = key = value = torch.rand(BATCH_SIZE, NUM_ATTENTION_HEADS, MAX_SEQUENCE_LENGTH, D_MODEL // NUM_ATTENTION_HEADS)

mask = torch.full((MAX_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH), float("-inf"))
mask = torch.triu(mask, diagonal=1)

attention_output, attention_weights = scaled_dot_product_attention(query, key, value, mask)

Mask shape: torch.Size([200, 200])
Query Key Value shape in Self Attention: torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64])
Attention scores shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention weights shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention output shape in Self Attention: torch.Size([32, 8, 200, 64])


In [9]:
x = torch.randn((BATCH_SIZE, MAX_SEQUENCE_LENGTH, D_MODEL))
y = torch.randn((BATCH_SIZE, MAX_SEQUENCE_LENGTH, D_MODEL))

mask = torch.full((MAX_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH), float("-inf"))
mask = torch.triu(mask, diagonal=1)

## MultiHead Attention Layer

The **Multi-Head Attention** mechanism is a fundamental part of the Transformer architecture, as described in the paper **"Attention Is All You Need"**. This module allows the model to focus on different positions of the input simultaneously, improving its ability to model relationships in sequences.

## Key Concepts

### 1. Purpose of Multi-Head Attention

- Multi-Head Attention lets the model attend to several positions of the input sequence in parallel.
- Each head works in its own learned subspace of dimension $d_k = d_{\text{model}} / h$, so different heads can capture different relationships within the sequence.

### 2. Linear Projections

- The first step in the multi-head attention mechanism is the linear projection of the input into three different matrices: Queries ($Q$), Keys ($K$), and Values ($V$).
- The projections are computed as:
  $$
  Q = X W_Q, \quad K = X W_K, \quad V = X W_V
  $$
  where $X$ is the input matrix, and $W_Q$, $W_K$, $W_V$ are learnable weight matrices.
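As a small sanity check of the projection formula, the sketch below compares an `nn.Linear` layer with the explicit matrix product. It assumes `bias=False` so that the layer matches $Q = X W_Q$ exactly; the layers used in this notebook keep the default bias, which adds a learnable offset on top.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 512
X = torch.randn(4, 10, d_model)  # (batch_size, seq_len, d_model)

# bias=False is an assumption made here so the layer is a pure matrix product.
W_Q = nn.Linear(d_model, d_model, bias=False)

Q_layer = W_Q(X)
# nn.Linear stores its weight as (out_features, in_features), hence the transpose.
Q_manual = X @ W_Q.weight.T

print(Q_layer.shape)
print(torch.allclose(Q_layer, Q_manual, atol=1e-5))
```

The layer only touches the last dimension, so the same weights are applied to every position of every sequence in the batch.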

### 3. Scaled Dot-Product Attention

- The core operation of attention is the dot product between the queries and the keys, divided by $\sqrt{d_k}$, the square root of the key dimension:
  $$
  \text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V
  $$
  where $d_k$ is the dimension of the key vectors. The softmax is applied row-wise, so the attention weights of each query over all keys sum to 1.
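The division by $\sqrt{d_k}$ can be motivated numerically. If the components of a query and a key are independent with zero mean and unit variance (an assumption made for this illustration), their dot product has variance of roughly $d_k$, which pushes softmax into regions with very small gradients. A rough sketch:

```python
import math
import torch

torch.manual_seed(0)
d_k = 64
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)        # 10,000 independent query-key dot products
scaled = raw / math.sqrt(d_k)

print(raw.var().item())          # roughly d_k
print(scaled.var().item())       # roughly 1
```

After scaling, the scores have about unit variance regardless of the head dimension.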

### 4. Multi-Head Attention

- Attention is computed independently for each head, and the head outputs are then concatenated:
  $$
  \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O
  $$
  where $\text{head}_i = \text{Attention}(Q_i, K_i, V_i)$, with $Q_i$, $K_i$, $V_i$ the slices of the projected queries, keys, and values that belong to head $i$, and $W_O$ is the output weight matrix.

- The concatenated output is then passed through a final linear layer to project it back to the original dimension $d_{\text{model}}$.
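The split into heads and the later concatenation are pure reshapes, so going there and back should reproduce the original tensor. A minimal sketch with small, arbitrarily chosen sizes:

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, num_heads, d_k = 2, 5, 8, 64
d_model = num_heads * d_k

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)

# Concat: (batch, heads, seq, d_k) -> (batch, seq, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

print(heads.shape)
print(torch.equal(merged, x))
```

Head `i` therefore corresponds to columns `i * d_k` through `(i + 1) * d_k - 1` of the projected tensor.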

##### view torch

In [10]:
query_ = torch.randn((32, 8, 200, 64))
Q = query_.view(32, -1, 8, 64 ) # 32 200 8 64
Q = Q.transpose(1, 2) # 32 8 200 64
Q.shape

torch.Size([32, 8, 200, 64])
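One caveat when combining `view` and `transpose`: `transpose` returns a view with permuted strides, and `view` can refuse to merge dimensions on such a tensor. The sketch below, using arbitrary small shapes, shows the failure and two ways around it (`contiguous()` followed by `view`, or `reshape`).

```python
import torch

torch.manual_seed(0)
t = torch.randn(2, 8, 5, 64).transpose(1, 2)  # shape (2, 5, 8, 64), no longer contiguous
print(t.is_contiguous())

try:
    t.view(2, 5, 512)        # tries to merge the last two dims without copying
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

a = t.contiguous().view(2, 5, 512)  # copy into contiguous memory first
b = t.reshape(2, 5, 512)            # reshape copies only when it has to
print(torch.equal(a, b))
```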

#### Linear Projection

In [11]:
linear_ = torch.nn.Linear(512, 512)
x_ = torch.rand((32, 200, 512)) # (batch_size, seq_len, d_model); the linear layer acts on the last dimension only
out_ = linear_(x_)

print("Output after linear layer")
print(out_.shape)

out_ = out_.view(32, -1, 8, 64)

print("After view")
print(out_.shape)

out_ = out_.transpose(1, 2)

print("After transpose")
print(out_.shape) # (batch_size=32, num_heads=8, seq_len=200, d_k=64)

Output after linear layer
torch.Size([32, 200, 512])
After view
torch.Size([32, 200, 8, 64])
After transpose
torch.Size([32, 8, 200, 64])


In [12]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        
        # Linear layers for query, key, and value projections
        # Project to d_model in one step and split into heads afterwards; this is cheaper than a separate projection per head
        self.W_Query = nn.Linear(d_model, d_model)
        self.W_Key = nn.Linear(d_model, d_model)
        self.W_Value = nn.Linear(d_model, d_model)
        
        # Final linear layer after concatenating heads
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Project and reshape input to (batch_size, num_heads, seq_len, d_k)
        Q = self.W_Query(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_Key(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_Value(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        print(f"Query Key Value Shape in MultiHead Attention after Projecting and Reshaping: {Q.shape} {K.shape} {V.shape}")
        
        # Scaled dot-product attention
        attention_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and project back to d_model
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        output = self.linear_layer(attention_output)
        
        return output


In [13]:
# Example
d_model = 512
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)

query = torch.randn(32, 200, 512)  # (batch_size, seq_len, d_model)
key = torch.randn(32, 200, 512)
value = torch.randn(32, 200, 512)
mask = torch.full((MAX_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH), float("-inf"))
mask = torch.triu(mask, diagonal=1)

output = mha(query, key, value, mask)

print(f"Output of Multi Head Attention: {output.shape}")

Query Key Value Shape in MultiHead Attention after Projecting and Reshaping: torch.Size([32, 8, 200, 64]) torch.Size([32, 8, 200, 64]) torch.Size([32, 8, 200, 64])
Mask shape: torch.Size([200, 200])
Query Key Value shape in Self Attention: torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64])
Attention scores shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention weights shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention output shape in Self Attention: torch.Size([32, 8, 200, 64])
Output of Multi Head Attention: torch.Size([32, 200, 512])


## FeedForward Network

In [14]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.2):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x
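"Position-wise" means the same two-layer network is applied to every position independently, with no mixing across the sequence. The sketch below checks this with a stand-in built from `nn.Sequential`; it mirrors the class above but leaves out dropout so that the comparison is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048

# Stand-in for PositionWiseFeedForward without dropout (an assumption for this check).
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch_size, seq_len, d_model)

full = ffn(x)  # whole sequence at once
# One position at a time, then stacked back along the sequence dimension.
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(full.shape)
print(torch.allclose(full, per_position, atol=1e-5))
```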

## Decoder Layer

In [31]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.enc_dec_attention = MultiHeadAttention(d_model,num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout) # pass the layer's dropout rate through instead of the FFN default
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Masked self attention
        print("..........................................................................................")
        print("\nDECODER SELF ATTENTION")
        self_attention_output = self.self_attention(x, x, x, tgt_mask)
        print("\nADD AND NORM")
        x = self.norm1(x + self.dropout(self_attention_output)) # add and Norm
        print(f"After add and norm: {x.size()}")
        # Cross Attention
        print("\nCROSS ATTENTION")
        enc_dec_attention_output = self.enc_dec_attention(x, enc_output, enc_output, src_mask)
        print(f"Output: {enc_dec_attention_output.size()}")
        print("\nADD AND NORM")
        x = self.norm2(x + self.dropout(enc_dec_attention_output)) # add and norm2
        print(f"After add and norm2: {x.size()}")
        
        # Feed Forward NN
        print("\nFEED FORWARD NN")
        ff_output = self.feed_forward(x)
        print(f"Output of ff_nn: {ff_output.size()}")
        print("\nADD AND NORM")
        x = self.norm3(x + self.dropout(ff_output))
        print(f"After add and norm3: {x.size()}")
        print("................................................................................................")
        return x
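In cross-attention the queries come from the decoder and the keys and values from the encoder, so the score matrix has shape (target length, source length). A square causal mask does not fit that shape in general; `src_mask` is typically a padding mask over encoder positions instead. The sketch below illustrates the shapes using PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the class defined in this notebook, with arbitrary lengths 5 and 7.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 512, 8
tgt_len, src_len = 5, 7

# Built-in layer used only as a stand-in for the notebook's MultiHeadAttention.
cross_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

decoder_states = torch.randn(2, tgt_len, d_model)  # queries
encoder_states = torch.randn(2, src_len, d_model)  # keys and values

out, weights = cross_attention(decoder_states, encoder_states, encoder_states)

print(out.shape)      # one output vector per decoder position
print(weights.shape)  # (batch, tgt_len, src_len), averaged over heads
```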

## Positional Encoding

In [16]:
torch.arange(0, 10, 2)

tensor([0, 2, 4, 6, 8])

In [17]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


In [18]:
#### Testing
pe = PositionalEncoding(512, 200)
x = torch.randn(32, 200, 512)
output = pe(x)

print(f"Output of Positional Encoding: {output.shape}")

Output of Positional Encoding: torch.Size([32, 200, 512])
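The `div_term` in the class uses an exp/log form of the frequencies. The sketch below checks, under the same `d_model` and `max_len` as the test above, that it agrees with the form written in the paper, $1 / 10000^{2i / d_{\text{model}}}$, and rebuilds the table outside the module for inspection.

```python
import math
import torch

d_model, max_len = 512, 200
position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)           # 0, 2, 4, ...

# Form used in the class: exp(-2i * ln(10000) / d_model)
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
# Form written in the paper: 1 / 10000^(2i / d_model)
direct = 1.0 / (10000.0 ** (two_i / d_model))

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)  # even columns
pe[:, 1::2] = torch.cos(position * div_term)  # odd columns

print(torch.allclose(div_term, direct, rtol=1e-4))
print(pe[0, :4])  # position 0: sin(0) = 0 and cos(0) = 1, alternating
```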


## Full Decoder

In [32]:
class FullDecoder(torch.nn.Module):
    def __init__(self, d_model=512, num_layers=6, num_heads=8, d_ff=2048, vocab_size=10000, max_len=200):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        print("TOKEN EMBEDDING")
        x = self.embedding(x)
        print(f"embedding: {x.size()}")
        print("\nPOSITIONAL EMBEDDING")
        x = self.positional_encoding(x)
        print(f"Input to Decoder: {x.shape}\n")
        
        print("**************************DECODER START**************************")
        for i, layer in enumerate(self.layers):
            print(f"DECODER LAYER: {i+1}")
            x = layer(x, enc_output, src_mask, tgt_mask)
            print(f"Output of Decoder layer {i+1}: {x.size()}")
        print("**************************DECODER FINISH**************************")
        return self.fc_out(x)
        

In [33]:
### Testing
VOCAB_SIZE = 10000
D_MODEL = 512
NUM_ATTENTION_HEADS = 8
NUM_LAYERS = 3
FFN_HIDDEN = 2048
MAX_SEQUENCE_LENGTH = 200
BATCH_SIZE = 32

decoder_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MAX_SEQUENCE_LENGTH))
encoder_output = torch.randn(BATCH_SIZE, MAX_SEQUENCE_LENGTH, D_MODEL)

decoder = FullDecoder(
    d_model=D_MODEL,
    num_layers=NUM_LAYERS,
    num_heads=NUM_ATTENTION_HEADS,
    d_ff=FFN_HIDDEN,
    vocab_size=VOCAB_SIZE,
    max_len=MAX_SEQUENCE_LENGTH
)

decoder

FullDecoder(
  (embedding): Embedding(10000, 512)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (layers): ModuleList(
    (0-2): 3 x DecoderLayer(
      (self_attention): MultiHeadAttention(
        (W_Query): Linear(in_features=512, out_features=512, bias=True)
        (W_Key): Linear(in_features=512, out_features=512, bias=True)
        (W_Value): Linear(in_features=512, out_features=512, bias=True)
        (linear_layer): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (enc_dec_attention): MultiHeadAttention(
        (W_Query): Linear(in_features=512, out_features=512, bias=True)
        (W_Key): Linear(in_features=512, out_features=512, bias=True)
        (W_Value): Linear(in_features=512, out_features=512, bias=True)
        (linear_layer): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm2): LayerNorm((512,), eps=1e-05, e

In [35]:
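# The causal mask belongs to decoder self-attention (tgt_mask); cross-attention over the encoder output is left unmasked here
output = decoder(decoder_input, encoder_output, src_mask=None, tgt_mask=mask)

print(f"Output after final layer: {output.shape}")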

TOKEN EMBEDDING
embedding: torch.Size([32, 200, 512])

POSITIONAL EMBEDDING
Input to Decoder: torch.Size([32, 200, 512])

**************************DECODER START**************************
DECODER LAYER: 1
..........................................................................................

DECODER SELF ATTENTION
Query Key Value Shape in MultiHead Attention after Projecting and Reshaping: torch.Size([32, 8, 200, 64]) torch.Size([32, 8, 200, 64]) torch.Size([32, 8, 200, 64])
Mask shape: torch.Size([200, 200])
Query Key Value shape in Self Attention: torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64]), torch.Size([32, 8, 200, 64])
Attention scores shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention weights shape in Self Attention: torch.Size([32, 8, 200, 200])
Attention output shape in Self Attention: torch.Size([32, 8, 200, 64])

ADD AND NORM
After add and norm: torch.Size([32, 200, 512])

CROSS ATTENTION
Query Key Value Shape in MultiHead Attention after Projec