# In [None]:
"""
LSTM Scaling Study - Complete Training Pipeline
Standard LSTM architecture with modifications for music data

CS-GY 6923 Scaling Laws Project
"""

import os
import json
import time
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =============================================================================
# MODEL ARCHITECTURE
# =============================================================================

@dataclass
class LSTMConfig:
    """Architecture hyperparameters consumed by ``LSTMLanguageModel``.

    The field order below is also the positional-argument order of the
    generated constructor.
    """
    vocab_size: int = 256  # rows of the embedding table and width of the output layer
    embedding_dim: int = 128  # token embedding width (= LSTM input size)
    hidden_dim: int = 256  # LSTM hidden size per direction
    num_layers: int = 2  # number of stacked LSTM layers
    dropout: float = 0.2  # used by the nn.Dropout module and, when num_layers > 1, between LSTM layers
    bidirectional: bool = False  # when True the LSTM output width becomes 2 * hidden_dim


class LSTMLanguageModel(nn.Module):
    """Token embedding -> stacked LSTM -> dropout -> linear vocabulary head.

    The same ``nn.Dropout`` module is applied to the embedded inputs and to
    the LSTM outputs. All sizes come from the ``LSTMConfig`` passed in.
    """

    def __init__(self, config: LSTMConfig):
        super().__init__()
        self.config = config

        direction_count = 2 if config.bidirectional else 1
        # Inter-layer dropout is only requested when there is more than one
        # layer; a single-layer LSTM gets 0.0.
        stacked_dropout = config.dropout if config.num_layers > 1 else 0.0

        # NOTE(review): padding_idx=0 stops gradient updates for token 0, and
        # _init_weights() overwrites that row with uniform noise. Confirm that
        # index 0 of the vocabulary really is a padding symbol.
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
            padding_idx=0,
        )
        self.lstm = nn.LSTM(
            input_size=config.embedding_dim,
            hidden_size=config.hidden_dim,
            num_layers=config.num_layers,
            dropout=stacked_dropout,
            bidirectional=config.bidirectional,
            batch_first=True,
        )
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.hidden_dim * direction_count, config.vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every learnable tensor with the scheme used in this study.

        * embedding: uniform in [-0.1, 0.1]
        * LSTM input weights: Xavier uniform; recurrent weights: orthogonal
        * LSTM biases: zero, except the second quarter of each bias vector
          (the forget-gate slice), which is set to 1
        * output layer: Xavier-uniform weight, zero bias
        """
        nn.init.uniform_(self.embedding.weight, -0.1, 0.1)

        for param_name, tensor in self.lstm.named_parameters():
            if 'weight_ih' in param_name:
                nn.init.xavier_uniform_(tensor.data)
                continue
            if 'weight_hh' in param_name:
                nn.init.orthogonal_(tensor.data)
                continue
            if 'bias' in param_name:
                length = tensor.size(0)
                tensor.data.fill_(0)
                # Both bias_ih and bias_hh match 'bias', so each of them gets
                # a forget-gate slice of ones.
                tensor.data[length // 4:length // 2].fill_(1)

        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor, hidden: Optional[Tuple] = None) -> Tuple[torch.Tensor, Tuple, torch.Tensor]:
        """
        Run the model on a batch of token ids.

        Args:
            x: token ids, [batch_size, seq_len]
            hidden: optional (h_0, c_0) state; None lets the LSTM start from
                its default zero state
        Returns:
            logits: [batch_size, seq_len, vocab_size]
            hidden: (h_n, c_n) as returned by the LSTM
            output: LSTM output after dropout,
                [batch_size, seq_len, hidden_dim * num_directions]
        """
        token_vectors = self.dropout(self.embedding(x))

        # nn.LSTM accepts hx=None and then uses zero initial states, so a
        # single call covers both the "state given" and "no state" cases.
        lstm_out, hidden = self.lstm(token_vectors, hidden)
        lstm_out = self.dropout(lstm_out)

        return self.fc(lstm_out), hidden, lstm_out

    def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return an all-zero (h_0, c_0) pair sized for ``batch_size`` on ``device``."""
        direction_count = 2 if self.config.bidirectional else 1
        state_shape = (
            self.config.num_layers * direction_count,
            batch_size,
            self.config.hidden_dim,
        )
        h0 = torch.zeros(*state_shape, device=device)
        c0 = torch.zeros(*state_shape, device=device)
        return (h0, c0)

    def get_num_params(self) -> int:
        """Return the number of trainable scalar parameters."""
        trainable = 0
        for tensor in self.parameters():
            if tensor.requires_grad:
                trainable += tensor.numel()
        return trainable


# =============================================================================
# MODEL CONFIGURATIONS FOR SCALING STUDY
# =============================================================================

# Scaling ladder used by main(): one LSTMConfig per model size, listed from
# smallest to largest. Every entry uses vocab_size=256 and leaves
# `bidirectional` at its default.
MODEL_CONFIGS = {
    size_name: LSTMConfig(
        vocab_size=256,
        embedding_dim=emb_width,
        hidden_dim=hidden_width,
        num_layers=depth,
        dropout=drop_prob,
    )
    for size_name, emb_width, hidden_width, depth, drop_prob in (
        # name,    embedding_dim, hidden_dim, num_layers, dropout
        ('tiny',   64,            128,        2,          0.2),
        ('small',  128,           256,        3,          0.2),
        ('medium', 256,           384,        4,          0.3),
        ('large',  384,           512,        5,          0.3),
    )
}


# =============================================================================
# DATASET
# =============================================================================

class SequenceDataset(Dataset):
    """Character-level next-token dataset backed by one text file.

    Every index ``i`` yields the window ``data[i : i + seq_length]`` as input
    and the same window shifted right by one token as target, so consecutive
    items overlap in all but one position.

    The encoded corpus is stored as a single ``torch.long`` tensor rather than
    a Python list: that keeps the memory footprint at 8 bytes per token and
    makes ``__getitem__`` a cheap slice instead of a list-to-tensor conversion.
    """

    def __init__(self, data_path: str, vocab_path: str, seq_length: int = 256):
        """
        Args:
            data_path: path to the UTF-8 text file to tokenize
            vocab_path: path to the vocabulary JSON; token ids are assigned by
                enumeration order of the loaded object
            seq_length: number of tokens in each input/target window
        """
        self.seq_length = seq_length

        # Read the vocabulary with an explicit encoding so non-ASCII symbols
        # decode the same way on every platform (the corpus below is read as
        # UTF-8 too).
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)

        # NOTE(review): ids come from enumerate(vocab). If vocab.json is a
        # {symbol: id} mapping rather than a list, confirm its insertion order
        # matches the stored ids.
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.char_to_idx = {ch: i for i, ch in enumerate(vocab)}
        self.idx_to_char = {i: ch for i, ch in enumerate(vocab)}

        print(f"Loading data from {data_path}...")
        with open(data_path, 'r', encoding='utf-8') as f:
            text = f.read()

        print("Tokenizing...")
        self.data = torch.tensor(self._encode(text), dtype=torch.long)
        print(f"Dataset contains {len(self.data):,} tokens")

    def _encode(self, text: str) -> List[int]:
        """Map each character to its id; unknown characters map to '<UNK>' (or 0 if absent)."""
        unk_idx = self.char_to_idx.get('<UNK>', 0)
        lookup = self.char_to_idx.get
        return [lookup(ch, unk_idx) for ch in text]

    def decode(self, indices: List[int]) -> str:
        """Map ids back to text; ids outside the vocabulary become '<UNK>'.

        Accepts any iterable of integer-like values (Python ints, NumPy
        integers, or a 1-D tensor), since each element is coerced with int().
        """
        return ''.join(self.idx_to_char.get(int(idx), '<UNK>') for idx in indices)

    def __len__(self) -> int:
        # A corpus shorter than one full window has no valid example; clamp to
        # zero because a negative value makes len() itself raise.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (input, target) windows of length ``seq_length`` starting at ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``; without this
                check an out-of-range index would silently yield a truncated
                window.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        stop = idx + self.seq_length
        # clone() so callers get tensors that do not alias the shared corpus.
        x = self.data[idx:stop].clone()
        y = self.data[idx + 1:stop + 1].clone()
        return x, y


