In [None]:
#========================================================
# ULTIMATE CLAUDE — NanoDiidi with Modern Transformer Architecture
#========================================================
#
# This is NanoDiidi rebuilt from the ground up with architectural
# refinements common in modern large language models (plus optional
# multi-modal, RLHF and fast-inference components):
#
#   1.  RoPE        — Rotary Position Embeddings (replaces learned pos emb)
#   2.  RMSNorm     — Root Mean Square Normalization (replaces LayerNorm)
#   3.  SwiGLU      — Gated feed-forward with SiLU activation
#   4.  GQA         — Grouped Query Attention (shared K/V heads)
#   5.  QK-Norm     — RMSNorm on Q and K before attention
#   6.  MoE         — Mixture of Experts with top-k routing
#   7.  Sliding Window Attention — bounded local attention
#   8.  KV-Cache    — efficient autoregressive generation
#   9.  Residual Scaling — 1/sqrt(n_layers) for training stability
#  10.  Vision Encoder — ViT-style multi-modal support (optional)
#  11.  Reward Model — scalar scoring head for RLHF
#  12.  Speculative Decoding — draft-then-verify for fast inference
#
#========================================================


#========================================================
# --- 1. IMPORTS
#========================================================
import os
import math
import time
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm.auto import tqdm
from pathlib import Path
import tiktoken
from typing import Optional, Callable, Tuple, Dict, List
from dataclasses import dataclass


