# MLSA Transformer Project: Python Code Summarization (OPTIMIZED V2)

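**CRITICAL**: If you hit a CUDA `OutOfMemoryError`, open the top menu and select **Runtime -> Restart session**, then run all cells again. Restarting releases the GPU memory still held by the previous run. If the error persists after a restart, lower `batch_size` in `train()`.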

## 1. Environment Setup

In [None]:
# Install the notebook's third-party dependencies quietly (-q).
# NOTE(review): torchmetrics is installed here but not imported in any cell shown.
%pip install -q torch datasets transformers torchmetrics tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import RobertaTokenizer
import numpy as np
import copy
import math
import os
from tqdm.auto import tqdm

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 2. Optimized Dataset Implementation

In [None]:
class CodeSummarizationDataset(Dataset):
    """CodeXGLUE code-to-text (Python) pairs tokenized for seq2seq training.

    Each item is a dict with ``input_ids`` (source ids, padded/truncated to
    ``max_code_len``) and ``labels`` (docstring ids, padded/truncated to
    ``max_summary_len``).

    Args:
        split: split name forwarded to ``load_dataset`` (e.g. ``'train'``).
        tokenizer: tokenizer applied to both the source and the summary.
        max_code_len: source length after padding/truncation.
        max_summary_len: target length after padding/truncation.
        subset_size: if truthy and smaller than the split, keep only the first
            ``subset_size`` rows.
        use_code_tokens: build the source from the space-joined ``code_tokens``
            field instead of the raw ``code`` string. For Python rows the raw
            ``code`` string still embeds the function's own docstring, which is
            the prediction target. ``code_tokens`` has it stripped, and the
            CodeXGLUE baseline builds its inputs this way.
            NOTE(review): confirm against the dataset card. Pass ``False`` to
            restore the previous raw-``code`` behavior.
    """

    def __init__(self, split, tokenizer, max_code_len=256, max_summary_len=128, subset_size=None,
                 use_code_tokens=True):
        self.dataset = load_dataset("code_x_glue_ct_code_to_text", "python", split=split)
        if subset_size and subset_size < len(self.dataset):
            self.dataset = self.dataset.select(range(subset_size))
        self.tokenizer = tokenizer
        self.max_code_len = max_code_len
        self.max_summary_len = max_summary_len
        self.use_code_tokens = use_code_tokens

    def __len__(self):
        return len(self.dataset)

    def _source_text(self, item):
        """Return the model-input string for one raw dataset row."""
        if self.use_code_tokens:
            # Avoids feeding the target docstring back in as input (see class doc).
            return ' '.join(item['code_tokens'])
        return item['code']

    def _encode(self, text, max_len):
        """Tokenize ``text`` into a 1-D tensor of exactly ``max_len`` token ids."""
        enc = self.tokenizer(text, max_length=max_len, padding='max_length', truncation=True,
                             return_tensors='pt')
        # return_tensors='pt' adds a leading batch axis of size 1; drop it.
        return enc['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {
            'input_ids': self._encode(self._source_text(item), self.max_code_len),
            'labels': self._encode(item['docstring'], self.max_summary_len),
        }

def get_dataloaders(batch_size=64):
    """Create the CodeBERT tokenizer plus train/validation DataLoaders.

    The training split is capped at 50,000 examples and shuffled; the
    validation split is capped at 5,000 examples and kept in order.

    Returns:
        A ``(train_loader, val_loader, tokenizer)`` tuple.
    """
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

    def _load(split_name, cap):
        return CodeSummarizationDataset(split_name, tokenizer, subset_size=cap)

    train_ds = _load('train', 50000)
    val_ds = _load('validation', 5000)
    print(f"Dataset Statistics: Train={len(train_ds)}, Val={len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    return train_loader, val_loader, tokenizer

## 3. Transformer Model

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to sqrt(d_model)-scaled inputs.

    Args:
        max_len: number of positions precomputed in the ``pe`` buffer.
        d_model: feature width; the sin/cos interleave assumes it is even.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        positions = torch.arange(0, max_len).float().unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * inv_freq
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature columns
        table[:, 1::2] = torch.cos(angles)  # odd feature columns
        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        width = x.size(-1)
        scaled = x * np.sqrt(width)
        return scaled + self.pe[:, :seq_len, :]

class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with cached keys/values.

    ``init_keys`` projects a key/value source once and stores the per-head
    results on the module (``pk``/``pv``). ``forward`` then attends from a
    query over those cached tensors, so ``init_keys`` must run first.
    """

    def __init__(self, n_heads, d_model, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads  # per-head width; assumes n_heads divides d_model
        projections = [nn.Linear(d_model, d_model) for _ in range(4)]
        self.l_q, self.l_k, self.l_v, self.l_o = projections
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """Reshape to (batch, n_heads, seq, d_k); input presumably (batch, seq, d_model)."""
        batch = x.size(0)
        return x.view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)

    def init_keys(self, key):
        """Project ``key`` into per-head keys and values and cache both."""
        self.pk = self.split_heads(self.l_k(key))
        self.pv = self.split_heads(self.l_v(key))

    def forward(self, query, mask=None):
        heads_q = self.split_heads(self.l_q(query))
        logits = heads_q @ self.pk.transpose(-2, -1)
        logits = logits / np.sqrt(self.d_k)
        if mask is not None:
            # Insert a head axis so a single mask broadcasts across every head.
            logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        weights = self.dropout(F.softmax(logits, dim=-1))
        context = (weights @ self.pv).transpose(1, 2).contiguous()
        merged = context.view(query.size(0), -1, self.d_model)
        return self.l_o(merged)

class SubLayerWrapper(nn.Module):
    """Pre-norm residual block: returns ``x + dropout(sublayer(norm(x)))``.

    When ``is_self_attn`` is true, the sublayer's ``init_keys`` is called with
    the normalized input first, so attention keys/values come from it too.
    Extra keyword arguments (e.g. ``mask``) are forwarded to the sublayer.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, sublayer, is_self_attn=False, **kwargs):
        normed = self.norm(x)
        if is_self_attn:
            sublayer.init_keys(normed)
        residual = sublayer(normed, **kwargs)
        return x + self.drop(residual)

class EncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention followed by a feed-forward net.

    Each of the two sublayers is applied through a ``SubLayerWrapper``
    (LayerNorm -> sublayer -> dropout -> residual add).
    """

    def __init__(self, n_heads, d_model, ff_units, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_heads, d_model, dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_units, d_model),
        )
        self.subs = nn.ModuleList(SubLayerWrapper(d_model, dropout) for _ in range(2))

    def forward(self, x, mask=None):
        attn_wrap, ffn_wrap = self.subs
        attended = attn_wrap(x, self.self_attn, is_self_attn=True, mask=mask)
        return ffn_wrap(attended, self.ffn)

class DecoderLayer(nn.Module):
    """One pre-norm decoder layer: masked self-attention, cross-attention, FFN.

    ``init_keys`` must be called with the encoder states before ``forward`` so
    the cross-attention has keys/values to attend over. ``t_mask`` masks the
    decoder self-attention; ``s_mask`` masks the cross-attention.
    """

    def __init__(self, n_heads, d_model, ff_units, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_heads, d_model, dropout)
        self.cross_attn = MultiHeadedAttention(n_heads, d_model, dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_units, d_model),
        )
        self.subs = nn.ModuleList(SubLayerWrapper(d_model, dropout) for _ in range(3))

    def init_keys(self, states):
        """Cache ``states`` as the cross-attention keys/values."""
        self.cross_attn.init_keys(states)

    def forward(self, x, s_mask=None, t_mask=None):
        self_wrap, cross_wrap, ffn_wrap = self.subs
        hidden = self_wrap(x, self.self_attn, is_self_attn=True, mask=t_mask)
        hidden = cross_wrap(hidden, self.cross_attn, mask=s_mask)
        return ffn_wrap(hidden, self.ffn)

class EncoderDecoderTransf(nn.Module):
    """Encoder-decoder Transformer returning target-vocabulary logits.

    Args:
        n_layers: number of encoder layers and, separately, of decoder layers.
        n_heads: attention heads in every attention sublayer.
        d_model: model width.
        ff_units: hidden width of each feed-forward block.
        vocab_size: size of the (shared-size) source and target vocabularies.
        max_len: longest sequence covered by the positional table and the
            causal mask.
        pad_idx: optional source padding id. When set, source positions equal
            to it are masked out of encoder self-attention and decoder
            cross-attention. ``None`` (the default) keeps the previous
            behavior of attending to every source position.
    """

    def __init__(self, n_layers, n_heads, d_model, ff_units, vocab_size, max_len=512, pad_idx=None):
        super().__init__()
        self.pe = PositionalEncoding(max_len, d_model)
        self.enclayers = nn.ModuleList([EncoderLayer(n_heads, d_model, ff_units) for _ in range(n_layers)])
        self.declayers = nn.ModuleList([DecoderLayer(n_heads, d_model, ff_units) for _ in range(n_layers)])
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.out_linear = nn.Linear(d_model, vocab_size)
        # NOTE(review): one LayerNorm is shared by the encoder output and the
        # decoder output. Kept as-is so existing checkpoints still load.
        self.norm = nn.LayerNorm(d_model)
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.pad_idx = pad_idx
        # Lower-triangular causal mask: position i may attend to positions <= i.
        self.register_buffer('mask', (1 - torch.triu(torch.ones((1, max_len, max_len)), diagonal=1)))

    def forward(self, src, tgt, src_mask=None):
        """Encode ``src`` and decode ``tgt`` with teacher forcing.

        Args:
            src: source token ids, presumably (batch, src_len).
            tgt: decoder input ids, presumably (batch, tgt_len).
            src_mask: optional explicit source mask (nonzero/True = attend),
                broadcastable as (batch, 1, src_len). If omitted and
                ``pad_idx`` was given, it is derived from ``src``.

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        if src_mask is None and self.pad_idx is not None:
            # (batch, 1, src_len): the same key mask applies to every query row.
            src_mask = (src != self.pad_idx).unsqueeze(1)

        x = self.pe(self.src_embed(src))
        for layer in self.enclayers:
            x = layer(x, mask=src_mask)
        enc_states = self.norm(x)

        for layer in self.declayers:
            layer.init_keys(enc_states)
        tgt_len = tgt.size(1)
        causal = self.mask[:, :tgt_len, :tgt_len]
        y = self.pe(self.tgt_embed(tgt))
        for layer in self.declayers:
            y = layer(y, s_mask=src_mask, t_mask=causal)
        return self.out_linear(self.norm(y))

## 4. Accelerated Training Loop

In [None]:
def train():
    """Train the code-summarization Transformer and checkpoint every epoch.

    Builds the CodeBERT-tokenized loaders, trains with Adam using teacher
    forcing (decoder input ``labels[:, :-1]``, target ``labels[:, 1:]``),
    prints the mean train and validation loss per epoch, and saves
    ``checkpoint_epoch_{n}.pt`` to the working directory.
    """
    # SAFER HYPERPARAMETERS TO AVOID OOM
    batch_size, epochs, lr = 64, 10, 1e-4
    d_model, n_heads, n_layers, ff_units = 256, 8, 4, 512

    train_loader, val_loader, tokenizer = get_dataloaders(batch_size=batch_size)
    model = EncoderDecoderTransf(n_layers, n_heads, d_model, ff_units, tokenizer.vocab_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Padded target positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    def _batch_loss(batch):
        """Teacher-forced loss for one batch: predict labels[t+1] from labels[:t+1]."""
        ids, labels = batch['input_ids'].to(device), batch['labels'].to(device)
        dec_in, targets = labels[:, :-1], labels[:, 1:].contiguous()
        logits = model(ids, dec_in)
        return criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            optimizer.zero_grad()
            loss = _batch_loss(batch)
            loss.backward()
            optimizer.step()
            step_loss = loss.item()
            total_loss += step_loss
            pbar.set_postfix({'loss': f"{step_loss:.4f}", 'ppl': f"{math.exp(step_loss):.2f}"})
        # Mean over batches; max(..., 1) guards against an empty loader.
        train_loss = total_loss / max(len(train_loader), 1)

        model.eval()
        v_loss = 0.0
        with torch.no_grad():
            for b in val_loader:
                v_loss += _batch_loss(b).item()
        val_loss = v_loss / max(len(val_loader), 1)

        print(f"Epoch {epoch+1} Train Loss: {train_loss:.4f}")
        print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f}")
        torch.save(model.state_dict(), f'checkpoint_epoch_{epoch+1}.pt')

# Start training only when run as a script or notebook (``__name__`` is
# "__main__" in both). Importing this module does not launch a 10-epoch run.
if __name__ == "__main__":
    train()