# WikiText-2 Depth Experiment - Colab Notebook (PyTorch)

This notebook runs PyTorch depth experiments for the project in `03-wikitext2`.

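**What this notebook does:**
- installs PyTorch dependencies in Colab
- clones the project directly from GitHub (no ZIP upload)
- verifies the required project files
- downloads WikiText-2 from Kaggle Hub
- runs the PyTorch depth experiments and saves their results

**Note:** the data-loading step (section 5) still uses the legacy NL2Bash loader; the downloaded WikiText-2 files are located but not yet used for training.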

**Run time estimate:**
- PyTorch (all depths): ~30-60 min on GPU


## 1. Setup & Dependencies

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q matplotlib numpy kagglehub

print("✓ Dependencies installed")


In [None]:
# Optional: persist outputs by mounting Google Drive (Colab only).
# Flip MOUNT_DRIVE to True to enable; the Colab import only happens when enabled.
MOUNT_DRIVE = False

if not MOUNT_DRIVE:
    print("Drive mount skipped (set MOUNT_DRIVE=True to enable)")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print("✓ Google Drive mounted")


In [None]:
# Clone project from GitHub (no manual ZIP upload required)
import os
import shutil
import subprocess

REPO_URL = 'https://github.com/ng3gn/ground-up-vla.git'
REPO_BRANCH = 'main'
PROJECT_SUBDIR = '03-wikitext2'  # set to '' if this project is at repo root
CHECKOUT_DIR = '/content/ground-up-vla'

# A previous run of this cell chdir'd into the checkout (see the end of this
# cell). Step out to its parent before deleting it. Otherwise the kernel's cwd
# no longer exists, and `git clone`, which inherits that cwd, fails on re-run.
checkout_parent = os.path.dirname(os.path.abspath(CHECKOUT_DIR))
os.makedirs(checkout_parent, exist_ok=True)
os.chdir(checkout_parent)

if os.path.exists(CHECKOUT_DIR):
    shutil.rmtree(CHECKOUT_DIR)

# Shallow clone: only the latest snapshot of REPO_BRANCH is needed.
subprocess.run(['git', 'clone', '--depth', '1', '--branch', REPO_BRANCH, REPO_URL, CHECKOUT_DIR], check=True)
PROJECT_DIR = os.path.join(CHECKOUT_DIR, PROJECT_SUBDIR) if PROJECT_SUBDIR else CHECKOUT_DIR

# Fail with a clear message (instead of a bare chdir error) if PROJECT_SUBDIR is wrong.
if not os.path.isdir(PROJECT_DIR):
    raise FileNotFoundError(
        f"Project directory not found in checkout: {PROJECT_DIR} (check PROJECT_SUBDIR)"
    )
os.chdir(PROJECT_DIR)

print(f"✓ Repo cloned to: {CHECKOUT_DIR}")
print(f"✓ Working directory: {PROJECT_DIR}")


## 2. Verify Project Checkout

Check that the clone contains the modules imported by later cells. This replaces the earlier manual Drive/ZIP copy steps.


In [None]:
import os

# Modules the later cells import from the checkout; fail fast if any is missing.
required_files = [
    'vocab.py',
    'tokenizer.py',
    'dataset.py',
    'main.py',
    'model/transformer_full.py',
]

missing = []
for f in required_files:
    path = os.path.join(PROJECT_DIR, f)
    # isfile (not exists) so a same-named directory doesn't count as present
    if os.path.isfile(path):
        print(f"✓ {f}")
    else:
        print(f"✗ {f}")
        missing.append(f)

if missing:
    raise FileNotFoundError(f"Missing required project files: {missing}")

print("\n✓ Project files verified")


## 3. Download WikiText-2 from Kaggle Hub


In [None]:
import os
import kagglehub

# Set this if you know the exact Kaggle dataset handle. Leave None to auto-try common handles.
KAGGLE_WIKITEXT2_HANDLE = None
CANDIDATE_HANDLES = [
    'jboysen/wikitext2',
    'jboysen/wikitext-2',
    'mrityunjaybiswas/wikitext2',
    'mrityunjaybiswas/wikitext-2',
]

handles_to_try = [KAGGLE_WIKITEXT2_HANDLE] if KAGGLE_WIKITEXT2_HANDLE else CANDIDATE_HANDLES
last_err = None
raw_download_dir = None

# Try each handle in order and stop at the first successful download. Keep the
# last error so the failure message below is actionable.
for handle in handles_to_try:
    try:
        print(f"Trying Kaggle handle: {handle}")
        raw_download_dir = kagglehub.dataset_download(handle)
        print(f"✓ Downloaded with handle: {handle}")
        break
    except Exception as e:
        last_err = e
        print(f"✗ Failed: {handle} -> {e}")

if raw_download_dir is None:
    raise RuntimeError(
        "Could not download WikiText-2 via kagglehub. "
        "Set KAGGLE_WIKITEXT2_HANDLE to your dataset handle and rerun this cell.\n"
        f"Last error: {last_err}"
    )

# Find the directory that actually contains WikiText-2 text files. The files
# may sit in a nested subdirectory of the download.
required = {'wiki.train.tokens', 'wiki.valid.tokens', 'wiki.test.tokens'}
WIKITEXT2_DIR = None
for root, _, files in os.walk(raw_download_dir):
    if required.issubset(set(files)):
        WIKITEXT2_DIR = root
        break

if WIKITEXT2_DIR is None:
    raise FileNotFoundError(
        "Downloaded dataset did not contain expected files: "
        "wiki.train.tokens, wiki.valid.tokens, wiki.test.tokens"
    )

print(f"✓ WikiText-2 directory: {WIKITEXT2_DIR}")
for name in sorted(required):
    p = os.path.join(WIKITEXT2_DIR, name)
    size_mb = os.path.getsize(p) / 1024 / 1024
    print(f"  - {name} ({size_mb:.2f} MB)")


In [None]:
import torch

# Record the runtime so results can be tied to the hardware they ran on.
has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")


## 4. Import Modules & Configuration

In [None]:
import os
import sys
import json
import time
import math
import numpy as np
import matplotlib.pyplot as plt

# Make the checkout importable with highest priority. Drop any earlier copies
# first, so re-running this cell doesn't keep stacking duplicate entries.
sys.path[:] = [p for p in sys.path if p != PROJECT_DIR]
sys.path.insert(0, PROJECT_DIR)

from vocab import Vocabulary
from tokenizer import NL2BashTokenizer
from dataset import NL2BashDataset, create_pytorch_dataloader

print("✓ All modules imported successfully")


In [None]:
# Configuration
import torch

SEED = 42
# Seed torch's global RNG (CPU and CUDA) so model initialization and data
# shuffling are reproducible for this session.
torch.manual_seed(SEED)

PYTORCH_CONFIG = {
    'd_model': 128,
    'n_heads': 1,
    'd_ff': 512,
    'batch_size': 16,
    'lr': 0.0001,
    'n_epochs': 20,
    'depths': [1, 2, 4, 8, 16],
    'dropout': 0.0,
    'max_len': 128,
    'seed': SEED,
}

# Output directory inside the checkout. This location is ephemeral: the clone
# cell deletes CHECKOUT_DIR when it re-runs. Copy the results to Drive to keep them.
OUTPUT_DIR = os.path.join(PROJECT_DIR, 'experiment_outputs')
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("✓ Configuration loaded")
print(f"  PyTorch depths: {PYTORCH_CONFIG['depths']}")


## 5. Load Data


In [None]:
def load_data(data_dir='data'):
    """Build the legacy NL2Bash vocabulary, tokenizer and train/dev/test splits.

    Args:
        data_dir: directory, joined onto PROJECT_DIR, that holds
            ``shared_vocab.txt``, ``all.nl`` and ``all.cm``.

    Returns:
        ``(vocab, tokenizer, train_dataset, dev_dataset, test_dataset)``.
    """
    base = os.path.join(PROJECT_DIR, data_dir)

    vocab = Vocabulary.load(os.path.join(base, 'shared_vocab.txt'))
    tokenizer = NL2BashTokenizer(vocab)
    full_dataset = NL2BashDataset(
        os.path.join(base, 'all.nl'),
        os.path.join(base, 'all.cm'),
        tokenizer,
    )

    # 10 / 1 / 1 ratios with a fixed seed (presumably makes the split deterministic)
    splits = full_dataset.split(train_ratio=10, dev_ratio=1, test_ratio=1, seed=42)
    train_dataset, dev_dataset, test_dataset = splits
    return vocab, tokenizer, train_dataset, dev_dataset, test_dataset

print(f"WikiText-2 files are available at: {WIKITEXT2_DIR}")
print("If your branch now uses WikiText-2 loaders, replace this cell with that loader logic.")
print("For now, the code below still uses the legacy NL2Bash loader.")

print("Loading data...")
vocab, tokenizer, train_dataset, dev_dataset, test_dataset = load_data()

print("✓ Data loaded")
for caption, value in (
    ("Vocab size", len(vocab)),
    ("Train examples", len(train_dataset)),
    ("Dev examples", len(dev_dataset)),
    ("Test examples", len(test_dataset)),
):
    print(f"  {caption}: {value}")


## 6. Helper Functions

In [None]:
import torch

def compute_exact_match_pytorch(model, dataset, tokenizer, device, num_samples=100, max_gen_len=64):
    """Generate commands for the first examples of ``dataset`` and score exact matches.

    For each scored example, the NL prefix ``combined_ids[:nl_length]`` is
    passed to ``model.generate``. The first generated sequence is decoded with
    ``tokenizer.decode_cm`` and compared, after stripping whitespace, to
    ``cm_text``.

    Args:
        model: object exposing ``generate(tokens, start_id=..., end_id=..., max_len=...)``.
        dataset: indexable dataset whose items are dicts with ``'nl_length'``,
            ``'combined_ids'`` and ``'cm_text'``.
        tokenizer: exposes ``vocab.start_id``, ``vocab.end_id`` and ``decode_cm``.
        device: torch device that the input tokens are placed on.
        num_samples: maximum number of leading examples to score; ``None`` scores
            the whole dataset.
        max_gen_len: ``max_len`` forwarded to ``model.generate`` (was fixed at 64).

    Returns:
        Fraction of scored examples that matched exactly, or 0.0 if none were scored.
    """
    total = len(dataset) if num_samples is None else min(num_samples, len(dataset))
    if total <= 0:
        return 0.0

    # Put the model in eval mode for generation, and restore training mode
    # afterwards so calling this mid-training has no lasting side effect.
    was_training = getattr(model, 'training', False)
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for i in range(total):
                ex = dataset[i]
                nl_ids = ex['combined_ids'][:ex['nl_length']]
                # as_tensor accepts a list or an existing tensor; add a batch dim of 1
                nl_tokens = torch.as_tensor(nl_ids, dtype=torch.long, device=device).unsqueeze(0)

                generated = model.generate(
                    nl_tokens,
                    start_id=tokenizer.vocab.start_id,
                    end_id=tokenizer.vocab.end_id,
                    max_len=max_gen_len,
                )

                generated_ids = generated[0].cpu().tolist()
                generated_text = tokenizer.decode_cm(generated_ids, skip_special_tokens=True)

                if generated_text.strip() == ex['cm_text'].strip():
                    correct += 1
    finally:
        if was_training:
            model.train()

    return correct / total

print("✓ PyTorch helper functions defined")

In [None]:
def plot_depth_comparison(results, framework_name, config, output_path):
    """Draw per-epoch loss curves, one line per depth: train on the left, dev on the right.

    Args:
        results: mapping depth -> dict with ``'train_losses'``, ``'dev_losses'``
            and ``'n_params'``.
        framework_name: label used in the figure title.
        config: experiment config; only ``config["d_model"]`` is read.
        output_path: path the PNG is saved to.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    ax_train, ax_dev = axes
    fig.suptitle(f'{framework_name} — Depth Comparison (d_model={config["d_model"]})', fontsize=14)

    ordered_runs = sorted(results.items())
    palette = plt.cm.viridis(np.linspace(0, 0.9, len(results)))

    for idx, (n_layers, run) in enumerate(ordered_runs):
        # Both panels share the x axis derived from the train history length
        xs = list(range(1, len(run['train_losses']) + 1))
        line_style = dict(
            color=palette[idx],
            label=f'n_layers={n_layers} ({run["n_params"]:,} params)',
            marker='o',
            markersize=3,
        )
        ax_train.plot(xs, run['train_losses'], **line_style)
        ax_dev.plot(xs, run['dev_losses'], **line_style)

    for ax, panel_title in ((ax_train, 'Training Loss'), (ax_dev, 'Dev Loss')):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(panel_title)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved plot to {output_path}")

print("✓ Plotting functions defined")

In [None]:
def plot_final_bar_chart(results, framework_name, config, output_path):
    """Summarize each depth's final metrics as bar charts.

    Left panel: grouped bars of final train/dev/test loss per depth.
    Right panel: test exact-match accuracy per depth (stored as a fraction and
    plotted as a percentage), with a value label above each bar.

    Args:
        results: mapping depth -> dict with ``'final_train_loss'``,
            ``'final_dev_loss'``, ``'final_test_loss'`` and ``'exact_match'``.
        framework_name: label used in the figure title.
        config: experiment config; only ``config["d_model"]`` is read.
        output_path: path the PNG is saved to.
    """
    depth_keys = sorted(results.keys())

    def _column(metric_name):
        # One value per depth, in ascending depth order
        return [results[d][metric_name] for d in depth_keys]

    train_vals = _column('final_train_loss')
    dev_vals = _column('final_dev_loss')
    test_vals = _column('final_test_loss')
    em_vals = _column('exact_match')

    positions = np.arange(len(depth_keys))
    bar_w = 0.22
    tick_labels = [str(d) for d in depth_keys]

    def _style_axis(ax, ylabel, title):
        # Shared depth-indexed x axis and labels for both panels
        ax.set_xlabel('n_layers')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)

    fig, (loss_ax, em_ax) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f'{framework_name} — Final Metrics by Depth (d_model={config["d_model"]})', fontsize=14)

    # Grouped loss bars: train left of centre, dev centred, test right
    for offset, vals, series_name, colour in (
        (-bar_w, train_vals, 'Train', '#2196F3'),
        (0, dev_vals, 'Dev', '#FF9800'),
        (bar_w, test_vals, 'Test', '#4CAF50'),
    ):
        loss_ax.bar(positions + offset, vals, bar_w, label=series_name, color=colour)
    _style_axis(loss_ax, 'Loss', 'Final Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3, axis='y')

    percentages = [v * 100 for v in em_vals]
    em_bars = em_ax.bar(positions, percentages, bar_w * 2, color='#9C27B0')
    _style_axis(em_ax, 'Exact Match (%)', 'Test Exact Match Accuracy')
    em_ax.grid(True, alpha=0.3, axis='y')
    # Annotate each bar just above its top (bar height is already in percent)
    for rect, frac in zip(em_bars, em_vals):
        em_ax.text(rect.get_x() + rect.get_width() / 2, rect.get_height() + 0.5,
                   f'{frac:.1%}', ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved bar chart to {output_path}")

print("✓ Bar chart functions defined")

In [None]:
def save_results(results, output_path):
    """Write experiment results to ``output_path`` as indented JSON.

    Depth keys become strings, because JSON object keys must be strings. Loss
    histories, final losses and exact match are cast to ``float`` and
    ``n_params`` to ``int``, so the values serialize as plain JSON numbers.
    """
    list_fields = ('train_losses', 'dev_losses')
    scalar_fields = ('final_train_loss', 'final_dev_loss', 'final_test_loss', 'exact_match')

    def _to_plain(entry):
        plain = {key: [float(v) for v in entry[key]] for key in list_fields}
        plain.update({key: float(entry[key]) for key in scalar_fields})
        plain['n_params'] = int(entry['n_params'])
        return plain

    payload = {str(depth): _to_plain(entry) for depth, entry in results.items()}
    with open(output_path, 'w') as fh:
        json.dump(payload, fh, indent=2)
    print(f"Saved results to {output_path}")

print("✓ Results saving functions defined")

## 7. Run PyTorch Experiment

⚠️ This takes 30–60 minutes with GPU acceleration. For faster runs, reduce `PYTORCH_CONFIG['depths']` or `PYTORCH_CONFIG['n_epochs']` (see the commented-out lines in the next cell).

In [None]:
# Optional: Reduce config for faster testing
# PYTORCH_CONFIG['depths'] = [1]  # Test only depth 1
# PYTORCH_CONFIG['n_epochs'] = 2  # Test only 2 epochs

print("Ready to run PyTorch experiment")
# Echo the knobs that dominate run time before committing to the full sweep
for _caption, _key in (("Depths", 'depths'), ("Epochs", 'n_epochs'), ("Batch size", 'batch_size')):
    print(f"{_caption}: {PYTORCH_CONFIG[_key]}")

In [None]:
# Import PyTorch training functions
import torch
import torch.optim as optim
from model.transformer_full import TransformerDecoder
from main import compute_masked_loss, train_epoch, evaluate

def run_pytorch_experiment(vocab, tokenizer, train_dataset, dev_dataset, test_dataset, config):
    """Train one TransformerDecoder per depth and collect loss histories and final metrics.

    Args:
        vocab: vocabulary. ``len(vocab)`` sets the model vocab size, and
            ``vocab.pad_id`` is used to pad batches.
        tokenizer: forwarded to ``compute_exact_match_pytorch`` for decoding.
        train_dataset, dev_dataset, test_dataset: datasets wrapped with
            ``create_pytorch_dataloader``.
        config: dict with ``d_model``, ``n_heads``, ``d_ff``, ``batch_size``,
            ``lr``, ``n_epochs``, ``depths``, ``dropout`` and ``max_len``.
            Optional keys:
              - ``'seed'``: if set, torch's RNG is re-seeded before each depth,
                so every depth starts from the same RNG state.
              - ``'exact_match_samples'``: number of test examples scored for
                exact match (default 100).

    Returns:
        dict mapping depth -> {'train_losses', 'dev_losses', 'final_train_loss',
        'final_dev_loss', 'final_test_loss', 'exact_match', 'n_params'}.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nPyTorch device: {device}")

    seed = config.get('seed')
    em_samples = config.get('exact_match_samples', 100)

    # The loaders don't depend on depth, so build them once. They are already
    # re-iterated every epoch below, so reusing them across depths is safe.
    train_loader = create_pytorch_dataloader(
        train_dataset, config['batch_size'], shuffle=True, pad_id=vocab.pad_id
    )
    dev_loader = create_pytorch_dataloader(
        dev_dataset, config['batch_size'], shuffle=False, pad_id=vocab.pad_id
    )
    test_loader = create_pytorch_dataloader(
        test_dataset, config['batch_size'], shuffle=False, pad_id=vocab.pad_id
    )

    results = {}

    for depth in config['depths']:
        print(f"\n{'='*60}")
        print(f"PyTorch: Training depth={depth}")
        print(f"{'='*60}")

        if seed is not None:
            # Same init/shuffle RNG state for every depth -> fairer comparison
            torch.manual_seed(seed)

        model = TransformerDecoder(
            vocab_size=len(vocab),
            d_model=config['d_model'],
            n_heads=config['n_heads'],
            n_layers=depth,
            d_ff=config['d_ff'],
            dropout=config['dropout'],
            max_len=config['max_len'],
        ).to(device)

        n_params = sum(p.numel() for p in model.parameters())
        print(f"  Parameters: {n_params:,}")

        optimizer = optim.Adam(model.parameters(), lr=config['lr'])

        train_losses = []
        dev_losses = []

        for epoch in range(1, config['n_epochs'] + 1):
            train_metrics = train_epoch(model, train_loader, optimizer, device, epoch, log_interval=9999)
            dev_metrics = evaluate(model, dev_loader, device)

            train_losses.append(train_metrics['loss'])
            dev_losses.append(dev_metrics['loss'])

            print(f"  Epoch {epoch:2d}/{config['n_epochs']} | "
                  f"Train loss: {train_metrics['loss']:.4f} | "
                  f"Dev loss: {dev_metrics['loss']:.4f} | "
                  f"Time: {train_metrics['time']:.1f}s")

        # Final evaluation on test set
        test_metrics = evaluate(model, test_loader, device)
        print(f"\n  Test loss: {test_metrics['loss']:.4f}")

        # Exact match accuracy on test set
        print(f"  Computing exact match on test set...")
        exact_match = compute_exact_match_pytorch(model, test_dataset, tokenizer, device, num_samples=em_samples)
        print(f"  Exact match accuracy: {exact_match:.2%}")

        results[depth] = {
            'train_losses': train_losses,
            'dev_losses': dev_losses,
            'final_train_loss': train_losses[-1],
            'final_dev_loss': dev_losses[-1],
            'final_test_loss': test_metrics['loss'],
            'exact_match': exact_match,
            'n_params': n_params,
        }

        print(f"\n  >> depth={depth}: "
              f"train={train_losses[-1]:.4f} | "
              f"dev={dev_losses[-1]:.4f} | "
              f"test={test_metrics['loss']:.4f} | "
              f"exact_match={exact_match:.1%}")

        # Release this depth's model and optimizer before building the next,
        # larger one, so their GPU memory can be reclaimed.
        del model, optimizer
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return results

print("✓ PyTorch experiment function defined")

In [None]:
# Run the full depth sweep with the configuration defined above
banner = "=" * 60
print(banner)
print("PYTORCH EXPERIMENT")
print(banner)

pytorch_results = run_pytorch_experiment(
    vocab, tokenizer, train_dataset, dev_dataset, test_dataset,
    config=PYTORCH_CONFIG,
)

In [None]:
# Persist the raw numbers first, then render both figures into OUTPUT_DIR
save_results(pytorch_results, os.path.join(OUTPUT_DIR, 'pytorch_results.json'))

for plot_fn, png_name in (
    (plot_depth_comparison, 'pytorch_depth_comparison.png'),
    (plot_final_bar_chart, 'pytorch_final_bar_chart.png'),
):
    plot_fn(pytorch_results, 'PyTorch', PYTORCH_CONFIG, os.path.join(OUTPUT_DIR, png_name))

## 8. Summary & Next Steps


In [None]:
print('\n' + '='*60)
print('EXPERIMENT COMPLETE')
print('='*60)
print(f'\nResults saved to: {OUTPUT_DIR}')
print('\nGenerated files:')
# os.listdir order is arbitrary; sort it so the report is stable between runs.
for f in sorted(os.listdir(OUTPUT_DIR)):
    path = os.path.join(OUTPUT_DIR, f)
    if os.path.isfile(path):
        size = os.path.getsize(path) / 1024
        print(f'  - {f} ({size:.1f} KB)')

print('\n✓ PyTorch experiment completed')


In [None]:
# Optional: copy results to Google Drive.
# This runs only when MOUNT_DRIVE=True (section 1) and the Drive mount is present.
import os
import shutil

DRIVE_RESULTS_DIR = '/content/drive/MyDrive/wikitext2_results'

if MOUNT_DRIVE and os.path.isdir('/content/drive/MyDrive'):
    # dirs_exist_ok lets repeated runs update the same Drive folder
    shutil.copytree(OUTPUT_DIR, DRIVE_RESULTS_DIR, dirs_exist_ok=True)
    print(f"Results copied to Drive: {DRIVE_RESULTS_DIR}")
else:
    print("To copy results to Drive, set MOUNT_DRIVE=True in section 1 and rerun the notebook.")