#========================================================
# --- 2. CONFIGURATION
#========================================================
@dataclass
class ModelConfig:
    """Architecture and training hyperparameters for UltimateClaude.

    ``__post_init__`` validates the attention-head and expert settings when the
    config is constructed. Combinations that would otherwise fail later with a
    tensor-shape error inside a forward pass instead raise a clear
    ``ValueError`` immediately. Fields mutated after construction are not
    re-validated.
    """
    # --- Model architecture ---
    vocab_size: int = 100277       # tiktoken cl100k_base
    n_layer: int = 6
    n_head: int = 6                # query heads
    n_kv_head: int = 2             # key/value heads (GQA: n_head must be divisible by n_kv_head)
    n_embd: int = 384              # must be divisible by n_head
    block_size: int = 256          # max sequence length
    dropout: float = 0.1

    # MoE
    n_expert: int = 4              # number of experts (1 = dense, no MoE)
    n_expert_active: int = 2       # top-k experts routed per token
    moe_every_n: int = 2           # apply MoE every N layers; others use dense SwiGLU

    # Sliding window (0 = full attention in all layers)
    sliding_window: int = 128

    # Vision encoder (disabled by default for text-only training)
    use_vision: bool = False
    image_size: int = 224
    patch_size: int = 16
    vision_n_layer: int = 4
    vision_n_head: int = 4
    vision_n_embd: int = 384

    # --- Training ---
    batch_size: int = 16
    learning_rate: float = 6e-4
    warmup_iters: int = 500
    max_iters: int = 5000
    eval_interval: int = 500

    def __post_init__(self) -> None:
        """Reject head/expert settings that cannot yield consistent tensor shapes."""
        if self.n_head < 1 or self.n_kv_head < 1:
            raise ValueError(
                f"n_head and n_kv_head must be >= 1 (got n_head={self.n_head}, "
                f"n_kv_head={self.n_kv_head})")
        # head_dim = n_embd // n_head; the attention output is reshaped back to
        # n_embd, which only works when the division is exact.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})")
        # Each K/V head is repeated n_head // n_kv_head times to match the query
        # heads; a remainder leaves Q and K/V with different head counts.
        if self.n_head % self.n_kv_head != 0:
            raise ValueError(
                f"n_head ({self.n_head}) must be divisible by n_kv_head ({self.n_kv_head})")
        # Top-k routing needs 1 <= k <= n_expert whenever MoE layers are built.
        if self.n_expert > 1 and not 1 <= self.n_expert_active <= self.n_expert:
            raise ValueError(
                f"n_expert_active ({self.n_expert_active}) must be in "
                f"[1, n_expert={self.n_expert}]")

    # --- Derived properties ---
    @property
    def head_dim(self) -> int:
        """Per-head width of Q/K/V projections."""
        return self.n_embd // self.n_head

    @property
    def n_kv_groups(self) -> int:
        """Number of query heads sharing each key/value head."""
        return self.n_head // self.n_kv_head

    @property
    def n_patches(self) -> int:
        """Number of image patches fed to the vision encoder."""
        return (self.image_size // self.patch_size) ** 2


# Instantiate default config (a ModelConfig with all default field values).
config = ModelConfig()

# --- Training settings ---
# NOTE(review): these names are not read by any code visible in this section;
# their roles below are inferred from the names only — confirm against the
# data-loading / training code.
TRAIN_NEW_MODEL = False
TOKENIZATION_MODE = 'bpe'
# presumably a local file holding a GitHub access token — keep it out of version control.
GITHUB_TOKEN_FILE = 'github_token_secret.txt'
# NOTE(review): LOCAL_FILE names "..._ver_4.txt" while TRAINING_DATA_URL points at
# "..._ver_3.txt" — confirm the version mismatch is intentional.
LOCAL_FILE = 'training_data_tinystories_ver_4.txt'
TRAINING_DATA_URL = "https://raw.githubusercontent.com/diidihamm/Project_7/main/training_data_tinystories_ver_3.txt"
MODEL_FILENAME = 'tiny_qa_model_2.pth'
TRAINED_MODEL_URL = "https://github.com/diidihamm/Project_7/releases/download/model-20260208_014754/tiny_qa_model_2.pth"

# --- RLHF settings ---
# RLHF_LR, RLHF_KL_COEFF and RLHF_CLIP_RANGE are read by rlhf_ppo_train;
# NOTE(review): RLHF_EPOCHS / RLHF_BATCH_SIZE are not used by the visible code.
RUN_RLHF = False          # Set True to run RLHF after base training
RLHF_EPOCHS = 2           # PPO epochs
RLHF_BATCH_SIZE = 4       # Prompts per PPO batch
RLHF_LR = 1e-5            # Much smaller LR for RLHF (fine-tuning)
RLHF_KL_COEFF = 0.1       # KL penalty — prevents policy from drifting too far from reference
RLHF_CLIP_RANGE = 0.2     # PPO clipping range
# Instructions used as rollout prompts for preference generation / PPO.
RLHF_PROMPTS = [
    "Tell me a short story about a cat.",
    "Write a story about a little girl who finds a magic flower.",
    "Tell me about a boy who learns to be brave.",
    "Write a short story about friendship.",
    "Tell me a bedtime story about the moon.",
    "Write a story about a dog who goes on an adventure.",
    "Tell me about a princess who saves a dragon.",
    "Write a short story about sharing.",
]

# --- Device ---
# Module-level device string used by the helpers below when building tensors.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    # Allow TF32 for CUDA matmuls and cuDNN ops: faster on GPUs that support it,
    # at slightly reduced float32 precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


#========================================================
# --- 3. CORE COMPONENTS
#========================================================

# ---- 3a. RMSNorm ----
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike LayerNorm there is no mean subtraction and no bias. The input is
    divided by sqrt(mean(x**2) + eps) and scaled by a learned per-feature gain
    initialised to ones.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = (mean_square + self.eps).sqrt()
        normalized = x / denom
        return self.weight * normalized


# ---- 3b. Rotary Position Embeddings (RoPE) ----
class RotaryPositionEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings (RoPE).

    Frequency i is ``base ** (-2i / dim)``. ``forward(seq_len, offset)``
    returns the ``[seq_len, dim // 2]`` cos and sin rows for absolute
    positions ``offset .. offset + seq_len - 1``. The cache is rebuilt (on the
    buffers' current device/dtype) whenever a longer range is requested.
    """
    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (base ** exponents))
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        positions = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Non-persistent: derived from inv_freq, so not saved in state_dict.
        self.register_buffer('cos_cache', angles.cos(), persistent=False)
        self.register_buffer('sin_cache', angles.sin(), persistent=False)

    def forward(self, seq_len: int, offset: int = 0) -> Tuple[Tensor, Tensor]:
        stop = offset + seq_len
        if self.cos_cache.shape[0] < stop:
            self._build_cache(stop)
        rows = slice(offset, stop)
        return self.cos_cache[rows], self.sin_cache[rows]


def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate Q or K by position-dependent angles (half-split RoPE layout).

    The last dimension of ``x`` is split into a first and second half, and
    each (first[j], second[j]) pair is rotated by the angle whose cos/sin are
    given. ``cos``/``sin`` get two leading singleton dims prepended so a
    ``[T, head_dim // 2]`` table broadcasts against ``x`` of shape
    ``[B, n_heads, T, head_dim]``.
    """
    half = x.shape[-1] // 2
    first = x[..., :half]
    second = x[..., half:]
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    rotated_first = first * cos_b - second * sin_b
    rotated_second = first * sin_b + second * cos_b
    return torch.cat((rotated_first, rotated_second), dim=-1)


# ---- 3c. SwiGLU Feed-Forward ----
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w_down(silu(w_gate(x)) * w_up(x))``, then dropout.

    The SiLU-activated gate path decides WHICH features pass, and the linear
    "up" path supplies WHAT passes. When ``hidden_dim`` is omitted it defaults
    to 8/3 * dim, rounded up to a multiple of 64.
    """
    def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        if hidden_dim is None:
            raw_width = int(8 * dim / 3)
            hidden_dim = -(-raw_width // 64) * 64  # ceil to a multiple of 64
        # Creation order (gate, up, down) is kept so parameter init/RNG order is unchanged.
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up   = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w_gate(x))
        value = self.w_up(x)
        return self.dropout(self.w_down(gate * value))


# ---- 3d. Mixture of Experts ----
class MoELayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer.

    A bias-free linear router scores each token against ``n_experts`` SwiGLU
    experts. Every token is sent to its ``n_active`` highest-probability
    experts, and their outputs are mixed using the top-k router probabilities
    renormalised to sum to 1.

    ``forward`` returns ``(output, aux_loss)``. ``aux_loss`` is the
    load-balancing term ``n_experts * sum_e(frac_routed_e * mean_prob_e)``,
    where ``frac_routed_e`` is the share of all (token, slot) assignments
    that went to expert e.
    """
    def __init__(self, dim: int, n_experts: int, n_active: int, dropout: float = 0.0):
        super().__init__()
        self.n_experts = n_experts
        self.n_active = n_active
        self.router = nn.Linear(dim, n_experts, bias=False)
        self.experts = nn.ModuleList(SwiGLU(dim, dropout=dropout) for _ in range(n_experts))

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        B, T, D = x.shape
        tokens = x.view(-1, D)

        # Routing: full softmax, keep the top-k and renormalise their weights.
        gate_probs = F.softmax(self.router(tokens), dim=-1)
        weights, chosen = torch.topk(gate_probs, self.n_active, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        # Dispatch: for each routing slot, run every expert on the tokens it won.
        mixed = torch.zeros_like(tokens)
        for slot in range(self.n_active):
            slot_choice = chosen[:, slot]
            slot_weight = weights[:, slot]
            for e, expert in enumerate(self.experts):
                routed = slot_choice == e
                if routed.any():
                    mixed[routed] += slot_weight[routed].unsqueeze(-1) * expert(tokens[routed])

        # Load-balancing loss: routed fraction per expert x mean router probability.
        counts = F.one_hot(chosen, num_classes=self.n_experts).sum(dim=(0, 1))
        routed_frac = counts.to(torch.get_default_dtype()) / (B * T * self.n_active)
        mean_prob = gate_probs.mean(dim=0)
        aux_loss = self.n_experts * (routed_frac * mean_prob).sum()

        return mixed.view(B, T, D), aux_loss


#========================================================
# --- 4. GROUPED QUERY ATTENTION
#========================================================
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE, QK-Norm, an optional
    sliding window and an optional KV-cache.

    ``n_head`` query heads share ``n_kv_head`` key/value heads. Each K/V head
    is repeated ``n_kv_groups`` times before the dot product, so the returned
    cache stores only ``n_kv_head`` heads.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_kv_groups = config.n_kv_groups
        self.head_dim = config.head_dim
        self.sliding_window = config.sliding_window

        q_width = config.n_head * config.head_dim
        kv_width = config.n_kv_head * config.head_dim
        self.q_proj = nn.Linear(config.n_embd, q_width, bias=False)
        self.k_proj = nn.Linear(config.n_embd, kv_width, bias=False)
        self.v_proj = nn.Linear(config.n_embd, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, config.n_embd, bias=False)

        # QK-Norm: per-head RMSNorm on queries and keys before RoPE.
        self.q_norm = RMSNorm(config.head_dim)
        self.k_norm = RMSNorm(config.head_dim)

        # One RoPE table shared by all heads; longer ranges are rebuilt on demand.
        self.rope = RotaryPositionEmbedding(config.head_dim, config.block_size * 4)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

    def _attention_mask(self, n_query: int, n_key: int, device) -> Tensor:
        """Boolean [n_query, n_key] mask, True where attention is NOT allowed.

        The queries are the last ``n_query`` of ``n_key`` positions. A key is
        blocked if it lies after the query, or (when a sliding window is set
        and the sequence is longer than it) if it is ``sliding_window`` or more
        positions behind the query.
        """
        key_pos = torch.arange(n_key, device=device)
        query_pos = torch.arange(n_key - n_query, n_key, device=device)
        distance = query_pos[:, None] - key_pos[None, :]
        blocked = distance < 0
        if self.sliding_window > 0 and n_key > self.sliding_window:
            blocked = blocked | (distance >= self.sliding_window)
        return blocked

    def forward(self, x: Tensor, kv_cache: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Attend from the T new positions in ``x`` over cached + new keys.

        ``kv_cache`` (if given) holds RoPE-rotated (k, v) of shape
        [B, n_kv_head, T_past, head_dim]. Returns the projected output and the
        updated (k, v) cache covering T_past + T positions.
        """
        B, T, D = x.shape

        q = self.q_norm(self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2))
        k = self.k_norm(self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2))
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        # New tokens start right after whatever is already cached.
        past_len = 0 if kv_cache is None else kv_cache[0].shape[2]
        cos, sin = self.rope(T, offset=past_len)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present = (k, v)

        # GQA: repeat each K/V head so every query head has a partner.
        keys = k.repeat_interleave(self.n_kv_groups, dim=1)
        values = v.repeat_interleave(self.n_kv_groups, dim=1)
        total_len = keys.shape[2]

        scores = (q @ keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        blocked = self._attention_mask(T, total_len, x.device)
        scores = scores.masked_fill(blocked[None, None], float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        context = (weights @ values).transpose(1, 2).contiguous().view(B, T, D)
        return self.resid_dropout(self.o_proj(context)), present


#========================================================
# --- 5. TRANSFORMER BLOCK
#========================================================
class TransformerBlock(nn.Module):
    """Pre-norm decoder block.

    ``x + s * attn(norm(x))`` followed by ``x + s * ffn(norm(x))``, where
    ``s = 1 / sqrt(n_layer)``. The FFN is a MoE layer when MoE is enabled
    (``n_expert > 1``) and ``layer_idx`` is a multiple of ``moe_every_n``;
    otherwise it is a dense SwiGLU.
    """
    def __init__(self, layer_idx: int, config: ModelConfig):
        super().__init__()
        # Registration order (norms, attn, ffn) kept so init/RNG order is unchanged.
        self.attn_norm = RMSNorm(config.n_embd)
        self.ffn_norm  = RMSNorm(config.n_embd)
        self.attn = GroupedQueryAttention(config)

        moe_enabled = config.n_expert > 1
        # Short-circuit avoids the modulo when MoE is disabled.
        self.use_moe = moe_enabled and layer_idx % config.moe_every_n == 0
        self.ffn = (
            MoELayer(config.n_embd, config.n_expert, config.n_expert_active, config.dropout)
            if self.use_moe
            else SwiGLU(config.n_embd, dropout=config.dropout)
        )

        self.residual_scale = 1.0 / math.sqrt(config.n_layer)

    def forward(self, x: Tensor, kv_cache=None) -> Tuple[Tensor, Tuple, float]:
        """Return ``(x_out, new_kv_cache, aux_loss)``.

        ``aux_loss`` is the MoE load-balancing tensor for MoE blocks and the
        float ``0.0`` for dense blocks.
        """
        h, new_kv_cache = self.attn(self.attn_norm(x), kv_cache)
        x = x + self.residual_scale * h

        normed = self.ffn_norm(x)
        if self.use_moe:
            h, aux_loss = self.ffn(normed)
        else:
            h, aux_loss = self.ffn(normed), 0.0
        x = x + self.residual_scale * h

        return x, new_kv_cache, aux_loss


#========================================================
# --- 6. VISION ENCODER (optional, for multi-modal)
#========================================================
class VisionEncoder(nn.Module):
    """ViT-style encoder: image -> patches -> embeddings -> transformer -> project.

    Output is a sequence of ``n_patches`` visual tokens in the language model's
    embedding space (width ``n_embd``).
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        # forward() reads the patch size from here; this was previously never
        # stored, so every forward call raised AttributeError.
        self.config = config
        n_patches = config.n_patches
        patch_dim = config.patch_size * config.patch_size * 3  # RGB input assumed

        self.patch_embed = nn.Linear(patch_dim, config.vision_n_embd)
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches, config.vision_n_embd) * 0.02)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'norm1': RMSNorm(config.vision_n_embd),
                'attn': nn.MultiheadAttention(config.vision_n_embd, config.vision_n_head, batch_first=True),
                'norm2': RMSNorm(config.vision_n_embd),
                'ffn': SwiGLU(config.vision_n_embd),
            })
            for _ in range(config.vision_n_layer)
        ])
        self.norm_out = RMSNorm(config.vision_n_embd)
        self.proj = nn.Linear(config.vision_n_embd, config.n_embd)

    def forward(self, images: Tensor) -> Tensor:
        """Encode ``images`` to ``[B, n_patches, n_embd]`` visual tokens.

        Assumes ``images`` is [B, 3, image_size, image_size]: ``patch_embed``
        expects 3*p*p inputs and ``pos_embed`` has exactly ``n_patches`` rows.
        """
        B, C, H, W = images.shape
        p = self.config.patch_size
        # [B, C, H/p, W/p, p, p] -> [B, H/p, W/p, C, p, p] -> [B, n_patches, C*p*p]
        patches = images.unfold(2, p, p).unfold(3, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(B, -1, C * p * p)

        x = self.patch_embed(patches) + self.pos_embed
        for layer in self.layers:
            normed = layer['norm1'](x)
            # Bidirectional (unmasked) self-attention; the weights are discarded.
            attn_out, _ = layer['attn'](normed, normed, normed, need_weights=False)
            x = x + attn_out
            x = x + layer['ffn'](layer['norm2'](x))
        return self.proj(self.norm_out(x))


#========================================================
# --- 7. MAIN MODEL: UltimateClaude
#========================================================
class UltimateClaude(nn.Module):
    """Decoder-only transformer language model.

    Pipeline: token embedding scaled by sqrt(n_embd) -> optional prepended
    vision tokens -> dropout -> ``n_layer`` TransformerBlocks -> RMSNorm ->
    LM head. The LM head shares its weight with the token embedding, and
    positions are encoded inside attention via RoPE, so there is no learned
    position embedding.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embedding (NO learned position embedding — RoPE handles it)
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.emb_scale = math.sqrt(config.n_embd)
        self.emb_dropout = nn.Dropout(config.dropout)

        # Transformer stack
        self.layers = nn.ModuleList([
            TransformerBlock(i, config) for i in range(config.n_layer)
        ])
        self.norm_out = RMSNorm(config.n_embd)

        # Language model head (weight-tied with token embedding)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight

        # Vision encoder (only if enabled)
        self.vision_encoder = VisionEncoder(config) if config.use_vision else None

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) for Linear/Embedding weights and zeros for Linear biases.

        Other parameters (RMSNorm gains, VisionEncoder.pos_embed) keep the
        initialisation from their own constructors.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Count parameters; the shared embedding/LM-head matrix is counted once
        and, with ``non_embedding=True``, excluded entirely."""
        n = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n -= self.tok_emb.weight.numel()
        return n

    def forward(self, idx: Tensor, targets: Optional[Tensor] = None,
                images: Optional[Tensor] = None,
                kv_caches: Optional[List] = None
                ) -> Tuple[Tensor, Optional[Tensor], List]:
        """Run the model.

        Args:
            idx: [B, T] token ids.
            targets: optional next-token ids with the same number of elements as
                ``idx``. When given, loss = cross-entropy + 0.01 * summed MoE
                load-balancing loss.
            images: optional image batch; only used when the model was built
                with ``use_vision=True``, in which case vision tokens are
                prepended to the text tokens.
            kv_caches: optional list of per-layer (k, v) caches from a previous call.

        Returns:
            (logits, loss, new_kv_caches). When both ``images`` and ``targets``
            are used, the vision positions are sliced off the returned logits.
        """
        B, T = idx.shape

        # Token embeddings with scaling
        x = self.tok_emb(idx) * self.emb_scale

        # Prepend vision tokens if images provided
        if images is not None and self.vision_encoder is not None:
            vis_tokens = self.vision_encoder(images)
            x = torch.cat([vis_tokens, x], dim=1)
            T = x.shape[1]

        x = self.emb_dropout(x)

        # Transformer layers
        new_kv_caches = []
        total_aux_loss = 0.0
        for i, layer in enumerate(self.layers):
            kv_cache = kv_caches[i] if kv_caches is not None else None
            x, new_kv, aux_loss = layer(x, kv_cache)
            new_kv_caches.append(new_kv)
            total_aux_loss += aux_loss

        x = self.norm_out(x)
        logits = self.lm_head(x)

        # Compute loss if targets provided
        loss = None
        if targets is not None:
            if images is not None and self.vision_encoder is not None:
                logits = logits[:, self.config.n_patches:, :]
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = loss + 0.01 * total_aux_loss  # MoE load-balancing

        return logits, loss, new_kv_caches

    @torch.no_grad()
    def generate(self, idx: Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 40, top_p: float = 0.9,
                 repetition_penalty: float = 1.2,
                 stop_strings: Optional[List[str]] = None,
                 decode_fn: Optional[Callable] = None,
                 callback: Optional[Callable] = None) -> Tensor:
        """Autoregressively append up to ``max_new_tokens`` tokens to ``idx``.

        The first step runs the prompt (truncated to its last ``block_size``
        tokens). Later steps feed only the newest token and reuse the
        per-layer KV-cache. The cache is never truncated.

        Per step, these are applied in order:
        1. A repetition penalty on every token already in ``idx[0]``: positive
           logits are divided by it, the rest are multiplied by it.
        2. Temperature.
        3. Top-k filtering.
        4. Top-p filtering.

        ``temperature <= 0`` selects greedy (argmax) decoding instead of sampling.

        ``callback(token_id)`` is invoked per new token. Generation stops early
        once ``decode_fn(generated_ids)`` contains any of ``stop_strings``.
        Only batch row 0 drives the penalty, callback and stop check.

        The model runs in eval mode here. Its previous train/eval mode is
        restored on exit, so sampling mid-training does not silently turn off
        dropout.
        """
        was_training = self.training
        self.eval()
        try:
            kv_caches = None
            all_generated = []

            for _ in range(max_new_tokens):
                if kv_caches is None:
                    idx_input = idx[:, -self.config.block_size:]
                else:
                    idx_input = idx[:, -1:]

                logits, _, kv_caches = self(idx_input, kv_caches=kv_caches)
                logits = logits[:, -1, :]

                # Repetition penalty (vectorized over the distinct tokens seen so far)
                if repetition_penalty != 1.0:
                    seen = torch.unique(idx[0])
                    prev = logits[0, seen]
                    logits[0, seen] = torch.where(prev > 0, prev / repetition_penalty,
                                                  prev * repetition_penalty)

                if temperature <= 0:
                    # Greedy decoding; avoids dividing by zero below.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature

                    # Top-k filtering
                    if top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')

                    # Top-p (nucleus) filtering: drop tokens once the mass of the
                    # strictly-more-likely tokens already exceeds top_p.
                    if top_p < 1.0:
                        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                        sorted_probs = F.softmax(sorted_logits, dim=-1)
                        cumulative = torch.cumsum(sorted_probs, dim=-1)
                        mask = (cumulative - sorted_probs) > top_p
                        sorted_logits[mask] = float('-inf')
                        # sorted_idx is a permutation, so every slot is overwritten.
                        logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)

                idx = torch.cat([idx, idx_next], dim=1)

                token_id = idx_next[0].item()
                all_generated.append(token_id)

                if callback:
                    callback(token_id)

                if stop_strings and decode_fn:
                    text = decode_fn(all_generated)
                    if any(s in text for s in stop_strings):
                        break

            return idx
        finally:
            self.train(was_training)


