# In [None]:
"""
Transformer Scaling Study - Complete Training Pipeline
Based on nanoGPT architecture with modifications for music data

CS-GY 6923 Scaling Laws Project
"""

import os
import json
import math
import time
from pathlib import Path
from dataclasses import dataclass, asdict, replace
from typing import Optional, Dict, List

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =============================================================================
# MODEL ARCHITECTURE (Based on nanoGPT)
# =============================================================================

@dataclass
class GPTConfig:
    """Hyper-parameters of a GPT model.

    Fields:
        vocab_size: number of token ids (rows of the token embedding).
        block_size: maximum context length; sizes the position embedding and
            the causal attention mask.
        n_layer: number of transformer blocks.
        n_head: attention heads per block; must divide ``n_embd``.
        n_embd: model width.
        dropout: dropout probability used throughout the model.
        bias: whether the Linear layers inside each block carry bias terms
            (LayerNorms always have weight and bias; the output head never has bias).
    """
    vocab_size: int = 256
    block_size: int = 256  # context length
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 128
    dropout: float = 0.1
    bias: bool = False

    def get_param_count(self):
        """Return the number of parameters a GPT built from this config has.

        The count mirrors the layers the model creates, so it equals
        ``GPT(config).get_num_params()``:

        * token embedding, counted once because it is tied to the output head;
        * learned position embedding (``block_size x n_embd``);
        * per block: fused QKV + output projection (4·d²), MLP d→4d→d (8·d²),
          two LayerNorms (2·2d), plus Linear biases (9·d) when ``bias`` is set;
        * the final LayerNorm (2·d).
        """
        d = self.n_embd
        # Token embedding (shared with lm_head, so not added again) + positions.
        params = self.vocab_size * d + self.block_size * d

        per_block = (
            3 * d * d + d * d        # attention: qkv projection + output projection
            + 4 * d * d + 4 * d * d  # MLP: expand to 4d, project back to d
            + 2 * (2 * d)            # ln_1 and ln_2, weight + bias each
        )
        if self.bias:
            per_block += 3 * d + d + 4 * d + d  # biases of the four Linear layers
        params += self.n_layer * per_block

        params += 2 * d  # final LayerNorm
        return params


class CausalSelfAttention(nn.Module):
    """Masked multi-head self-attention.

    Every position attends only to itself and earlier positions; future
    positions are blocked with a lower-triangular mask stored as the ``bias``
    buffer (shape ``1 x 1 x block_size x block_size``).
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        width = config.n_embd
        # A single fused projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        # Mixes the concatenated head outputs back to the model width.
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        # Dropout on the attention weights and on the sub-layer output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout

        # 1 where attention is allowed (key index <= query index), 0 elsewhere.
        ctx = config.block_size
        causal = torch.ones(ctx, ctx).tril()
        self.register_buffer("bias", causal.view(1, 1, ctx, ctx))

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores; disallowed (future) positions become -inf.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        future = self.bias[:, :, :seq_len, :seq_len] == 0
        weights = self.attn_dropout(F.softmax(scores.masked_fill(future, float('-inf')), dim=-1))

        # Merge heads back to (batch, seq, width) and project.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer.

    Expands each position to ``4 * n_embd`` features, applies GELU, projects
    back to ``n_embd`` and applies dropout.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))


class Block(nn.Module):
    """Pre-LayerNorm transformer layer.

    Runs causal self-attention, then the feed-forward MLP. Each sub-layer sees
    a LayerNorm-ed input and its output is added back to the residual stream.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed


class GPT(nn.Module):
    """GPT language model.

    Token embedding plus learned position embedding, ``config.n_layer``
    transformer blocks, a final LayerNorm and a bias-free linear head whose
    weight is shared with the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the token embedding and the output head share one
        # matrix, so it appears (and is counted) only once in self.parameters().
        self.transformer.wte.weight = self.lm_head.weight

        # N(0, 0.02) for every Linear/Embedding weight, zeros for Linear biases.
        self.apply(self._init_weights)

        # Report number of parameters
        print(f"Model parameters: {self.get_num_params():,}")

    def _init_weights(self, module):
        """Initialise one sub-module in place (applied recursively by ``self.apply``)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (batch, time) tensor of token ids, ``time <= config.block_size``.
            targets: optional tensor of target ids with one entry per position
                of ``idx``.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, time, vocab_size).
            ``loss`` is the mean cross-entropy over all positions, or None when
            no targets are given.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"

        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)  # (t, n_embd), broadcast over the batch
        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous target tensors are accepted too.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss

    def get_num_params(self, non_embedding: bool = False):
        """Count model parameters.

        Args:
            non_embedding: if True, exclude the position-embedding table. The
                token embedding stays in the count because it is tied to (and
                therefore also is) the output head; this follows nanoGPT's
                convention and is useful when fitting scaling laws.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively extend ``idx`` with ``max_new_tokens`` new tokens.

        The caller is responsible for putting the model in eval mode if
        dropout should be disabled.

        Args:
            idx: (batch, time) tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling. 0 selects
                the most likely token at every step (greedy decoding).
            top_k: if given, sample only among the ``top_k`` most likely tokens.

        Returns:
            ``idx`` with the generated tokens concatenated along dim 1.

        Raises:
            ValueError: if ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit in the model's context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                # Dividing by 0 would yield inf/NaN probabilities; decode greedily instead.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


# =============================================================================
# MODEL CONFIGURATIONS FOR SCALING STUDY
# =============================================================================

# Architecture grid for the scaling sweep, smallest to largest. Every model
# shares vocab_size=256, a 256-token context and dropout 0.1; only depth, head
# count and width vary. (train_model sets the real vocabulary size from
# vocab.json before each model is built.)
_SCALING_GRID = (
    # name      n_layer  n_head  n_embd   (head dim = n_embd / n_head)
    ('tiny',    2,       2,      64),     # 32
    ('small',   4,       4,      128),    # 32
    ('medium',  6,       6,      252),    # 42
    ('large',   8,       8,      384),    # 48
    ('xl',      12,      12,     516),    # 43
)

MODEL_CONFIGS = {
    name: GPTConfig(
        vocab_size=256,
        block_size=256,
        n_layer=layers,
        n_head=heads,
        n_embd=width,
        dropout=0.1,
    )
    for name, layers, heads, width in _SCALING_GRID
}


# =============================================================================
# DATASET
# =============================================================================

class MusicDataset(Dataset):
    """Character-level next-token dataset over a UTF-8 text file.

    Item ``i`` is ``(x, y)``, both ``torch.long`` tensors of length
    ``block_size``. ``x`` holds tokens ``[i, i + block_size)`` and ``y`` is the
    same window shifted one position to the right (the prediction targets).
    """

    def __init__(self, data_path: str, vocab_path: str, block_size: int):
        self.block_size = block_size

        # Load vocabulary. A token's id is its position when iterating the
        # JSON value (presumably a list of characters; TODO confirm against
        # the preprocessing step).
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        self.vocab = vocab
        self.char_to_idx = {ch: i for i, ch in enumerate(vocab)}
        self.idx_to_char = {i: ch for i, ch in enumerate(vocab)}

        # Load data
        print(f"Loading data from {data_path}...")
        with open(data_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize
        print("Tokenizing...")
        self.tokens = self.encode(text)
        print(f"Dataset has {len(self.tokens):,} tokens")

    def encode(self, text: str) -> List[int]:
        """Map each character of *text* to its id.

        Characters missing from the vocabulary map to the ``<UNK>`` id when the
        vocabulary defines one.

        Raises:
            KeyError: if *text* contains an unknown character and the
                vocabulary has no ``<UNK>`` entry.
        """
        lookup = self.char_to_idx
        # Resolve the fallback once. The previous form evaluated
        # lookup['<UNK>'] for every character and so failed on any text
        # whenever the vocabulary had no <UNK>.
        unk = lookup.get('<UNK>')
        if unk is not None:
            return [lookup.get(ch, unk) for ch in text]
        try:
            return [lookup[ch] for ch in text]
        except KeyError as err:
            raise KeyError(
                f"character {err.args[0]!r} is not in the vocabulary and no '<UNK>' token is defined"
            ) from None

    def decode(self, indices: List[int]) -> str:
        """Decode token ids (ints or 0-d tensors, e.g. from a 1-D tensor) to text."""
        return ''.join(self.idx_to_char[int(i)] for i in indices)

    def __len__(self):
        # Each item needs block_size + 1 tokens (input plus shifted target);
        # a text shorter than that yields an empty dataset, not a negative length.
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        # Get chunk of data
        chunk = self.tokens[idx:idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


# =============================================================================
# TRAINING UTILITIES
# =============================================================================

@dataclass
class TrainingConfig:
    """Optimisation, evaluation, device and output settings shared by every model in the sweep."""
    # Data: directory holding train.txt, val.txt and the vocabulary file
    data_dir: str = "/content/drive/MyDrive/processed_music_data"
    vocab_file: str = "vocab.json"  # file name inside data_dir

    # Training: AdamW hyper-parameters and the LR schedule consumed by get_lr
    batch_size: int = 32
    learning_rate: float = 3e-4  # peak LR, reached at the end of warmup
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0  # max global gradient norm per optimizer step
    warmup_steps: int = 2000  # length of the linear warmup, in optimizer steps
    max_epochs: int = 1

    # Evaluation
    eval_interval: int = 500  # run a validation estimate every this many steps
    eval_iters: int = 100  # max validation batches per periodic estimate

    # System
    # NOTE: this default is evaluated once, when the class body is executed.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    compile_model: bool = False  # torch.compile for speed; NOTE(review): not read anywhere in this file

    # Logging: each model's checkpoint, config and results go in a sub-directory named after it
    output_dir: str = "/content/drive/MyDrive/models/checkpoints"


def get_lr(step: int, config: "TrainingConfig", total_steps: int):
    """Return the learning rate for 0-based optimizer step *step*.

    Schedule:

    * Linear warmup from 0 toward ``config.learning_rate`` over the first
      ``config.warmup_steps`` steps.
    * Cosine decay from the peak down to 10% of it, reached at
      ``total_steps``.
    * The rate then stays at that 10% floor for any later step.

    If ``total_steps <= warmup_steps`` there is no decay phase, and every
    post-warmup step returns the floor.
    """
    peak = config.learning_rate
    warmup = config.warmup_steps

    # Warmup
    if step < warmup:
        return peak * step / warmup

    # Cosine decay, with progress clamped to [0, 1] so the cosine never turns
    # back up past total_steps and a zero-length decay phase cannot divide by zero.
    decay_steps = total_steps - warmup
    if decay_steps <= 0:
        progress = 1.0
    else:
        progress = min(1.0, (step - warmup) / decay_steps)
    return peak * 0.1 + 0.9 * peak * (1 + math.cos(math.pi * progress)) / 2


@torch.no_grad()
def estimate_loss(model, eval_loader, eval_iters, device):
    """Return the mean loss of *model* over up to *eval_iters* batches.

    Each batch's loss is weighted by its number of target elements. With the
    model's loss being a per-element mean (assumed), this gives the true
    per-token mean, and a short final batch is not over-weighted. The model is
    switched to eval mode for the pass and restored to its previous
    train/eval mode afterwards, even if an exception is raised.

    Args:
        model: module callable as ``model(x, y) -> (logits, loss)`` with a
            scalar ``loss``.
        eval_loader: iterable of ``(x, y)`` tensor pairs.
        eval_iters: maximum number of batches to evaluate.
        device: device each batch is moved to.

    Returns:
        The weighted mean loss as a float, or ``nan`` if no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_count = 0
    try:
        for i, (x, y) in enumerate(eval_loader):
            if i >= eval_iters:
                break
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            count = y.numel()
            total_loss += loss.item() * count
            total_count += count
    finally:
        model.train(was_training)

    if total_count == 0:
        return float('nan')
    return total_loss / total_count


