In [None]:
# Transformer Implementation

In [1]:
import math

import torch
from torch import nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer

# --- Data Loading and Tokenization ---
# Sequence length used for padding/truncation, and DataLoader batch size
max_length = 128
batch_size = 256

# Pretrained tokenizer, used for both the source and the target side
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# WMT14 Czech-English dataset
datasets = load_dataset("wmt/wmt14", "cs-en")
# Preprocessing function
def preprocess(batch):
    """Tokenize a batch of translation pairs.

    Takes the 'cs' string of every entry in ``batch['translation']`` as the
    source and the 'en' string as the target, tokenizes both lists with the
    notebook-level ``tokenizer`` (padded and truncated to ``max_length``),
    and returns the source encoding with the target token ids stored under
    the 'labels' key.
    """
    def encode(texts):
        # Same tokenizer settings for both sides of the pair.
        return tokenizer(texts, padding='max_length', truncation=True, max_length=max_length)

    pairs = batch['translation']
    czech_texts = [pair['cs'] for pair in pairs]
    english_texts = [pair['en'] for pair in pairs]

    model_inputs = encode(czech_texts)
    model_inputs['labels'] = encode(english_texts)['input_ids']
    return model_inputs

# Run `preprocess` over the dataset in batches, dropping the original columns,
# and expose the three tokenized fields as torch tensors.
tokenized = datasets.map(
    preprocess,
    batched=True,
    remove_columns=datasets['train'].column_names,
)
tokenized.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'labels'],
)

# Shuffled mini-batches over the training split
dataloader = torch.utils.data.DataLoader(
    tokenized['train'],
    batch_size=batch_size,
    shuffle=True,
)


In [2]:
# --- Model Components ---
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    A ``(1, max_len, d_model)`` table is built once and registered as the
    non-trainable buffer ``pe``: even feature columns hold
    ``sin(pos * div_term)`` and odd feature columns hold
    ``cos(pos * div_term)``, with
    ``div_term = exp(arange(0, d_model, 2) * (-ln(10000) / d_model))``.

    Args:
        d_model: feature size of the embeddings. Both even and odd values
            are supported.
        max_len: number of positions stored in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so drop the surplus frequency; for an even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        # x: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T * d_k ** -0.5) @ v``.

    Args:
        d_k: per-head feature size; raw scores are multiplied by
            ``d_k ** -0.5``.
    """

    def __init__(self, d_k):
        super().__init__()
        self.scaling = d_k ** -0.5

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: tensors of shape (batch, heads, seq_len, d_k).
            mask: optional tensor broadcastable to the score tensor
                (batch, heads, q_len, k_len). Positions where ``mask == 0``
                receive no attention weight.

        Returns:
            The attention-weighted sum of ``v``, one vector per query. A
            query whose keys are all masked yields a zero vector rather
            than NaN.
        """
        # q, k, v: (batch, heads, seq_len, d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling
        if mask is None:
            return torch.matmul(torch.softmax(scores, dim=-1), v)

        blocked = (mask == 0)
        scores = scores.masked_fill(blocked, float('-inf'))
        # A row made only of -inf turns softmax into NaN. Find such rows,
        # give them finite scores so softmax stays well defined, then zero
        # their attention weights.
        dead_rows = blocked.expand_as(scores).all(dim=-1, keepdim=True)
        scores = scores.masked_fill(dead_rows, 0.0)
        attn = torch.softmax(scores, dim=-1).masked_fill(dead_rows, 0.0)
        return torch.matmul(attn, v)

class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate query and key/value inputs.

    ``d_model`` features are split evenly across ``num_heads`` heads of size
    ``d_k = d_model // num_heads``; ``d_model`` must be divisible by
    ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        head_dim, leftover = divmod(d_model, num_heads)
        assert leftover == 0
        self.d_k = head_dim
        self.num_heads = num_heads

        # Input projections for queries, keys and values, then the output
        # projection applied after the heads are merged.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(self.d_k)

    def _split_heads(self, projected, batch_size):
        # (batch, len, d_model) -> (batch, heads, len, d_k)
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, kv, mask=None):
        """Attend from ``q`` over ``kv``.

        ``q`` is a 3-D tensor (batch, seq_len, d_model); keys and values are
        both projected from ``kv``. ``mask`` is forwarded unchanged to the
        scaled dot-product attention. The result has ``q``'s batch size and
        sequence length.
        """
        batch_size, seq_len, _ = q.shape

        queries = self._split_heads(self.w_q(q), batch_size)
        keys = self._split_heads(self.w_k(kv), batch_size)
        values = self._split_heads(self.w_v(kv), batch_size)

        per_head = self.attention(queries, keys, values, mask)  # (batch, heads, seq_len, d_k)

        # Bring the head axis back next to the features and flatten the two.
        merged = per_head.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.fc(merged)

class PositionwiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between.

    ``fc1`` maps ``d_model`` features to ``d_ff``; ``fc2`` maps them back to
    ``d_model``.
    """

    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply ``fc2(relu(fc1(x)))``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)

class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to its
    input, and the sum is layer-normalized.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is handed to the self-attention."""
        # Sub-layer 1: self-attention (queries and keys/values both come from x)
        attended = self.self_attn(x, x, mask)
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))