#========================================================
# --- 8. REWARD MODEL (for RLHF)
#========================================================
class RewardModel(nn.Module):
    """Transformer that maps a token sequence to one scalar score.

    It uses the same block stack as UltimateClaude, with its own untied
    weights and PyTorch's default initialisation. Instead of an LM head, the
    hidden state at the last position goes through a bias-free linear
    ``score_head``. It is trained on preference pairs and used as the reward
    signal for RLHF.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.emb_scale = math.sqrt(config.n_embd)
        self.layers = nn.ModuleList(
            [TransformerBlock(layer_idx, config) for layer_idx in range(config.n_layer)]
        )
        self.norm_out = RMSNorm(config.n_embd)
        self.score_head = nn.Linear(config.n_embd, 1, bias=False)

    def forward(self, idx: Tensor) -> Tensor:
        """Score ``idx`` ([B, T] token ids), returning a [B] tensor.

        The score is read at the last position, and no padding mask is
        applied, so the final column should hold real tokens.
        """
        hidden = self.tok_emb(idx) * self.emb_scale
        for block in self.layers:
            hidden = block(hidden)[0]
        hidden = self.norm_out(hidden)
        final_state = hidden[:, -1, :]
        return self.score_head(final_state).squeeze(-1)


#========================================================
# --- 9. SPECULATIVE DECODING
#========================================================
@torch.no_grad()
def speculative_decode(target: UltimateClaude, draft: UltimateClaude,
                       idx: Tensor, max_new_tokens: int,
                       n_draft: int = 4, temperature: float = 0.8) -> Tensor:
    """Speculative sampling: ``draft`` proposes tokens, ``target`` verifies them.

    Each round works as follows:
    - The draft model samples up to ``n_draft`` tokens one at a time.
    - The target model scores the whole proposal in a single forward pass.
    - Draft token i is accepted with probability min(1, p_target / p_draft).
    - On the first rejection, a replacement is sampled from the normalised
      residual max(0, p_target - p_draft) and the round ends.
    - If every proposal is accepted, one bonus token is sampled from the
      target's last position, but only while the budget allows.

    As a result, exactly ``max_new_tokens`` tokens are appended to ``idx``.

    Assumes batch size 1 (probabilities are read from row 0) and
    ``temperature > 0``. Neither model's train/eval mode is changed here;
    call ``.eval()`` beforehand to disable dropout.
    """
    generated = 0
    while generated < max_new_tokens:
        start = idx.shape[1]
        budget = max_new_tokens - generated

        # --- Draft: propose up to n_draft tokens autoregressively ---
        draft_idx = idx.clone()
        draft_probs_list = []
        for _ in range(min(n_draft, budget)):
            logits, _, _ = draft(draft_idx[:, -draft.config.block_size:])
            dp = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            draft_probs_list.append(dp)
            draft_idx = torch.cat([draft_idx, torch.multinomial(dp, 1)], dim=1)
        n_proposed = len(draft_probs_list)

        # --- Target: verify all draft tokens in ONE forward pass ---
        # The logit at position -(n_proposed - i) - 1 predicts draft token i.
        target_logits, _, _ = target(draft_idx[:, -target.config.block_size:])

        for i in range(n_proposed):
            tp = F.softmax(target_logits[:, -(n_proposed - i) - 1, :] / temperature, dim=-1)
            dp = draft_probs_list[i]
            tok = draft_idx[0, start + i]
            p_t = tp[0, tok].item()
            p_d = dp[0, tok].item()

            if torch.rand(1).item() < min(1.0, p_t / max(p_d, 1e-10)):
                continue  # accepted

            # Rejected: resample from the residual distribution. Fall back to the
            # target distribution if the residual has no mass (numerical edge case).
            residual = F.relu(tp - dp)
            mass = residual.sum(dim=-1, keepdim=True)
            resample_probs = residual / (mass + 1e-10) if mass.item() > 0 else tp
            new_tok = torch.multinomial(resample_probs, 1)
            idx = torch.cat([idx, draft_idx[:, start:start + i], new_tok], dim=1)
            generated += i + 1
            break
        else:
            # Every proposal accepted.
            idx = draft_idx
            generated += n_proposed
            # The bonus token is only taken if it fits in the budget.
            if generated < max_new_tokens:
                bonus = F.softmax(target_logits[:, -1, :] / temperature, dim=-1)
                idx = torch.cat([idx, torch.multinomial(bonus, 1)], dim=1)
                generated += 1

    return idx


#========================================================
# --- 10. RLHF TRAINING PIPELINE
#========================================================
# RLHF turns a text-completion engine into a helpful assistant.
# Four models work together:
#   1. Policy  (trainable)  — the model we're improving
#   2. Reference (frozen)   — copy of policy BEFORE RLHF, prevents drift
#   3. Reward model (frozen) — scores "how good" a response is
#   4. Value head (trainable) — estimates expected future reward at each token
#
# Training loop (PPO):
#   For each prompt:
#     → Policy generates a response
#     → Reward model scores it
#     → Compare policy vs reference to compute KL penalty
#     → Compute advantages (how much better than expected)
#     → Update policy with clipped PPO objective
#========================================================

def compute_log_probs(model: UltimateClaude, input_ids: Tensor) -> Tensor:
    """Per-token log-likelihood of a sequence under ``model``.

    Returns a [B, T - 1] tensor whose entry t is
    log P(input_ids[:, t + 1] | input_ids[:, :t + 1]). The first token has no
    prediction. Gradients flow unless the caller wraps this in ``torch.no_grad()``.
    """
    logits = model(input_ids)[0]
    # Position t's logits predict token t + 1: drop the last position and
    # shift the targets left by one.
    predicted = F.log_softmax(logits, dim=-1)[:, :-1, :]
    next_ids = input_ids[:, 1:, None]
    return torch.gather(predicted, dim=-1, index=next_ids).squeeze(-1)


class PolicyWithValueHead(nn.Module):
    """PPO actor-critic wrapper around an UltimateClaude policy.

    The wrapped model's embedding, blocks and final norm feed two heads:
    - the model's own LM head (actor);
    - a new bias-free linear value head (critic) predicting one scalar per position.

    The value head is zero-initialised, so all value estimates start at 0.
    """
    def __init__(self, policy: UltimateClaude):
        super().__init__()
        self.policy = policy
        self.value_head = nn.Linear(policy.config.n_embd, 1, bias=False)
        nn.init.zeros_(self.value_head.weight)

    def forward(self, idx: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(logits, values)`` for token ids ``idx``; values are [B, T].

        This is the text-only path: no vision tokens and no KV-cache. The MoE
        auxiliary losses returned by each block are discarded.
        """
        body = self.policy
        hidden = body.emb_dropout(body.tok_emb(idx) * body.emb_scale)
        for block in body.layers:
            hidden = block(hidden)[0]
        hidden = body.norm_out(hidden)
        logits = body.lm_head(hidden)
        values = self.value_head(hidden).squeeze(-1)
        return logits, values

    def generate(self, *args, **kwargs):
        """Delegate sampling to the wrapped policy's ``generate``."""
        return self.policy.generate(*args, **kwargs)