# =============================================================================
# TRAINING CONFIGURATION
# =============================================================================

@dataclass
class TrainingConfig:
    """Run-level settings read by ``train_model`` (paths, optimisation, system).

    The field order below is also the positional-argument order of the
    generated constructor.
    """
    # Data: train.txt, val.txt and the vocabulary file are looked up inside data_dir
    data_dir: str = "/content/drive/MyDrive/processed_music_data"
    vocab_file: str = "vocab.json"
    seq_length: int = 256  # window length handed to SequenceDataset

    # Training
    batch_size: int = 64
    num_epochs: int = 1
    learning_rate: float = 0.001
    grad_clip: float = 5.0  # max gradient norm passed to clip_grad_norm_

    # Optimizer
    optimizer: str = 'adam'  # 'adam' (case-insensitive) selects Adam; any other value selects SGD
    momentum: float = 0.9  # only used by the SGD branch
    weight_decay: float = 1e-5

    # Learning rate schedule
    use_scheduler: bool = True
    scheduler_type: str = 'step'  # 'step' selects StepLR; any other value selects ReduceLROnPlateau
    step_size: int = 5  # StepLR period, counted in epochs (the scheduler is stepped once per epoch)
    gamma: float = 0.5  # decay factor for both scheduler types
    patience: int = 2  # ReduceLROnPlateau only

    # Evaluation
    eval_interval: int = 500  # NOTE(review): not read by the training code visible in this file

    # System
    # The default is computed once, when this class body is executed.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_workers: int = 2  # DataLoader worker processes

    # Output: each model writes into <output_dir>/<model_name>/
    output_dir: str = "/content/drive/MyDrive/models/lstm_checkpoints"
    save_best_only: bool = True  # NOTE(review): not read by the training code visible in this file


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_epoch(
    model: LSTMLanguageModel,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    grad_clip: float,
    epoch: int
) -> float:
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Args:
        model: network returning ``(logits, hidden, lstm_out)``; only the
            logits are used here.
        train_loader: yields ``(x, y)`` batches of token ids.
        criterion: loss applied to flattened logits/targets. It is assumed to
            return the mean loss over the tokens of a batch (the default for
            ``nn.CrossEntropyLoss``).
        optimizer: optimiser stepped once per batch.
        device: device the batches are moved to.
        grad_clip: maximum gradient norm for ``clip_grad_norm_``.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        The token-weighted mean training loss. Weighting by token count keeps
        a short final batch from counting as much as a full one.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    token_count = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        logits, _, _ = model(x)

        # [batch, seq_len, vocab] -> [batch*seq_len, vocab]; y -> [batch*seq_len]
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # .item() forces a device sync, so read the scalar once and reuse it.
        batch_loss = loss.item()
        batch_tokens = y.numel()
        loss_sum += batch_loss * batch_tokens
        token_count += batch_tokens

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if token_count == 0:
        raise ValueError("train_loader produced no batches; cannot compute a training loss")

    return loss_sum / token_count