class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual add and
    its own LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, look_ahead_mask=None, padding_mask=None):
        """Run the block.

        Args:
            x: decoder-side input.
            enc_out: encoder output; supplies keys and values for the
                cross-attention.
            look_ahead_mask: mask passed to the decoder self-attention.
            padding_mask: mask passed to the cross-attention.
        """
        # Sub-layer 1: self-attention over the decoder input
        self_attended = self.self_attn(x, x, look_ahead_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # Sub-layer 2: queries from the decoder, keys/values from the encoder
        cross_attended = self.cross_attn(x, enc_out, padding_mask)
        x = self.norm2(x + self.dropout(cross_attended))

        # Sub-layer 3: feed-forward network
        transformed = self.ffn(x)
        return self.norm3(x + self.dropout(transformed))

class Transformer(nn.Module):
    """Encoder-decoder Transformer with a shared token embedding.

    Source and target ids are looked up in the same ``token_emb`` table,
    scaled by ``sqrt(d_model)`` and passed through the same ``pos_emb``.
    ``num_layers`` encoder layers and ``num_layers`` decoder layers follow,
    and ``fc_out`` projects the decoder states to vocabulary logits.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output projection.
        d_model: embedding / hidden feature size.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers and of decoder layers.
        d_ff: hidden size of each position-wise feed-forward network.
        max_len: number of positions held by the positional encoding.
        pad_token_id: token id treated as padding when building masks.
            When left as ``None`` the notebook-level
            ``tokenizer.pad_token_id`` is looked up on each call, which is
            the previous behaviour.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_len=5000,
                 pad_token_id=None):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)
        self.enc_layers = nn.ModuleList([EncoderLayer(d_model,num_heads,d_ff) for _ in range(num_layers)])
        self.dec_layers = nn.ModuleList([DecoderLayer(d_model,num_heads,d_ff) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_padding_mask(self, seq):
        """Return a (batch, 1, 1, seq_len) bool mask, True where ``seq`` is not padding."""
        # seq: (batch, seq_len)
        pad_id = tokenizer.pad_token_id if self.pad_token_id is None else self.pad_token_id
        return (seq != pad_id).unsqueeze(1).unsqueeze(2)

    def generate_look_ahead_mask(self, size, device=None):
        """Return a (1, 1, size, size) lower-triangular bool mask.

        Args:
            size: sequence length.
            device: where to create the mask. Defaults to the device of
                this module's parameters, so the model also runs on CPU
                (the mask used to be created on "cuda" unconditionally).
        """
        if device is None:
            device = self.fc_out.weight.device
        mask = torch.tril(torch.ones((size, size), device=device)).bool()
        return mask.unsqueeze(0).unsqueeze(1)

    def forward(self, src, tgt):
        """Compute logits of shape (batch, tgt_len, vocab_size).

        Args:
            src: (batch, src_len) source token ids.
            tgt: (batch, tgt_len) decoder input token ids.
        """
        # src, tgt: (batch, seq_len)
        src_mask = self.generate_padding_mask(src)
        # A target position may attend to a key only if that key is not
        # padding and does not lie in the future.
        tgt_mask = self.generate_padding_mask(tgt) & self.generate_look_ahead_mask(
            tgt.size(1), device=tgt.device
        )

        emb_scale = math.sqrt(self.token_emb.embedding_dim)

        # Embedding + Positional Encoding
        enc = self.pos_emb(self.token_emb(src) * emb_scale)
        for layer in self.enc_layers:
            enc = layer(enc, src_mask)

        dec = self.pos_emb(self.token_emb(tgt) * emb_scale)
        for layer in self.dec_layers:
            dec = layer(dec, enc, look_ahead_mask=tgt_mask, padding_mask=src_mask)

        return self.fc_out(dec)


In [3]:
import math

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the model with the tokenizer's vocabulary size and move it to `device`.
model = Transformer(vocab_size=tokenizer.vocab_size)
model = model.to(device)


In [4]:
# --- Training Loop ---
num_epochs = 3
# Padding positions in the labels do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params/1e6}M")


for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        src = batch['input_ids'].to(device)
        tgt = batch['labels'].to(device)

        optimizer.zero_grad()
        # Teacher forcing: feed all but the last target token and score the
        # prediction at each position against the following token.
        # NOTE(review): the decoder input starts at the first target token,
        # so that token is never predicted - confirm whether a start token
        # is meant to be prepended.
        outputs = model(src, tgt[:, :-1])  # (batch, seq_len-1, vocab_size)
        loss = criterion(outputs.reshape(-1, tokenizer.vocab_size), tgt[:, 1:].reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch}/{num_epochs}]  Loss: {avg_loss:.4f}")

Total parameters: 77.040996M
Epoch [1/3]  Loss: 4.2592
Epoch [2/3]  Loss: 3.1632
Epoch [3/3]  Loss: 2.7719


In [5]:
# Shell escape: run nvidia-smi and show its GPU report in the cell output.
!nvidia-smi

Fri May  9 14:07:56 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.12              Driver Version: 550.90.12      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:01:00.0 Off |                    0 |
| N/A   31C    P0             53W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  |   00