# Import libraries

In [None]:
import math
import random
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Utilities & config

In [None]:
# Reproducibility: Python's RNG and PyTorch's RNG are seeded with the same value.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Reserved vocabulary ids: padding, start-of-sequence, end-of-sequence.
PAD_ID, SOS_ID, EOS_ID = 0, 1, 2

@dataclass
class ModelConfig:
    """Hyper-parameters consumed by ``Transformer`` when building the encoder and decoder."""

    # Vocabulary size of the source side (passed to the encoder embedding).
    enc_vocab_size: int
    # Vocabulary size of the target side (decoder embedding and output projection).
    dec_vocab_size: int
    # Number of positions in the positional-encoding table of both stacks.
    max_len: int = 64
    # Embedding / hidden width shared by every layer.
    d_model: int = 128
    # Number of attention heads; MultiHeadAttention asserts d_model % n_head == 0.
    n_head: int = 8
    # Hidden width of the position-wise feed-forward network.
    ffn_hidden: int = 512
    # Number of layers in the encoder stack and, separately, in the decoder stack.
    n_layers: int = 2
    # Dropout probability used by embeddings, attention and feed-forward blocks.
    dropout: float = 0.1
    # NOTE(review): not read by any code visible here; presumably a learning-rate
    # warm-up length for a scheduler defined elsewhere — confirm.
    warmup_steps: int = 4000

# Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    Column ``2i`` holds ``sin(pos * w_i)`` and column ``2i + 1`` holds
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    built once in ``__init__`` and registered as the buffer ``pe``, so it
    follows the module across devices and is part of the state dict.

    Args:
        d_model: Width of each position vector. Both even and odd values are
            supported.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even column, so the
        # last frequency is dropped for the cosine half instead of raising a
        # shape-mismatch error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)  # [max_len, d_model]

    def forward(self, length: int) -> torch.Tensor:
        """Return the first ``length`` rows of the table as ``[1, length, d_model]``.

        Slicing means a ``length`` larger than ``max_len`` yields only
        ``max_len`` rows.
        """
        return self.pe[:length].unsqueeze(0)

# Token + Positional Embedding

In [None]:
class TransformerEmbedding(nn.Module):
    """Token embedding scaled by sqrt(d_model), plus positional encoding, then dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding width.
        max_len: Number of positions in the positional-encoding table.
        dropout: Dropout probability applied to the summed embedding.
        pad_idx: Token id passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, dropout: float, pad_idx: int = PAD_ID):
        super().__init__()
        self.d_model = d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_embedding = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids ``x`` of shape [B, L] into [B, L, d_model]."""
        _, seq_len = x.size()
        scaled_tokens = self.token_embedding(x) * math.sqrt(self.d_model)
        positions = self.pos_embedding(seq_len).to(x.device)  # [1, L, d_model]
        return self.dropout(scaled_tokens + positions)

# Scaled Dot-Product Attention

In [None]:
class ScaleDotProductAttention(nn.Module):
    """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v`` with an optional mask."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Attend from queries to keys/values.

        Args:
            q: Queries, ``[..., L_q, D]`` (``[B, H, L_q, D]`` in this model).
            k: Keys, ``[..., L_k, D]``.
            v: Values, ``[..., L_k, D_v]``.
            mask: Optional tensor broadcastable to ``[..., L_q, L_k]``;
                positions where it equals 0 are excluded from attention.
            e: Unused; kept so existing callers passing it keep working.

        Returns:
            ``(out, attn)`` where ``out`` is ``[..., L_q, D_v]`` and ``attn``
            holds the attention weights ``[..., L_q, L_k]``.
        """
        d_k = q.size(-1)
        # Transpose the last two axes so any number of leading batch/head
        # dimensions is accepted.
        score = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative value of the score dtype: a literal -1e9
            # cannot be represented in float16 and makes masked_fill raise.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        attn = self.softmax(score)
        out = attn @ v
        return out, attn

# Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        d_model: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        dropout: Dropout probability applied to the output projection.
    """

    def __init__(self, d_model: int, n_head: int, dropout: float):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.scale_dot_product_attention = ScaleDotProductAttention()
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, L, d_model]`` into ``[B, n_head, L, d_head]``."""
        B, L, _ = x.size()
        return x.view(B, L, self.n_head, self.d_head).transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, H, L, D]`` back into ``[B, L, H * D]``."""
        B, H, L, D = x.size()
        # transpose() returns a non-contiguous view, so contiguous() is
        # required before view().
        return x.transpose(1, 2).contiguous().view(B, L, H * D)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query input ``[B, L_q, d_model]``.
            k: Key input ``[B, L_k, d_model]``.
            v: Value input ``[B, L_k, d_model]``.
            mask: Optional mask forwarded to the scaled dot-product attention.

        Returns:
            Tensor ``[B, L_q, d_model]``.
        """
        q = self.split_heads(self.W_q(q))
        k = self.split_heads(self.W_k(k))
        v = self.split_heads(self.W_v(v))
        # The attention module is stored as `scale_dot_product_attention`;
        # calling the undefined `self.attn` raised AttributeError.
        out, _ = self.scale_dot_product_attention(q, k, v, mask=mask)
        out = self.combine_heads(out)
        return self.dropout(self.W_o(out))

