# In [None]:
"""
Best Model Training with Hyperparameter Tuning
Complete pipeline: Grid search → Best model training → Sample evaluation
Fixed at 5 epochs for final training

CS-GY 6923 Scaling Laws Project
"""

import os
import json
import math
import time
import re
import itertools
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict, List, Tuple
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from music21 import converter


# Import the GPT model from document 3 (same architecture)
@dataclass
class GPTConfig:
    """Architecture hyperparameters consumed by the GPT model classes in this file."""
    # Rows of the token embedding table and width of the output logits.
    vocab_size: int = 256
    # Rows of the position embedding table and side length of the causal mask.
    block_size: int = 256
    # Number of transformer Blocks stacked in the model.
    n_layer: int = 4
    # Attention heads per block; CausalSelfAttention asserts it divides n_embd.
    n_head: int = 4
    # Width of the embeddings / residual stream.
    n_embd: int = 128
    # Dropout probability used for embeddings, attention weights and residual branches.
    dropout: float = 0.1
    # Whether the attention and MLP Linear layers carry a bias term.
    bias: bool = False


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        ctx = config.block_size
        # One fused projection yields queries, keys and values together.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        # Output projection applied after the heads are merged back.
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        # Lower-triangular mask stored as a buffer named "bias" (the name is part
        # of the state_dict, so it must not change).
        causal_mask = torch.tril(torch.ones(ctx, ctx)).view(1, 1, ctx, ctx)
        self.register_buffer("bias", causal_mask)

    def forward(self, x):
        """Apply masked multi-head attention to ``x`` and return a tensor of the same size."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        queries, keys, values = (
            split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        scores = (queries @ keys.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Positions in the future get -inf so softmax assigns them zero weight.
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (same size as the input)."""
        expanded = self.gelu(self.c_fc(x))
        projected = self.c_proj(expanded)
        return self.dropout(projected)


class Block(nn.Module):
    """One transformer layer: pre-LayerNorm attention and MLP, each added back residually."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run attention then the MLP on ``x``, normalising the input of each sub-layer."""
        attended = self.attn(self.ln_1(x))
        x = x + attended
        fed_forward = self.mlp(self.ln_2(x))
        return x + fed_forward


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings feed a stack of ``Block`` layers,
    followed by a final LayerNorm and a linear head producing vocabulary
    logits. The head and the token embedding share one weight matrix.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output head are the same tensor.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits for every position of ``idx`` and, optionally, the loss.

        Args:
            idx: 2-D tensor of token indices (batch, time).
            targets: Optional tensor of target token indices; when given, the
                cross-entropy over all positions is returned as the loss.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if ``targets`` is ``None``.

        Raises:
            ValueError: If the time dimension exceeds ``config.block_size``.
        """
        b, t = idx.size()
        # Fail early with a clear message: positions beyond block_size have no
        # row in the position embedding table and would otherwise surface as an
        # opaque index error (a device-side assert on CUDA).
        if t > self.config.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {t}; block size is only {self.config.block_size}"
            )
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
        for block in self.transformer.h:
            x = block(x)
        logits = self.lm_head(self.transformer.ln_f(x))
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def get_num_params(self):
        """Return the total element count over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens appended to ``idx``.

        The train/eval mode of the module is not changed here; callers that
        want dropout disabled must call ``eval()`` first.

        Args:
            idx: 2-D tensor of conditioning token indices (batch, time).
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the last-step logits.
            top_k: If truthy, sample only from the ``top_k`` highest logits.

        Returns:
            ``idx`` extended along the time dimension by the sampled tokens.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens once it grows too long.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit is excluded from sampling.
                logits[logits < v[:, [-1]]] = -float('Inf')
            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx


class MusicDataset(Dataset):
    """Character-level next-token dataset over a single text file.

    Item ``i`` is the pair ``(tokens[i : i + block_size], tokens[i + 1 : i + block_size + 1])``,
    i.e. an input window and the same window shifted by one position.

    Args:
        data_path: Path to the UTF-8 text file to tokenize.
        vocab_path: Path to a JSON list of characters; a character's position
            in the list is its token index.
        block_size: Length of each input/target window.
    """

    def __init__(self, data_path: str, vocab_path: str, block_size: int):
        self.block_size = block_size
        # Read the vocab as UTF-8, matching how the corpus itself is read, so
        # non-ASCII symbols do not depend on the platform's default encoding.
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)
        self.char_to_idx = {ch: i for i, ch in enumerate(vocab)}
        self.idx_to_char = {i: ch for i, ch in enumerate(vocab)}
        with open(data_path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Characters missing from the vocab are mapped to index 0.
        # Tokens are tensorized once here so __getitem__ only has to slice.
        self.tokens = torch.tensor(
            [self.char_to_idx.get(ch, 0) for ch in text], dtype=torch.long
        )

    def __len__(self):
        # A corpus no longer than block_size has no complete window; report 0
        # instead of a negative length (which makes len() raise).
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` LongTensor pair for window ``idx``."""
        window = self.tokens[idx:idx + self.block_size + 1]
        # Clone so callers get independent tensors rather than views of the corpus.
        return window[:-1].clone(), window[1:].clone()

    def decode(self, indices):
        """Map an iterable of integer token indices back to text ('<UNK>' for unknown indices)."""
        return ''.join([self.idx_to_char.get(i, '<UNK>') for i in indices])


@dataclass
class TrainingConfig:
    """Model size, optimisation, logging and path settings for one training run."""
    # Run label; not read by any code visible in this file.
    model_name: str = 'best'
    # Architecture fields, copied into GPTConfig by train_model.
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    block_size: int = 512
    dropout: float = 0.1

    batch_size: int = 32
    epochs: int = 5  # Default; hyperparameter_tuning overrides this with 2
    learning_rate: float = 6e-4  # Peak learning rate reached at the end of warmup
    weight_decay: float = 0.1  # Passed to AdamW
    warmup_steps: int = 2000  # Linear warmup length used by get_lr
    grad_clip: float = 1.0  # Max gradient norm passed to clip_grad_norm_

    eval_interval: int = 500  # Run validation every this many optimizer steps
    save_interval: int = 1000  # Write a checkpoint every this many optimizer steps

    # data_dir must contain vocab.json, train.txt, val.txt and test.txt.
    data_dir: str = "/content/drive/MyDrive/processed_music_data"
    # Checkpoints are written under <output_dir>/checkpoints.
    output_dir: str = "/content/drive/MyDrive/models/best_model"
    # Evaluated once, when the class is defined.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_lr(step, config, total_steps: Optional[int] = None):
    """Return the learning rate for ``step``: linear warmup, then cosine decay to 10% of peak.

    Args:
        step: Zero-based optimizer step.
        config: Object exposing ``warmup_steps``, ``learning_rate`` and ``epochs``.
        total_steps: Step at which the cosine decay reaches its floor. When
            omitted, the legacy fallback is used: a module-level
            ``train_loader`` global if one exists, otherwise 10000 steps.
            Callers that know the real schedule length should pass it.

    Returns:
        The learning rate as a float. After the decay horizon the rate stays
        at the floor (``0.1 * config.learning_rate``).
    """
    if step < config.warmup_steps:
        return config.learning_rate * step / config.warmup_steps

    if total_steps is None:
        # Legacy behaviour kept for notebook use, where train_loader may be a global.
        total_steps = len(train_loader) * config.epochs if 'train_loader' in globals() else 10000

    decay_steps = total_steps - config.warmup_steps
    if decay_steps <= 0:
        # Warmup already covers the whole schedule; nothing left to decay over.
        progress = 1.0
    else:
        # Clamp so the cosine does not swing back up once the horizon is passed.
        progress = min(1.0, (step - config.warmup_steps) / decay_steps)

    floor = config.learning_rate * 0.1
    return floor + 0.9 * config.learning_rate * (1 + math.cos(math.pi * progress)) / 2


@torch.no_grad()
def evaluate(model, loader, device, max_iters=100):
    """Return the mean loss of ``model`` over at most ``max_iters`` batches of ``loader``.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` themselves.
    """
    model.eval()
    exhausted = object()
    batches = iter(loader)
    batch_losses = []
    while len(batch_losses) < max_iters:
        batch = next(batches, exhausted)
        if batch is exhausted:
            break
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        _, batch_loss = model(inputs, targets)
        batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)


def _atomic_torch_save(obj, path):
    """Serialize ``obj`` to a sibling temp file, then move it over ``path`` in one step."""
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    # os.replace swaps the file in as a single operation, so a crash during
    # serialization leaves any previous checkpoint at `path` untouched.
    os.replace(tmp_path, path)


def save_checkpoint(model, optimizer, epoch, config, best_val_loss, step=0):
    """Write the training state to ``<output_dir>/checkpoints``.

    Two files are written: ``checkpoint_latest.pt`` (the file auto-resume
    looks for) and ``checkpoint_epoch_{epoch}.pt``. Each is written via a
    temporary file so an interruption mid-write cannot corrupt an existing
    checkpoint.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Epoch number recorded in the checkpoint and in the file name.
        config: Dataclass instance with an ``output_dir`` field; stored via ``asdict``.
        best_val_loss: Best validation loss so far.
        step: Global optimizer step recorded in the checkpoint.

    Returns:
        Path of ``checkpoint_latest.pt``.
    """
    checkpoint_dir = Path(config.output_dir) / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
        'config': asdict(config)
    }

    # Save latest checkpoint
    latest_path = checkpoint_dir / "checkpoint_latest.pt"
    _atomic_torch_save(checkpoint, latest_path)

    # Also save epoch-specific checkpoint
    epoch_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
    _atomic_torch_save(checkpoint, epoch_path)

    print(f"Checkpoint saved at epoch {epoch}")
    return latest_path


def load_checkpoint(config):
    """Return the contents of ``<output_dir>/checkpoints/checkpoint_latest.pt``, or ``None`` if absent."""
    latest_path = Path(config.output_dir) / "checkpoints" / "checkpoint_latest.pt"

    if not latest_path.exists():
        return None

    print(f"Found checkpoint: {latest_path}")
    # NOTE: weights_only=False unpickles arbitrary Python objects; only point
    # this at checkpoint files produced by this pipeline.
    return torch.load(latest_path, map_location='cpu', weights_only=False)


def train_model(config: TrainingConfig, is_tuning=False):
    """Train a GPT on the music corpus, resuming from the latest checkpoint if one exists.

    Reads ``vocab.json``, ``train.txt``, ``val.txt`` and ``test.txt`` from
    ``config.data_dir``, trains for ``config.epochs`` epochs with AdamW and
    the ``get_lr`` schedule, validates every ``config.eval_interval`` steps
    and at the end of every epoch, and checkpoints every
    ``config.save_interval`` steps and at the end of every epoch.

    Args:
        config: Training configuration.
        is_tuning: When true, sample generation after training is skipped.

    Returns:
        Dict with keys ``'model'``, ``'val_loss'`` (best validation loss seen),
        ``'test_loss'``, ``'test_perplexity'`` and ``'samples'`` (empty when
        ``is_tuning`` is true).

    Raises:
        KeyboardInterrupt, Exception: Re-raised after an emergency checkpoint
            has been written.
    """
    device = torch.device(config.device)

    # Load data
    data_dir = Path(config.data_dir)
    vocab_path = data_dir / "vocab.json"
    with open(vocab_path, encoding='utf-8') as f:
        vocab = json.load(f)

    train_ds = MusicDataset(data_dir / "train.txt", vocab_path, config.block_size)
    val_ds = MusicDataset(data_dir / "val.txt", vocab_path, config.block_size)
    test_ds = MusicDataset(data_dir / "test.txt", vocab_path, config.block_size)

    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False)

    # Model
    model_config = GPTConfig(
        vocab_size=len(vocab),
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )
    model = GPT(model_config).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    # Initialize training state
    start_epoch = 0
    best_val_loss = float('inf')
    global_step = 0

    # AUTO-RESUME: Try to load checkpoint
    checkpoint = load_checkpoint(config)
    if checkpoint is not None:
        print("RESUMING FROM CHECKPOINT")
        print("="*80)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint['best_val_loss']
        global_step = checkpoint.get('step', 0)
        print(f"Resumed from epoch {start_epoch}")
        print(f"   Best val loss so far: {best_val_loss:.4f}")
        print("="*80 + "\n")
    else:
        print("No checkpoint found. Starting fresh training.\n")

    # Epoch an emergency checkpoint should resume from: the epoch in progress,
    # or the next one once the current epoch's batches are all done.
    resume_epoch = start_epoch

    # Training loop
    try:
        for epoch in range(start_epoch, config.epochs):
            resume_epoch = epoch
            model.train()
            epoch_losses = []
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}")

            for x, y in pbar:
                # Set the scheduled LR *before* the update so the very first
                # steps honour warmup instead of running at the optimizer's
                # initial (peak) learning rate.
                # NOTE(review): get_lr is called without a total step count, so
                # unless a module-level `train_loader` exists it uses its
                # 10000-step fallback horizon (train_loader is local here).
                lr = get_lr(global_step, config)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr

                x, y = x.to(device), y.to(device)
                _, loss = model(x, y)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()

                loss_value = loss.item()
                epoch_losses.append(loss_value)
                pbar.set_postfix({'loss': f'{loss_value:.4f}', 'lr': f'{lr:.2e}'})

                # Periodic evaluation
                if global_step % config.eval_interval == 0 and global_step > 0:
                    val_loss = evaluate(model, val_loader, device)
                    print(f"\nStep {global_step}: val_loss = {val_loss:.4f}")

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        print(f"New best model!")

                    # evaluate() leaves the model in eval mode.
                    model.train()

                # Save checkpoint every save_interval steps
                if global_step % config.save_interval == 0 and global_step > 0:
                    save_checkpoint(model, optimizer, epoch, config, best_val_loss, global_step)

                global_step += 1

            # All batches of this epoch are done; a resume should start the next one.
            resume_epoch = epoch + 1

            # Evaluate first so the end-of-epoch result counts towards
            # best_val_loss and is recorded in the checkpoint written below.
            train_loss = np.mean(epoch_losses)
            val_loss = evaluate(model, val_loader, device)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                print(f"New best model!")

            # End of epoch: ALWAYS save checkpoint
            print(f"\nEnd of epoch {epoch+1} - Saving checkpoint...")
            save_checkpoint(model, optimizer, resume_epoch, config, best_val_loss, global_step)

            # Epoch summary
            print(f"\nEpoch {epoch+1}/{config.epochs} Summary:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss:   {val_loss:.4f}")
            print(f"  Best Val:   {best_val_loss:.4f}\n")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Saving emergency checkpoint...")
        save_checkpoint(model, optimizer, resume_epoch, config, best_val_loss, global_step)
        print("Emergency checkpoint saved. You can resume training later.")
        raise

    except Exception as e:
        print(f"\nError during training: {e}")
        print("Saving emergency checkpoint...")
        save_checkpoint(model, optimizer, resume_epoch, config, best_val_loss, global_step)
        print("Emergency checkpoint saved.")
        raise

    # Final eval over the whole test set (max_iters is effectively unlimited)
    test_loss = evaluate(model, test_loader, device, max_iters=999999)
    test_ppl = np.exp(test_loss)

    # Generate samples if not tuning
    samples = []
    if not is_tuning:
        samples = generate_samples(model, train_ds, config, device)

    return {
        'model': model,
        'val_loss': best_val_loss,
        'test_loss': test_loss,
        'test_perplexity': test_ppl,
        'samples': samples
    }


def generate_samples(model, dataset, config, device):
    """Sample 5 sequences at each of the temperatures 0.8, 0.9 and 1.0.

    Every sequence starts from token index 0 and is extended by 512 tokens
    with top-k (k=50) sampling, then decoded with ``dataset.decode``.
    ``config`` is accepted but not used.

    Returns:
        List of dicts with keys ``'temp'``, ``'text'`` and ``'valid'``
        (``True`` when the decoded text is longer than 50 characters).
    """
    model.eval()
    collected = []

    for temperature, _ in itertools.product((0.8, 0.9, 1.0), range(5)):
        prompt = torch.tensor([[0]], device=device)
        tokens = model.generate(prompt, max_new_tokens=512, temperature=temperature, top_k=50)
        decoded = dataset.decode(tokens[0].tolist())
        collected.append({
            'temp': temperature,
            'text': decoded,
            'valid': len(decoded) > 50,
        })

    return collected


def hyperparameter_tuning(base_config: TrainingConfig):
    """Train each candidate (lr, dropout, weight_decay) for 2 epochs and return the best one.

    Each candidate is trained with ``train_model(..., is_tuning=True)`` under
    ``<base_config.output_dir>/tuning/config_<n>``. A candidate whose training
    raises is recorded with infinite loss rather than aborting the search.
    The ranked results are written to ``<output_dir>/tuning/results.json``.

    Args:
        base_config: Configuration the trial runs are derived from; every
            field not overridden by the trial is inherited from it.

    Returns:
        The result dict with the lowest validation loss (keys ``'config_id'``,
        ``'lr'``, ``'dropout'``, ``'weight_decay'``, ``'val_loss'``, ``'test_ppl'``).
    """
    # Local import: only `dataclass` and `asdict` are imported at file level.
    from dataclasses import replace

    print("\n" + "="*80)
    print("HYPERPARAMETER TUNING (2 epochs per config)")
    print("="*80)

    # REDUCED SEARCH SPACE - Only 4 configs due to limited compute
    search_space = [
        # (lr, dropout, weight_decay)
        (3e-4, 0.1, 0.1),   # Conservative baseline
        (6e-4, 0.1, 0.1),   # Standard GPT-style
        (1e-3, 0.1, 0.0),   # Higher LR, no weight decay
        (6e-4, 0.0, 0.1),   # No dropout
    ]
    n_configs = len(search_space)

    print(f"Testing {n_configs} carefully selected configurations")
    print("This will save ~75% compute vs full grid search\n")

    results = []
    for idx, (lr, dropout, wd) in enumerate(search_space):
        print(f"\nConfig {idx+1}/{n_configs}: LR={lr:.2e}, Dropout={dropout:.2f}, WD={wd:.2f}")

        # Derive the trial from base_config so settings such as warmup_steps,
        # grad_clip, eval/save intervals and device carry over instead of
        # silently reverting to the class defaults.
        config = replace(
            base_config,
            model_name=f'tune_{idx+1}',
            dropout=dropout,
            learning_rate=lr,
            weight_decay=wd,
            epochs=2,
            output_dir=f"{base_config.output_dir}/tuning/config_{idx+1}",
        )

        try:
            result = train_model(config, is_tuning=True)
            results.append({
                'config_id': idx+1,
                'lr': lr,
                'dropout': dropout,
                'weight_decay': wd,
                'val_loss': result['val_loss'],
                'test_ppl': result['test_perplexity']
            })
            print(f"Val Loss: {result['val_loss']:.4f}, Test PPL: {result['test_perplexity']:.2f}")
        except Exception as e:
            # Best effort: one failing candidate must not abort the search.
            print(f"Failed: {e}")
            results.append({'config_id': idx+1, 'lr': lr, 'dropout': dropout, 'weight_decay': wd,
                           'val_loss': float('inf'), 'test_ppl': float('inf')})

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Sort by val loss
    results.sort(key=lambda x: x['val_loss'])

    # Print results
    print("\n" + "="*80)
    print("TUNING RESULTS")
    print("="*80)
    for rank, r in enumerate(results[:5], 1):
        print(f"{rank}. LR={r['lr']:.2e}, Dropout={r['dropout']:.2f}, WD={r['weight_decay']:.2f} "
              f"→ Val={r['val_loss']:.4f}, PPL={r['test_ppl']:.2f}")

    # Save results
    tuning_dir = Path(base_config.output_dir) / "tuning"
    tuning_dir.mkdir(parents=True, exist_ok=True)
    with open(tuning_dir / "results.json", 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    best = results[0]
    print(f"\nBEST: LR={best['lr']:.2e}, Dropout={best['dropout']:.2f}, WD={best['weight_decay']:.2f}")

    return best


def main():
    """Run the full pipeline: hyperparameter search, final 5-epoch training, sample export.

    The best (lr, dropout, weight_decay) from ``hyperparameter_tuning`` is
    used to train the final model under ``<output_dir>/final``; the generated
    samples are written to ``<output_dir>/final/samples``.
    """
    # Local import: only `dataclass` and `asdict` are imported at file level.
    from dataclasses import replace

    base_config = TrainingConfig(
        n_layer=12,
        n_head=12,
        n_embd=768,
        block_size=512,
        epochs=2,
        data_dir="/content/drive/MyDrive/processed_music_data",
        output_dir="/content/drive/MyDrive/models/best_model"
    )

    # Tuning
    print("\nHYPERPARAMETER TUNING")
    best_hp = hyperparameter_tuning(base_config)

    # Train with best hyperparameters + sample generation
    print("\n" + "="*80)
    print("FINAL TRAINING WITH BEST HYPERPARAMETERS + GENERATE SAMPLES")
    print("="*80)

    # Derive from base_config so every setting not overridden here is inherited.
    final_config = replace(
        base_config,
        model_name='final_best',
        dropout=best_hp['dropout'],
        learning_rate=best_hp['lr'],
        weight_decay=best_hp['weight_decay'],
        epochs=5,
        output_dir=f"{base_config.output_dir}/final",
    )

    result = train_model(final_config, is_tuning=False)

    # Save samples
    samples_dir = Path(final_config.output_dir) / "samples"
    samples_dir.mkdir(parents=True, exist_ok=True)
    for i, s in enumerate(result['samples']):
        sample_path = samples_dir / f"sample_{i+1:03d}_temp{s['temp']}.txt"
        # The corpus is read as UTF-8, so write samples the same way rather
        # than with the platform's default encoding.
        with open(sample_path, 'w', encoding='utf-8') as f:
            f.write(s['text'])

    print("\n" + "="*80)
    print("COMPLETE!")
    print("="*80)
    print(f"Test Perplexity: {result['test_perplexity']:.2f}")
    print(f"Valid Samples: {sum(1 for s in result['samples'] if s['valid'])}/{len(result['samples'])}")
    print(f"Results: {final_config.output_dir}")
    print("="*80)

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()