# WikiText-2 Depth Experiment - Colab Notebook (PyTorch)

This notebook measures how the depth (number of decoder layers) of a decoder-only Transformer language model affects its performance on WikiText-2.

What this notebook does:
- installs dependencies
- clones the project directly from GitHub (no ZIP upload)
- downloads WikiText-2 via `kagglehub`
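- builds a word-level language-modeling dataset (training-set vocabulary, `<eos>`-terminated lines, fixed-length blocks)
- trains decoder-only Transformers at several depths and compares their loss, perplexity and token accuracy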


## 1. Setup and Dependencies


In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q matplotlib numpy kagglehub

print('Dependencies installed')


In [None]:
# Optional: mount Google Drive so outputs survive the end of the Colab session.
# Flip the flag to True to enable; the default skips mounting entirely.
MOUNT_DRIVE = False

if not MOUNT_DRIVE:
    print('Drive mount skipped (set MOUNT_DRIVE=True to enable)')
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print('Google Drive mounted')


In [None]:
# Clone project from GitHub (no manual ZIP upload required).
# The checkout is deleted and re-cloned on every run so it always matches
# REPO_BRANCH; the working directory is then switched to PROJECT_DIR.
import os
import shutil
import subprocess

REPO_URL = 'https://github.com/ng3gn/ground-up-vla.git'
REPO_BRANCH = 'main'
PROJECT_SUBDIR = '03-wikitext2'
CHECKOUT_DIR = '/content/ground-up-vla'

# On a re-run this cell has already chdir'd into CHECKOUT_DIR (see below).
# Step out to the parent before deleting it; otherwise the kernel -- and the
# `git clone` child, which inherits the cwd -- is left in a deleted directory.
parent_dir = os.path.dirname(CHECKOUT_DIR)
os.makedirs(parent_dir, exist_ok=True)
os.chdir(parent_dir)

if os.path.exists(CHECKOUT_DIR):
    shutil.rmtree(CHECKOUT_DIR)

# Shallow clone of the single branch we need; check=True raises on failure.
subprocess.run(['git', 'clone', '--depth', '1', '--branch', REPO_BRANCH, REPO_URL, CHECKOUT_DIR], check=True)
PROJECT_DIR = os.path.join(CHECKOUT_DIR, PROJECT_SUBDIR)
os.chdir(PROJECT_DIR)

print(f'Repo cloned to: {CHECKOUT_DIR}')
print(f'Working directory: {PROJECT_DIR}')


## 2. Download WikiText-2


In [None]:
import os
import kagglehub

WIKITEXT2_DIR = kagglehub.dataset_download('vivekmettu/wikitext2-data')

required = {'wiki.train.tokens', 'wiki.valid.tokens', 'wiki.test.tokens'}

# Some Kaggle datasets place files in a nested subfolder: when the download
# root does not hold all three splits, take the first directory yielded by a
# top-down os.walk that does.
if not required <= set(os.listdir(WIKITEXT2_DIR)):
    found = next(
        (root for root, _, files in os.walk(WIKITEXT2_DIR) if required <= set(files)),
        None,
    )
    if found is None:
        raise FileNotFoundError('Could not find wiki.train/valid/test.tokens in downloaded dataset')
    WIKITEXT2_DIR = found

print('Path to dataset files:', WIKITEXT2_DIR)
for fname in sorted(required):
    path = os.path.join(WIKITEXT2_DIR, fname)
    print(f'  - {fname} ({os.path.getsize(path) / (1024 * 1024):.2f} MB)')


In [None]:
import torch

# Report the PyTorch build and whether a CUDA device is visible.
cuda_ok = torch.cuda.is_available()
print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {cuda_ok}')
if cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')


## 3. Imports and Configuration


In [None]:
import os
import json
import time
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Reuse project model implementation from the cloned repository.
import sys

# Put PROJECT_DIR first on the import path exactly once; a plain insert would
# stack another duplicate entry every time this cell is re-run.
sys.path[:] = [PROJECT_DIR] + [entry for entry in sys.path if entry != PROJECT_DIR]
from model.transformer_full import TransformerDecoder

print('Modules imported')


In [None]:
# Experiment hyperparameters. `max_len` is both the model's context length and
# the length of each training block; `max_train_tokens=None` uses the full split.
CONFIG = {
    'd_model': 256,
    'n_heads': 4,
    'd_ff': 1024,
    'dropout': 0.1,
    'max_len': 128,
    'batch_size': 32,
    'lr': 3e-4,
    'weight_decay': 0.01,
    'n_epochs': 8,
    'depths': [1, 2, 4, 8],
    'max_vocab_size': 20000,
    'min_freq': 2,
    'max_train_tokens': 1200000,
    'grad_clip': 1.0,
    'seed': 42,  # RNG seed for weight init, dropout, shuffling and sampling
}

# Seed torch's global RNG (CPU and CUDA) so a Restart & Run All reproduces the
# same initialisation, batch order and samples. GPU kernels may still add
# small run-to-run nondeterminism.
torch.manual_seed(CONFIG['seed'])

OUTPUT_DIR = os.path.join(PROJECT_DIR, 'experiment_outputs_wikitext2')
os.makedirs(OUTPUT_DIR, exist_ok=True)

print('Configuration loaded')
print('Depths:', CONFIG['depths'])
print('Output dir:', OUTPUT_DIR)


## 4. Build WikiText-2 Dataset


In [None]:
def read_wikitext_split(path):
    """Return the non-blank lines of the UTF-8 text file at `path`, each stripped of surrounding whitespace."""
    with open(path, 'r', encoding='utf-8') as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text]


def build_vocab(train_lines, max_vocab_size=20000, min_freq=2):
    """Build a whitespace-token vocabulary from the training lines.

    Args:
        train_lines: iterable of strings; tokens are obtained with `str.split()`.
        max_vocab_size: upper bound on the number of corpus tokens plus specials.
            The specials are always kept, even if this bound is smaller than them.
        min_freq: minimum count for a corpus token to be included.

    Returns:
        (stoi, itos): `itos` is `['<pad>', '<unk>', '<eos>']` followed by corpus
        tokens in descending frequency (ties keep first-seen order); `stoi` is
        its inverse mapping. Every token appears exactly once.
    """
    counter = Counter()
    for line in train_lines:
        counter.update(line.split())

    specials = ['<pad>', '<unk>', '<eos>']
    reserved = set(specials)
    # A corpus can already contain a literal special token (e.g. `<unk>`).
    # Skip those so each special keeps its single reserved id. Otherwise itos
    # would list it twice and stoi would point at the later duplicate.
    budget = max(max_vocab_size - len(specials), 0)
    tokens = [
        tok for tok, freq in counter.most_common()
        if freq >= min_freq and tok not in reserved
    ][:budget]
    itos = specials + tokens
    stoi = {tok: i for i, tok in enumerate(itos)}
    return stoi, itos


def encode_lines(lines, stoi):
    """Flatten `lines` into one id list, mapping OOV words to `<unk>` and ending every line with `<eos>`."""
    unk_id, eos_id = stoi['<unk>'], stoi['<eos>']
    encoded = []
    for line in lines:
        encoded += [stoi.get(word, unk_id) for word in line.split()] + [eos_id]
    return encoded


def make_lm_blocks(token_ids, seq_len):
    """Cut a flat id sequence into non-overlapping next-token-prediction blocks.

    Returns `(x, y)` long tensors of shape `(n_blocks, seq_len)` where `y` is `x`
    shifted left by one position. Trailing ids that do not fill a block are
    dropped. Raises ValueError if not even one block fits.
    """
    block_count = (len(token_ids) - 1) // seq_len
    if block_count < 1:
        raise ValueError('Not enough tokens to create at least one block')

    # One extra id at the end supplies the final target.
    usable = torch.tensor(token_ids[: block_count * seq_len + 1], dtype=torch.long)
    inputs = usable[:-1].reshape(block_count, seq_len)
    targets = usable[1:].reshape(block_count, seq_len)
    return inputs, targets


# Read the three splits (empty lines are dropped by read_wikitext_split).
split_files = {
    'train': 'wiki.train.tokens',
    'valid': 'wiki.valid.tokens',
    'test': 'wiki.test.tokens',
}
split_lines = {
    split: read_wikitext_split(os.path.join(WIKITEXT2_DIR, fname))
    for split, fname in split_files.items()
}
train_lines = split_lines['train']
valid_lines = split_lines['valid']
test_lines = split_lines['test']

# The vocabulary is fitted on the training split only.
stoi, itos = build_vocab(
    train_lines,
    max_vocab_size=CONFIG['max_vocab_size'],
    min_freq=CONFIG['min_freq'],
)

train_ids, valid_ids, test_ids = (
    encode_lines(lines, stoi) for lines in (train_lines, valid_lines, test_lines)
)

# Optionally truncate the training stream to keep runtime bounded.
train_token_cap = CONFIG['max_train_tokens']
if train_token_cap is not None:
    train_ids = train_ids[:train_token_cap]

block_len = CONFIG['max_len']
train_x, train_y = make_lm_blocks(train_ids, block_len)
valid_x, valid_y = make_lm_blocks(valid_ids, block_len)
test_x, test_y = make_lm_blocks(test_ids, block_len)

loader_batch = CONFIG['batch_size']
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=loader_batch, shuffle=True)
valid_loader = DataLoader(TensorDataset(valid_x, valid_y), batch_size=loader_batch, shuffle=False)
test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=loader_batch, shuffle=False)

print('WikiText-2 prepared')
print('Vocab size:', len(itos))
print('Train blocks:', len(train_x))
print('Valid blocks:', len(valid_x))
print('Test blocks:', len(test_x))


## 5. Helper Functions


In [None]:
def evaluate_lm(model, dataloader, device):
    """Evaluate a language model on every batch of `dataloader` without gradients.

    Args:
        model: module mapping input ids `x` to logits whose last dimension is the
            vocabulary (presumably `(batch, seq, vocab)` -- TODO confirm for the
            project model).
        dataloader: iterable of `(x, y)` pairs of input ids and target ids.
        device: device the batches are moved to.

    Returns:
        dict with 'loss' (mean cross-entropy per target token), 'perplexity'
        (exp of that loss) and 'token_acc' (fraction of targets equal to the
        argmax prediction). An empty loader yields loss 0, perplexity 1, acc 0.

    Leaves `model` in eval mode.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    total_correct = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)

            # Summed (not averaged) so the final mean is exact per token even
            # when the last batch is smaller. `reshape` also accepts a
            # non-contiguous logits tensor, where `view` would raise.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), y.reshape(-1), reduction='sum'
            )
            total_loss += loss.item()

            total_correct += (logits.argmax(dim=-1) == y).sum().item()
            total_tokens += y.numel()

    avg_loss = total_loss / max(total_tokens, 1)
    ppl = float(np.exp(avg_loss))
    acc = total_correct / max(total_tokens, 1)

    return {'loss': avg_loss, 'perplexity': ppl, 'token_acc': acc}


def train_epoch_lm(model, dataloader, optimizer, device, grad_clip=1.0):
    """Run one training epoch over `dataloader` and return its token-averaged loss.

    Args:
        model: module mapping input ids `x` to logits whose last dimension is the
            vocabulary (presumably `(batch, seq, vocab)` -- TODO confirm).
        dataloader: iterable of `(x, y)` pairs of input ids and target ids.
        optimizer: optimizer over `model`'s parameters; stepped once per batch.
        device: device the batches are moved to.
        grad_clip: max global gradient norm. `None` or a value <= 0 disables
            clipping. A max norm of 0 would otherwise zero every gradient.

    Returns:
        dict with 'loss' (mean per-token training loss over the epoch),
        'perplexity' (exp of that loss) and 'time' (wall-clock seconds).
    """
    model.train()
    total_loss = 0.0
    total_tokens = 0
    t0 = time.perf_counter()  # monotonic clock for interval timing

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(x)
        # `reshape` tolerates non-contiguous logits where `view` would raise.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), y.reshape(-1), reduction='mean'
        )
        loss.backward()
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Re-weight the batch mean by its token count so the epoch average is
        # per token, not per batch.
        n_tokens = y.numel()
        total_loss += loss.item() * n_tokens
        total_tokens += n_tokens

    avg_loss = total_loss / max(total_tokens, 1)
    return {
        'loss': avg_loss,
        'perplexity': float(np.exp(avg_loss)),
        'time': time.perf_counter() - t0,
    }


def generate_text(model, prompt, stoi, itos, device, max_len_ctx=128, max_new_tokens=40, temperature=1.0):
    """Sample a continuation of `prompt` token by token.

    Args:
        model: module mapping ids of shape `(1, T)` to logits; only the logits at
            the last position are used (indexed as `logits[:, -1, :]`).
        prompt: whitespace-separated words; unknown words map to `<unk>`. An
            empty prompt starts from a single `<eos>`.
        stoi, itos: vocabulary mappings as returned by `build_vocab`.
        device: device the ids tensor is created on.
        max_len_ctx: the model only ever sees the most recent `max_len_ctx`
            tokens; the returned text still contains the full history.
        max_new_tokens: upper bound on sampled tokens; sampling also stops
            after an `<eos>`.
        temperature: softmax temperature (clamped below at 1e-5).

    Returns:
        The prompt plus the sampled tokens joined by spaces, with `<eos>`
        tokens removed.
    """
    model.eval()
    unk_id = stoi['<unk>']
    eos_id = stoi['<eos>']

    ids = [stoi.get(tok, unk_id) for tok in prompt.split()]
    if not ids:
        ids = [eos_id]

    # `generated` keeps every token; only a sliding window of it is fed to the
    # model, so outputs longer than the context no longer lose their start.
    generated = torch.tensor([ids], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            context = generated[:, -max_len_ctx:]
            logits = model(context)
            next_logits = logits[:, -1, :] / max(temperature, 1e-5)
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_id], dim=1)

            if next_id.item() == eos_id:
                break

    out_ids = generated[0].tolist()
    out_toks = [itos[i] if 0 <= i < len(itos) else '<unk>' for i in out_ids]
    return ' '.join(tok for tok in out_toks if tok != '<eos>')


print('Helper functions defined')


In [None]:
def plot_depth_comparison(results, config, output_path):
    """Draw per-epoch curves for every depth in three side-by-side panels.

    Panels: train loss, valid loss, valid perplexity. Each depth is one line,
    labelled with its parameter count. The figure is saved to `output_path`
    (150 dpi) and then shown.
    """
    ordered_depths = sorted(results)
    panels = (
        ('train_losses', 'Train Loss', 'Loss'),
        ('valid_losses', 'Valid Loss', 'Loss'),
        ('valid_ppls', 'Valid Perplexity', 'Perplexity'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(f"PyTorch WikiText-2 Depth Comparison (d_model={config['d_model']})", fontsize=14)

    palette = plt.cm.viridis(np.linspace(0, 0.9, len(ordered_depths)))

    for depth, color in zip(ordered_depths, palette):
        run = results[depth]
        # The epoch axis length is taken from the training curve.
        epoch_axis = list(range(1, len(run['train_losses']) + 1))
        line_label = f"layers={depth} ({run['n_params']:,} params)"
        for ax, (series_key, _, _) in zip(axes, panels):
            ax.plot(epoch_axis, run[series_key], color=color, marker='o', markersize=3, label=line_label)

    for ax, (_, panel_title, y_label) in zip(axes, panels):
        ax.set_title(panel_title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved plot to', output_path)


print('Plotting function defined')


In [None]:
def plot_final_metrics(results, output_path):
    """Bar-chart final test loss, perplexity and token accuracy per depth.

    The figure is saved to `output_path` (150 dpi) and then shown.
    """
    ordered = sorted(results.keys())
    panels = (
        ('test_loss', 'Test Loss by Depth', 'Loss'),
        ('test_perplexity', 'Test Perplexity by Depth', 'Perplexity'),
        ('test_token_acc', 'Test Token Accuracy by Depth', 'Accuracy'),
    )
    # Gather every metric before creating the figure, so a missing key fails fast.
    heights = {metric: [results[d][metric] for d in ordered] for metric, _, _ in panels}
    tick_labels = [str(d) for d in ordered]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (metric, panel_title, y_label) in zip(axes, panels):
        ax.bar(tick_labels, heights[metric])
        ax.set_title(panel_title)
        ax.set_xlabel('Depth')
        ax.set_ylabel(y_label)
        ax.grid(True, axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved plot to', output_path)


print('Final metric plotting function defined')


In [None]:
def save_results(results, output_path):
    """Write the experiment `results` dict to `output_path` as indented UTF-8 JSON.

    Top-level (depth) keys are stringified. NumPy scalars, NumPy arrays and
    torch tensors anywhere in the nested values are converted to native
    Python values: integers stay integers and floats stay floats. The JSON is
    fully encoded before the file is opened, so a serialization error does not
    leave a truncated file behind.
    """
    def _to_builtin(obj):
        # json calls this only for objects it cannot serialize natively.
        if isinstance(obj, np.generic):
            return obj.item()  # np.int64 -> int, np.float32 -> float, np.bool_ -> bool
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    serializable = {str(depth): dict(data) for depth, data in results.items()}
    text = json.dumps(serializable, indent=2, default=_to_builtin)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(text)

    print('Saved results to', output_path)


print('Result saver defined')


## 6. Run PyTorch Experiment


In [None]:
# Optional: reduce runtime for quick checks
# CONFIG['depths'] = [1, 2]
# CONFIG['n_epochs'] = 2
# CONFIG['max_train_tokens'] = 300000


In [None]:
def _train_single_depth(depth, config, device):
    """Train, evaluate and sample from one model with `depth` decoder layers; return its result dict."""
    seed = config.get('seed')
    if seed is not None:
        # Re-seed right before building the model so each depth's run depends
        # only on the seed, not on which depths were trained before it.
        torch.manual_seed(seed)

    model = TransformerDecoder(
        vocab_size=len(itos),
        d_model=config['d_model'],
        n_heads=config['n_heads'],
        n_layers=depth,
        d_ff=config['d_ff'],
        dropout=config['dropout'],
        max_len=config['max_len'],
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters())
    print(f'Parameters: {n_params:,}')

    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    train_losses = []
    valid_losses = []
    valid_ppls = []

    for epoch in range(1, config['n_epochs'] + 1):
        train_metrics = train_epoch_lm(
            model,
            train_loader,
            optimizer,
            device,
            grad_clip=config['grad_clip'],
        )
        valid_metrics = evaluate_lm(model, valid_loader, device)

        train_losses.append(train_metrics['loss'])
        valid_losses.append(valid_metrics['loss'])
        valid_ppls.append(valid_metrics['perplexity'])

        print(
            f"Epoch {epoch:2d}/{config['n_epochs']} | "
            f"train loss {train_metrics['loss']:.4f} | "
            f"valid loss {valid_metrics['loss']:.4f} | "
            f"valid ppl {valid_metrics['perplexity']:.2f} | "
            f"time {train_metrics['time']:.1f}s"
        )

    # The test split is scored with the weights after the final epoch.
    test_metrics = evaluate_lm(model, test_loader, device)
    print(
        f"Test loss {test_metrics['loss']:.4f} | "
        f"Test ppl {test_metrics['perplexity']:.2f} | "
        f"Test token acc {test_metrics['token_acc']:.2%}"
    )

    sample = generate_text(
        model,
        prompt='the meaning of life is',
        stoi=stoi,
        itos=itos,
        device=device,
        max_len_ctx=config['max_len'],
        max_new_tokens=30,
        temperature=1.0,
    )
    print('Sample generation:', sample)

    # Model and optimizer references are dropped when this helper returns,
    # so they can be released before the next depth is built.
    return {
        'n_params': n_params,
        'train_losses': train_losses,
        'valid_losses': valid_losses,
        'valid_ppls': valid_ppls,
        'test_loss': test_metrics['loss'],
        'test_perplexity': test_metrics['perplexity'],
        'test_token_acc': test_metrics['token_acc'],
        'sample_text': sample,
    }


def run_pytorch_experiment(config):
    """Train one TransformerDecoder per entry of `config['depths']` and collect metrics.

    Uses the module-level `itos`/`stoi` vocabulary and the
    `train_loader`/`valid_loader`/`test_loader` built earlier. If `config`
    has a 'seed' key, torch's RNG is re-seeded before each depth.

    Returns:
        dict mapping depth -> {'n_params', 'train_losses', 'valid_losses',
        'valid_ppls', 'test_loss', 'test_perplexity', 'test_token_acc',
        'sample_text'}.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Device:', device)

    results = {}

    for depth in config['depths']:
        print('\n' + '=' * 60)
        print(f'Training depth={depth}')
        print('=' * 60)
        results[depth] = _train_single_depth(depth, config, device)

    return results


print('Experiment runner defined')


In [None]:
# Run the full depth sweep defined in CONFIG.
banner = '=' * 60
print(banner)
print('PYTORCH WIKITEXT-2 EXPERIMENT')
print(banner)

results = run_pytorch_experiment(CONFIG)


In [None]:
# Persist the raw metrics, then render the curve and bar-chart figures.
results_json_path = os.path.join(OUTPUT_DIR, 'pytorch_wikitext2_results.json')
curves_png_path = os.path.join(OUTPUT_DIR, 'pytorch_wikitext2_depth_curves.png')
final_png_path = os.path.join(OUTPUT_DIR, 'pytorch_wikitext2_final_metrics.png')

save_results(results, results_json_path)
plot_depth_comparison(results, CONFIG, curves_png_path)
plot_final_metrics(results, final_png_path)


## 7. Summary


In [None]:
# List every regular file written to OUTPUT_DIR with its size.
print('\n' + '=' * 60)
print('EXPERIMENT COMPLETE')
print('=' * 60)
print('\nResults saved to:', OUTPUT_DIR)
print('\nGenerated files:')
for entry in sorted(os.listdir(OUTPUT_DIR)):
    entry_path = os.path.join(OUTPUT_DIR, entry)
    if not os.path.isfile(entry_path):
        continue
    print(f'  - {entry} ({os.path.getsize(entry_path) / 1024:.1f} KB)')


In [None]:
# Optional: copy results to Google Drive when it was mounted in the setup
# section (MOUNT_DRIVE=True); otherwise explain how to enable it.
import shutil

DRIVE_RESULTS_DIR = '/content/drive/MyDrive/wikitext2_results'

if MOUNT_DRIVE:
    # dirs_exist_ok lets repeated runs overwrite a previous copy.
    shutil.copytree(OUTPUT_DIR, DRIVE_RESULTS_DIR, dirs_exist_ok=True)
    print(f'Results copied to Drive: {DRIVE_RESULTS_DIR}')
else:
    print('To copy results to Drive, set MOUNT_DRIVE=True in the setup section and re-run.')