def train_epoch(
    model: GPT,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    config: TrainingConfig,
    epoch: int,
    total_steps: int,
    start_step: Optional[int] = None,
):
    """Run one pass over *train_loader*, updating *model* in place.

    Step numbering is global: it continues across epochs, so the LR schedule
    and the recorded history do not restart each epoch. Each step's learning
    rate comes from ``get_lr`` and is written into the optimizer *before* that
    step's update.

    Args:
        model: model returning ``(logits, loss)`` when called with targets.
        train_loader: batches of ``(x, y)`` for training.
        val_loader: batches used for the periodic validation estimate.
        optimizer: optimizer whose param-group LRs are overwritten each step.
        config: training settings (device, grad_clip, eval cadence, schedule).
        epoch: 1-based epoch number, used for the progress-bar label and, by
            default, to derive the first global step.
        total_steps: total optimizer steps across all epochs (for the schedule).
        start_step: global index of this epoch's first step; defaults to
            ``(epoch - 1) * len(train_loader)``.

    Returns:
        ``(mean training loss over this epoch, history)``. ``history`` is a list
        of ``{'step', 'loss', 'lr'}`` dicts, one per optimizer step.
    """
    model.train()

    if start_step is None:
        start_step = (epoch - 1) * len(train_loader)

    losses = []
    train_history = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for batch_idx, (x, y) in enumerate(pbar):
        step = start_step + batch_idx
        x, y = x.to(config.device), y.to(config.device)

        # Apply the scheduled LR before stepping so this update actually uses
        # it. Setting it after optimizer.step() made step 0 run at the base LR
        # (skipping warmup) and every later step lag the schedule by one.
        lr = get_lr(step, config, total_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass
        _, loss = model(x, y)

        # Backward pass, gradient clipping and weight update
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        # One host/device sync per step instead of three separate .item() calls.
        loss_value = loss.item()
        losses.append(loss_value)
        train_history.append({
            'step': step,
            'loss': loss_value,
            'lr': lr
        })

        pbar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'lr': f'{lr:.2e}'
        })

        # Periodic evaluation
        if step > 0 and step % config.eval_interval == 0:
            val_loss = estimate_loss(model, val_loader, config.eval_iters, config.device)
            print(f"\nStep {step}: val_loss = {val_loss:.4f}")

    return np.mean(losses), train_history