@torch.no_grad()
def _response_nll(model: UltimateClaude, sequence: Tensor, prompt_len: int) -> float:
    """Mean next-token cross-entropy of the tokens after the first ``prompt_len``.

    ``sequence`` is [1, L]: prompt followed by response. Logit t predicts token
    t + 1, so response tokens ``sequence[:, start:]`` are scored by
    ``logits[:, start - 1:]``. Returns +inf when there is nothing to score.
    """
    start = max(prompt_len, 1)  # token 0 has no preceding context to predict it
    if sequence.size(1) <= start:
        return float('inf')
    logits, _, _ = model(sequence[:, :-1])
    resp_logits = logits[:, start - 1:, :]
    resp_targets = sequence[:, start:]
    return F.cross_entropy(resp_logits.reshape(-1, resp_logits.size(-1)),
                           resp_targets.reshape(-1)).item()


@torch.no_grad()
def generate_preference_data(model: UltimateClaude, prompts: List[str],
                             encode: Callable, decode: Callable) -> List[Tuple[str, str, str]]:
    """Generate synthetic (prompt, chosen, rejected) preference triples.

    For each prompt, one response is sampled at temperature 0.5 and one at
    1.2. Each is scored by the model's own mean next-token NLL over the
    response tokens only. Lower NLL is treated as more coherent and becomes
    "chosen". Pairs whose two responses are identical are skipped.

    In production RLHF this data comes from human annotators ranking response
    pairs; here it is approximated with self-play. Uses the module-level
    ``device`` for input tensors.
    """
    model.eval()
    preferences = []

    for prompt in prompts:
        formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"
        prompt_ids = encode(formatted)
        prompt_len = len(prompt_ids)
        input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

        responses = []
        for temp in [0.5, 1.2]:  # low temp = coherent, high temp = creative/noisy
            output = model.generate(input_tensor.clone(), max_new_tokens=80,
                                    temperature=temp, top_k=40, top_p=0.9)
            response_text = decode(output[0, prompt_len:].tolist())
            # Score with properly shifted targets (logit t vs token t+1),
            # restricted to the response so the shared prompt does not dilute it.
            responses.append((response_text, _response_nll(model, output, prompt_len)))

        # Sort by loss: lower loss = more coherent = "chosen"
        responses.sort(key=lambda x: x[1])
        chosen = responses[0][0]
        rejected = responses[-1][0]

        if chosen != rejected:
            preferences.append((formatted, chosen, rejected))
            print(f"  Prompt: {prompt[:50]}...")
            print(f"    Chosen  (loss={responses[0][1]:.3f}): {chosen[:60]}...")
            print(f"    Rejected(loss={responses[-1][1]:.3f}): {rejected[:60]}...")

    print(f"\nGenerated {len(preferences)} preference pairs")
    return preferences