@torch.no_grad()
def evaluate(
    model: LSTMLanguageModel,
    val_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, float]:
    """Compute the mean loss and perplexity of ``model`` over ``val_loader``.

    Args:
        model: network returning ``(logits, hidden, lstm_out)``; it is put in
            eval mode and left there.
        val_loader: iterable of ``(x, y)`` batches of token ids.
        criterion: loss applied to flattened logits/targets. It is assumed to
            return the mean loss over the tokens of a batch (the default for
            ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, perplexity)`` where ``avg_loss`` is the token-weighted
        mean loss and ``perplexity`` is ``exp(avg_loss)``. Weighting by token
        count makes the result independent of how the data is split into
        batches; a plain mean of batch means over-weights a short final batch.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0

    for x, y in val_loader:
        x = x.to(device)
        y = y.to(device)

        logits, _, _ = model(x)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

        batch_tokens = y.numel()
        loss_sum += loss.item() * batch_tokens
        token_count += batch_tokens

    if token_count == 0:
        raise ValueError("val_loader produced no batches; cannot compute a validation loss")

    avg_loss = loss_sum / token_count
    perplexity = np.exp(avg_loss)

    return avg_loss, perplexity


def train_model(
    model_name: str,
    config: LSTMConfig,
    train_config: TrainingConfig
) -> Dict:
    """Train one model end to end and persist its checkpoint and metrics.

    Steps: load the vocabulary, build train/val datasets and loaders, create
    the model, optimiser and optional LR scheduler, run ``num_epochs`` epochs
    (evaluating after each one and checkpointing whenever validation loss
    improves), then write ``results.json`` and ``config.json`` into
    ``<output_dir>/<model_name>/``.

    Args:
        model_name: label used in log output and as the output sub-directory.
        config: model hyperparameters. ``config.vocab_size`` is overwritten
            in place with the size of the loaded vocabulary, so the caller's
            object is modified.
        train_config: run settings. An ``optimizer`` other than ``'adam'``
            selects SGD, and a ``scheduler_type`` other than ``'step'``
            selects ReduceLROnPlateau.

    Returns:
        A JSON-serialisable dict with the parameter count, config, loss
        histories, final metrics, wall-clock training time (seconds) and peak
        GPU memory (GiB; 0 when CUDA is unavailable).
    """

    print(f"\n{'='*70}")
    print(f"Training {model_name.upper()} Model")
    print(f"{'='*70}")

    data_dir = Path(train_config.data_dir)
    output_dir = Path(train_config.output_dir) / model_name
    device = torch.device(train_config.device)

    # Update vocab size. Explicit encoding so non-ASCII vocabulary symbols
    # load identically on every platform.
    vocab_path = data_dir / train_config.vocab_file
    with open(vocab_path, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    config.vocab_size = len(vocab)

    # Create datasets
    train_dataset = SequenceDataset(
        data_dir / "train.txt",
        vocab_path,
        train_config.seq_length
    )

    val_dataset = SequenceDataset(
        data_dir / "val.txt",
        vocab_path,
        train_config.seq_length
    )

    # Create dataloaders. Pinned host memory only helps host-to-GPU copies, so
    # it is requested only when the target device is CUDA.
    loader_options = dict(
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
        pin_memory=device.type == 'cuda'
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    # Peak-memory statistics are process-wide. Reset them so the figure
    # reported below belongs to this model only, not to whatever ran earlier
    # in the same process.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    # Create model
    model = LSTMLanguageModel(config).to(device)

    n_params = model.get_num_params()
    print(f"Model parameters: {n_params:,}")

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    if train_config.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=train_config.learning_rate,
            weight_decay=train_config.weight_decay
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=train_config.learning_rate,
            momentum=train_config.momentum,
            weight_decay=train_config.weight_decay
        )

    # Learning rate scheduler
    scheduler = None
    if train_config.use_scheduler:
        if train_config.scheduler_type == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=train_config.step_size,
                gamma=train_config.gamma
            )
        else:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=train_config.gamma,
                patience=train_config.patience
            )

    # Training loop
    output_dir.mkdir(parents=True, exist_ok=True)
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    start_time = time.time()

    for epoch in range(1, train_config.num_epochs + 1):
        # Train
        train_loss = train_epoch(
            model, train_loader, criterion, optimizer,
            device, train_config.grad_clip, epoch
        )
        train_losses.append(train_loss)

        # Evaluate
        val_loss, perplexity = evaluate(model, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f"\nEpoch {epoch}/{train_config.num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Perplexity: {perplexity:.2f}")

        # Update learning rate (ReduceLROnPlateau needs the monitored metric)
        if scheduler is not None:
            if train_config.scheduler_type == 'step':
                scheduler.step()
            else:
                scheduler.step(val_loss)
            print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'config': asdict(config)
            }, output_dir / "best_model.pt")
            print(f"Saved best model (val_loss: {val_loss:.4f})")

    training_time = time.time() - start_time

    # Final evaluation. The weights have not changed since the last epoch's
    # evaluation, so reuse those numbers instead of paying for a second full
    # pass over the validation set; only evaluate when no epoch ran.
    if val_losses:
        final_val_loss, final_perplexity = val_loss, perplexity
    else:
        final_val_loss, final_perplexity = evaluate(model, val_loader, criterion, device)

    # Get GPU memory (peak since the reset above, in GiB)
    gpu_memory = 0
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.max_memory_allocated() / 1024**3

    # Results
    results = {
        'model_name': model_name,
        'n_params': n_params,
        'config': asdict(config),
        'best_val_loss': best_val_loss,
        'final_val_loss': final_val_loss,
        'final_perplexity': final_perplexity,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'gpu_memory_gb': gpu_memory
    }

    # Save results
    with open(output_dir / "results.json", 'w') as f:
        json.dump(results, f, indent=2)

    with open(output_dir / "config.json", 'w') as f:
        json.dump(asdict(config), f, indent=2)

    print(f"\n{'='*70}")
    print(f"Training Complete: {model_name}")
    print(f"{'='*70}")
    print(f"Parameters: {n_params:,}")
    print(f"Best Val Loss: {best_val_loss:.4f}")
    print(f"Final Perplexity: {final_perplexity:.2f}")
    print(f"Training Time: {training_time/60:.1f} minutes")
    print(f"GPU Memory: {gpu_memory:.2f} GB")
    print(f"{'='*70}\n")

    return results


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Train every entry of MODEL_CONFIGS in turn and write a combined report.

    Per-model results are collected in declaration order, dumped to
    ``all_results.json`` in the output directory, and summarised as a table
    on stdout.
    """
    rule = "=" * 70

    train_config = TrainingConfig(
        data_dir="/content/drive/MyDrive/processed_music_data",
        batch_size=64,
        num_epochs=1,
        learning_rate=0.001,
        output_dir="/content/drive/MyDrive/models/lstm_checkpoints"
    )

    print(f"Device: {train_config.device}")
    print(f"Data directory: {train_config.data_dir}")
    print(f"Output directory: {train_config.output_dir}")

    all_results = {}
    for model_name, config in MODEL_CONFIGS.items():
        all_results[model_name] = train_model(model_name, config, train_config)

        # Return cached GPU blocks to the driver before the next model starts
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save combined results
    combined_path = Path(train_config.output_dir) / "all_results.json"
    with open(combined_path, 'w') as f:
        json.dump(all_results, f, indent=2)

    # Summary table
    print("\n" + rule)
    print("Training Summary")
    print(rule)
    print(f"{'Model':<10} {'Params':<12} {'Val Loss':<12} {'Perplexity':<12}")
    print("-" * 70)
    for name in all_results:
        res = all_results[name]
        print(f"{name:<10} {res['n_params']:>11,} {res['final_val_loss']:>11.4f} {res['final_perplexity']:>11.2f}")
    print(rule)


if __name__ == "__main__":
    main()

# In [None]:
"""
RNN (LSTM) vs Transformer (GPT) Scaling Law Comparison
Complete analysis of architectural scaling differences

CS-GY 6923 Scaling Laws Project
"""

import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.optimize import curve_fit
from typing import Dict, List, Tuple
import seaborn as sns

# Global plotting defaults for every figure created below: a dark-grid
# matplotlib style sheet plus seaborn's "husl" colour palette. Both calls
# change process-wide state as soon as this cell/module is executed.
# NOTE(review): the 'seaborn-v0_8-*' style names appear to exist only in
# newer matplotlib releases (3.6+, I believe) — confirm the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# =============================================================================
# SCALING LAW FUNCTIONS
# =============================================================================

def power_law(N, a, alpha, c):
    """Power law: L = a * N^(-alpha) + c"""
    return a * np.power(N, -alpha) + c


