# Phase 2–3: Train DeepPTX

Load the Parquet dataset, build the tokenizers, create DataLoaders with curriculum sampling,
and train the Pointer-Generator Transformer.

**Works on:**
- Google Colab (T4/A100) — fastest, uses CUDA AMP
- Mac M1/M2/M3 — uses MPS acceleration (no AMP, still fast)
- CPU — slowest fallback

In [None]:
# --- Run this cell first on Google Colab to clone the repo ---
# The existence of /content is used as the Colab marker; anywhere else this cell
# only imports os and does nothing.
import os
if os.path.exists("/content"):
    %cd /content
    # WARNING: this wipes any existing checkout before cloning. The Paths cell's
    # default ("Option 2") keeps dataset_100k.parquet and checkpoints/ inside this
    # directory, so re-running this cell deletes them as well.
    !rm -rf /content/DeepPTX
    !git clone https://github.com/ns-1456/DeepPTX.git /content/DeepPTX
    # Work from inside the fresh checkout from here on.
    %cd /content/DeepPTX
    # Install dependencies quietly (-q).
    !pip install -q torch pyarrow pandas wandb tqdm

In [None]:
import os
import sys

# Detect environment: the existence of /content is used as the Colab marker.
IN_COLAB = os.path.exists("/content")

# On Colab the repo lives at a fixed clone location; otherwise the parent of the
# current working directory is taken as the repo root.
REPO_ROOT = "/content/DeepPTX" if IN_COLAB else os.path.abspath("..")

# Put the repo root at the front of sys.path (once) so that the ptx_decompiler
# imports below resolve against this checkout.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

# These imports must come after the sys.path adjustment above.
import torch
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm

from ptx_decompiler.utils import get_device, supports_amp, is_colab

# Select the compute device once; later cells branch on DEVICE.type.
DEVICE = get_device()
print(f"Device: {DEVICE}")
print(f"AMP supported: {supports_amp(DEVICE)}")

_running_on_mps = DEVICE.type == "mps"
if _running_on_mps:
    print("Apple Silicon detected — using MPS acceleration")
    print("Tip: MPS is ~3-5x faster than CPU for training")

In [None]:
# ======================= Paths =======================
# Colab: data on Drive or in /content
# Local Mac: data in repo root (generated by notebook 01)
#
# Defines:
#   DATA_PATH - location of the Parquet dataset read by the data-loading cell.
#   SAVE_DIR  - directory (pathlib.Path) where checkpoints are written.

if IN_COLAB:
    # Option 1: Google Drive (persistent)
    # from google.colab import drive; drive.mount("/content/drive")
    # DATA_PATH = "/content/drive/MyDrive/NeuralPTX/dataset_100k.parquet"
    # SAVE_DIR  = Path("/content/drive/MyDrive/NeuralPTX/checkpoints")

    # Option 2: Local to Colab (faster I/O, lost on disconnect)
    DATA_PATH = os.path.join(REPO_ROOT, "dataset_100k.parquet")
    SAVE_DIR  = Path(REPO_ROOT) / "checkpoints"
else:
    # Local Mac / Linux
    DATA_PATH = os.path.join(REPO_ROOT, "dataset_100k.parquet")
    SAVE_DIR  = Path(REPO_ROOT) / "checkpoints"

print(f"Data: {DATA_PATH}")
print(f"Checkpoints: {SAVE_DIR}")

# Fail fast with an explicit exception instead of `assert`: assert statements are
# removed when Python runs with -O, which would let a missing dataset slip through
# to a less helpful error in a later cell.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"Dataset not found at {DATA_PATH} — run notebook 01 first!")

