<a href="https://colab.research.google.com/github/gnoejh/ict1022/blob/main/Transformer/11_decoder_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

# Attention, Positional Encoding, Decoder Block
- https://newsletter.theaiedge.io/p/the-transformer-architecture-v2

### 1. Attention Mechanism


**Objective**: Enable models to focus on relevant parts of input sequences, essential for handling longer dependencies.

**Explanation**: 
- Attention allows a model to “attend” to specific input parts when generating each output. It’s especially useful for tasks requiring selective referencing of input tokens (e.g., translation, summarization).

    

In [6]:
# Simple attention mechanism
def attention(query, key, value):
    """Scaled dot-product attention.

    The last two dimensions of each argument are treated as
    (sequence, feature); any leading dimensions are handled by
    ``matmul`` broadcasting.

    Args:
        query: Tensor of queries.
        key: Tensor of keys; its last dimension sets the scaling factor.
        value: Tensor of values, one per key position.

    Returns:
        A tuple ``(output, attn_weights)`` where ``attn_weights`` is the
        softmax (over key positions) of the scaled query/key similarities
        and ``output`` is the corresponding weighted sum of ``value``.
    """
    # Square root of the key feature size, held as a float32 scalar tensor.
    scale = torch.tensor(key.size(-1), dtype=torch.float32).sqrt()

    # Pairwise query/key dot products, one row per query position.
    similarity = query.matmul(key.transpose(-2, -1))

    # Normalise each row so the weights over key positions sum to one.
    attn_weights = (similarity / scale).softmax(dim=-1)

    return attn_weights.matmul(value), attn_weights

# Example inputs: three independent random tensors, each shaped
# (batch size = 5, sequence length = 3, embedding size = 8).
# They are drawn in the order query, key, value.
query, key, value = (torch.randn(5, 3, 8) for _ in range(3))

# Run attention once and show both the attended values and the weights
output, weights = attention(query, key, value)
print("Attention Output:", output)
print("Attention Weights:", weights)


Attention Output: tensor([[[-0.4368,  1.1387, -0.0765, -0.4963,  0.6597,  0.8202,  1.0736,
          -0.4982],
         [-0.2681,  1.0187, -0.1171, -0.4373,  0.7556,  0.8140,  0.9195,
          -0.8436],
         [-0.7186,  1.3628,  0.2090, -0.7164, -0.0629,  1.0361,  2.0843,
           1.2794]],

        [[ 0.5862,  1.1620, -0.6905, -0.6694, -0.0634,  0.3326,  0.6173,
          -0.4638],
         [ 0.6313,  1.2360, -0.8401, -0.5667, -0.3285,  0.3045,  0.5720,
          -0.7707],
         [ 0.2021,  0.9877, -0.2883, -0.7976,  0.1314,  0.2375,  0.7687,
           0.0092]],

        [[-0.6210,  1.3632, -0.9874,  0.0380, -0.3837, -1.3748, -0.1222,
           0.1058],
         [-0.7731,  1.2723, -0.4795,  0.2371, -0.2152, -0.6736,  0.1619,
           0.1475],
         [-1.0058,  1.0546,  0.2966,  0.3289, -0.0406,  0.5681,  0.5247,
           0.1245]],

        [[-0.9264, -0.6532,  0.1170, -0.4256,  0.2450, -0.5173, -0.0366,
           0.8772],
         [-0.4929, -0.5228,  0.1945,  0.4260, 

### 2. Self-Attention Mechanism


**Objective**: Understand how self-attention allows a model to relate each word in a sentence to every other word, helping capture context effectively.

**Explanation**:
   - **Queries, Keys, and Values**: In self-attention, each word is represented by three vectors: a query, a key, and a value.
   - **Dot Product and Scaling**: The query and key vectors are multiplied to determine attention scores, representing the similarity between words. These scores are scaled and then passed through a softmax function to obtain attention weights.
   - **Weighted Sum**: Each word’s final representation is the sum of all value vectors, weighted by the attention weights (the softmax-normalized scores), allowing the model to focus on relevant words in context.
    

### 3. Multi-Head Attention


**Objective**: Enhance the model’s ability to capture diverse word relationships by using multiple attention heads.

**Explanation**:
   - **Multiple Heads**: Multi-head attention applies self-attention multiple times in parallel, each with different learned projections. This allows the model to capture various types of dependencies, such as grammatical or semantic.
   - **Concatenation and Linear Transformation**: Each head’s output is concatenated and passed through a linear layer, blending the insights from all heads into a single representation.
    

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``num_heads`` heads of
    size ``embed_size // num_heads``, attended independently, concatenated
    and passed through a final linear layer.

    Args:
        embed_size: Size of the last dimension of every input and of the
            output. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0, (
            "embed_size must be divisible by num_heads"
        )
        self.head_dim = embed_size // num_heads
        self.num_heads = num_heads

        # Linear layers for queries, keys, values
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

        # Final linear layer after concatenation
        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, projected):
        """Reshape (batch, length, embed) into (batch, heads, length, head_dim)."""
        batch_size, length, _ = projected.size()
        return projected.view(
            batch_size, length, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape (batch, query_len, embed_size).
            key: Tensor of shape (batch, key_len, embed_size).
            value: Tensor of shape (batch, key_len, embed_size).
            mask: Optional tensor broadcastable to
                (batch, num_heads, query_len, key_len). Positions where it
                equals 0 / ``False`` are excluded from attention. A query
                row whose keys are all masked yields NaN, because the
                softmax is then taken over ``-inf`` only.

        Returns:
            Tensor of shape (batch, query_len, embed_size).
        """
        batch_size, query_len, embed_size = query.size()

        # Each input is split using its OWN sequence length, so that
        # cross-attention works when key/value length differs from the
        # query length.
        q = self._split_heads(self.query(query))
        k = self._split_heads(self.key(key))
        v = self._split_heads(self.value(value))

        # Scaled dot-product scores: (batch, heads, query_len, key_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, v)

        # Concatenate heads and pass through the final linear layer
        context = context.transpose(1, 2).contiguous().view(
            batch_size, query_len, embed_size
        )
        return self.fc_out(context)

# Example usage
embed_size, num_heads = 16, 4

# Batch size = 2, Sequence length = 5 (sampled before the module is created)
x = torch.randn(2, 5, embed_size)

multi_head_attention = MultiHeadAttention(embed_size, num_heads)

# Self-attention: one tensor plays the query, key and value roles
output = multi_head_attention(query=x, key=x, value=x)
print("Multi-Head Attention Output:", output)


Multi-Head Attention Output: tensor([[[ 2.2437e-01, -6.2121e-02,  3.3677e-01,  3.4617e-01, -1.2561e-01,
          -1.4320e-02, -3.0945e-01, -1.2908e-01,  9.7997e-02, -1.2948e-01,
           1.8598e-01, -1.2567e-01, -1.3778e-01, -1.3313e-01,  4.0892e-01,
           2.7462e-01],
         [ 2.7985e-01, -9.5385e-02,  2.7260e-01,  3.0782e-01, -8.0252e-03,
          -3.4150e-02, -3.4385e-01, -8.9990e-02,  1.1554e-01, -1.2106e-01,
           1.8222e-01, -1.0269e-01, -6.9136e-02, -2.5624e-01,  3.5380e-01,
           3.0314e-01],
         [ 2.1652e-01, -1.2097e-01,  1.3035e-01,  2.7102e-01, -7.1895e-02,
          -1.0088e-01, -2.7516e-01, -1.3812e-01,  1.3554e-01, -1.2206e-01,
           8.2788e-02, -1.4428e-01, -9.2552e-02, -1.0975e-01,  1.7394e-01,
           3.4829e-01],
         [-1.9586e-02, -2.2503e-01, -1.5532e-02,  2.9430e-01,  3.6480e-04,
          -2.1269e-01, -1.4501e-01, -1.2479e-01,  1.9416e-01,  3.2542e-02,
           2.2505e-02, -1.6535e-01, -2.5836e-01, -9.1432e-02,  1.7192e-01,

### 4. Positional Encoding


**Objective**: Introduce positional information into word embeddings, as Transformers process words in parallel without inherent sequence order.

**Explanation**:
   - Since Transformers do not process words sequentially, positional encodings are added to the embeddings to give the model information about the order of words.
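   - **Sine and Cosine Functions**: Positional encodings are computed using sine and cosine functions with varying frequencies, creating unique patterns for each position in the sequence. For position $pos$, pair index $k$ and embedding size $d$: $PE_{(pos,\,2k)} = \sin\left(pos / 10000^{2k/d}\right)$ and $PE_{(pos,\,2k+1)} = \cos\left(pos / 10000^{2k/d}\right)$.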
    

In [8]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    For position ``pos`` and even column ``i`` the table holds
    ``sin(pos / 10000 ** (i / embed_size))`` in column ``i`` and the
    matching cosine in column ``i + 1``.

    Args:
        embed_size: Size of the embedding (last) dimension. Odd sizes are
            supported; the last column then holds a sine with no paired
            cosine.
        max_len: Number of positions to precompute.
    """

    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)

        # ``even_idx`` already equals 2k, so the exponent is even_idx / d
        # (multiplying by 2 again would square the intended wavelengths).
        even_idx = torch.arange(0, embed_size, 2, dtype=torch.float32)
        inv_freq = torch.exp(even_idx * (-math.log(10000.0) / embed_size))
        angles = position * inv_freq  # (max_len, ceil(embed_size / 2))

        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Leading batch dimension of 1 so the table broadcasts over batches;
        # a buffer moves with the module but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, embed_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# Example usage
embed_size, seq_len = 16, 10
pos_encoding = PositionalEncoding(embed_size)

# A single random sequence of seq_len embeddings
x = torch.randn(1, seq_len, embed_size)

# Add the positional table to the embeddings and show the result
output = pos_encoding(x)
print("Positional Encoding Output:", output)


Positional Encoding Output: tensor([[[-0.3780,  2.4281,  1.1620,  2.2261,  0.3933,  3.6151, -0.3845,
          -0.5912, -0.7609,  1.5486, -0.8610,  3.1962, -0.3819,  0.4404,
           0.8772,  3.1241],
         [ 0.5593,  1.5919, -0.8094,  0.7374, -0.2747,  0.1606, -0.4240,
           1.1753,  0.2719, -0.4462, -0.2870,  0.7103,  0.2695, -0.2540,
          -0.6235,  0.8658],
         [ 0.5411,  1.9566, -0.7689,  2.7376,  2.2273,  1.1636, -0.0636,
           0.1820, -0.6461,  1.4622, -1.7594,  0.7083, -1.5733,  1.4677,
           0.6176,  0.3939],
         [ 1.2280,  0.0967,  0.8716,  1.9796,  0.0933,  1.1175, -0.6185,
          -0.5824, -1.0138,  1.4257,  0.7839,  0.7707,  0.5574,  2.9357,
           1.6869,  1.6726],
         [-1.1834, -0.7026, -0.2265,  1.2105,  0.1016,  1.0678,  0.4235,
           1.6662, -0.5722, -0.1417,  0.7989, -0.1837, -0.7619, -0.9646,
          -0.0545,  1.5960],
         [-2.2983,  0.0287,  0.1830,  1.7864,  1.2916,  1.4007, -0.3100,
           1.2188,  0.29

### 5. Transformer Decoder Block


**Objective**: Understand how the Transformer decoder combines self-attention, encoder-decoder attention, and feedforward layers, the fundamental units in a Transformer.

**Explanation**:
   - **Self-Attention Layer**: The decoder starts with a masked multi-head self-attention layer, which lets each position attend to earlier positions of the target sequence while preventing it from attending to future tokens. (The simplified code below does not apply the mask.)
   - **Encoder-Decoder Attention**: The second attention layer attends to the encoder’s output, allowing the decoder to focus on relevant parts of the input sequence.
   - **Add & Norm**: After each attention layer, a residual connection is added, followed by layer normalization to stabilize learning.
   - **Feedforward Layer**: The attention output is passed through a position-wise feedforward network (two linear layers with a ReLU in between) for further processing.
   - **Final Add & Norm**: A final residual connection and layer normalization follow the feedforward layer and complete the block, making the Transformer’s decoder robust to varying input sequences.
    

In [9]:
class TransformerDecoderBlock(nn.Module):
    """One Transformer decoder block.

    Applies three sub-layers in order, each followed by a residual
    connection and layer normalization (Add & Norm):

    1. self-attention over the decoder input,
    2. cross-attention from the decoder state to the encoder output,
    3. a two-layer feedforward network with a ReLU in between.

    Args:
        embed_size: Size of the embedding (last) dimension.
        num_heads: Number of attention heads in both attention layers.
        forward_expansion: The feedforward hidden size is
            ``forward_expansion * embed_size``.
    """

    def __init__(self, embed_size, num_heads, forward_expansion):
        super().__init__()
        self.self_attention = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)

        # Single encoder-decoder (cross) attention layer
        self.cross_attention = MultiHeadAttention(embed_size, num_heads)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.norm3 = nn.LayerNorm(embed_size)

    def forward(self, x, enc_out):
        """Run the block.

        Args:
            x: Decoder input of shape (batch, seq_len, embed_size).
            enc_out: Encoder output used as key and value for
                cross-attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # MultiHeadAttention returns a single tensor; unpacking it as a
        # pair would split it along the batch dimension instead.
        # NOTE: no causal mask is passed here, so every position can also
        # attend to later positions.
        self_attended = self.self_attention(x, x, x)
        x = self.norm1(self_attended + x)  # Add & Norm

        # Queries come from the decoder; keys and values from the encoder.
        cross_attended = self.cross_attention(x, enc_out, enc_out)
        x = self.norm2(cross_attended + x)  # Add & Norm

        fed_forward = self.feed_forward(x)
        x = self.norm3(fed_forward + x)  # Add & Norm
        return x

# Example usage
embed_size, num_heads, forward_expansion = 16, 4, 4
decoder_block = TransformerDecoderBlock(embed_size, num_heads, forward_expansion)

# Decoder input first, then encoder output; both are (2, 5, embed_size)
x, enc_out = (torch.randn(2, 5, embed_size) for _ in range(2))

output = decoder_block(x, enc_out)
print("Transformer Decoder Block Output:", output)


Transformer Decoder Block Output: tensor([[[-0.0504,  0.2236, -1.9815,  1.2312,  0.4428, -0.4968, -1.0361,
          -0.2568,  0.3426,  0.2296,  2.6510,  0.5192, -0.9043, -0.0385,
          -0.0826, -0.7931],
         [ 1.0573,  0.5727, -1.7287,  0.5727,  1.4477,  0.3912,  0.1467,
          -1.9621, -0.3355, -0.2418, -0.9247,  0.4133, -0.1908, -1.3024,
           0.9555,  1.1290],
         [ 2.4449, -1.0871, -0.1706, -1.1433,  0.0641, -0.1058,  0.4111,
           1.0630, -0.8163, -1.0199, -1.0326, -1.0673,  0.8826, -0.0682,
           0.5570,  1.0883],
         [-0.2866, -1.5279, -0.0509,  1.1634, -0.2142,  0.0341,  0.3061,
          -0.3619,  1.1847, -1.8800, -0.9231,  0.6996, -1.1453,  0.3403,
           1.9252,  0.7363],
         [ 1.6752, -1.7031, -0.9129,  0.5246, -1.7744, -0.1799,  1.0348,
          -0.4354,  0.4402, -0.8662,  1.2825,  0.2872, -0.5341, -0.1110,
          -0.0582,  1.3306]],

        [[ 0.5203, -0.7214, -1.0901,  0.0691, -0.6485, -1.1235,  0.8946,
           1.418