def fit_scaling_law(param_counts: np.ndarray, losses: np.ndarray) -> Tuple[np.ndarray, Dict]:
    """Fit ``L = a * N^(-alpha) + c`` to (parameter count, loss) pairs.

    Args:
        param_counts: model sizes N (any 1-D array-like of numbers).
        losses: loss measured for each model size, same length as
            ``param_counts``.

    Returns:
        ``(popt, fit_info)`` on success, where ``popt`` is the array
        ``[a, alpha, c]`` and ``fit_info`` holds the fitted values, their
        standard errors (square roots of the covariance diagonal) and
        ``r_squared``. ``r_squared`` is NaN when all losses are identical,
        because the coefficient of determination is undefined there.

        ``(None, None)`` when the fit cannot be performed: fewer points than
        free parameters, mismatched input lengths, or an optimiser failure.
        The reason is printed. Callers must check for ``None``.
    """
    p0 = [1.0, 0.1, 0.5]

    param_counts = np.asarray(param_counts, dtype=float)
    losses = np.asarray(losses, dtype=float)

    # A three-parameter model needs at least three points. Reject bad input
    # here with a readable message rather than relying on the optimiser's own.
    if param_counts.shape != losses.shape or param_counts.size < len(p0):
        print(
            f"Error fitting: need at least {len(p0)} matching (N, L) points, "
            f"got {param_counts.size} parameter counts and {losses.size} losses"
        )
        return None, None

    # Only the optimiser call is guarded, so programming errors in the
    # bookkeeping below are not silently reported as "fit failed".
    try:
        popt, pcov = curve_fit(power_law, param_counts, losses, p0=p0, maxfev=10000)
    except (RuntimeError, ValueError, TypeError) as e:
        print(f"Error fitting: {e}")
        return None, None

    # Calculate R-squared
    residuals = losses - power_law(param_counts, *popt)
    ss_res = np.sum(residuals**2)
    ss_tot = np.sum((losses - np.mean(losses))**2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float('nan')

    perr = np.sqrt(np.diag(pcov))

    fit_info = {
        'a': popt[0],
        'alpha': popt[1],
        'c': popt[2],
        'a_stderr': perr[0],
        'alpha_stderr': perr[1],
        'c_stderr': perr[2],
        'r_squared': r_squared
    }

    return popt, fit_info


# =============================================================================
# DATA LOADING
# =============================================================================

def load_results(lstm_file: Path, gpt_file: Path) -> Tuple[Dict, Dict]:
    """Read the LSTM and GPT result files and return their parsed JSON contents.

    The LSTM file is read first; a progress line is printed before each read.
    """
    sources = (
        ("Loading LSTM results...", lstm_file),
        ("Loading GPT results...", gpt_file),
    )

    parsed = []
    for announcement, json_path in sources:
        print(announcement)
        with open(json_path, 'r') as handle:
            parsed.append(json.load(handle))

    lstm_results, gpt_results = parsed
    return lstm_results, gpt_results


def extract_metrics(results: Dict) -> Dict:
    """Flatten a ``{model_name: result_dict}`` mapping into parallel arrays.

    Every model must provide ``n_params``. The other values are looked up
    under several possible key names and default to 0 when absent. The
    returned sequences are all ordered by ascending parameter count.
    """

    def _validation_loss(entry):
        # Precedence: 'val_loss', then 'final_val_loss', then 'best_val_loss' (default 0).
        for key in ('val_loss', 'final_val_loss'):
            if key in entry:
                return entry[key]
        return entry.get('best_val_loss', 0)

    def _training_loss(entry):
        # Precedence: scalar 'train_loss', then the last item of a non-empty
        # 'train_losses' history, then 0.
        if 'train_loss' in entry:
            return entry['train_loss']
        history = entry.get('train_losses')
        return history[-1] if history else 0

    # One row per model: (name, params, val loss, train loss, time, memory)
    rows = [
        (
            name,
            entry['n_params'],
            _validation_loss(entry),
            _training_loss(entry),
            entry.get('training_time', 0),
            entry.get('gpu_memory_gb', 0),
        )
        for name, entry in results.items()
    ]

    # Sort by parameter count
    order = np.argsort([row[1] for row in rows])
    ranked = [rows[i] for i in order]
    names, params, val_losses, train_losses, times, memory = (
        [row[column] for row in ranked] for column in range(6)
    )

    return {
        'model_names': names,
        'params': np.array(params),
        'val_losses': np.array(val_losses),
        'train_losses': np.array(train_losses),
        'times': np.array(times),
        'memory': np.array(memory)
    }


# =============================================================================
# PLOTTING FUNCTIONS
# =============================================================================

def plot_combined_scaling_laws(lstm_metrics: Dict, gpt_metrics: Dict, output_dir: Path):
    """Draw both architectures' validation loss against parameter count on one axis.

    Observed points are scattered and labelled per model, a fitted power-law
    curve is overlaid for each architecture whose fit succeeded, and a text
    box comparing exponents and R² is added when both fits are available.
    The figure is saved as ``combined_scaling_laws.png`` in ``output_dir``.

    Returns:
        ``(fig, lstm_fit, gpt_fit)``; a fit entry is ``None`` if that fit failed.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    # Fit scaling laws
    lstm_popt, lstm_fit = fit_scaling_law(lstm_metrics['params'], lstm_metrics['val_losses'])
    gpt_popt, gpt_fit = fit_scaling_law(gpt_metrics['params'], gpt_metrics['val_losses'])

    # Everything that differs between the two architectures, LSTM first so
    # that draw order and legend order stay LSTM -> Transformer.
    architectures = (
        dict(metrics=lstm_metrics, color='#E74C3C', marker='s',
             observed='LSTM (observed)', tag='LSTM', offset=(8, -8),
             fit_name='LSTM fit', popt=lstm_popt, fit=lstm_fit),
        dict(metrics=gpt_metrics, color='#3498DB', marker='o',
             observed='Transformer (observed)', tag='GPT', offset=(8, 8),
             fit_name='Transformer fit', popt=gpt_popt, fit=gpt_fit),
    )

    # Observed points with per-model labels
    for arch in architectures:
        counts = arch['metrics']['params']
        observed_losses = arch['metrics']['val_losses']
        ax.scatter(counts, observed_losses,
                   s=150, alpha=0.7, c=arch['color'], edgecolors='black',
                   linewidth=1.5, marker=arch['marker'], label=arch['observed'], zorder=3)

        for model_name, count, loss in zip(arch['metrics']['model_names'], counts, observed_losses):
            ax.annotate(f"{arch['tag']}-{model_name}", (count, loss), xytext=arch['offset'],
                        textcoords='offset points', fontsize=8, alpha=0.7)

    # Fitted curves, evaluated on 100 log-spaced sizes across the observed range
    for arch in architectures:
        if arch['popt'] is None:
            continue
        counts = arch['metrics']['params']
        size_grid = np.logspace(np.log10(counts.min()), np.log10(counts.max()), 100)
        exponent_text = f"{arch['fit']['alpha']:.4f}"
        ax.plot(size_grid, power_law(size_grid, *arch['popt']), '--',
                linewidth=2.5, color=arch['color'],
                label=f"{arch['fit_name']} (α={exponent_text})", alpha=0.8, zorder=2)

    # Formatting
    ax.set_xscale('log')
    ax.set_xlabel('Number of Parameters (N)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Validation Loss (L)', fontsize=13, fontweight='bold')
    ax.set_title('Scaling Laws: LSTM vs Transformer Architecture',
                 fontsize=15, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=10, loc='upper right', framealpha=0.9)

    # Comparison box: the larger exponent is reported as the better scaler
    if lstm_fit and gpt_fit:
        if gpt_fit['alpha'] > lstm_fit['alpha']:
            verdict = 'Transformer scales better!'
        else:
            verdict = 'LSTM scales better!'

        comparison_text = "\n".join([
            "Scaling Exponents:",
            f"LSTM: α = {lstm_fit['alpha']:.4f} ± {lstm_fit['alpha_stderr']:.4f}",
            f"Transformer: α = {gpt_fit['alpha']:.4f} ± {gpt_fit['alpha_stderr']:.4f}",
            "",
            "R² Values:",
            f"LSTM: {lstm_fit['r_squared']:.4f}",
            f"Transformer: {gpt_fit['r_squared']:.4f}",
            "",
            verdict,
        ])

        ax.text(0.03, 0.97, comparison_text, transform=ax.transAxes,
                fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.85))

    target = output_dir / 'combined_scaling_laws.png'
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved combined scaling plot to {target}")

    return fig, lstm_fit, gpt_fit


def plot_computational_efficiency(lstm_metrics: Dict, gpt_metrics: Dict, output_dir: Path):
    """Plot training time and peak GPU memory, each normalised by parameter count.

    Left panel: training time per parameter (microseconds). Right panel: GPU
    memory per parameter (bytes); it is drawn only when both architectures
    report at least one non-zero memory value. The figure is saved as
    ``computational_efficiency.png`` in ``output_dir``.

    Args:
        lstm_metrics: output of ``extract_metrics`` for the LSTM runs.
        gpt_metrics: output of ``extract_metrics`` for the Transformer runs.
        output_dir: directory the PNG is written to.

    Returns:
        The matplotlib figure.
    """

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # 'gpu_memory_gb' is written by the LSTM pipeline as bytes / 1024**3, so
    # converting back to bytes must use the same factor. The previous 1e9
    # understated bytes-per-parameter by about 7%.
    # NOTE(review): assumes the GPT results use the same GiB convention — confirm.
    bytes_per_gib = 1024 ** 3

    # Calculate efficiency metrics
    lstm_time_per_param = lstm_metrics['times'] / lstm_metrics['params'] * 1e6  # microseconds per param
    gpt_time_per_param = gpt_metrics['times'] / gpt_metrics['params'] * 1e6

    lstm_memory_per_param = lstm_metrics['memory'] / lstm_metrics['params'] * bytes_per_gib  # bytes per param
    gpt_memory_per_param = gpt_metrics['memory'] / gpt_metrics['params'] * bytes_per_gib

    # Plot 1: Training time per parameter
    ax1.scatter(lstm_metrics['params'], lstm_time_per_param,
               s=150, alpha=0.7, c='#E74C3C', edgecolors='black',
               linewidth=1.5, marker='s', label='LSTM', zorder=3)
    ax1.scatter(gpt_metrics['params'], gpt_time_per_param,
               s=150, alpha=0.7, c='#3498DB', edgecolors='black',
               linewidth=1.5, marker='o', label='Transformer', zorder=3)

    ax1.set_xscale('log')
    ax1.set_xlabel('Number of Parameters', fontsize=12, fontweight='bold')
    ax1.set_ylabel('Training Time per Parameter (μs)', fontsize=12, fontweight='bold')
    ax1.set_title('Time Efficiency: Training Time per Parameter', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3, linestyle='--')
    ax1.legend(fontsize=10)

    # Add annotations (LSTM labels above the marker, Transformer labels below)
    for name, n, t in zip(lstm_metrics['model_names'], lstm_metrics['params'], lstm_time_per_param):
        ax1.annotate(name, (n, t), xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.7)
    for name, n, t in zip(gpt_metrics['model_names'], gpt_metrics['params'], gpt_time_per_param):
        ax1.annotate(name, (n, t), xytext=(5, -10), textcoords='offset points', fontsize=8, alpha=0.7)

    # Plot 2: Memory per parameter (skipped when either side has no memory data)
    if np.any(lstm_memory_per_param > 0) and np.any(gpt_memory_per_param > 0):
        ax2.scatter(lstm_metrics['params'], lstm_memory_per_param,
                   s=150, alpha=0.7, c='#E74C3C', edgecolors='black',
                   linewidth=1.5, marker='s', label='LSTM', zorder=3)
        ax2.scatter(gpt_metrics['params'], gpt_memory_per_param,
                   s=150, alpha=0.7, c='#3498DB', edgecolors='black',
                   linewidth=1.5, marker='o', label='Transformer', zorder=3)

        ax2.set_xscale('log')
        ax2.set_xlabel('Number of Parameters', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Memory per Parameter (bytes)', fontsize=12, fontweight='bold')
        ax2.set_title('Memory Efficiency: GPU Memory per Parameter', fontsize=13, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')
        ax2.legend(fontsize=10)

        # Add annotations
        for name, n, m in zip(lstm_metrics['model_names'], lstm_metrics['params'], lstm_memory_per_param):
            ax2.annotate(name, (n, m), xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.7)
        for name, n, m in zip(gpt_metrics['model_names'], gpt_metrics['params'], gpt_memory_per_param):
            ax2.annotate(name, (n, m), xytext=(5, -10), textcoords='offset points', fontsize=8, alpha=0.7)

    plt.tight_layout()
    plt.savefig(output_dir / 'computational_efficiency.png', dpi=300, bbox_inches='tight')
    print(f"Saved computational efficiency plot to {output_dir / 'computational_efficiency.png'}")

    return fig


def plot_absolute_performance(lstm_metrics: Dict, gpt_metrics: Dict, output_dir: Path):
    """Plot total training time (minutes) and peak GPU memory (GB) against model size.

    The memory panel is drawn only when both architectures report at least
    one non-zero memory value. The figure is saved as
    ``absolute_resources.png`` in ``output_dir`` and returned.
    """

    def _draw_panel(axis, lstm_values, gpt_values, y_label, panel_title):
        # One log-x scatter panel with both architectures and per-model labels.
        axis.scatter(lstm_metrics['params'], lstm_values,
                     s=150, alpha=0.7, c='#E74C3C', edgecolors='black',
                     linewidth=1.5, marker='s', label='LSTM', zorder=3)
        axis.scatter(gpt_metrics['params'], gpt_values,
                     s=150, alpha=0.7, c='#3498DB', edgecolors='black',
                     linewidth=1.5, marker='o', label='Transformer', zorder=3)

        axis.set_xscale('log')
        axis.set_xlabel('Number of Parameters', fontsize=12, fontweight='bold')
        axis.set_ylabel(y_label, fontsize=12, fontweight='bold')
        axis.set_title(panel_title, fontsize=13, fontweight='bold')
        axis.grid(True, alpha=0.3, linestyle='--')
        axis.legend(fontsize=10)

        # LSTM labels sit above the marker, Transformer labels below it
        labelled_series = (
            (lstm_metrics, lstm_values, (5, 5)),
            (gpt_metrics, gpt_values, (5, -10)),
        )
        for metrics, values, offset in labelled_series:
            for model_name, count, value in zip(metrics['model_names'], metrics['params'], values):
                axis.annotate(model_name, (count, value), xytext=offset,
                              textcoords='offset points', fontsize=8, alpha=0.7)

    fig, (time_ax, memory_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Panel 1: total training time, converted to minutes
    _draw_panel(time_ax,
                lstm_metrics['times'] / 60, gpt_metrics['times'] / 60,
                'Total Training Time (minutes)', 'Absolute Training Time vs Model Size')

    # Panel 2: total GPU memory, only when both sides have data
    lstm_has_memory = any(lstm_metrics['memory'] > 0)
    if lstm_has_memory and any(gpt_metrics['memory'] > 0):
        _draw_panel(memory_ax,
                    lstm_metrics['memory'], gpt_metrics['memory'],
                    'Peak GPU Memory (GB)', 'Absolute GPU Memory vs Model Size')

    target = output_dir / 'absolute_resources.png'
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved absolute resources plot to {target}")

    return fig


def plot_sample_efficiency(lstm_metrics: Dict, gpt_metrics: Dict, output_dir: Path):
    """Scatter each model's validation loss against its training time.

    One point is drawn per model for both architectures. The ``'times'``
    values are divided by 60 and shown on an axis labelled in minutes, and
    every point is annotated with the model name and its parameter count.

    Args:
        lstm_metrics: Metrics for the LSTM runs; reads ``'model_names'``,
            ``'times'``, ``'val_losses'`` and ``'params'``.
        gpt_metrics: Same keys, for the Transformer runs.
        output_dir: Directory that receives ``sample_efficiency.png``.

    Returns:
        The matplotlib figure that was drawn and saved.
    """
    fig, ax = plt.subplots(figsize=(10, 7))

    # (metrics, face colour, marker, legend label, annotation offset in points)
    series = (
        (lstm_metrics, '#E74C3C', 's', 'LSTM', (8, 8)),
        (gpt_metrics, '#3498DB', 'o', 'Transformer', (8, -12)),
    )

    # Pass 1: the scatter points, LSTM first so the legend order is stable.
    for metrics, colour, marker, legend_label, _ in series:
        ax.scatter(
            metrics['times'] / 60,
            metrics['val_losses'],
            s=200,
            alpha=0.7,
            c=colour,
            edgecolors='black',
            linewidth=1.5,
            marker=marker,
            label=legend_label,
            zorder=3,
        )

    # Pass 2: a "name (param count)" label beside every point. The two
    # architectures use opposite vertical offsets so nearby labels collide less.
    for metrics, _, _, _, offset in series:
        points = zip(
            metrics['model_names'],
            metrics['times'] / 60,
            metrics['val_losses'],
            metrics['params'],
        )
        for model_name, minutes, loss, n_params in points:
            ax.annotate(
                f"{model_name}\n({n_params:,})",
                (minutes, loss),
                xytext=offset,
                textcoords='offset points',
                fontsize=8,
                alpha=0.8,
            )

    ax.set_xlabel('Training Time (minutes)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Validation Loss', fontsize=12, fontweight='bold')
    ax.set_title(
        'Sample Efficiency: Loss vs Training Time\n(Lower-left is better)',
        fontsize=14,
        fontweight='bold',
        pad=20,
    )
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=11, loc='best')

    plt.tight_layout()
    target = output_dir / 'sample_efficiency.png'
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved sample efficiency plot to {target}")

    return fig


def plot_loss_comparison_bar(lstm_metrics: Dict, gpt_metrics: Dict, output_dir: Path):
    """Draw grouped bars of validation loss, pairing models by list position.

    The i-th LSTM model and the i-th Transformer model share one x slot
    (LSTM bar on the left, Transformer bar on the right). When one
    architecture has more models than the other, the extra slots hold a
    single bar.

    Args:
        lstm_metrics: Metrics for the LSTM runs; reads ``'model_names'``,
            ``'val_losses'`` and ``'params'``.
        gpt_metrics: Same keys, for the Transformer runs.
        output_dir: Directory that receives ``loss_comparison_bars.png``.

    Returns:
        The matplotlib figure that was drawn and saved.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    lstm_count = len(lstm_metrics['model_names'])
    gpt_count = len(gpt_metrics['model_names'])
    slot_count = max(lstm_count, gpt_count)

    positions = np.arange(slot_count)
    width = 0.35
    shift = width / 2  # each bar sits half a bar-width off the slot centre

    lstm_bars = ax.bar(
        positions[:lstm_count] - shift, lstm_metrics['val_losses'], width,
        label='LSTM', color='#E74C3C', edgecolor='black', linewidth=1.2,
    )
    gpt_bars = ax.bar(
        positions[:gpt_count] + shift, gpt_metrics['val_losses'], width,
        label='Transformer', color='#3498DB', edgecolor='black', linewidth=1.2,
    )

    # Print the loss value on top of every bar.
    for bar in [*lstm_bars, *gpt_bars]:
        top = bar.get_height()
        centre = bar.get_x() + bar.get_width() / 2.
        ax.text(centre, top, f'{top:.4f}', ha='center', va='bottom', fontsize=9)

    ax.set_xlabel('Model Size Category', fontsize=12, fontweight='bold')
    ax.set_ylabel('Validation Loss', fontsize=12, fontweight='bold')
    ax.set_title('Validation Loss Comparison: LSTM vs Transformer',
                 fontsize=14, fontweight='bold', pad=20)
    ax.set_xticks(positions)

    # Tick text: model name(s) on the first line, parameter count(s) below.
    tick_labels = []
    for slot in range(slot_count):
        has_lstm = slot < lstm_count
        has_gpt = slot < gpt_count
        if has_lstm and has_gpt:
            names_line = f"{lstm_metrics['model_names'][slot]}/{gpt_metrics['model_names'][slot]}\n"
            counts_line = f"({lstm_metrics['params'][slot]:,} / {gpt_metrics['params'][slot]:,})"
            tick_labels.append(names_line + counts_line)
        elif has_lstm:
            tick_labels.append(
                f"{lstm_metrics['model_names'][slot]}\n({lstm_metrics['params'][slot]:,})"
            )
        else:
            tick_labels.append(
                f"{gpt_metrics['model_names'][slot]}\n({gpt_metrics['params'][slot]:,})"
            )

    ax.set_xticklabels(tick_labels, fontsize=9)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, linestyle='--', axis='y')

    plt.tight_layout()
    target = output_dir / 'loss_comparison_bars.png'
    plt.savefig(target, dpi=300, bbox_inches='tight')
    print(f"Saved loss comparison bars to {target}")

    return fig