In [None]:
# On Colab: apply AMP fix to copy_mechanism (scatter in float32, then cast back)
#
# This rewrites ptx_decompiler/model/copy_mechanism.py on disk: if the file still
# contains one of the known older endings of the scatter code, that ending is
# replaced by `new_block`. Outside Colab this cell does nothing.
if IN_COLAB:
    import importlib

    copy_mechanism_path = os.path.join(REPO_ROOT, "ptx_decompiler", "model", "copy_mechanism.py")
    with open(copy_mechanism_path) as src:
        content = src.read()

    # Replacement text: performs the scatter_add_ in float32 and casts back afterwards.
    new_block = '''        contribution = contribution * valid.to(contribution.dtype)
        # Scatter in float32 to avoid AMP dtype mismatches; use explicit dtype/device
        dtype_orig = logits_vocab.dtype
        logits_f32 = logits_vocab.to(torch.float32)
        contribution_f32 = contribution.to(device=logits_f32.device, dtype=torch.float32)
        logits_f32.scatter_add_(2, enc_ids_clamped, contribution_f32)
        logits_vocab = logits_f32.to(dtype_orig)
        return logits_vocab, p_gen'''

    # Known older spellings of the same code, tried in this order; the first one
    # found in the file is the one that gets replaced.
    legacy_blocks = (
        (
            "        contribution = (contribution * valid.float()).to(logits_vocab.dtype)\n"
            "        logits_vocab.scatter_add_(2, enc_ids_clamped, contribution)\n"
            "        return logits_vocab, p_gen"
        ),
        (
            "        contribution = contribution * valid.to(contribution.dtype)\n"
            "        contribution = contribution.to(logits_vocab.dtype)\n"
            "        logits_vocab.scatter_add_(2, enc_ids_clamped, contribution)\n"
            "        return logits_vocab, p_gen"
        ),
        (
            "        contribution = contribution * valid.to(contribution.dtype)\n"
            "        # Scatter in float32 to avoid AMP dtype mismatches with scatter_add_, then cast back\n"
            "        dtype_orig = logits_vocab.dtype\n"
            "        logits_vocab = logits_vocab.float().scatter_add_(\n"
            "            2, enc_ids_clamped, contribution.float()\n"
            "        ).to(dtype_orig)\n"
            "        return logits_vocab, p_gen"
        ),
    )

    matched = next((variant for variant in legacy_blocks if variant in content), None)
    applied = matched is not None
    if applied:
        content = content.replace(matched, new_block)
        with open(copy_mechanism_path, "w") as dst:
            dst.write(content)
        print("Applied copy_mechanism AMP fix (float32 scatter).")
    elif "contribution_f32" in content:
        # The marker name only exists in the patched text.
        print("copy_mechanism already patched.")
    else:
        print("copy_mechanism.py format changed — check manually.")

    # If the module was imported before the file was rewritten, reload it.
    # NOTE(review): reload refreshes this module object only; names other modules
    # already imported from it may still point at the old code — confirm, or
    # restart the runtime after patching.
    if "ptx_decompiler.model.copy_mechanism" in sys.modules:
        import ptx_decompiler.model.copy_mechanism as _cm
        importlib.reload(_cm)
        print("Reloaded copy_mechanism module.")

## Load Data & Build Tokenizers

In [None]:
from ptx_decompiler.data.dataset import (
    load_parquet_for_training,
    collate_pad_batch,
    CurriculumSampler,
    PTXASTDataset,
)
from ptx_decompiler.tokenizer import PTXTokenizer, ASTTokenizer

# Read the whole dataset once; it is used here to build the PTX vocabulary.
df = pd.read_parquet(DATA_PATH)
print(f"Loaded {len(df):,} samples")

# PTX vocabulary is built from the "ptx_normalized" column.
ptx_tokenizer = PTXTokenizer(max_vocab_size=2000)
print("Building PTX vocabulary...")
ptx_tokenizer.build_vocab(df["ptx_normalized"].tolist())
print(f"PTX vocab size: {len(ptx_tokenizer)}")

# The AST tokenizer is constructed without any data.
ast_tokenizer = ASTTokenizer()
print(f"AST vocab size: {len(ast_tokenizer)}")

# 90/10 train/validation split with a fixed seed.
train_ds, val_ds = load_parquet_for_training(
    DATA_PATH, ptx_tokenizer, ast_tokenizer, train_ratio=0.9, seed=42
)
print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,}")

curriculum_sampler = CurriculumSampler(train_ds, shuffle=True, seed=42)

_on_cuda = DEVICE.type == "cuda"
# With gradient checkpointing, 32 fits tier 3 on T4 16GB; use 16 if OOM
BATCH_SIZE = 32 if _on_cuda else 64
NUM_WORKERS = 2 if _on_cuda else 0  # parallel data loading on GPU
PIN_MEMORY = _on_cuda


def _pad_collate(batch):
    """Collate a batch via collate_pad_batch using both tokenizers' pad ids."""
    return collate_pad_batch(batch, ptx_tokenizer.pad_id, ast_tokenizer.pad_id)