def train_model(
    model_name: str,
    config: GPTConfig,
    train_config: TrainingConfig
):
    """Train one model configuration end to end and save its artifacts.

    Steps:

    1. Read the vocabulary to size the embedding.
    2. Build the train/val datasets and loaders.
    3. Train with AdamW for ``train_config.max_epochs`` epochs.
    4. Evaluate on the whole validation loader.
    5. Write ``model.pt``, ``config.json`` and ``results.json`` under
       ``train_config.output_dir / model_name``.

    The passed-in *config* is left unchanged; a copy with ``vocab_size`` set
    from the vocabulary file is used.

    Args:
        model_name: label used in logs and as the output sub-directory name.
        config: architecture for this run.
        train_config: data, optimisation and output settings.

    Returns:
        Dict with these keys:

        * ``model_name`` and ``n_params``.
        * ``config``: the config as a dict.
        * ``train_loss``: mean loss of the last epoch, or nan if none ran.
        * ``val_loss``.
        * ``training_time``: seconds spent training, excluding the final
          evaluation.
        * ``gpu_memory_gb``: peak CUDA allocation during this run, 0 without
          CUDA.
        * ``train_history``: per-step dicts.
    """

    print(f"\n{'='*70}")
    print(f"Training {model_name.upper()} model")
    print(f"{'='*70}")

    # Update vocab size from data. Work on a copy so the caller's config,
    # which may be a shared MODEL_CONFIGS entry, is not mutated.
    vocab_path = Path(train_config.data_dir) / train_config.vocab_file
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    config = replace(config, vocab_size=len(vocab))

    # Create datasets
    train_dataset = MusicDataset(
        Path(train_config.data_dir) / "train.txt",
        vocab_path,
        config.block_size
    )

    val_dataset = MusicDataset(
        Path(train_config.data_dir) / "val.txt",
        vocab_path,
        config.block_size
    )

    # Pinned host memory only helps when batches are copied to a CUDA device;
    # on CPU it just triggers a warning.
    pin_memory = str(train_config.device).startswith('cuda')

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_config.batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=train_config.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory
    )

    # The CUDA peak-memory counter is process-wide. Reset it so the reported
    # figure belongs to this model, not to a larger one trained earlier.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    # Create model
    model = GPT(config).to(train_config.device)
    n_params = model.get_num_params()

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_config.learning_rate,
        betas=(train_config.beta1, train_config.beta2),
        weight_decay=train_config.weight_decay
    )

    # Calculate total steps
    total_steps = len(train_loader) * train_config.max_epochs

    # Training tracking
    start_time = time.time()
    all_train_history = []
    train_loss = float('nan')  # stays nan if max_epochs == 0

    # Train
    for epoch in range(train_config.max_epochs):
        train_loss, train_history = train_epoch(
            model, train_loader, val_loader, optimizer,
            train_config, epoch + 1, total_steps
        )
        all_train_history.extend(train_history)

    training_time = time.time() - start_time

    # Final evaluation over the whole validation set
    print("\nFinal evaluation...")
    val_loss = estimate_loss(model, val_loader, len(val_loader), train_config.device)

    gpu_memory = 0
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    results = {
        'model_name': model_name,
        'n_params': n_params,
        'config': asdict(config),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'training_time': training_time,
        'gpu_memory_gb': gpu_memory,
        'train_history': all_train_history
    }

    # Save results
    output_dir = Path(train_config.output_dir) / model_name
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / "model.pt")

    with open(output_dir / "config.json", 'w') as f:
        json.dump(asdict(config), f, indent=2)

    with open(output_dir / "results.json", 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*70}")
    print(f"TRAINING COMPLETE: {model_name}")
    print(f"{'='*70}")
    print(f"Parameters:     {n_params:,}")
    print(f"Train loss:     {train_loss:.4f}")
    print(f"Val loss:       {val_loss:.4f}")
    print(f"Training time:  {training_time/60:.1f} minutes")
    print(f"GPU memory:     {gpu_memory:.2f} GB")
    print(f"Saved to:       {output_dir}")
    print(f"{'='*70}\n")

    return results


# =============================================================================
# MAIN TRAINING SCRIPT
# =============================================================================