# =============================================================================
# ANALYSIS AND REPORTING
# =============================================================================

def _halving_multiplier(alpha):
    """Return 2**(1/alpha), or None when alpha is not strictly positive.

    ``inf`` is returned when the result is too large for a float, so a very
    small exponent degrades to an "infinite" multiplier instead of raising.
    """
    if not alpha > 0:  # also rejects NaN
        return None
    try:
        return float(2.0 ** (1.0 / alpha))
    except OverflowError:
        return float('inf')


def generate_comparison_report(lstm_fit: Dict, gpt_fit: Dict,
                              lstm_metrics: Dict, gpt_metrics: Dict,
                              output_dir: Path):
    """Build the LSTM-vs-Transformer text report, print it and save it.

    The report has six sections: fitted scaling-law parameters, scaling
    efficiency, computational efficiency, sample efficiency, a fixed
    discussion, and recommendations. It is printed to stdout and written to
    ``comparison_report.txt`` inside ``output_dir``.

    Args:
        lstm_fit: Fit result for the LSTM runs; reads ``'alpha'``,
            ``'alpha_stderr'``, ``'a'``, ``'a_stderr'`` and ``'c'``. A falsy
            value (e.g. ``None``) means no fit is available and the
            fit-dependent lines are skipped.
        gpt_fit: Same as ``lstm_fit``, for the Transformer runs.
        lstm_metrics: Per-model metrics; reads ``'model_names'``,
            ``'params'``, ``'val_losses'``, ``'times'`` and ``'memory'``.
            The numeric entries must support element-wise arithmetic
            (NumPy arrays). The unit labels in the report treat ``'times'``
            as seconds and ``'memory'`` as GB — presumably what the caller
            provides; verify against the metrics extraction code.
        gpt_metrics: Same as ``lstm_metrics``, for the Transformer runs.
        output_dir: Existing directory that receives the report file.

    Returns:
        None.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    report = [
        heavy_rule,
        "LSTM VS TRANSFORMER SCALING LAW COMPARISON",
        heavy_rule,
        "",
    ]

    # ------------------------------------------------------------------
    # 1. Fitted scaling-law parameters
    # ------------------------------------------------------------------
    report.append("1. SCALING LAW PARAMETERS")
    report.append(light_rule)
    report.append(f"{'Architecture':<15} {'α (exponent)':<20} {'a (coeff)':<20} {'c (offset)':<15}")
    report.append(light_rule)

    for arch_label, fit in (('LSTM', lstm_fit), ('Transformer', gpt_fit)):
        if fit:
            report.append(f"{arch_label:<15} {fit['alpha']:.6f} ± {fit['alpha_stderr']:.6f}  "
                          f"{fit['a']:.6f} ± {fit['a_stderr']:.6f}  {fit['c']:.6f}")

    report.append("")

    # ------------------------------------------------------------------
    # 2. Scaling efficiency
    # ------------------------------------------------------------------
    report.append("2. SCALING EFFICIENCY ANALYSIS")
    report.append(light_rule)

    if lstm_fit and gpt_fit:
        alpha_diff = gpt_fit['alpha'] - lstm_fit['alpha']
        better_arch = "Transformer" if alpha_diff > 0 else "LSTM"

        report.append(f"Scaling Exponent Difference: Δα = {alpha_diff:.6f}")
        report.append(f"Better Scaling Architecture: {better_arch}")
        report.append("")

        # Multiplying the parameter count by 2**(1/α) halves a term that
        # decays as N**(-α). NOTE(review): with a non-zero offset c this is
        # presumably the reducible part of the loss, not the total loss —
        # confirm against the fitted functional form.
        lstm_mult = _halving_multiplier(lstm_fit['alpha'])
        gpt_mult = _halving_multiplier(gpt_fit['alpha'])

        report.append("To halve the loss:")
        if lstm_mult is None or gpt_mult is None:
            # A zero or negative exponent has no finite halving multiplier.
            report.append("  Not defined: both scaling exponents must be positive")
        else:
            report.append(f"  LSTM needs:        {lstm_mult:.2f}× more parameters")
            report.append(f"  Transformer needs: {gpt_mult:.2f}× more parameters")
            if np.isfinite(lstm_mult) and np.isfinite(gpt_mult):
                # Always report the ratio as >= 1 and name the architecture
                # that actually needs fewer extra parameters.
                if gpt_mult < lstm_mult:
                    report.append(f"  Efficiency ratio:  {lstm_mult / gpt_mult:.2f}× (Transformer advantage)")
                elif lstm_mult < gpt_mult:
                    report.append(f"  Efficiency ratio:  {gpt_mult / lstm_mult:.2f}× (LSTM advantage)")
                else:
                    report.append(f"  Efficiency ratio:  {1.0:.2f}× (no advantage)")
        report.append("")

    # ------------------------------------------------------------------
    # 3. Computational efficiency
    # ------------------------------------------------------------------
    report.append("3. COMPUTATIONAL EFFICIENCY")
    report.append(light_rule)

    lstm_avg_time_per_param = np.mean(lstm_metrics['times'] / lstm_metrics['params'])
    gpt_avg_time_per_param = np.mean(gpt_metrics['times'] / gpt_metrics['params'])

    # Memory figures are only averaged when at least one run recorded a
    # positive value; otherwise the memory section is omitted below.
    lstm_avg_mem_per_param = np.mean(lstm_metrics['memory'] / lstm_metrics['params']) if any(lstm_metrics['memory'] > 0) else 0
    gpt_avg_mem_per_param = np.mean(gpt_metrics['memory'] / gpt_metrics['params']) if any(gpt_metrics['memory'] > 0) else 0

    report.append("Average Training Time per Parameter:")
    report.append(f"  LSTM:        {lstm_avg_time_per_param*1e6:.2f} μs/param")
    report.append(f"  Transformer: {gpt_avg_time_per_param*1e6:.2f} μs/param")
    report.append(f"  Ratio:       {lstm_avg_time_per_param / gpt_avg_time_per_param:.2f}× "
                  f"({'LSTM faster' if lstm_avg_time_per_param < gpt_avg_time_per_param else 'Transformer faster'})")
    report.append("")

    if lstm_avg_mem_per_param > 0 and gpt_avg_mem_per_param > 0:
        report.append("Average GPU Memory per Parameter:")
        report.append(f"  LSTM:        {lstm_avg_mem_per_param*1e9:.2f} bytes/param")
        report.append(f"  Transformer: {gpt_avg_mem_per_param*1e9:.2f} bytes/param")
        report.append(f"  Ratio:       {lstm_avg_mem_per_param / gpt_avg_mem_per_param:.2f}× "
                      f"({'LSTM more efficient' if lstm_avg_mem_per_param < gpt_avg_mem_per_param else 'Transformer more efficient'})")
        report.append("")

    # ------------------------------------------------------------------
    # 4. Sample efficiency
    # ------------------------------------------------------------------
    report.append("4. SAMPLE EFFICIENCY")
    report.append(light_rule)

    lstm_best_idx = np.argmin(lstm_metrics['val_losses'])
    gpt_best_idx = np.argmin(gpt_metrics['val_losses'])

    report.append("Best Validation Loss:")
    report.append(f"  LSTM:        {lstm_metrics['val_losses'][lstm_best_idx]:.6f} "
                  f"({lstm_metrics['model_names'][lstm_best_idx]}, {lstm_metrics['params'][lstm_best_idx]:,} params)")
    report.append(f"  Transformer: {gpt_metrics['val_losses'][gpt_best_idx]:.6f} "
                  f"({gpt_metrics['model_names'][gpt_best_idx]}, {gpt_metrics['params'][gpt_best_idx]:,} params)")
    report.append("")

    # Validation loss divided by training time expressed in hours.
    lstm_loss_per_time = lstm_metrics['val_losses'] / (lstm_metrics['times'] / 3600)
    gpt_loss_per_time = gpt_metrics['val_losses'] / (gpt_metrics['times'] / 3600)

    report.append("Loss Achieved per Training Hour (lower is better):")
    report.append(f"  LSTM average:        {np.mean(lstm_loss_per_time):.6f}")
    report.append(f"  Transformer average: {np.mean(gpt_loss_per_time):.6f}")
    report.append("")

    # ------------------------------------------------------------------
    # 5. Discussion (fixed text plus one data-dependent paragraph)
    # ------------------------------------------------------------------
    report.extend([
        "5. DISCUSSION: WHY THESE DIFFERENCES?",
        light_rule,
        "",
        "Architectural Differences:",
        "",
        "a) Parallelization:",
        "   - Transformers: Process entire sequence in parallel (attention)",
        "   - LSTMs: Sequential processing (one token at a time)",
        "   → Transformers train faster per epoch but may use more memory",
        "",
        "b) Long-range Dependencies:",
        "   - Transformers: Direct connections via self-attention (O(n²))",
        "   - LSTMs: Information flows through hidden state (sequential)",
        "   → Transformers better at capturing long-range patterns",
        "",
        "c) Parameter Efficiency:",
    ])

    gpt_alpha = f"{gpt_fit['alpha']:.4f}" if gpt_fit else "N/A"
    lstm_alpha = f"{lstm_fit['alpha']:.4f}" if lstm_fit else "N/A"

    report.append(f"   - Higher α value ({gpt_alpha} vs {lstm_alpha}) means:")

    if lstm_fit and gpt_fit:
        if gpt_fit['alpha'] > lstm_fit['alpha']:
            report.append("   - Transformers scale better: more benefit from additional parameters")
            report.append("   - Each new parameter in Transformer provides more loss reduction")
        else:
            report.append("   - LSTMs scale better: more benefit from additional parameters")
            report.append("   - Each new parameter in LSTM provides more loss reduction")

    report.extend([
        "",
        "d) Computational Trade-offs:",
        "   - Transformers: Higher throughput (parallel), higher memory (attention)",
        "   - LSTMs: Lower memory, sequential bottleneck limits speed",
        "",
    ])

    # ------------------------------------------------------------------
    # 6. Recommendations
    # ------------------------------------------------------------------
    report.append("6. PRACTICAL RECOMMENDATIONS")
    report.append(light_rule)

    if lstm_fit and gpt_fit and gpt_fit['alpha'] > lstm_fit['alpha']:
        report.extend([
            "Choose TRANSFORMERS when:",
            "  - Scaling to large models (better parameter efficiency)",
            "  - GPU memory is available (supports parallel attention)",
            "  - Long-range dependencies are important",
            "  - Training throughput is priority",
            "",
            "Choose LSTMS when:",
            "  - Resource-constrained environments (lower memory)",
            "  - Smaller models are sufficient",
            "  - Sequential processing is acceptable",
            "  - Inference latency matters more than throughput",
        ])
    else:
        report.append("Based on these results, consider your specific constraints")
        report.append("  and requirements when choosing architecture.")

    report.extend(["", heavy_rule, "END OF REPORT", heavy_rule])

    # Print to console
    report_text = "\n".join(report)
    print("\n" + report_text)

    # The report contains non-ASCII characters (α, ±, Δ, ×, μ, →, ²), so the
    # file encoding is pinned instead of relying on the platform default.
    report_path = output_dir / "comparison_report.txt"
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_text)

    print(f"\nReport saved to {report_path}")


def create_summary_table(lstm_metrics: Dict, gpt_metrics: Dict):
    """Print a per-model table covering both architectures to stdout.

    Every LSTM model is listed first, followed by every Transformer model.
    Each row shows the architecture, model name, parameter count, validation
    loss, training time (``'times'`` divided by 60, shown under a minutes
    heading) and memory.

    Args:
        lstm_metrics: Metrics for the LSTM runs; reads ``'model_names'``,
            ``'params'``, ``'val_losses'``, ``'times'`` and ``'memory'``.
        gpt_metrics: Same keys, for the Transformer runs.
    """
    border = "=" * 100

    print("\n" + border)
    print("DETAILED MODEL COMPARISON")
    print(border)
    print(f"{'Architecture':<12} {'Model':<10} {'Params':<15} {'Val Loss':<12} {'Time (min)':<12} {'Memory (GB)':<12}")
    print("-" * 100)

    # Both architectures share one row format; only the leading tag differs.
    for arch_tag, metrics in (('LSTM', lstm_metrics), ('Transformer', gpt_metrics)):
        rows = zip(
            metrics['model_names'],
            metrics['params'],
            metrics['val_losses'],
            metrics['times'] / 60,
            metrics['memory'],
        )
        for model, n_params, loss, minutes, mem in rows:
            print(f"{arch_tag:<12} {model:<10} {n_params:>14,} {loss:>11.6f} {minutes:>11.2f} {mem:>11.2f}")

    print(border + "\n")


# =============================================================================
# MAIN ANALYSIS
# =============================================================================

def main():
    """Run the end-to-end LSTM-vs-Transformer comparison.

    Loads both result files, prints the summary table, renders the five
    comparison plots, writes the text report, and finally dumps the fits and
    per-model metrics to ``comparison_data.json``. Input and output
    locations are hard-coded below (paths under ``/content/drive`` —
    presumably a Colab Google Drive mount).
    """
    rule = "=" * 80

    # Configuration
    lstm_results_file = Path("/content/drive/MyDrive/models/lstm_checkpoints/all_results.json")
    gpt_results_file = Path("/content/drive/MyDrive/models/checkpoints/all_results.json")
    output_dir = Path("/content/drive/MyDrive/models/comparison_analysis")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(rule)
    print("RNN (LSTM) VS TRANSFORMER (GPT) SCALING ANALYSIS")
    print(rule)
    print(f"LSTM results: {lstm_results_file}")
    print(f"GPT results:  {gpt_results_file}")
    print(f"Output dir:   {output_dir}")
    print(rule + "\n")

    # Load the raw results and reduce them to per-model metric collections.
    lstm_results, gpt_results = load_results(lstm_results_file, gpt_results_file)

    print(f"Loaded {len(lstm_results)} LSTM models")
    print(f"Loaded {len(gpt_results)} GPT models\n")

    lstm_metrics = extract_metrics(lstm_results)
    gpt_metrics = extract_metrics(gpt_results)

    create_summary_table(lstm_metrics, gpt_metrics)

    print("\nGenerating comparison plots...\n")

    # The combined scaling-law plot also yields the two fits used by the report.
    _, lstm_fit, gpt_fit = plot_combined_scaling_laws(lstm_metrics, gpt_metrics, output_dir)

    # Remaining plots, in order: computational efficiency, absolute
    # resources, sample efficiency, loss comparison bars.
    for draw_plot in (
        plot_computational_efficiency,
        plot_absolute_performance,
        plot_sample_efficiency,
        plot_loss_comparison_bar,
    ):
        draw_plot(lstm_metrics, gpt_metrics, output_dir)

    generate_comparison_report(lstm_fit, gpt_fit, lstm_metrics, gpt_metrics, output_dir)

    def summarize(metrics):
        # ``.tolist()`` turns the array-valued entries into plain lists for JSON.
        return {
            'models': metrics['model_names'],
            'params': metrics['params'].tolist(),
            'val_losses': metrics['val_losses'].tolist(),
            'times': metrics['times'].tolist(),
            'memory': metrics['memory'].tolist()
        }

    comparison_data = {
        'lstm_fit': lstm_fit,
        'gpt_fit': gpt_fit,
        'lstm_summary': summarize(lstm_metrics),
        'gpt_summary': summarize(gpt_metrics)
    }

    with open(output_dir / 'comparison_data.json', 'w') as f:
        json.dump(comparison_data, f, indent=2)

    print("\n" + rule)
    print("ANALYSIS COMPLETE")
    print(rule)
    print(f"All plots and reports saved to: {output_dir}")
    print(rule + "\n")


# Run the full comparison only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    main()