# Settings common to the training and validation loaders.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    collate_fn=_pad_collate,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=NUM_WORKERS > 0,
)

# Training order is delegated to the curriculum sampler; validation keeps dataset order.
train_loader = torch.utils.data.DataLoader(train_ds, sampler=curriculum_sampler, **_loader_options)
val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, **_loader_options)
print(f"Batch size: {BATCH_SIZE} | Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

## Build Model & Trainer

In [None]:
import inspect
from ptx_decompiler.model import PTXDecompilerModel
from ptx_decompiler.training import Trainer, get_cosine_schedule_with_warmup

# Map PTX tokens → AST vocab for copy mechanism.
# Entry i holds the AST id of the token whose PTX id is i, or -1 when that token
# does not appear in the AST vocabulary.
ptx_to_ast = torch.full((len(ptx_tokenizer),), -1, dtype=torch.long)
for tok, ptx_id in ptx_tokenizer.vocab.items():
    if tok in ast_tokenizer.vocab:
        ptx_to_ast[ptx_id] = ast_tokenizer.vocab[tok]

model = PTXDecompilerModel(
    ptx_vocab_size=len(ptx_tokenizer),
    ast_vocab_size=len(ast_tokenizer),
    d_model=256,
    n_heads=8,
    d_ff=1024,
    encoder_layers=6,
    decoder_layers=6,
    dropout=0.1,
    use_copy=True,
    ptx_to_ast_map=ptx_to_ast,
    use_gradient_checkpointing=DEVICE.type == "cuda",  # save memory so we can use batch 32
).to(DEVICE)

# torch.compile: faster after first epoch (PyTorch 2.0+, CUDA)
USE_COMPILE = True  # set False if first epoch is too slow or errors
if USE_COMPILE and DEVICE.type == "cuda" and hasattr(torch, "compile"):
    model = torch.compile(model, mode="reduce-overhead")
    print("Model compiled with torch.compile (first epoch will be slow).")

NUM_EPOCHS = 30
# Prefer the fused AdamW kernel on CUDA; PyTorch builds whose AdamW does not accept
# the `fused` keyword raise TypeError, in which case the plain optimizer is used.
try:
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, fused=DEVICE.type == "cuda")
except TypeError:
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

num_steps = len(train_loader) * NUM_EPOCHS
# Cap warmup at 10% of the run so that short runs (small dataset or few epochs)
# do not spend the whole schedule warming up. Runs of 10,000 steps or more keep
# the original 1000-step warmup.
num_warmup_steps = min(1000, max(1, num_steps // 10))
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps
)

trainer_kw = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
    pad_id_ast=ast_tokenizer.pad_id,
    eos_id_ast=ast_tokenizer.eos_id,
    label_smoothing=0.1,
    # Request AMP only where the project's own helper reports support, rather than
    # unconditionally; this matches the "AMP supported" line printed at startup.
    use_amp=supports_amp(DEVICE),
    curriculum_sampler=curriculum_sampler,
    save_dir=SAVE_DIR,
    use_wandb=False,
)
# Only pass val_every when the installed Trainer accepts it (older checkouts may not).
if "val_every" in inspect.signature(Trainer.__init__).parameters:
    trainer_kw["val_every"] = 2  # validate every 2 epochs (faster)
trainer = Trainer(**trainer_kw)

print(f"Model parameters: {model.count_parameters():,}")
print(f"AMP enabled: {trainer.use_amp}")
print(f"Training on: {DEVICE}")

## Train

In [None]:
# Run training for NUM_EPOCHS epochs with the Trainer configured in the previous cell.
trainer.train(num_epochs=NUM_EPOCHS)

In [None]:
# Save final checkpoint
SAVE_DIR.mkdir(parents=True, exist_ok=True)
final_path = SAVE_DIR / "checkpoint_final.pt"

# When torch.compile was applied, `model` is a wrapper that keeps the real module
# in `_orig_mod`, and its state_dict keys carry an "_orig_mod." prefix. Save the
# underlying module's weights so the checkpoint loads into a plain (uncompiled)
# PTXDecompilerModel; for an uncompiled model the getattr falls back to `model`.
model_to_save = getattr(model, "_orig_mod", model)
torch.save({"model": model_to_save.state_dict(), "epoch": NUM_EPOCHS - 1}, final_path)
print(f"Saved final checkpoint to {final_path}")