def train_reward_model(reward_model: RewardModel, preferences: List[Tuple[str, str, str]],
                       encode: Callable, n_epochs: int = 3, lr: float = 1e-4):
    """Fit ``reward_model`` on (prompt, chosen, rejected) triples.

    Uses the Bradley-Terry objective ``-log(sigmoid(r_chosen - r_rejected))``:
    P(A preferred over B) = sigmoid(r_A - r_B). Each pair is one AdamW step
    with gradient norm clipped to 1.0. Sequences are truncated to their first
    512 tokens and placed on the module-level ``device``. Per-epoch mean loss
    and pairwise accuracy are printed. The model is left in eval mode.
    """
    optimizer = torch.optim.AdamW(reward_model.parameters(), lr=lr)
    reward_model.train()
    n_pairs = max(len(preferences), 1)

    def to_ids(text):
        ids = torch.tensor([encode(text)], dtype=torch.long, device=device)
        return ids[:, :512]  # truncate to a reasonable length

    for epoch in range(n_epochs):
        running_loss = 0.0
        n_correct = 0

        for prompt, chosen, rejected in preferences:
            good_ids = to_ids(prompt + chosen)
            bad_ids = to_ids(prompt + rejected)

            r_good = reward_model(good_ids)
            r_bad = reward_model(bad_ids)
            loss = -F.logsigmoid(r_good - r_bad).mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(reward_model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            n_correct += int(r_good.item() > r_bad.item())

        accuracy = n_correct / n_pairs * 100
        avg_loss = running_loss / n_pairs
        print(f"  Reward Model Epoch {epoch+1}/{n_epochs} — Loss: {avg_loss:.4f}, Accuracy: {accuracy:.1f}%")

    reward_model.eval()


def rlhf_ppo_train(policy_model: UltimateClaude, reward_model: RewardModel,
                    prompts: List[str], encode: Callable, decode: Callable,
                    n_epochs: int = 2):
    """Proximal Policy Optimization (PPO) for RLHF.

    The heart of RLHF: adjust the policy so it generates responses
    that score higher on the reward model, while staying close to
    the original reference policy (via KL penalty).

    For each prompt:
      1. Sample one response from the current policy.
      2. Score it with the reward model.
      3. Build per-token rewards: -RLHF_KL_COEFF * (log p_policy - log p_ref)
         on every response token, plus the scalar reward on the last token.
      4. Convert them to undiscounted reward-to-go returns.
      5. Take one clipped-PPO + 0.5 * value-loss optimizer step.

    Args:
        policy_model: model being aligned; the trained policy weights are
            loaded back into it at the end.
        reward_model: trained reward model; frozen here (its parameters'
            requires_grad flags are cleared and left cleared).
        prompts: instructions, wrapped in the Alpaca template before generation.
        encode: text -> list of token ids.
        decode: not used by this function.
        n_epochs: number of passes over `prompts`.

    Returns:
        `policy_model` with the trained policy weights loaded.
    """

    import copy

    # --- Setup the models ---
    # 1. Policy with value head (trainable)
    policy = PolicyWithValueHead(policy_model).to(device)
    policy.train()

    # 2. Reference model (frozen copy — the anchor)
    ref_model = copy.deepcopy(policy_model).to(device)
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad = False

    # 3. Reward model is already trained and frozen
    reward_model.eval()
    for p in reward_model.parameters():
        p.requires_grad = False

    optimizer = torch.optim.AdamW(policy.parameters(), lr=RLHF_LR)

    print(f"\n{'='*70}")
    print("PPO TRAINING")
    print(f"{'='*70}")
    print(f"  Epochs: {n_epochs} | Prompts: {len(prompts)}")
    print(f"  KL coeff: {RLHF_KL_COEFF} | Clip range: {RLHF_CLIP_RANGE}")
    print(f"  Reference model frozen. Policy training begins.\n")

    for epoch in range(n_epochs):
        epoch_rewards = []
        epoch_policy_loss = []
        epoch_value_loss = []
        epoch_kl = []

        for prompt_text in prompts:
            # --- ROLLOUT: Generate a response from current policy ---
            formatted = f"### Instruction:\n{prompt_text}\n\n### Response:\n"
            prompt_ids = encode(formatted)
            prompt_len = len(prompt_ids)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            with torch.no_grad():
                output = policy.generate(input_tensor.clone(), max_new_tokens=80,
                                         temperature=0.8, top_k=40)
            full_ids = output  # [1, prompt_len + response_len]
            if full_ids.shape[1] <= prompt_len + 1:
                continue  # Skip if fewer than 2 response tokens were generated

            # --- SCORING ---
            # Reward from the reward model
            with torch.no_grad():
                reward = reward_model(full_ids).item()
            epoch_rewards.append(reward)

            # Log probs from policy and reference
            with torch.no_grad():
                old_log_probs = compute_log_probs(policy.policy, full_ids)
                ref_log_probs = compute_log_probs(ref_model, full_ids)

            # Value estimates from the value head
            with torch.no_grad():
                _, old_values = policy(full_ids)
                old_values = old_values[:, :-1]  # align with log_probs

            # --- COMPUTE ADVANTAGES ---
            # Only for response tokens (not the prompt)
            resp_start = prompt_len - 1  # -1 because log_probs is shifted
            resp_old_log = old_log_probs[:, resp_start:]
            resp_ref_log = ref_log_probs[:, resp_start:]
            resp_values = old_values[:, resp_start:]

            # KL penalty per token
            kl = resp_old_log - resp_ref_log
            epoch_kl.append(kl.mean().item())

            # Per-token rewards: KL penalty everywhere, actual reward at last token
            per_token_reward = -RLHF_KL_COEFF * kl
            per_token_reward[:, -1] += reward

            # Returns = undiscounted reward-to-go (reverse cumulative sum over the
            # response). Because the reward-model score sits only on the final token,
            # using the immediate per-token reward as the return would give every
            # earlier token zero credit for it.
            returns = per_token_reward.flip(-1).cumsum(-1).flip(-1).detach()

            # Advantages = returns - value baseline, normalized within this rollout
            advantages = returns - resp_values.detach()
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            # --- PPO UPDATE ---
            # Recompute log probs under CURRENT policy (may have changed)
            new_logits, new_values = policy(full_ids)
            new_log_probs = F.log_softmax(new_logits, dim=-1)
            new_token_log = new_log_probs[:, :-1, :].gather(
                -1, full_ids[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            new_resp_log = new_token_log[:, resp_start:]
            new_resp_values = new_values[:, :-1][:, resp_start:]

            # PPO ratio
            ratio = torch.exp(new_resp_log - resp_old_log.detach())
            clipped_ratio = torch.clamp(ratio, 1.0 - RLHF_CLIP_RANGE, 1.0 + RLHF_CLIP_RANGE)

            # Policy loss (maximize advantage, clipped)
            policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

            # Value loss (predict reward-to-go accurately)
            value_loss = F.mse_loss(new_resp_values, returns)

            # Total loss
            loss = policy_loss + 0.5 * value_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()

            epoch_policy_loss.append(policy_loss.item())
            epoch_value_loss.append(value_loss.item())

        # Epoch summary
        avg_reward = sum(epoch_rewards) / max(len(epoch_rewards), 1)
        avg_ploss = sum(epoch_policy_loss) / max(len(epoch_policy_loss), 1)
        avg_vloss = sum(epoch_value_loss) / max(len(epoch_value_loss), 1)
        avg_kl = sum(epoch_kl) / max(len(epoch_kl), 1)
        print(f"  PPO Epoch {epoch+1}/{n_epochs} — "
              f"Reward: {avg_reward:.4f} | Policy Loss: {avg_ploss:.4f} | "
              f"Value Loss: {avg_vloss:.4f} | KL: {avg_kl:.4f}")

    # Copy trained weights back to the original model
    policy_model.load_state_dict(policy.policy.state_dict())
    print(f"\nRLHF complete. Policy updated.")
    return policy_model


def run_rlhf_pipeline(model: UltimateClaude, encode: Callable, decode: Callable,
                      config: ModelConfig):
    """Run the nano-scale RLHF pipeline and return the (possibly updated) policy.

    Stages:
      1. Build synthetic preference pairs from RLHF_PROMPTS with the current model.
      2. Train a fresh RewardModel on them. It copies vocab, heads, width and
         block size from `config`, uses max(2, n_layer // 2) layers and n_expert=1.
      3. Run PPO for RLHF_EPOCHS epochs against that reward model.

    If fewer than two preference pairs come out of stage 1, the input model is
    returned unchanged.

    In production:
      - Preference data comes from thousands of human annotators
      - Reward model is trained on millions of comparisons
      - PPO runs for thousands of steps across many GPUs
      - Constitutional AI adds a self-critique loop
    """
    banner = "=" * 70
    print("\n" + banner)
    print("RLHF PIPELINE")
    print(banner)

    # Stage 1: synthetic preference pairs
    print("\n--- Step 1: Generating preference data ---")
    pairs = generate_preference_data(model, RLHF_PROMPTS, encode, decode)
    if len(pairs) < 2:
        print("Not enough preference pairs generated. Skipping RLHF.")
        return model

    # Stage 2: reward model (shallower than the policy, always dense)
    print("\n--- Step 2: Training reward model ---")
    rm_layers = max(2, config.n_layer // 2)
    rm_config = ModelConfig(
        vocab_size=config.vocab_size,
        n_layer=rm_layers,
        n_head=config.n_head,
        n_kv_head=config.n_kv_head,
        n_embd=config.n_embd,
        block_size=config.block_size,
        n_expert=1,
    )
    scorer = RewardModel(rm_config).to(device)
    n_params = sum(p.numel() for p in scorer.parameters())
    print(f"  Reward model: {n_params/1e6:.2f}M params")
    train_reward_model(scorer, pairs, encode, n_epochs=3)

    # Stage 3: PPO against the trained scorer
    print("\n--- Step 3: PPO training ---")
    return rlhf_ppo_train(model, scorer, RLHF_PROMPTS, encode, decode,
                          n_epochs=RLHF_EPOCHS)


#========================================================
# --- 11. TRAINING UTILITIES
#========================================================
def get_lr(it: int, config: "ModelConfig") -> float:
    """Cosine learning-rate schedule with linear warmup.

    - it < warmup_iters: linear ramp from 0 up to learning_rate.
    - warmup_iters <= it < max_iters: cosine decay from learning_rate down to
      0.1 * learning_rate.
    - it >= max_iters: constant 0.1 * learning_rate.

    Args:
        it: current iteration (0-based).
        config: object providing `learning_rate`, `warmup_iters` and `max_iters`.

    Returns:
        The learning rate to use at iteration `it`.
    """
    base_lr = config.learning_rate
    min_lr = 0.1 * base_lr
    if it < config.warmup_iters:
        return base_lr * it / max(config.warmup_iters, 1)
    # `>=` rather than `>`: at it == max_iters the cosine has already reached
    # min_lr, and this keeps the decay branch below from dividing 0/0 when
    # warmup_iters == max_iters.
    if it >= config.max_iters:
        return min_lr
    decay_ratio = (it - config.warmup_iters) / (config.max_iters - config.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * 0.9 * base_lr


def get_batch(data: Tensor, config: ModelConfig) -> Tuple[Tensor, Tensor]:
    """Sample a random batch of (input, next-token target) windows from `data`.

    Draws `config.batch_size` start offsets i uniformly from
    [0, len(data) - block_size) and returns x = data[i : i + block_size] and
    y = data[i + 1 : i + block_size + 1], each stacked to shape
    (batch_size, block_size) and moved to the global `device`.
    """
    span = config.block_size
    starts = torch.randint(len(data) - span, (config.batch_size,))
    # Absolute index grid, one row per sampled window: (batch_size, span)
    rows = starts.unsqueeze(1) + torch.arange(span)
    inputs = data[rows]
    targets = data[rows + 1]
    return inputs.to(device), targets.to(device)


@torch.no_grad()
def estimate_loss(model: UltimateClaude, train_data: Tensor, val_data: Tensor,
                  config: ModelConfig) -> Dict[str, float]:
    """Estimate the mean loss on random train and val batches.

    Runs `eval_iters` batches per split (5 when `device == 'cpu'`, 200
    otherwise) with gradients disabled and the model in eval mode. Afterwards
    the model is returned to whichever train/eval mode it was in on entry;
    previously it was unconditionally switched to train mode.

    Returns:
        {'train': mean_train_loss, 'val': mean_val_loss}
    """
    was_training = model.training
    model.eval()
    eval_iters = 5 if device == 'cpu' else 200
    out: Dict[str, float] = {}
    try:
        for split, data in (('train', train_data), ('val', val_data)):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(data, config)
                _, loss, _ = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
    finally:
        # Restore the caller's mode even if a forward pass raised.
        model.train(was_training)
    return out


#========================================================
# --- 11. DATA LOADING
#========================================================
def load_training_data(local_file: str, timeout: float = 30.0) -> Optional[str]:
    """Return training text from `local_file`, else TRAINING_DATA_URL, else a fallback line.

    A missing, unreadable or empty local file falls through to the download. A
    failed or empty download falls through to a one-sentence fallback string.
    In practice this never returns None (the annotation is kept for
    compatibility).

    Args:
        local_file: path to a UTF-8 text file; undecodable bytes are replaced.
        timeout: seconds before the download attempt gives up. Without it,
            `urlopen` can block indefinitely.
    """
    if os.path.exists(local_file):
        try:
            with open(local_file, 'r', encoding='utf-8', errors='replace') as f:
                text = f.read()
        except OSError as e:
            print(f"Error reading local file: {e}")
        else:
            if text:
                print(f"Loaded {len(text):,} characters from {local_file}")
                return text
            # An empty corpus would make every downstream batch draw fail.
            print(f"Local file {local_file} is empty; trying download instead.")
    try:
        import urllib.request
        print(f"Downloading training data from {TRAINING_DATA_URL}...")
        with urllib.request.urlopen(TRAINING_DATA_URL, timeout=timeout) as response:
            text = response.read().decode('utf-8', errors='replace')
        if text:
            print(f"Downloaded {len(text):,} characters")
            return text
        print("Downloaded training data is empty.")
    except Exception as e:
        # Best effort: any network/URL problem falls back to the built-in sentence.
        print(f"Download failed: {e}")
    return "To be, or not to be, that is the question."


#========================================================
# --- 12. MODEL SAVE / LOAD
#========================================================
def get_save_path(filename):
    """Resolve `filename` to an absolute save location.

    Absolute paths are returned untouched. Relative names are joined onto
    '/content' when `google.colab` can be imported, otherwise onto the current
    working directory.
    """
    try:
        import google.colab  # noqa: F401 — imported only to detect Colab
    except ImportError:
        base_dir = os.getcwd()
    else:
        base_dir = '/content'
    return filename if os.path.isabs(filename) else os.path.join(base_dir, filename)


def save_model(model: UltimateClaude, filepath: str, vocab_size: int,
               stoi: dict, itos: dict, config: ModelConfig,
               tokenization_mode: str = 'bpe'):
    """Write a checkpoint with `torch.save` and return the path written.

    The checkpoint holds the model state_dict, the vocab size, the char-level
    tables (`stoi`/`itos`), the tokenization mode, and the architecture fields
    that `load_model` needs to rebuild the network. Relative paths are
    resolved through `get_save_path`.
    """
    target = filepath if os.path.isabs(filepath) else get_save_path(filepath)

    # Architecture fields persisted under checkpoint['config'], in this order.
    arch_fields = ('n_layer', 'n_head', 'n_kv_head', 'n_embd', 'block_size',
                   'dropout', 'n_expert', 'n_expert_active', 'moe_every_n',
                   'sliding_window')
    payload = {
        'model_state_dict': model.state_dict(),
        'vocab_size': vocab_size,
        'stoi': stoi,
        'itos': itos,
        'tokenization_mode': tokenization_mode,
        'config': {name: getattr(config, name) for name in arch_fields},
    }
    torch.save(payload, target)
    size_mb = os.path.getsize(target) / 1024 / 1024
    print(f"Model saved to {target} ({size_mb:.2f} MB)")
    return target


def load_model(filepath: str):
    """Rebuild an UltimateClaude from a checkpoint written by `save_model`.

    Returns:
        (model, stoi, itos, vocab_size, tokenization_mode), or five Nones when
        `filepath` does not exist. Keys missing from older checkpoints fall back
        to: n_kv_head = n_head, n_expert = 1, n_expert_active = 1,
        moe_every_n = 2, sliding_window = 0, tokenization 'bpe', and empty
        stoi/itos. The model is moved to the global `device` and put in eval mode.

    NOTE(review): `torch.load` unpickles the file. Only load checkpoints from
    trusted sources.
    """
    if not os.path.exists(filepath):
        print(f"Model file '{filepath}' not found!")
        return (None,) * 5

    ckpt = torch.load(filepath, map_location=device)
    vocab = ckpt['vocab_size']
    stoi = ckpt.get('stoi', {})
    itos = ckpt.get('itos', {})
    mode = ckpt.get('tokenization_mode', 'bpe')
    arch = ckpt['config']

    rebuilt_config = ModelConfig(
        vocab_size=vocab,
        n_layer=arch['n_layer'],
        n_head=arch['n_head'],
        n_kv_head=arch.get('n_kv_head', arch['n_head']),
        n_embd=arch['n_embd'],
        block_size=arch['block_size'],
        dropout=arch['dropout'],
        n_expert=arch.get('n_expert', 1),
        n_expert_active=arch.get('n_expert_active', 1),
        moe_every_n=arch.get('moe_every_n', 2),
        sliding_window=arch.get('sliding_window', 0),
    )
    net = UltimateClaude(rebuilt_config).to(device)
    net.load_state_dict(ckpt['model_state_dict'])
    net.eval()

    print(f"Model loaded from {filepath}")
    print(f"   Vocab: {vocab:,} | Tokenization: {mode} | Params: {net.get_num_params()/1e6:.2f}M")
    return net, stoi, itos, vocab, mode


#========================================================
# --- 13. GITHUB UTILITIES
#========================================================
def get_github_token(token_file=None):
    """Return a GitHub token, or None if none is configured.

    Candidate files are `token_file` as given, the same name under /content/,
    and the same name under the current working directory. The first existing,
    readable candidate whose first line (stripped) is non-empty and does not
    start with '#' wins. Otherwise the GITHUB_TOKEN environment variable is
    used, if non-empty.

    Args:
        token_file: token file name or path. Defaults to the module-level
            GITHUB_TOKEN_FILE.
    """
    if token_file is None:
        token_file = GITHUB_TOKEN_FILE
    candidates = [token_file, f'/content/{token_file}',
                  os.path.join(os.getcwd(), token_file)]
    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            with open(path, 'r') as f:
                token = f.readline().strip()
        except (OSError, UnicodeDecodeError):
            # Unreadable candidate (permissions, a directory, bad encoding):
            # try the next one. A bare `except:` would also swallow Ctrl-C.
            continue
        if token and not token.startswith('#'):
            return token
    return os.environ.get('GITHUB_TOKEN') or None


def save_to_github_release(filename, username="diidihamm", repo_name="Project_7",
                           timeout=60):
    """Create a new GitHub release in `username/repo_name` and upload `filename` as an asset.

    The release tag is "model-YYYYmmdd_HHMMSS". Relative filenames are resolved
    through `get_save_path`.

    Args:
        filename: checkpoint path; its basename becomes the asset name.
        username: repository owner.
        repo_name: repository name.
        timeout: per-request network timeout in seconds. Without it, a stalled
            connection blocks forever.

    Returns:
        True if the asset was uploaded. False otherwise: no token, file
        missing, HTTP error, or any exception (which is printed).
    """
    import requests
    from datetime import datetime

    try:
        token = get_github_token()
        if not token:
            print("No GitHub token found, skipping upload.")
            return False
        filepath = filename if os.path.isabs(filename) else get_save_path(filename)
        asset_name = os.path.basename(filename)
        if not os.path.exists(filepath):
            print(f"File not found, skipping upload: {filepath}")
            return False
        file_size = os.path.getsize(filepath)
        print(f"Uploading: {asset_name} ({file_size / (1024*1024):.2f} MB)")

        now = datetime.now()  # one timestamp, so the tag and title agree
        headers = {'Authorization': f'token {token}', 'Accept': 'application/vnd.github.v3+json'}
        release_url = f"https://api.github.com/repos/{username}/{repo_name}/releases"
        release_data = {'tag_name': f"model-{now.strftime('%Y%m%d_%H%M%S')}",
                        'name': f"UltimateClaude Model - {now}", 'draft': False}
        resp = requests.post(release_url, headers=headers, json=release_data, timeout=timeout)
        if resp.status_code != 201:
            print(f"Failed to create release: {resp.status_code}")
            return False

        # upload_url is a URI template ending in "{?name,label}"; drop the template part.
        upload_url = resp.json()['upload_url'].split('{', 1)[0]
        upload_headers = {'Authorization': f'token {token}', 'Content-Type': 'application/octet-stream'}
        with open(filepath, 'rb') as f:
            # `params=` URL-encodes the asset name. Passing the file object streams
            # the upload instead of reading the whole checkpoint into memory.
            upload_resp = requests.post(upload_url, params={'name': asset_name},
                                        headers=upload_headers, data=f, timeout=timeout)
        if upload_resp.status_code == 201:
            print(f"Uploaded to GitHub Releases: {upload_resp.json()['browser_download_url']}")
            return True
        print(f"Upload failed: {upload_resp.status_code}")
        return False
    except Exception as e:
        # Best-effort upload: report and carry on instead of aborting the run.
        print(f"GitHub upload error: {e}")
        return False


#========================================================
# --- 14. INTERACTIVE CHAT
#========================================================
def _print_model_stats(model: UltimateClaude) -> None:
    """Print the parameter count and architecture summary for the 'stats' chat command."""
    cfg = model.config
    print(f"\nModel: {model.get_num_params()/1e6:.2f}M parameters")
    print(f"Config: {cfg.n_layer} layers, {cfg.n_head} Q-heads, "
          f"{cfg.n_kv_head} KV-heads")
    moe_layers = sum(1 for i in range(cfg.n_layer)
                     if cfg.n_expert > 1 and i % cfg.moe_every_n == 0)
    print(f"MoE layers: {moe_layers}/{cfg.n_layer} "
          f"({cfg.n_expert} experts, top-{cfg.n_expert_active})")
    print(f"Sliding window: {cfg.sliding_window}")


def _edit_generation_settings(params: dict) -> None:
    """Show the generation settings and let the user change one of them in place."""
    print(f"\n1. max_tokens: {params['max_tokens']}")
    print(f"2. temperature: {params['temperature']}")
    print(f"3. top_k: {params['top_k']}")
    print(f"4. top_p: {params['top_p']}")
    print(f"5. repetition_penalty: {params['repetition_penalty']}")
    choice = input("Change (1-5): ").strip()
    try:
        if choice == '1': params['max_tokens'] = int(input("New value: "))
        elif choice == '2': params['temperature'] = float(input("New value: "))
        elif choice == '3': params['top_k'] = int(input("New value: "))
        elif choice == '4': params['top_p'] = float(input("New value: "))
        elif choice == '5': params['repetition_penalty'] = float(input("New value: "))
    except ValueError:
        print("Invalid input.")


def interactive_chat(model: UltimateClaude, encode: Callable, decode: Callable):
    """Run a REPL that wraps each message in the Alpaca template and streams the reply.

    Commands: quit/exit/q, help, clear, stats, settings. Replies are produced
    by `model.generate` with the current settings and stop strings, and are
    streamed as they are decoded. The model is put in eval mode first. The loop
    ends on 'quit' or when stdin reaches end-of-file.
    """
    print("\n" + "=" * 70)
    print("ULTIMATE CLAUDE — INTERACTIVE CHAT (Alpaca format)")
    print("=" * 70)
    print("\nCommands: 'quit', 'settings', 'help', 'clear', 'stats'")

    # Sample with dropout disabled: callers may hand over a model that is still
    # in training mode (e.g. right after the training loop).
    model.eval()

    params = {'max_tokens': 300, 'temperature': 0.7, 'top_k': 40,
              'top_p': 0.9, 'repetition_penalty': 1.2}
    stop_strings = ["<|endoftext|>", "### Instruction:", "### Input:"]

    print(f"\nGeneration: max_tokens={params['max_tokens']}, temp={params['temperature']}")
    print(f"Stop strings: {stop_strings}")
    print(f"Architecture: RoPE + RMSNorm + SwiGLU + GQA + MoE + Sliding Window")
    print("=" * 70)

    while True:
        try:
            prompt = input("\nYou: ").strip()
            if not prompt:
                continue

            cmd = prompt.lower()
            if cmd in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break
            elif cmd == 'help':
                print("\nCommands: quit, settings, help, clear, stats")
                continue
            elif cmd == 'clear':
                os.system('cls' if os.name == 'nt' else 'clear')
                continue
            elif cmd == 'stats':
                _print_model_stats(model)
                continue
            elif cmd == 'settings':
                _edit_generation_settings(params)
                continue

            # Format prompt as Alpaca
            formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"

            print(f"\nUltimateClaude: ", end='', flush=True)
            context = torch.tensor([encode(formatted)], dtype=torch.long, device=device)
            generated_tokens = []
            shown = 0  # number of decoded reply characters already printed
            start_time = time.time()

            def stream_callback(token_id: int):
                nonlocal shown
                generated_tokens.append(token_id)
                # Decode the whole reply so far rather than the lone token: a
                # multi-byte character split across BPE tokens would otherwise
                # print as replacement characters. While the tail is still an
                # incomplete character (decodes to U+FFFD), hold output back.
                text = decode(generated_tokens)
                if text.endswith('\ufffd'):
                    return
                print(text[shown:], end='', flush=True)
                shown = len(text)

            model.generate(
                context,
                max_new_tokens=params['max_tokens'],
                temperature=params['temperature'],
                top_k=params['top_k'],
                top_p=params['top_p'],
                repetition_penalty=params['repetition_penalty'],
                stop_strings=stop_strings,
                decode_fn=decode,
                callback=stream_callback,
            )

            # Flush anything held back (e.g. a reply that ended mid-character).
            remainder = decode(generated_tokens)[shown:]
            if remainder:
                print(remainder, end='', flush=True)

            gen_time = time.time() - start_time
            tok_per_sec = len(generated_tokens) / gen_time if gen_time > 0 else 0
            print(f"\n\n[{len(generated_tokens)} tokens in {gen_time:.2f}s ({tok_per_sec:.1f} tok/s)]")

        except EOFError:
            # stdin closed (piped input exhausted / non-interactive run). Without
            # this, input() raises on every iteration and the loop never ends.
            print("\nGoodbye!")
            break
        except KeyboardInterrupt:
            print("\n\nUse 'quit' to exit.")
        except Exception as e:
            print(f"\nError: {e}")


#========================================================
# --- 15. MAIN
#========================================================
def main():
    """Script entry point: chat with a saved model, or train a new one first.

    If TRAIN_NEW_MODEL is false and MODEL_FILENAME (resolved via
    `get_save_path`) exists and loads, the chat loop starts immediately.
    Otherwise the function:
      1. trains a new model on text from `load_training_data(LOCAL_FILE)`;
      2. saves it and attempts a GitHub Releases upload;
      3. if RUN_RLHF is set, runs RLHF and saves a "<name>_rlhf<ext>" copy;
      4. starts the chat loop.

    Mutates the module-level `config`: vocab_size, and block_size for tiny
    datasets.

    Raises:
        ValueError: if the train or val split has fewer than 2 tokens.
    """
    global config

    print(f"\nThe current folder is: {os.getcwd()}")
    print("=" * 70)

    model_path = get_save_path(MODEL_FILENAME)

    # --- Try loading existing model ---
    if not TRAIN_NEW_MODEL and os.path.exists(model_path):
        print("Loading existing model...")
        model, stoi, itos, vocab_size, tokenization_mode = load_model(model_path)
        if model is not None:
            if tokenization_mode == 'bpe':
                enc = tiktoken.get_encoding("cl100k_base")
                encode = lambda s: enc.encode(s, allowed_special='all')
                decode = lambda l: enc.decode(l)
            else:
                encode = lambda s: [stoi.get(c, 0) for c in s]
                decode = lambda l: ''.join([itos.get(i, '') for i in l])
            interactive_chat(model, encode, decode)
            return
        # Loading failed; fall through and train a fresh model.

    # --- Train new model ---
    print("Training new model...")
    print("=" * 70)
    print("\nArchitecture: RoPE + RMSNorm + SwiGLU + GQA + QK-Norm + MoE")
    print(f"  Layers: {config.n_layer} | Q-heads: {config.n_head} | KV-heads: {config.n_kv_head}")
    print(f"  Experts: {config.n_expert} (top-{config.n_expert_active}) | Sliding window: {config.sliding_window}")
    print("=" * 70)

    text = load_training_data(LOCAL_FILE)

    # Setup tokenizer
    stoi, itos = {}, {}
    if TOKENIZATION_MODE == 'bpe':
        enc = tiktoken.get_encoding("cl100k_base")
        vocab_size = enc.n_vocab
        config.vocab_size = vocab_size
        encode = lambda s: enc.encode(s, allowed_special='all')
        decode = lambda l: enc.decode(l)
    else:
        chars = sorted(list(set(text)))
        vocab_size = len(chars)
        config.vocab_size = vocab_size
        stoi = {ch: i for i, ch in enumerate(chars)}
        itos = {i: ch for i, ch in enumerate(chars)}
        encode = lambda s: [stoi.get(c, 0) for c in s]
        decode = lambda l: ''.join([itos.get(i, '') for i in l])

    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data))
    train_data, val_data = data[:n], data[n:]

    # get_batch draws start offsets from [0, len(split) - block_size), so each
    # split needs strictly more than block_size tokens.
    min_tokens = min(len(train_data), len(val_data))
    if min_tokens < 2:
        raise ValueError(
            f"Training data too small: each split needs at least 2 tokens "
            f"(train={len(train_data)}, val={len(val_data)})."
        )
    if min_tokens <= config.block_size:
        # No artificial floor: the previous max(8, ...) still exceeded splits of
        # <= 8 tokens and crashed get_batch.
        config.block_size = min_tokens - 1
        print(f"WARNING: Small dataset — block_size reduced to {config.block_size}")

    print(f"Training: {len(train_data):,} tokens | Validation: {len(val_data):,} tokens")

    model = UltimateClaude(config).to(device)
    print(f"Model: {model.get_num_params()/1e6:.2f}M parameters")

    # Report MoE layer distribution
    moe_layers = [i for i in range(config.n_layer)
                  if config.n_expert > 1 and i % config.moe_every_n == 0]
    dense_layers = [i for i in range(config.n_layer) if i not in moe_layers]
    print(f"  MoE layers: {moe_layers} | Dense layers: {dense_layers}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    use_amp = device == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    start_time = time.time()
    pbar = tqdm(range(config.max_iters), desc="Training Progress")

    for iter_num in pbar:
        if iter_num % config.eval_interval == 0:
            losses = estimate_loss(model, train_data, val_data, config)
            msg = f"Step {iter_num} | Train: {losses['train']:.4f} | Val: {losses['val']:.4f}"
            print(f"\n{msg}")
            pbar.set_description(msg)

        xb, yb = get_batch(train_data, config)
        lr = get_lr(iter_num, config)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        if use_amp:
            with torch.cuda.amp.autocast():
                _, loss, _ = model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            _, loss, _ = model(xb, yb)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    elapsed = (time.time() - start_time) / 60
    print(f"\nTraining complete in {elapsed:.2f} min")

    # Final eval
    final_losses = estimate_loss(model, train_data, val_data, config)
    print(f"Final — Train: {final_losses['train']:.4f} | Val: {final_losses['val']:.4f}")

    # Save
    print("\nSaving model...")
    saved_path = save_model(model, MODEL_FILENAME, vocab_size, stoi, itos, config, TOKENIZATION_MODE)

    if saved_path:
        print("\nUploading to GitHub Releases...")
        save_to_github_release(saved_path)

    # --- RLHF phase (optional) ---
    if RUN_RLHF:
        model = run_rlhf_pipeline(model, encode, decode, config)
        # Save the RLHF-aligned model next to the base one. splitext always
        # yields a distinct name; str.replace('.pth', ...) would return the
        # base filename unchanged (and overwrite it) when there is no '.pth'.
        root, ext = os.path.splitext(MODEL_FILENAME)
        rlhf_path = f"{root}_rlhf{ext}"
        print("\nSaving RLHF-aligned model...")
        save_model(model, rlhf_path, vocab_size, stoi, itos, config, TOKENIZATION_MODE)

    # The model is still in training mode after the loop; chat without dropout.
    model.eval()
    interactive_chat(model, encode, decode)


if __name__ == "__main__":
    # Script entry: Ctrl-C exits quietly; any other failure is reported with
    # its full traceback rather than propagating.
    try:
        main()
    except KeyboardInterrupt:
        print("\n\nGoodbye!")
    except Exception as exc:
        import traceback
        print(f"Error: {exc}")
        traceback.print_exc()
