# MLSA Transformer Project: Python Code Summarization (FINAL STABLE)

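The model classes in this notebook are kept in sync with your local `src/model.py`, so checkpoints saved here (plain `state_dict`s) can be loaded by that code on your Mac, provided the hyperparameters match. Because the checkpoints are written from a GPU runtime, load them with `torch.load(path, map_location="cpu")` (or `"mps"`).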

## 1. Environment Setup

In [None]:
# Install the Python dependencies quietly, mount Google Drive, and make sure the
# checkpoint folder exists (-p: no error if it is already there).
%pip install -q torch datasets transformers torchmetrics tqdm

from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): this path must stay in sync with SAVE_DIR defined in the next cell.
!mkdir -p "/content/drive/My Drive/MLSA_Transformer_Checkpoints"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import RobertaTokenizer
import numpy as np
import copy
import math
import os
from tqdm.auto import tqdm

# Default to the CPU; switch to the GPU when the runtime exposes one.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
# Checkpoints are written to Google Drive so they outlive the Colab runtime.
SAVE_DIR = "/content/drive/My Drive/MLSA_Transformer_Checkpoints"

## 2. Dataset Implementation

In [None]:
class CodeSummarizationDataset(Dataset):
    """CodeXGLUE code-to-text (Python) examples tokenized for seq2seq training.

    Each item is a dict of fixed-length token-id tensors:
      * ``input_ids``         -- the function's code, padded/truncated to ``max_code_len``;
      * ``labels``            -- the docstring, padded/truncated to ``max_summary_len``;
      * ``decoder_input_ids`` -- an independent copy of ``labels`` (the shift for
        teacher forcing is left to the training loop).

    Args:
        split: split name passed to ``load_dataset`` (e.g. ``'train'``).
        tokenizer: callable with the Hugging Face tokenizer call signature.
        max_code_len: token length of every ``input_ids`` tensor.
        max_summary_len: token length of every ``labels`` tensor.
        subset_size: if given and smaller than the split, keep only the first
            ``subset_size`` examples.
        use_code_tokens: build the source text from ``code_tokens`` joined with
            spaces (the CodeXGLUE baseline input) instead of the raw ``code``
            string. In the Python subset the raw ``code`` field still contains
            the function's docstring, i.e. the target would leak into the
            input. Pass ``False`` to restore the previous raw-``code`` input.
    """

    def __init__(self, split, tokenizer, max_code_len=256, max_summary_len=128,
                 subset_size=None, use_code_tokens=True):
        self.dataset = load_dataset("code_x_glue_ct_code_to_text", "python", split=split)
        # `is not None` so that subset_size=0 yields an empty subset instead of the full split.
        if subset_size is not None and subset_size < len(self.dataset):
            self.dataset = self.dataset.select(range(subset_size))
        self.tokenizer = tokenizer
        self.max_code_len = max_code_len
        self.max_summary_len = max_summary_len
        self.use_code_tokens = use_code_tokens

    def __len__(self):
        return len(self.dataset)

    def _source_text(self, item):
        """Return the text fed to the encoder for one raw example."""
        if self.use_code_tokens:
            return ' '.join(item['code_tokens'])
        return item['code']

    def _encode(self, text, max_len):
        """Tokenize ``text`` into a 1-D tensor padded/truncated to ``max_len`` ids."""
        enc = self.tokenizer(text, max_length=max_len, padding='max_length',
                             truncation=True, return_tensors='pt')
        return enc['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        input_ids = self._encode(self._source_text(item), self.max_code_len)
        labels = self._encode(item['docstring'], self.max_summary_len)
        return {
            'input_ids': input_ids,
            # Separate tensor so in-place edits to one never affect the other.
            'decoder_input_ids': labels.clone(),
            'labels': labels,
        }

def get_dataloaders(batch_size=64, train_subset_size=50000, val_subset_size=5000, num_workers=0):
    """Build the CodeBERT tokenizer and the train/validation DataLoaders.

    Args:
        batch_size: examples per batch for both loaders.
        train_subset_size: number of leading training examples to keep
            (``None`` for the whole split).
        val_subset_size: number of leading validation examples to keep
            (``None`` for the whole split).
        num_workers: DataLoader worker processes. Tokenization runs inside
            the dataset's ``__getitem__``, so values > 0 parallelise it.

    Returns:
        ``(train_loader, val_loader, tokenizer)``; only the training loader shuffles.
    """
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    train_ds = CodeSummarizationDataset('train', tokenizer, subset_size=train_subset_size)
    val_ds = CodeSummarizationDataset('validation', tokenizer, subset_size=val_subset_size)
    print(f"Training samples: {len(train_ds)}, Validation: {len(val_ds)}")
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers)
    return train_loader, val_loader, tokenizer

## 3. Transformer Architecture (must match `src/model.py`)

In [None]:
class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d_model) and add fixed sinusoidal position codes.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``
    with ``w_i = 10000 ** (-2i / d_model)``. The table lives in the ``pe`` buffer
    (shape ``(1, max_len, d_model)``), so it is part of the ``state_dict``.
    """

    def __init__(self, max_len, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        slope = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * slope)
        # For odd d_model there is one fewer odd column than frequencies; drop the surplus one.
        pe[:, 1::2] = torch.cos(position * slope)[:, :d_model // 2]
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x * sqrt(d_model) + pe`` for ``x`` shaped ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return (x * np.sqrt(self.d_model)) + self.pe[:, :seq_len, :]

class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with separately cached keys/values.

    Usage is two-step: ``init_keys(states)`` projects ``states`` into per-head
    keys and values and stores them on the module; ``forward(query, mask)``
    then attends from ``query`` over those cached keys/values.
    """

    def __init__(self, n_heads, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads  # feature width of each head
        # Independent projections for queries, keys, values and the output mix.
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def make_chunks(self, x):
        """Split features into heads: (B, L, d_model) -> (B, n_heads, L, d_k)."""
        batch, length = x.shape[0], x.shape[1]
        return x.view(batch, length, self.n_heads, self.d_k).transpose(1, 2)

    def init_keys(self, key):
        """Project ``key`` once and cache it as both the key and the value source."""
        self.proj_key = self.make_chunks(self.linear_key(key))
        self.proj_value = self.make_chunks(self.linear_value(key))

    def forward(self, query, mask=None):
        """Attend from ``query`` over the keys/values cached by ``init_keys``.

        Positions where ``mask`` equals 0 are excluded; a 3-D mask gets a head
        axis inserted at dim 1 so it broadcasts across heads.
        """
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        heads_q = self.make_chunks(self.linear_query(query))
        logits = heads_q @ self.proj_key.transpose(-2, -1)
        logits = logits / np.sqrt(self.d_k)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = self.dropout(logits.softmax(dim=-1))
        # (B, H, L, d_k) -> (B, L, H, d_k) -> (B, L, d_model)
        merged = (weights @ self.proj_value).transpose(1, 2).reshape(query.size(0), -1, self.d_model)
        return self.linear_out(merged)

class SubLayerWrapper(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(layernorm(x), **kwargs))``.

    With ``is_self_attn=True`` the normalized input is first passed to
    ``sublayer.init_keys`` so the attention also uses it as keys/values.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, sublayer, is_self_attn=False, **kwargs):
        normed = self.norm(x)
        if is_self_attn:
            sublayer.init_keys(normed)
        residual = self.drop(sublayer(normed, **kwargs))
        return x + residual

class EncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention, then a position-wise FFN."""

    def __init__(self, n_heads, d_model, ff_units, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.self_attn_heads = MultiHeadedAttention(n_heads, d_model, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_units, d_model),
        )
        # One residual wrapper per sub-layer: [self-attention, feed-forward].
        self.sublayers = nn.ModuleList([SubLayerWrapper(d_model, dropout) for _ in range(2)])

    def forward(self, query, mask=None):
        attn_block, ffn_block = self.sublayers
        hidden = attn_block(query, sublayer=self.self_attn_heads, is_self_attn=True, mask=mask)
        return ffn_block(hidden, sublayer=self.ffn)

class DecoderLayer(nn.Module):
    """Pre-norm Transformer decoder layer.

    Masked self-attention, then cross-attention over the cached encoder states,
    then a position-wise feed-forward network.
    """

    def __init__(self, n_heads, d_model, ff_units, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.self_attn_heads = MultiHeadedAttention(n_heads, d_model, dropout=dropout)
        self.cross_attn_heads = MultiHeadedAttention(n_heads, d_model, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_units, d_model),
        )
        # Residual wrappers for [self-attention, cross-attention, feed-forward].
        self.sublayers = nn.ModuleList([SubLayerWrapper(d_model, dropout) for _ in range(3)])

    def init_keys(self, states):
        """Cache encoder ``states`` as the cross-attention keys/values."""
        self.cross_attn_heads.init_keys(states)

    def forward(self, query, source_mask=None, target_mask=None):
        self_block, cross_block, ffn_block = self.sublayers
        hidden = self_block(query, sublayer=self.self_attn_heads, is_self_attn=True, mask=target_mask)
        hidden = cross_block(hidden, sublayer=self.cross_attn_heads, mask=source_mask)
        return ffn_block(hidden, sublayer=self.ffn)

class EncoderTransf(nn.Module):
    """Stack of ``n_layers`` encoder layers with positional encoding and a final LayerNorm.

    Every layer is an independent deep copy of ``encoder_layer``, so all copies
    start from the same initial weights.
    """

    def __init__(self, encoder_layer, n_layers=1, max_len=512):
        super().__init__()
        self.d_model = encoder_layer.d_model
        self.pe = PositionalEncoding(max_len, self.d_model)
        self.norm = nn.LayerNorm(self.d_model)
        self.layers = nn.ModuleList(copy.deepcopy(encoder_layer) for _ in range(n_layers))

    def forward(self, query, mask=None):
        """Encode already-embedded ``query``; the same ``mask`` goes to every layer."""
        hidden = self.pe(query)
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, mask)
        return self.norm(hidden)

class DecoderTransf(nn.Module):
    """Stack of ``n_layers`` decoder layers with positional encoding and a final LayerNorm.

    Every layer is an independent deep copy of ``decoder_layer``.
    """

    def __init__(self, decoder_layer, n_layers=1, max_len=512):
        super().__init__()
        self.d_model = decoder_layer.d_model
        self.pe = PositionalEncoding(max_len, self.d_model)
        self.norm = nn.LayerNorm(self.d_model)
        self.layers = nn.ModuleList(copy.deepcopy(decoder_layer) for _ in range(n_layers))

    def init_keys(self, states):
        """Hand the encoder ``states`` to every layer."""
        for dec_layer in self.layers:
            dec_layer.init_keys(states)

    def forward(self, query, s_mask=None, t_mask=None):
        hidden = self.pe(query)
        for dec_layer in self.layers:
            hidden = dec_layer(hidden, s_mask, t_mask)
        return self.norm(hidden)

class EncoderDecoderTransf(nn.Module):
    """Encoder-decoder Transformer over token ids.

    Source and target have separate embeddings; ``out_linear`` maps decoder
    outputs to target-vocabulary logits. ``subsequent_mask`` is a
    ``(1, max_len, max_len)`` lower-triangular buffer used as the default causal
    mask, so targets decoded with it may be at most ``max_len`` tokens long.
    """

    def __init__(self, encoder, decoder, src_vocab_size, tgt_vocab_size, max_len=512):
        super(EncoderDecoderTransf, self).__init__()
        self.encoder, self.decoder, self.d_model = encoder, decoder, encoder.d_model
        self.src_embed = nn.Embedding(src_vocab_size, self.d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, self.d_model)
        self.out_linear = nn.Linear(self.d_model, tgt_vocab_size)
        # 1 on and below the diagonal: position i may attend to positions <= i.
        self.register_buffer('subsequent_mask', (1 - torch.triu(torch.ones((1, max_len, max_len)), diagonal=1)))

    def encode(self, src, mask=None):
        """Encode ``src`` token ids and hand the states to ``decoder.init_keys``."""
        states = self.encoder(self.src_embed(src), mask)
        self.decoder.init_keys(states)
        return states

    def decode(self, tgt, s_mask=None, t_mask=None):
        """Return logits for ``tgt`` token ids against the states from the last ``encode``.

        Raises:
            ValueError: if ``t_mask`` is omitted and ``tgt`` is longer than ``max_len``.
        """
        if t_mask is None:
            tgt_len, max_len = tgt.size(1), self.subsequent_mask.size(-1)
            if tgt_len > max_len:
                raise ValueError(f"target length {tgt_len} exceeds max_len {max_len}")
            t_mask = self.subsequent_mask[:, :tgt_len, :tgt_len]
        outputs = self.decoder(self.tgt_embed(tgt), s_mask, t_mask)
        return self.out_linear(outputs)

    def forward(self, src, tgt, src_mask=None):
        """Encode ``src`` then decode ``tgt`` (teacher forcing).

        Args:
            src: source token ids.
            tgt: target input token ids.
            src_mask: optional source padding mask (0 = ignore), e.g.
                ``(src != pad_id).unsqueeze(1)``. It is passed to the encoder and
                used as the decoder's source mask. ``None`` (default) keeps the
                previous unmasked behaviour.
        """
        self.encode(src, src_mask)
        return self.decode(tgt, src_mask)

## 4. Faster Training Loop (Resumable from Checkpoints)

In [None]:
def _seq2seq_loss(model, batch, criterion, vocab_size):
    """Teacher-forced next-token cross-entropy for one batch.

    The decoder is fed ``labels[:, :-1]`` and scored against ``labels[:, 1:]``.
    """
    ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)
    logits = model(ids, labels[:, :-1])
    return criterion(logits.reshape(-1, vocab_size), labels[:, 1:].reshape(-1))


def _evaluate(model, loader, criterion, vocab_size):
    """Return the mean per-batch loss over ``loader`` without gradients (``nan`` if empty)."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            total += _seq2seq_loss(model, batch, criterion, vocab_size).item()
            n_batches += 1
    return total / n_batches if n_batches else float('nan')


def train(resume_epoch=None):
    """Train the code-summarization Transformer, checkpointing every epoch to ``SAVE_DIR``.

    After each epoch the model ``state_dict`` is saved as ``checkpoint_epoch_{N}.pt``
    (unchanged format, so the weights load straight into the model classes). The
    optimizer state is saved as ``optimizer_epoch_{N}.pt``. The mean training and
    validation losses are printed.

    Args:
        resume_epoch: if truthy, load ``checkpoint_epoch_{resume_epoch}.pt`` (and
            the matching optimizer file when present) and continue with epoch
            ``resume_epoch + 1``.

    Raises:
        FileNotFoundError: if ``resume_epoch`` is given but its model checkpoint is
            missing. Silently restarting would overwrite the earlier checkpoints.
    """
    batch_size, epochs, lr = 64, 10, 1e-4
    d_model, n_heads, n_layers, ff_units = 256, 8, 4, 512
    train_loader, val_loader, tokenizer = get_dataloaders(batch_size=batch_size)
    vocab_size = tokenizer.vocab_size

    enclayer = EncoderLayer(n_heads, d_model, ff_units)
    declayer = DecoderLayer(n_heads, d_model, ff_units)
    encoder = EncoderTransf(enclayer, n_layers, max_len=256)
    decoder = DecoderTransf(declayer, n_layers, max_len=128)
    model = EncoderDecoderTransf(encoder, decoder, vocab_size, vocab_size, max_len=128).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Padded target positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    start_epoch = 0
    if resume_epoch:
        ckpt = os.path.join(SAVE_DIR, f'checkpoint_epoch_{resume_epoch}.pt')
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"No checkpoint to resume from: {ckpt}")
        # map_location lets a checkpoint written on GPU load on a CPU-only runtime.
        model.load_state_dict(torch.load(ckpt, map_location=device))
        opt_ckpt = os.path.join(SAVE_DIR, f'optimizer_epoch_{resume_epoch}.pt')
        if os.path.exists(opt_ckpt):
            # Restore Adam's moment estimates instead of restarting them from zero.
            optimizer.load_state_dict(torch.load(opt_ckpt, map_location=device))
        start_epoch = resume_epoch
        print(f"--- Resumed from Epoch {resume_epoch} ---")

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            optimizer.zero_grad()
            loss = _seq2seq_loss(model, batch, criterion, vocab_size)
            loss.backward()
            optimizer.step()
            step_loss = loss.item()
            total_loss += step_loss
            pbar.set_postfix({'loss': f"{step_loss:.4f}"})

        train_loss = total_loss / max(len(train_loader), 1)
        val_loss = _evaluate(model, val_loader, criterion, vocab_size)

        path = os.path.join(SAVE_DIR, f'checkpoint_epoch_{epoch+1}.pt')
        torch.save(model.state_dict(), path)
        torch.save(optimizer.state_dict(), os.path.join(SAVE_DIR, f'optimizer_epoch_{epoch+1}.pt'))
        print(f"Epoch {epoch+1}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
        print(f"Epoch {epoch+1} finished and saved to Google Drive.")

# Start a fresh run; call train(resume_epoch=N) instead to continue from checkpoint_epoch_N.pt.
train()