<a href="https://colab.research.google.com/github/ishammansoor/AI-and-Machine-Learning/blob/main/TransformerEncoderBlock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input to queries, keys and values with three separate
    linear layers, computes ``softmax(Q K^T / sqrt(E)) V`` and passes the
    result through a final linear output projection.

    Args:
        embed_dim: Size ``E`` of the last (feature) dimension of the input;
            all projections map ``E -> E``.
    """

    def __init__(self, embed_dim):
        super().__init__()

        self.embed_dim = embed_dim

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Attend each position of ``x`` to every (unmasked) position of ``x``.

        Args:
            x: Tensor of shape ``(B, T, E)`` (or unbatched ``(T, E)``).
            mask: Optional mask broadcastable to the ``(..., T, T)``
                query-by-key score matrix, e.g. ``(T, T)`` for a causal mask
                or ``(B, 1, T)`` for key padding. A boolean mask marks the
                keys a query MAY attend to with ``True``. Any other dtype is
                added to the scores as an additive bias (use ``-inf`` to block
                a key; a row that is entirely ``-inf`` yields NaN). ``None``
                (default) lets every query attend to every key.

        Returns:
            ``(output, attn_weights)``: ``output`` has the same shape as
            ``x``; ``attn_weights`` has shape ``(..., T, T)`` and each row
            sums to 1.
        """
        E = x.size(-1)

        # step1: compute the Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # step2: compute scaled dot-product scores, optionally masked
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / E ** 0.5
        if mask is not None:
            if mask.dtype == torch.bool:
                # Use the most negative finite value rather than -inf so a
                # query with no attendable key gets uniform weights, not NaN.
                attn_scores = attn_scores.masked_fill(
                    ~mask, torch.finfo(attn_scores.dtype).min
                )
            else:
                attn_scores = attn_scores + mask
        attn_weights = F.softmax(attn_scores, dim=-1)

        # step3: apply the weights to the values, then project the result
        attn_output = torch.matmul(attn_weights, V)
        output = self.out_proj(attn_output)

        return output, attn_weights




In [4]:
class TransformerEncoderBlock(nn.Module):
    """Post-norm Transformer encoder block.

    Runs self-attention followed by a position-wise feed-forward network.
    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        ffn_hidden_dim: Hidden width of the two-layer ReLU feed-forward net.
        dropout: Dropout probability applied to each sub-layer's output
            before the residual addition (active only in ``train()`` mode).
            Defaults to ``0.0``, which leaves the outputs unchanged.
    """

    def __init__(self, embed_dim, ffn_hidden_dim, dropout=0.0):
        super().__init__()

        # Submodules with parameters are created in the same order as before,
        # so parameter initialisation and state_dict keys are unchanged.
        self.self_attn = SelfAttention(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_hidden_dim),
            nn.ReLU(),
            nn.Linear(ffn_hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        # Residual dropout; nn.Dropout has no parameters.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: Input tensor whose last dimension is ``embed_dim`` (the demo
                uses ``(batch, seq_len, embed_dim)``).

        Returns:
            ``(x, attn_weights)``: the encoded tensor (same shape as the
            input) and the attention weights returned by ``self.self_attn``.
        """
        # Self-attention + residual + norm
        attn_out, attn_weights = self.self_attn(x)
        x = self.norm1(x + self.dropout1(attn_out))

        # Feed-forward + residual + norm
        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ff_out))

        return x, attn_weights




In [5]:
# Smoke test: push one random batch through a single encoder block and
# confirm the block preserves the (batch, seq_len, embed_dim) shape.
batch_size, seq_len = 2, 5
embed_dim, ff_dim = 16, 64

x = torch.randn(batch_size, seq_len, embed_dim)
encoder_block = TransformerEncoderBlock(embed_dim, ff_dim)

out, weights = encoder_block(x)
print("Output shape:", out.shape)  # expected: torch.Size([2, 5, 16])

Output shape: torch.Size([2, 5, 16])
