<a href="https://colab.research.google.com/github/elangbijak4/LLM-SLM-Examples/blob/main/Transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input.

    The table is ``PE[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``.

    Args:
        d_model: Embedding dimension. Odd values are supported; the extra
            (even-indexed) last column then carries only a sine term.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i/d_model) for i = 0 .. ceil(d_model/2) - 1, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim div_term to match; for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (max_len, 1, d_model): broadcasts over dim 1 (the batch axis)
        # of a (seq_len, batch, d_model) input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Registered as a buffer: saved in state_dict and moved by .to(),
        # but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Dim 0 of ``x`` is treated as the sequence axis, so ``x`` is expected to
        be sequence-first, e.g. ``(seq_len, batch, d_model)``.
        """
        return x + self.pe[:x.size(0), :]

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first tensors.

    ``forward`` takes ``(batch, seq_len, d_model)`` inputs and splits
    ``d_model`` into ``num_heads`` heads of width ``d_k = d_model // num_heads``.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Query, key and value projections, in that order (see the zip in forward).
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        # Attention probabilities from the most recent forward call,
        # shaped (batch, num_heads, q_len, k_len); kept for inspection only.
        self.attention_weights = None

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            query, key, value: Per-head tensors whose last two dims are
                ``(len, d_k)``.
            mask: Optional tensor broadcastable to the score matrix; entries
                equal to 0 are excluded from attention.

        Returns:
            Tuple ``(attended_values, attention_probabilities)``.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than a
            # fixed -1e9, which does not fit in float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        p_attn = F.softmax(scores, dim=-1)
        return torch.matmul(p_attn, value), p_attn

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(batch, q_len, d_model)``.
            key, value: ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``; 0 marks blocked positions.

        Returns:
            ``(batch, q_len, d_model)``.
        """
        batch_size = query.size(0)
        # Project, then split the model width into heads: (batch, heads, len, d_k).
        query, key, value = [
            proj(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linear_layers, (query, key, value))
        ]
        x, self.attention_weights = self.scaled_dot_product_attention(query, key, value, mask=mask)
        # Merge the heads back: (batch, q_len, num_heads * d_k).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.output_linear(x)

class FeedForward(nn.Module):
    """Position-wise two-layer MLP computing ``linear2(relu(linear1(x)))``.

    Args:
        d_model: Width of the input and of the output.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        # Expand d_model -> d_ff ...
        self.linear1 = nn.Linear(d_model, d_ff)
        # ... then project d_ff -> d_model.
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the MLP independently to every position along the last dim."""
        hidden = self.linear1(x)
        activated = F.relu(hidden)
        return self.linear2(activated)

class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a feed-forward MLP.

    Each sub-layer's output is added to its input (residual) and the sum is
    layer-normalised.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model)
        # One LayerNorm per sub-layer: [after attention, after feed-forward].
        self.layer_norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(2))

    def forward(self, x, mask=None):
        """Encode ``x``; ``mask`` is forwarded unchanged to self-attention.

        NOTE(review): MultiHeadAttention treats dim 0 as the batch, so ``x``
        is expected to be batch-first ``(batch, seq_len, d_model)``.
        """
        attn_norm, ff_norm = self.layer_norms
        attended = self.self_attn(x, x, x, mask)
        x = attn_norm(x + attended)
        transformed = self.feed_forward(x)
        return ff_norm(x + transformed)

class DecoderLayer(nn.Module):
    """One post-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as ``norm(x + sublayer(x))``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.src_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model)
        # One LayerNorm per sub-layer, in the order they are applied.
        self.layer_norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Decode ``x`` while attending to the encoder output ``memory``.

        ``tgt_mask`` goes to self-attention and ``src_mask`` to the attention
        over ``memory``; neither is modified here.
        """
        sublayers = (
            lambda h: self.self_attn(h, h, h, tgt_mask),
            lambda h: self.src_attn(h, memory, memory, src_mask),
            self.feed_forward,
        )
        for norm, sublayer in zip(self.layer_norms, sublayers):
            x = norm(x + sublayer(x))
        return x

class Transformer(nn.Module):
    """Encoder-decoder Transformer over token ids.

    Args:
        input_dim: Number of rows in the embedding table.
        output_dim: Number of output logits per target position.
        d_model: Model / embedding width.
        num_heads: Attention heads per attention block.
        num_layers: Number of encoder layers and of decoder layers.

    NOTE(review): src and tgt share one embedding of size ``input_dim``, so
    tgt ids must be < ``input_dim`` even though logits range over
    ``output_dim``. Confirm this is intended when the two vocabularies differ.
    """

    def __init__(self, input_dim, output_dim, d_model, num_heads, num_layers):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads) for _ in range(num_layers)])
        self.output_linear = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Run the encoder over ``src`` and the decoder over ``tgt``.

        Args:
            src: ``(src_len, batch)`` token ids.
            tgt: ``(tgt_len, batch)`` token ids.
            src_mask: Optional mask for attention over the source, broadcastable
                to ``(batch, num_heads, q_len, src_len)``; 0 marks blocked keys.
            tgt_mask: Optional mask for decoder self-attention, broadcastable to
                ``(batch, num_heads, tgt_len, tgt_len)``. No causal mask is
                added automatically; pass one for autoregressive decoding.

        Returns:
            ``(tgt_len, batch, output_dim)`` logits.
        """
        scale = math.sqrt(self.embedding.embedding_dim)
        # Embedding and PositionalEncoding work sequence-first: (len, batch, d_model).
        src = self.positional_encoding(self.embedding(src) * scale)
        tgt = self.positional_encoding(self.embedding(tgt) * scale)

        # MultiHeadAttention treats dim 0 as the batch, so switch to batch-first.
        # Without this, attention would mix different examples of the batch
        # instead of positions within each sequence.
        memory = src.transpose(0, 1)
        out = tgt.transpose(0, 1)

        for layer in self.encoder_layers:
            memory = layer(memory, src_mask)

        for layer in self.decoder_layers:
            out = layer(out, memory, src_mask, tgt_mask)

        # Return to sequence-first so the (tgt_len, batch, output_dim) contract
        # holds. Making the d_model-sized tensor contiguous is cheaper than
        # copying the logits, and it keeps the result contiguous for .view() callers.
        return self.output_linear(out.transpose(0, 1).contiguous())

# --- Hyperparameters -------------------------------------------------------
input_dim = 10000   # rows of the (shared) embedding table / input vocabulary
output_dim = 10000  # output vocabulary: number of logits per position
d_model = 512       # model / embedding width
num_heads = 8       # attention heads per attention block
num_layers = 6      # how many encoder layers, and separately decoder layers

# --- Model -----------------------------------------------------------------
model = Transformer(
    input_dim=input_dim,
    output_dim=output_dim,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
)

# --- Random token ids, laid out (seq_len=10, batch_size=2) -----------------
src = torch.randint(0, input_dim, (10, 2))
tgt = torch.randint(0, output_dim, (10, 2))

# --- Forward pass; logits come back as (seq_len, batch_size, output_dim) ---
output = model(src, tgt)
print(output.shape)


torch.Size([10, 2, 10000])