# LayerNorm & FFN

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        d_model: Size of the last dimension being normalised.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))   # scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # shift

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        mu = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance.
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        std = torch.sqrt(sigma_sq + self.eps)
        normalised = (x - mu) / std
        return self.gamma * normalised + self.beta

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        hidden: Width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, hidden: int, dropout: float = 0.1):
        super().__init__()
        expand = nn.Linear(d_model, hidden)
        contract = nn.Linear(hidden, d_model)
        stages = [expand, nn.ReLU(), nn.Dropout(dropout), contract]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the network to ``x``."""
        return self.net(x)

# Encoder / Decoder Layers

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is followed by dropout, a residual add and layer
    normalisation.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Sub-layer 1: multi-head self-attention.
        self.self_attn = MultiHeadAttention(d_model, n_head, drop_prob)
        self.dropout1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)

        # Sub-layer 2: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

    def forward(self, x, src_mask):
        """Run the block on ``x`` using ``src_mask`` for self-attention."""
        residual = x
        attended = self.dropout1(self.self_attn(x, x, x, mask=src_mask))
        x = self.norm1(attended + residual)

        residual = x
        transformed = self.dropout2(self.ffn(x))
        return self.norm2(transformed + residual)

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer is followed by dropout, a residual add and layer
    normalisation.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Sub-layer 1: self-attention over the target sequence.
        self.self_attn = MultiHeadAttention(d_model, n_head, drop_prob)
        self.dropout1 = nn.Dropout(drop_prob)
        self.norm1 = LayerNorm(d_model)

        # Sub-layer 2: attention from decoder states to the encoder output.
        self.enc_dec_attn = MultiHeadAttention(d_model, n_head, drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)
        self.norm2 = LayerNorm(d_model)

        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.dropout3 = nn.Dropout(drop_prob)
        self.norm3 = LayerNorm(d_model)

    def forward(self, dec, enc, trg_mask, src_mask):
        """Run the block.

        Args:
            dec: Decoder-side input states.
            enc: Encoder output used as keys and values in cross-attention.
            trg_mask: Mask for the self-attention (look-ahead + padding).
            src_mask: Mask over encoder positions for the cross-attention.
        """
        residual = dec
        self_attended = self.dropout1(self.self_attn(dec, dec, dec, mask=trg_mask))
        x = self.norm1(self_attended + residual)

        residual = x
        cross_attended = self.dropout2(self.enc_dec_attn(x, enc, enc, mask=src_mask))
        x = self.norm2(cross_attended + residual)

        residual = x
        transformed = self.dropout3(self.ffn(x))
        return self.norm3(transformed + residual)

# Encoder / Decoder Stacks

In [None]:
class Encoder(nn.Module):
    """Embedding followed by a stack of ``n_layers`` encoder layers."""

    def __init__(self, vocab_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, pad_idx=PAD_ID):
        super().__init__()
        self.emb = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob, pad_idx)
        blocks = [EncoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, src_mask):
        """Encode token ids ``x``; returns hidden states [B, L, d_model]."""
        hidden = self.emb(x)
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden

In [None]:
class Decoder(nn.Module):
    """Embedding, a stack of ``n_layers`` decoder layers, and a vocabulary projection."""

    def __init__(self, vocab_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, pad_idx=PAD_ID):
        super().__init__()
        self.emb = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob, pad_idx)
        blocks = [DecoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layers)]
        self.layers = nn.ModuleList(blocks)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, trg, enc, trg_mask, src_mask):
        """Decode target ids ``trg`` against encoder output ``enc``.

        Returns:
            Logits of shape [B, L, vocab].
        """
        hidden = self.emb(trg)
        for block in self.layers:
            hidden = block(hidden, enc, trg_mask, src_mask)
        return self.proj(hidden)

# Full Transformer Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from a ``ModelConfig``.

    Args:
        cfg: Hyper-parameters for both stacks.
        pad_idx: Token id treated as padding when building attention masks.
    """

    def __init__(self, cfg: ModelConfig, pad_idx=PAD_ID):
        super().__init__()
        self.encoder = Encoder(cfg.enc_vocab_size, cfg.max_len, cfg.d_model,
                               cfg.ffn_hidden, cfg.n_head, cfg.n_layers, cfg.dropout, pad_idx)
        self.decoder = Decoder(cfg.dec_vocab_size, cfg.max_len, cfg.d_model,
                               cfg.ffn_hidden, cfg.n_head, cfg.n_layers, cfg.dropout, pad_idx)
        self.pad_idx = pad_idx

    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Build the key-padding mask for the source.

        Args:
            src: Source token ids, [B, Ls].

        Returns:
            Int tensor [B, 1, 1, Ls] (1 = keep, 0 = mask) that broadcasts over
            heads and query positions.
        """
        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2).int()

    def make_trg_mask(self, trg_in: torch.Tensor) -> torch.Tensor:
        """Build the combined padding + look-ahead mask for the target.

        Args:
            trg_in: Target input token ids, [B, Lt].

        Returns:
            Int tensor [B, 1, Lt, Lt] (1 = keep, 0 = mask).
        """
        seq_len = trg_in.size(1)
        pad_mask = (trg_in != self.pad_idx).unsqueeze(1).unsqueeze(2).int()  # [B,1,1,Lt]
        # Lower-triangular: position i may attend to positions <= i.
        sub_mask = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.int, device=trg_in.device)
        ).unsqueeze(0).unsqueeze(0)  # [1,1,Lt,Lt]
        return pad_mask & sub_mask

    def forward(self, src: torch.Tensor, trg_in: torch.Tensor) -> torch.Tensor:
        """Return decoder logits [B, Lt, dec_vocab] for teacher-forced input ``trg_in``."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg_in)
        memory = self.encoder(src, src_mask)
        return self.decoder(trg_in, memory, trg_mask, src_mask)

    @torch.no_grad()
    def greedy_decode(self, src: torch.Tensor, max_len: int, sos_id=SOS_ID, eos_id=EOS_ID) -> torch.Tensor:
        """Greedily generate target sequences for a batch of sources.

        Switches the module to eval mode (and leaves it there). Each sequence
        is tracked individually: once it has produced ``eos_id`` every later
        position is filled with ``self.pad_idx``, and decoding stops as soon
        as all sequences have finished or ``max_len`` tokens exist.

        Args:
            src: Source token ids, [B, Ls].
            max_len: Maximum output length including the start token. It
                should not exceed the positional-encoding table size the
                decoder was built with.
            sos_id: Token id used to start every sequence.
            eos_id: Token id that marks a sequence as finished.

        Returns:
            Long tensor [B, L] with L <= max_len, starting with ``sos_id``.
        """
        self.eval()
        B = src.size(0)
        src_mask = self.make_src_mask(src)
        memory = self.encoder(src, src_mask)
        ys = torch.full((B, 1), sos_id, dtype=torch.long, device=src.device)
        # Per-sequence completion flags; stopping only when every sequence
        # emits EOS on the same step would let finished rows keep generating.
        finished = torch.zeros(B, dtype=torch.bool, device=src.device)
        for _ in range(max_len - 1):
            trg_mask = self.make_trg_mask(ys)
            out = self.decoder(ys, memory, trg_mask, src_mask)  # [B, L, V]
            next_token = out[:, -1, :].argmax(dim=-1)  # [B]
            # Sequences that already ended only receive padding.
            next_token = next_token.masked_fill(finished, self.pad_idx)
            ys = torch.cat([ys, next_token.unsqueeze(1)], dim=1)
            finished = finished | (next_token == eos_id)
            if finished.all():
                break
        return ys