def main():
    """Train every model in MODEL_CONFIGS and merge the results into all_results.json.

    Per-model artifacts are written by train_model. The combined file maps
    model name to results dict:

    * Models trained in this run come first and replace any same-named entry
      from an earlier run.
    * Entries for models not retrained this time are kept after them.

    A summary table is printed at the end.

    Raises:
        ValueError: if an existing all_results.json does not contain a JSON object.
    """

    # Configuration
    train_config = TrainingConfig(
        data_dir="/content/drive/MyDrive/processed_music_data",
        batch_size=32,
        learning_rate=3e-4,
        max_epochs=1,
        output_dir="/content/drive/MyDrive/models/checkpoints"
    )

    print(f"Device: {train_config.device}")
    print(f"Data directory: {train_config.data_dir}")
    print(f"Output directory: {train_config.output_dir}")

    # Train all models
    all_results = {}

    for model_name, config in MODEL_CONFIGS.items():
        all_results[model_name] = train_model(model_name, config, train_config)

        # Release cached CUDA blocks before the next model is built.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save combined results (the directory may not exist if nothing was trained).
    output_dir = Path(train_config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "all_results.json"

    old_data = {}
    if output_file.exists():
        with open(output_file, "r") as f:
            old_data = json.load(f)
        if not isinstance(old_data, dict):
            raise ValueError(
                f"{output_file} must contain a JSON object mapping model names to results."
            )

    # New keys first and new values authoritative. The previous
    # {**all_results, **old_data} kept the order but let stale entries from
    # earlier runs overwrite the results just trained.
    combined = dict(all_results)
    for name, res in old_data.items():
        combined.setdefault(name, res)

    with open(output_file, "w") as f:
        json.dump(combined, f, indent=2)

    # Print summary
    print("\n" + "="*70)
    print("ALL MODELS TRAINED - SUMMARY")
    print("="*70)
    print(f"{'Model':<10} {'Params':<12} {'Val Loss':<10} {'Time (min)':<12}")
    print("-"*70)
    for name, res in all_results.items():
        print(f"{name:<10} {res['n_params']:>11,} {res['val_loss']:>10.4f} {res['training_time']/60:>11.1f}")
    print("="*70)

    print(f"\nResults saved to: {output_dir / 'all_results.json'}")


# Script entry point: train the whole model sweep. IPython/Jupyter also sets
# __name__ to "__main__", so executing this notebook cell starts training.
if __name__ == "__main__":
    main()

# In [None]:
"""
Scaling Law Analysis for Transformer Models
Analyzes training results and generates scaling plots

CS-GY 6923 Scaling Laws Project
"""

import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.optimize import curve_fit
from typing import Dict, List, Tuple

# =============================================================================
# SCALING LAW FITTING
# =============================================================================

def power_law(N, a, alpha, c):
    """
    Saturating power law  L(N) = a * N**(-alpha) + c.

    Args:
        N: parameter count(s); scalar or array-like (evaluated element-wise).
        a: amplitude of the reducible term.
        alpha: decay exponent.
        c: constant offset the curve approaches as N grows (for alpha > 0).
    """
    decay = np.power(N, -alpha)
    return c + a * decay


def fit_scaling_law(param_counts: np.ndarray, losses: np.ndarray) -> Tuple[np.ndarray, Dict]:
    """
    Fit ``L = a * N^(-alpha) + c`` to (parameter count, loss) pairs.

    Args:
        param_counts: model sizes N (any array-like of numbers).
        losses: matching losses L.

    Returns:
        ``(popt, fit_info)``. ``popt`` is the array ``[a, alpha, c]``.
        ``fit_info`` maps ``a``, ``alpha``, ``c``, their ``*_stderr`` and
        ``r_squared`` to plain floats. ``r_squared`` is nan when all losses
        are identical. Both are None if the fit cannot be computed: fewer than
        3 points, mismatched inputs, or no convergence.
    """
    N = np.asarray(param_counts, dtype=float)
    L = np.asarray(losses, dtype=float)

    # Three free parameters need at least three observations.
    if N.size < 3:
        print(f"Error fitting scaling law: need at least 3 data points, got {N.size}")
        return None, None

    # Initial guess for [a, alpha, c]
    p0 = [1.0, 0.1, 0.5]

    try:
        popt, pcov = curve_fit(power_law, N, L, p0=p0, maxfev=10000)
    except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError) as e:
        # RuntimeError: no convergence within maxfev; ValueError/TypeError: bad inputs.
        print(f"Error fitting scaling law: {e}")
        return None, None

    # Coefficient of determination; undefined when the losses have no variance.
    residuals = L - power_law(N, *popt)
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((L - L.mean()) ** 2))
    r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else float('nan')

    # One-sigma standard errors from the covariance diagonal (inf if undetermined).
    perr = np.sqrt(np.diag(pcov))

    fit_info = {
        'a': float(popt[0]),
        'alpha': float(popt[1]),
        'c': float(popt[2]),
        'a_stderr': float(perr[0]),
        'alpha_stderr': float(perr[1]),
        'c_stderr': float(perr[2]),
        'r_squared': r_squared
    }

    return popt, fit_info


# =============================================================================
# PLOTTING FUNCTIONS
# =============================================================================

def plot_scaling_law(results_dict: Dict, output_dir: Path):
    """
    Plot validation loss against parameter count and overlay a fitted power law.

    Models are sorted by ``n_params`` and labelled by name. When
    ``fit_scaling_law`` succeeds:

    * the fitted curve, its equation, alpha ± stderr and R² are drawn;
    * a textual interpretation is printed;
    * the fit statistics are written to ``output_dir / 'scaling_law_fit.json'``.

    The figure is saved to ``output_dir / 'scaling_law.png'``.

    Returns:
        ``(fig, fit_info)``; ``fit_info`` is None when the fit failed.
    """
    # Extract data
    models = []
    param_counts = []
    val_losses = []

    for model_name, results in results_dict.items():
        models.append(model_name)
        param_counts.append(results['n_params'])
        val_losses.append(results['val_loss'])

    param_counts = np.array(param_counts)
    val_losses = np.array(val_losses)

    # Sort by parameter count
    sort_idx = np.argsort(param_counts)
    param_counts = param_counts[sort_idx]
    val_losses = val_losses[sort_idx]
    models = [models[i] for i in sort_idx]

    # Fit scaling law
    popt, fit_info = fit_scaling_law(param_counts, val_losses)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Observed points with model-name labels
    ax.scatter(param_counts, val_losses, s=100, alpha=0.7,
               c='steelblue', edgecolors='black', linewidth=1.5,
               label='Observed', zorder=3)
    for n, loss, name in zip(param_counts, val_losses, models):
        ax.annotate(name, (n, loss),
                   xytext=(10, 5), textcoords='offset points',
                   fontsize=9, alpha=0.8)

    if popt is not None:
        # Smooth fitted curve over the observed size range (log-spaced)
        N_smooth = np.logspace(np.log10(param_counts.min()),
                               np.log10(param_counts.max()),
                               100)
        L_fitted = power_law(N_smooth, *popt)

        ax.plot(N_smooth, L_fitted, 'r--', linewidth=2,
                label='Power Law Fit', alpha=0.8, zorder=2)

        eq_text = f'$L = {popt[0]:.3f} \\cdot N^{{-{popt[1]:.3f}}} + {popt[2]:.3f}$\n'
        eq_text += f'$\\alpha = {fit_info["alpha"]:.4f} \\pm {fit_info["alpha_stderr"]:.4f}$\n'
        eq_text += f'$R^2 = {fit_info["r_squared"]:.4f}$'

        ax.text(0.05, 0.95, eq_text, transform=ax.transAxes,
                fontsize=11, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    # Formatting
    ax.set_xscale('log')
    ax.set_xlabel('Number of Parameters (N)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Validation Loss (L)', fontsize=12, fontweight='bold')
    ax.set_title('Neural Scaling Law: Model Size vs. Performance',
                 fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=10, loc='upper right')

    plt.tight_layout()
    plt.savefig(output_dir / 'scaling_law.png', dpi=300, bbox_inches='tight')
    print(f"Saved scaling law plot to {output_dir / 'scaling_law.png'}")

    if fit_info is not None:
        alpha = fit_info['alpha']
        print("\n" + "="*70)
        print("SCALING LAW FIT RESULTS")
        print("="*70)
        print("Power Law: L = a * N^(-α) + c")
        print("\nFitted Parameters:")
        print(f"  a (coefficient):      {fit_info['a']:.6f} ± {fit_info['a_stderr']:.6f}")
        print(f"  α (exponent):         {fit_info['alpha']:.6f} ± {fit_info['alpha_stderr']:.6f}")
        print(f"  c (irreducible loss): {fit_info['c']:.6f} ± {fit_info['c_stderr']:.6f}")
        print("\nGoodness of Fit:")
        print(f"  R² = {fit_info['r_squared']:.6f}")
        print("\nInterpretation:")
        print(f"  • Scaling exponent α = {fit_info['alpha']:.4f} indicates that")
        print(f"    validation loss decreases as N^(-{fit_info['alpha']:.4f})")
        # α governs only the reducible term a·N^(-α). Halving that term (not the
        # total loss, which also contains c) needs 2^(1/α)× more parameters.
        # np.exp2 returns inf instead of raising OverflowError for tiny α.
        if alpha > 0:
            print(f"  • To halve the reducible loss (L - c), you need ~{np.exp2(1.0 / alpha):.1f}× more parameters")
        else:
            print("  • α ≤ 0: the fit shows no loss reduction with more parameters")
        print(f"  • Irreducible loss c = {fit_info['c']:.4f} represents the")
        print("    theoretical minimum achievable loss")
        print("="*70 + "\n")

        with open(output_dir / 'scaling_law_fit.json', 'w') as f:
            json.dump(fit_info, f, indent=2)

    return fig, fit_info


def plot_training_curves(results_dict: Dict, output_dir: Path):
    """
    Plot per-step training loss for every model on one set of axes.

    Models without a non-empty ``train_history`` are skipped. Histories with at
    least 100 entries are smoothed by a trailing moving average over
    ``len(history) // 50`` steps. Each averaged value is plotted at the last
    step of its window, so the curve is not shifted earlier in training.

    The figure is saved to ``output_dir / 'training_curves.png'`` and returned.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    colors = plt.cm.viridis(np.linspace(0, 1, len(results_dict)))

    for (model_name, results), color in zip(results_dict.items(), colors):
        history = results.get('train_history')
        if not history:
            continue

        steps = [h['step'] for h in history]
        losses = [h['loss'] for h in history]

        window = max(1, len(losses) // 50)
        if window > 1:
            # 'valid' output k averages losses[k : k + window], so its x
            # coordinate is steps[k + window - 1]. The old steps[:len] alignment
            # shifted the whole curve left by window - 1 steps.
            plot_losses = np.convolve(losses, np.ones(window) / window, mode='valid')
            plot_steps = steps[window - 1:]
        else:
            plot_losses, plot_steps = losses, steps

        ax.plot(plot_steps, plot_losses, label=f"{model_name} ({results['n_params']:,} params)",
                linewidth=2, color=color, alpha=0.8)

    ax.set_xlabel('Training Step', fontsize=12, fontweight='bold')
    ax.set_ylabel('Training Loss', fontsize=12, fontweight='bold')
    ax.set_title('Training Loss Curves by Model Size', fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=9, loc='best')

    plt.tight_layout()
    plt.savefig(output_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
    print(f"Saved training curves to {output_dir / 'training_curves.png'}")

    return fig


def plot_training_metrics(results_dict: Dict, output_dir: Path):
    """
    Plot wall-clock training time and peak GPU memory against model size.

    Left panel: training time in minutes. Right panel: peak GPU memory in GB,
    or a placeholder message when no run recorded a positive value. Both
    panels use a log x axis and label points with model names. The figure is
    saved to ``output_dir / 'training_metrics.png'`` and returned.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # One record per model: (name, params, minutes, GB), read in the same key order as before.
    rows = [
        (name, res['n_params'], res['training_time'] / 60, res.get('gpu_memory_gb', 0))
        for name, res in results_dict.items()
    ]

    counts = np.array([row[1] for row in rows])
    order = np.argsort(counts)
    param_counts = counts[order]
    models = [rows[i][0] for i in order]
    training_times = [rows[i][2] for i in order]
    gpu_memory = [rows[i][3] for i in order]

    def _label_points(axis, ys):
        # Tag each point with its model name, offset slightly up-right.
        for n, yval, name in zip(param_counts, ys, models):
            axis.annotate(name, (n, yval),
                          xytext=(10, 5), textcoords='offset points',
                          fontsize=9, alpha=0.8)

    # Left panel: training time
    ax1.scatter(param_counts, training_times, s=100, alpha=0.7,
                c='coral', edgecolors='black', linewidth=1.5)
    _label_points(ax1, training_times)
    ax1.set_xscale('log')
    ax1.set_xlabel('Number of Parameters (N)', fontsize=11, fontweight='bold')
    ax1.set_ylabel('Training Time (minutes)', fontsize=11, fontweight='bold')
    ax1.set_title('Wall-Clock Training Time vs. Model Size', fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3, linestyle='--')

    # Right panel: GPU memory, or a placeholder when nothing was recorded
    has_gpu_data = any(m > 0 for m in gpu_memory)
    if not has_gpu_data:
        ax2.text(0.5, 0.5, 'No GPU memory data available',
                 transform=ax2.transAxes, ha='center', va='center',
                 fontsize=12, style='italic')
        ax2.set_xticks([])
        ax2.set_yticks([])
    else:
        ax2.scatter(param_counts, gpu_memory, s=100, alpha=0.7,
                    c='mediumseagreen', edgecolors='black', linewidth=1.5)
        _label_points(ax2, gpu_memory)
        ax2.set_xscale('log')
        ax2.set_xlabel('Number of Parameters (N)', fontsize=11, fontweight='bold')
        ax2.set_ylabel('Peak GPU Memory (GB)', fontsize=11, fontweight='bold')
        ax2.set_title('GPU Memory Usage vs. Model Size', fontsize=12, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')

    plt.tight_layout()
    target = output_dir / 'training_metrics.png'
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved training metrics to {output_dir / 'training_metrics.png'}")

    return fig


def plot_loss_comparison(results_dict: Dict, output_dir: Path):
    """
    Draw grouped bars of training vs. validation loss for each model.

    Models are ordered by parameter count. Every bar is annotated with its
    value, and each tick shows the model name above its parameter count. The
    figure is saved to ``output_dir / 'loss_comparison.png'`` and returned.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    rows = [
        (name, res['n_params'], res['train_loss'], res['val_loss'])
        for name, res in results_dict.items()
    ]

    # Sort by parameter count
    counts = np.array([row[1] for row in rows])
    order = np.argsort(counts)
    param_counts = counts[order]
    models = [rows[i][0] for i in order]
    train_losses = [rows[i][2] for i in order]
    val_losses = [rows[i][3] for i in order]

    positions = np.arange(len(models))
    width = 0.35

    # (x offset, heights, legend label, fill colour) for each bar series
    series = (
        (-width / 2, train_losses, 'Training Loss', 'skyblue'),
        (width / 2, val_losses, 'Validation Loss', 'lightcoral'),
    )
    containers = [
        ax.bar(positions + offset, heights, width, label=label,
               color=colour, edgecolor='black', linewidth=1.2)
        for offset, heights, label, colour in series
    ]

    # Value labels centred just above each bar
    for container in containers:
        for rect in container:
            top = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2., top,
                    f'{top:.3f}',
                    ha='center', va='bottom', fontsize=8)

    ax.set_xlabel('Model', fontsize=12, fontweight='bold')
    ax.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax.set_title('Training vs. Validation Loss by Model', fontsize=14, fontweight='bold', pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels([f"{m}\n({p:,})" for m, p in zip(models, param_counts)], fontsize=9)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')

    plt.tight_layout()
    target = output_dir / 'loss_comparison.png'
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved loss comparison to {output_dir / 'loss_comparison.png'}")

    return fig


def create_summary_table(results_dict: Dict, output_dir: Path):
    """
    Print a fixed-width table of per-model metrics, smallest model first.

    Columns: name, parameter count, train loss, validation loss, training time
    in minutes, and peak GPU memory in GB (0 when absent). ``output_dir`` is
    accepted for interface symmetry with the plotting functions but is not used.
    """
    rule = "=" * 90

    print("\n" + rule)
    print("MODEL TRAINING SUMMARY")
    print(rule)
    print(f"{'Model':<10} {'Parameters':<15} {'Train Loss':<12} {'Val Loss':<12} {'Time (min)':<12} {'GPU (GB)':<10}")
    print("-" * 90)

    def by_size(entry):
        return entry[1]['n_params']

    for name, res in sorted(results_dict.items(), key=by_size):
        row = (
            f"{name:<10} {res['n_params']:>14,} "
            f"{res['train_loss']:>11.4f} "
            f"{res['val_loss']:>11.4f} "
            f"{res['training_time']/60:>11.1f} "
            f"{res.get('gpu_memory_gb', 0):>9.2f}"
        )
        print(row)

    print(rule + "\n")


# =============================================================================
# MAIN ANALYSIS SCRIPT
# =============================================================================

def main():
    """
    Load the combined training results and produce the summary table and all plots.

    Reads ``all_results.json`` from the checkpoints directory. Every figure is
    written into ``/content/drive/MyDrive/models/analysis``, which is created
    if missing.
    """
    results_file = Path("/content/drive/MyDrive/models/checkpoints/all_results.json")
    analysis_dir = Path("/content/drive/MyDrive/models/analysis")
    analysis_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading results from: {results_file}")
    with open(results_file, 'r') as fh:
        results = json.load(fh)
    print(f"Found results for {len(results)} models\n")

    create_summary_table(results, analysis_dir)

    print("Generating plots...\n")
    plotters = (
        plot_scaling_law,       # 1. loss vs. parameters with power-law fit
        plot_training_curves,   # 2. smoothed training loss per model
        plot_training_metrics,  # 3. wall-clock time and GPU memory
        plot_loss_comparison,   # 4. train vs. validation loss bars
    )
    for plot in plotters:
        plot(results, analysis_dir)

    banner = '=' * 70
    print(f"\n{banner}")
    print("ANALYSIS COMPLETE")
    print(banner)
    print(f"All plots saved to: {analysis_dir}")
    print(f"{banner}\n")


# Script entry point: run the scaling analysis. IPython/Jupyter also sets
# __name__ to "__main__", so executing this cell runs it. In a shared notebook
# kernel, this cell's `main` also replaces the training cell's `main`.
if __name__ == "__main__":
    main()