# 🐉 Chimera Prose - 500M Eloquent Writing Model

**Target**: ~500M parameter hybrid recurrent-attention LLM for creative writing

**Hardware**: A100 40GB (Colab Pro+)

**Architecture**: Chimera (Griffin-style RG-LRU + Sliding Window GQA)

**Training tricks from successful models**:
- μP-inspired scaling (Cerebras/Microsoft)
- Cosine LR with warmup (GPT/Llama)
- AdamW β2=0.95 (Llama-2)
- Gradient checkpointing
- bf16 mixed precision
- torch.compile (PyTorch 2.0+)
- Packed sequences (no padding waste)
- Quality-filtered data blend

**Data blend**:
- 40% Literary prose (pg19/BookCorpus)
- 30% Creative writing (WritingPrompts)
- 20% Conversational (OpenAssistant)
- 10% High-quality web (Dolma/FineWeb subset)
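
One way to make these percentages concrete is to turn them into per-source token budgets before sampling. The sketch below is a minimal illustration and not the notebook's own blending code: `blend_budgets` and `take_until_budget` are hypothetical helpers, the source names are illustrative labels, and character length stands in for a real token count.

```python
def blend_budgets(percentages, total_tokens):
    """Split a total token budget across sources by integer percentages."""
    if sum(percentages.values()) != 100:
        raise ValueError("percentages must sum to 100")
    return {name: total_tokens * pct // 100 for name, pct in percentages.items()}


def take_until_budget(texts, budget, count_tokens=len):
    """Keep texts in order until adding the next one would exceed the budget."""
    kept, used = [], 0
    for text in texts:
        cost = count_tokens(text)
        if used + cost > budget:
            break
        kept.append(text)
        used += cost
    return kept, used


# Target mix from the list above (hypothetical labels for each source).
targets = {"literary": 40, "creative": 30, "conversational": 20, "web": 10}
budgets = blend_budgets(targets, total_tokens=1_000_000)
print(budgets)

# Character length is only a rough proxy for token count (assumption).
sample, used = take_until_budget(["a" * 300, "b" * 500, "c" * 400], budget=1000)
print(len(sample), used)
```

Budgeting by tokens rather than by document count matters here because book excerpts are far longer than chat replies, so equal document counts would not give equal token shares.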

In [None]:
#@title 1. Setup & Dependencies { display-mode: "form" }
import os
import subprocess

# Check GPU
!nvidia-smi --query-gpu=name,memory.total --format=csv

# Install dependencies
!pip install -q transformers datasets accelerate wandb sentencepiece
!pip install -q flash-attn --no-build-isolation 2>/dev/null || echo "Flash attention not available, using standard attention"

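# Clone repo (skip if this cell already ran and we are inside the checkout)
if os.path.basename(os.getcwd()) != 'chimera':
    if not os.path.exists('chimera'):
        !git clone https://github.com/entropadeus/chimera.git
    os.chdir('chimera')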

print("\n" + "="*50)
print("Setup complete!")
print("="*50)

In [None]:
#@title 2. Model Configuration - 500M Prose Model { display-mode: "form" }

import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

@dataclass
class ProseModelConfig:
    """
    ~500M params optimized for A100 40GB with headroom.
    
    Architecture choices:
    - d_model=1536: Sweet spot for expressiveness vs memory
    - 24 layers: Good depth for style/voice development
    - 24 heads: 64 head_dim (standard, hardware efficient)
    - 6 KV heads: 4:1 GQA ratio (memory efficient)
    - 2048 context: Enough for prose, fits in memory
    - 1024 sliding window: Local coherence
    
    Memory estimate:
    - Weights: ~2GB (fp32 master copy; autocast casts to bf16 per op)
    - Optimizer states: ~4GB (AdamW, two fp32 moments per parameter)
    - Activations: ~8-12GB (with grad checkpointing)
    - Gradients: ~2GB (fp32)
    - Total: ~16-20GB, leaving ~20GB headroom
    """
    d_model: int = 1536
    n_layers: int = 24
    n_heads: int = 24          # 1536/24 = 64 head_dim
    n_kv_heads: int = 6        # 4:1 GQA
    vocab_size: int = 32000
    max_seq_len: int = 2048
    sliding_window: int = 1024
    ffn_hidden_mult: float = 8/3  # SwiGLU ratio
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    rope_scaling: Optional[float] = None
    dropout: float = 0.0       # No dropout (modern practice)
    attention_every_n: int = 3  # Every 3rd layer is attention (2:1 recurrent:attention)
    tie_word_embeddings: bool = True
    
    def __post_init__(self):
        self.recurrence_dim = self.d_model
        self.ffn_hidden = int(self.d_model * self.ffn_hidden_mult)
        self.ffn_hidden = ((self.ffn_hidden + 255) // 256) * 256
        self.head_dim = self.d_model // self.n_heads

config = ProseModelConfig()

# Calculate param count
def estimate_params(cfg):
    embed = cfg.vocab_size * cfg.d_model
    
    n_attn = cfg.n_layers // cfg.attention_every_n
    n_recur = cfg.n_layers - n_attn
    
    # Attention layer params
    attn_qkv = cfg.d_model * (cfg.n_heads + 2 * cfg.n_kv_heads) * cfg.head_dim
    attn_out = cfg.n_heads * cfg.head_dim * cfg.d_model
    attn_total = (attn_qkv + attn_out) * n_attn
    
    # Recurrent layer params (with input gate)
    recur_input = cfg.d_model * cfg.d_model  # input_proj
    recur_gates = 2 * cfg.d_model * cfg.d_model  # input_gate + recurrence_gate
    recur_lambda = cfg.d_model  # lambda_param
    recur_out = cfg.d_model * cfg.d_model  # output_proj
    recur_total = (recur_input + recur_gates + recur_lambda + recur_out) * n_recur
    
    # FFN per layer
    ffn = 3 * cfg.d_model * cfg.ffn_hidden * cfg.n_layers
    
    # Norms
    norms = 2 * cfg.d_model * cfg.n_layers + cfg.d_model
    
    total = embed + attn_total + recur_total + ffn + norms
    return total

params = estimate_params(config)
print(f"Estimated parameters: {params:,} ({params/1e6:.1f}M)")
print(f"\nArchitecture:")
print(f"  Layers: {config.n_layers} ({config.n_layers - config.n_layers//config.attention_every_n} recurrent, {config.n_layers//config.attention_every_n} attention)")
print(f"  d_model: {config.d_model}")
print(f"  Heads: {config.n_heads} query, {config.n_kv_heads} KV (GQA)")
print(f"  Context: {config.max_seq_len} tokens")
print(f"  FFN hidden: {config.ffn_hidden}")
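
The memory figures in the config docstring depend on which tensors are held in bf16 and which in fp32. A rough way to check them is sketched below. `training_memory_gb` is a hypothetical helper, and its defaults assume fp32 weights and gradients plus two fp32 AdamW moments per parameter, which is what moving the model to the GPU without a dtype and training under autocast would imply. Activations are excluded because they depend on batch size and checkpointing.

```python
def training_memory_gb(n_params, weight_bytes=4, grad_bytes=4, optim_bytes=8):
    """Static training memory in GB (weights, gradients, optimizer state).

    Defaults are assumptions: fp32 weights and gradients, and two fp32
    AdamW moment buffers (8 bytes) per parameter. Activations are excluded.
    """
    parts = {
        "weights": n_params * weight_bytes / 1e9,
        "gradients": n_params * grad_bytes / 1e9,
        "optimizer": n_params * optim_bytes / 1e9,
    }
    parts["total"] = sum(parts.values())
    return parts


# Nominal 500M-parameter target from the notebook title.
fp32_budget = training_memory_gb(500_000_000)
# Same model if weights and gradients were stored in bf16 instead (2 bytes each).
bf16_budget = training_memory_gb(500_000_000, weight_bytes=2, grad_bytes=2)

for name, budget in (("fp32", fp32_budget), ("bf16", bf16_budget)):
    print(name, {k: round(v, 2) for k, v in budget.items()})
```

Passing the value printed by `estimate_params(config)` instead of the nominal 500M gives a budget for the configuration as actually defined.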

In [None]:
#@title 3. Download & Prepare High-Quality Training Data { display-mode: "form" }

from datasets import load_dataset
from transformers import AutoTokenizer
import random

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("\nDownloading datasets (this takes a few minutes)...")

# 1. Literary prose - Project Gutenberg (use HF parquet version)
print("  [1/4] Loading literary prose...")
try:
    # Try the working pg19 parquet version
    pg19 = load_dataset("emozilla/pg19", split="train", streaming=True)
    pg19_texts = []
    for i, x in enumerate(pg19):
        if i >= 3000:
            break
        if x.get('text') and len(x['text']) > 1000:
            pg19_texts.append(x['text'][:50000])
    print(f"    Loaded {len(pg19_texts)} literary texts")
except Exception as e:
    print(f"    pg19 failed: {e}")
    # Fallback: use a public domain books dataset
    try:
        books = load_dataset("storytracer/US_PD_Books", split="train", streaming=True)
        pg19_texts = []
        for i, x in enumerate(books):
            if i >= 3000:
                break
            if x.get('text') and len(x['text']) > 1000:
                pg19_texts.append(x['text'][:50000])
        print(f"    Loaded {len(pg19_texts)} book texts (fallback)")
    except Exception:
        print("    Using tinystories as prose fallback...")
        ts = load_dataset("roneneldan/TinyStories", split="train[:50000]")
        pg19_texts = [x['text'] for x in ts if len(x['text']) > 200]

# 2. Creative writing - WritingPrompts
print("  [2/4] Loading creative fiction...")
try:
    wp = load_dataset("Lambent/writing-prompts-cleaned", split="train", streaming=True)
    wp_texts = []
    for i, x in enumerate(wp):
        if i >= 30000:
            break
        story = x.get('story') or x.get('text', '')
        prompt = x.get('prompt', '')
        if story and len(story) > 300:
            wp_texts.append(f"Prompt: {prompt}\n\nStory: {story}" if prompt else story)
    print(f"    Loaded {len(wp_texts)} creative texts")
except Exception as e:
    print(f"    WritingPrompts failed: {e}")
    try:
        # Alternative creative writing dataset
        wp = load_dataset("lksy/prompts_stories", split="train", streaming=True)
        wp_texts = []
        for i, x in enumerate(wp):
            if i >= 30000:
                break
            if x.get('story') and len(x['story']) > 300:
                wp_texts.append(x['story'])
        print(f"    Loaded {len(wp_texts)} stories (fallback)")
    except Exception:
        wp_texts = []
        print("    No creative writing dataset available")

# 3. Conversational - OpenAssistant
print("  [3/4] Loading conversational data...")
try:
    oasst = load_dataset("OpenAssistant/oasst1", split="train")
    oasst_texts = []
    for x in oasst:
        if x.get('role') == 'assistant' and x.get('text') and len(x['text']) > 200:
            oasst_texts.append(x['text'])
    oasst_texts = oasst_texts[:20000]
    print(f"    Loaded {len(oasst_texts)} conversational texts")
except Exception as e:
    print(f"    OASST failed: {e}")
    oasst_texts = []

# 4. High-quality web text - FineWeb-Edu
print("  [4/4] Loading quality web text...")
try:
    fineweb = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split="train", streaming=True)
    fineweb_texts = []
    for i, x in enumerate(fineweb):
        if i >= 15000:
            break
        if x.get('text') and len(x['text']) > 500:
            fineweb_texts.append(x['text'][:10000])
    print(f"    Loaded {len(fineweb_texts)} web texts")
except Exception as e:
    print(f"    FineWeb failed: {e}")
    try:
        # Fallback to SlimPajama
        slim = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)
        fineweb_texts = []
        for i, x in enumerate(slim):
            if i >= 15000:
                break
            if x.get('text') and len(x['text']) > 500:
                fineweb_texts.append(x['text'][:10000])
        print(f"    Loaded {len(fineweb_texts)} texts (SlimPajama fallback)")
    except Exception:
        fineweb_texts = []
        print("    No web text available")

print(f"\nDataset sizes:")
print(f"  Literary: {len(pg19_texts):,} texts")
print(f"  Creative: {len(wp_texts):,} texts")
print(f"  Conversational: {len(oasst_texts):,} texts")
print(f"  Web: {len(fineweb_texts):,} texts")

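# Blend with per-source caps of 40k/30k/20k/10k texts (target: 40% literary, 30% creative,
# 20% conversational, 10% web). The caps count documents, not tokens, so the realized mix
# only matches the target if every source fills its cap and texts are similar in length.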
all_texts = []

if pg19_texts:
    n_lit = min(40000, len(pg19_texts))
    all_texts.extend(random.sample(pg19_texts, n_lit) if len(pg19_texts) > n_lit else pg19_texts)
    
if wp_texts:
    n_creative = min(30000, len(wp_texts))
    all_texts.extend(random.sample(wp_texts, n_creative) if len(wp_texts) > n_creative else wp_texts)
    
if oasst_texts:
    n_conv = min(20000, len(oasst_texts))
    all_texts.extend(random.sample(oasst_texts, n_conv) if len(oasst_texts) > n_conv else oasst_texts)
    
if fineweb_texts:
    n_web = min(10000, len(fineweb_texts))
    all_texts.extend(random.sample(fineweb_texts, n_web) if len(fineweb_texts) > n_web else fineweb_texts)

random.shuffle(all_texts)

print(f"\nFinal blend: {len(all_texts):,} texts")

if len(all_texts) < 1000:
    raise ValueError("Not enough training data! Check dataset availability.")

# Save to disk
os.makedirs('data', exist_ok=True)
with open('data/prose_blend.txt', 'w', encoding='utf-8') as f:
    f.write('\n\n<|endoftext|>\n\n'.join(all_texts))

file_size = os.path.getsize('data/prose_blend.txt') / 1e9
print(f"Saved to data/prose_blend.txt ({file_size:.2f} GB)")

In [None]:
#@title 4. Tokenize & Pack Sequences { display-mode: "form" }

import torch
from transformers import AutoTokenizer
from tqdm.auto import tqdm
import numpy as np

SEQ_LEN = 2048  # Match model config

print("Tokenizing corpus...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

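with open('data/prose_blend.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(f"Text length: {len(text):,} chars")

# Tokenize one document at a time and close each with the tokenizer's real EOS id.
# The literal "<|endoftext|>" separator is not a special token in the Llama vocabulary,
# so encoding it in-line would split it into ordinary subword pieces.
docs = text.split('\n\n<|endoftext|>\n\n')
del text
tokens = []
for doc in tqdm(docs, desc="Tokenizing"):
    tokens.extend(tokenizer.encode(doc, add_special_tokens=False))
    tokens.append(tokenizer.eos_token_id)
print(f"Token count: {len(tokens):,}")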

# Pack into sequences (no padding waste)
n_seqs = len(tokens) // SEQ_LEN
tokens = tokens[:n_seqs * SEQ_LEN]
packed = np.array(tokens, dtype=np.int32).reshape(n_seqs, SEQ_LEN)

print(f"\nPacked sequences: {n_seqs:,} x {SEQ_LEN}")
print(f"Total tokens: {n_seqs * SEQ_LEN:,}")

# Save
np.save('data/prose_packed.npy', packed)
print(f"Saved to data/prose_packed.npy")

# Quick stats
print(f"\nDataset size: {packed.nbytes / 1e9:.2f} GB")

In [None]:
#@title 5. Training Configuration { display-mode: "form" }

from dataclasses import dataclass

@dataclass
class TrainConfig:
    """
    Training configuration with tricks from successful models.
    
    Key insights:
    - Llama-2: AdamW β2=0.95, weight_decay=0.1, grad_clip=1.0
    - GPT-3: Cosine LR decay, warmup 0.1% of steps
    - Chinchilla: ~20 tokens per parameter optimal
    - μP: Scale LR inversely with width for transfer
    """
    # Batch size (tune for A100 40GB)
    micro_batch_size: int = 4         # Per-GPU batch
    gradient_accumulation: int = 16   # Effective batch = 4 * 16 = 64
    
    # Learning rate (μP-inspired: scale down for wider models)
    # Base LR 3e-4 for 768 width, scaled by sqrt(768/width): 3e-4 * sqrt(768/1536) ≈ 2.1e-4
    base_lr: float = 2e-4             # Slightly lower for 1536 width
    min_lr: float = 2e-5              # 10% of base
    warmup_steps: int = 1000          # ~2% of training
    
    # Optimizer (Llama-2 style)
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95               # Lower than default for stability
    eps: float = 1e-8
    grad_clip: float = 1.0
    
    # Training duration
    max_steps: int = 50000            # ~6.6B tokens with batch 64 * 2048
    eval_interval: int = 500
    save_interval: int = 2500
    log_interval: int = 10
    
    # Memory optimization
    gradient_checkpointing: bool = True
    mixed_precision: str = "bf16"     # A100 native bf16
    compile_model: bool = True        # torch.compile speedup
    
    # Data
    seq_len: int = 2048
    
    # Paths
    data_path: str = "data/prose_packed.npy"
    checkpoint_dir: str = "checkpoints"
    
    # Wandb
    use_wandb: bool = True
    wandb_project: str = "chimera-prose"
    wandb_run_name: str = "prose-500m-a100"

train_config = TrainConfig()

# Calculate training stats
effective_batch = train_config.micro_batch_size * train_config.gradient_accumulation
tokens_per_step = effective_batch * train_config.seq_len
total_tokens = train_config.max_steps * tokens_per_step

print("Training Configuration")
print("=" * 50)
print(f"Effective batch size: {effective_batch}")
print(f"Tokens per step: {tokens_per_step:,}")
print(f"Total training tokens: {total_tokens:,} ({total_tokens/1e9:.1f}B)")
print(f"\nLearning rate: {train_config.base_lr} → {train_config.min_lr}")
print(f"Warmup steps: {train_config.warmup_steps}")
print(f"\nOptimizations:")
print(f"  Gradient checkpointing: {train_config.gradient_checkpointing}")
print(f"  Mixed precision: {train_config.mixed_precision}")
print(f"  torch.compile: {train_config.compile_model}")
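
The training-duration arithmetic can be checked by hand. The sketch below recomputes tokens per step and total tokens from the values in `TrainConfig`, then compares the result with the ~20 tokens-per-parameter rule of thumb cited in the docstring. The 500M parameter count is the nominal target from the notebook title, taken here as an assumption rather than a measured value.

```python
import math

# Values copied from TrainConfig.
micro_batch_size = 4
gradient_accumulation = 16
seq_len = 2048
max_steps = 50_000

# Nominal model size (assumption: the title's ~500M target).
n_params = 500_000_000

tokens_per_step = micro_batch_size * gradient_accumulation * seq_len
total_tokens = max_steps * tokens_per_step
tokens_per_param = total_tokens / n_params

# Steps needed to reach ~20 tokens per parameter at this batch size.
chinchilla_tokens = 20 * n_params
chinchilla_steps = math.ceil(chinchilla_tokens / tokens_per_step)

print(f"Tokens per step:  {tokens_per_step:,}")
print(f"Total tokens:     {total_tokens:,} ({total_tokens / 1e9:.2f}B)")
print(f"Tokens per param: {tokens_per_param:.1f}")
print(f"Steps for 20 tokens/param: {chinchilla_steps:,}")
```

Note that these totals count every step's tokens, including repeats: when the packed dataset is smaller than the total, the loop simply cycles through it for multiple epochs.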

In [None]:
#@title 6. Build Model with Optimizations { display-mode: "form" }

import torch
import torch.nn as nn
from model import Chimera, ChimeraConfig

# Build config
model_config = ChimeraConfig(
    d_model=1536,
    n_layers=24,
    n_heads=24,
    n_kv_heads=6,
    vocab_size=32000,
    max_seq_len=2048,
    sliding_window=1024,
    attention_every_n=3,
    dropout=0.0,
)

print("Building model...")
model = Chimera(model_config)

# Move to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Enable gradient checkpointing
if train_config.gradient_checkpointing:
    # Wrap each layer with checkpointing
    from torch.utils.checkpoint import checkpoint
    
    class CheckpointedChimeraBlock(nn.Module):
        def __init__(self, block):
            super().__init__()
            self.block = block
            self.use_attention = block.use_attention
        
        def forward(self, x, cache=None, position_offset=0, use_cache=False):
            if self.training and not use_cache:
                # Use checkpointing during training
                def custom_forward(x_inner):
                    return self.block(x_inner, None, position_offset, False)[0]
                return checkpoint(custom_forward, x, use_reentrant=False), None
            return self.block(x, cache, position_offset, use_cache)
    
    model.layers = nn.ModuleList([
        CheckpointedChimeraBlock(layer) for layer in model.layers
    ])
    print("  ✓ Gradient checkpointing enabled")

# Compile model (PyTorch 2.0+)
if train_config.compile_model and hasattr(torch, 'compile'):
    print("  Compiling model (takes ~2 min first run)...")
    model = torch.compile(model, mode="reduce-overhead")
    print("  ✓ torch.compile enabled")

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Statistics:")
print(f"  Total parameters: {n_params:,} ({n_params/1e6:.1f}M)")
print(f"  Trainable parameters: {n_trainable:,}")
print(f"  Model size (bf16): {n_params * 2 / 1e9:.2f} GB")

# Memory estimate
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
print(f"  GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
#@title 7. Setup Optimizer & Scheduler { display-mode: "form" }

import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# Separate weight decay for different param types (Llama-2 style)
def get_param_groups(model, weight_decay):
    """Apply weight decay only to weights, not biases/norms/embeddings."""
    decay = set()
    no_decay = set()
    
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # No decay for biases, norms, embeddings
        if 'bias' in name or 'norm' in name or 'embed' in name or 'lambda' in name:
            no_decay.add(name)
        else:
            decay.add(name)
    
    param_dict = {name: param for name, param in model.named_parameters()}
    
    return [
        {"params": [param_dict[n] for n in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[n] for n in sorted(no_decay)], "weight_decay": 0.0},
    ]

param_groups = get_param_groups(model, train_config.weight_decay)
print(f"Param groups: {len(param_groups[0]['params'])} with decay, {len(param_groups[1]['params'])} without")

# AdamW optimizer (Llama-2 betas)
optimizer = AdamW(
    param_groups,
    lr=train_config.base_lr,
    betas=(train_config.beta1, train_config.beta2),
    eps=train_config.eps,
)

# Cosine schedule with warmup
def get_lr_lambda(step):
    """Cosine decay with linear warmup."""
    if step < train_config.warmup_steps:
        # Linear warmup
        return step / train_config.warmup_steps
    else:
        # Cosine decay to min_lr
        progress = (step - train_config.warmup_steps) / (train_config.max_steps - train_config.warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        # Decay from base_lr to min_lr
        lr_range = 1.0 - (train_config.min_lr / train_config.base_lr)
        return (train_config.min_lr / train_config.base_lr) + lr_range * cosine_decay

scheduler = LambdaLR(optimizer, get_lr_lambda)

# Mixed precision scaler (bf16 doesn't need scaling on A100, but we set it up anyway)
scaler = torch.cuda.amp.GradScaler(enabled=(train_config.mixed_precision == "fp16"))

print(f"\nOptimizer: AdamW")
print(f"  Base LR: {train_config.base_lr}")
print(f"  Min LR: {train_config.min_lr}")
print(f"  Weight decay: {train_config.weight_decay}")
print(f"  Betas: ({train_config.beta1}, {train_config.beta2})")
print(f"\nScheduler: Cosine with {train_config.warmup_steps} warmup steps")
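
The name-based split used by `get_param_groups` can be tried out without a model. The sketch below applies the same substring rule to a handful of made-up parameter names. The names are hypothetical and are not taken from the Chimera source; `split_decay` is an illustrative helper.

```python
# Same markers the notebook checks for in parameter names.
NO_DECAY_MARKERS = ("bias", "norm", "embed", "lambda")


def split_decay(names):
    """Return (decay, no_decay) name lists using substring matching."""
    decay, no_decay = [], []
    for name in names:
        if any(marker in name for marker in NO_DECAY_MARKERS):
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


# Hypothetical parameter names, for illustration only.
example_names = [
    "embed_tokens.weight",
    "layers.0.attn.q_proj.weight",
    "layers.0.attn_norm.weight",
    "layers.1.recurrent.lambda_param",
    "layers.1.ffn.w1.weight",
    "final_norm.weight",
]
decay, no_decay = split_decay(example_names)
print("decay:   ", decay)
print("no decay:", no_decay)
```

Because the rule matches substrings, any weight matrix whose name happens to contain one of the markers is also exempted, so it is worth printing both groups once for the real model.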

In [None]:
#@title 8. Data Loader { display-mode: "form" }

import numpy as np
from torch.utils.data import Dataset, DataLoader

class PackedDataset(Dataset):
    """Dataset for pre-packed token sequences."""
    
    def __init__(self, data_path):
        self.data = np.load(data_path)
        print(f"Loaded {len(self.data):,} sequences")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens = torch.from_numpy(self.data[idx].astype(np.int64))
        # Input: all but last token, Target: all but first token
        return tokens[:-1], tokens[1:]

dataset = PackedDataset(train_config.data_path)

# DataLoader with shuffling
dataloader = DataLoader(
    dataset,
    batch_size=train_config.micro_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

print(f"\nDataLoader:")
print(f"  Batch size: {train_config.micro_batch_size}")
print(f"  Batches per epoch: {len(dataloader):,}")
print(f"  Tokens per epoch: {len(dataset) * train_config.seq_len:,}")

In [None]:
#@title 9. Training Loop { display-mode: "form" }

import time
import wandb
from tqdm.auto import tqdm

# Initialize wandb
if train_config.use_wandb:
    wandb.init(
        project=train_config.wandb_project,
        name=train_config.wandb_run_name,
        config={
            "model": {
                "d_model": model_config.d_model,
                "n_layers": model_config.n_layers,
                "n_heads": model_config.n_heads,
                "params": n_params,
            },
            "training": vars(train_config),
        }
    )

# Create checkpoint dir
os.makedirs(train_config.checkpoint_dir, exist_ok=True)

# Training state
global_step = 0
tokens_seen = 0
best_loss = float('inf')

# Precision context
autocast_dtype = torch.bfloat16 if train_config.mixed_precision == "bf16" else torch.float16

print("="*60)
print("STARTING TRAINING")
print("="*60)

model.train()
optimizer.zero_grad()

data_iter = iter(dataloader)
pbar = tqdm(total=train_config.max_steps, desc="Training")

accumulation_loss = 0.0
start_time = time.time()

while global_step < train_config.max_steps:
    # Accumulate gradients
    for micro_step in range(train_config.gradient_accumulation):
        try:
            inputs, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            inputs, targets = next(data_iter)
        
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward pass with mixed precision
        with torch.autocast(device_type="cuda", dtype=autocast_dtype):
            logits, _ = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, model_config.vocab_size),
                targets.view(-1),
                ignore_index=-100,
            )
            loss = loss / train_config.gradient_accumulation
        
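        # Backward pass (the scaler is a no-op unless mixed_precision == "fp16")
        scaler.scale(loss).backward()
        accumulation_loss += loss.item()
    
    # Gradient clipping (unscale first so the threshold applies to the true gradients)
    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)
    
    # Optimizer step
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    optimizer.zero_grad()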
    
    # Update counters
    global_step += 1
    tokens_seen += train_config.micro_batch_size * train_config.gradient_accumulation * train_config.seq_len
    
    # Logging
    if global_step % train_config.log_interval == 0:
        elapsed = time.time() - start_time
        tokens_per_sec = tokens_seen / elapsed
        current_lr = scheduler.get_last_lr()[0]
        # accumulation_loss holds one mean loss per optimizer step since the last log,
        # so average over the interval before reporting it
        avg_loss = accumulation_loss / train_config.log_interval
        
        pbar.set_postfix({
            'loss': f'{avg_loss:.4f}',
            'ppl': f'{math.exp(avg_loss):.2f}',
            'lr': f'{current_lr:.2e}',
            'tok/s': f'{tokens_per_sec/1000:.1f}k',
        })
        
        if train_config.use_wandb:
            wandb.log({
                "train/loss": avg_loss,
                "train/perplexity": math.exp(avg_loss),
                "train/lr": current_lr,
                "train/grad_norm": grad_norm.item(),
                "train/tokens_seen": tokens_seen,
                "train/tokens_per_sec": tokens_per_sec,
            }, step=global_step)
        
        accumulation_loss = 0.0
    
    pbar.update(1)
    
    # Save checkpoint
    if global_step % train_config.save_interval == 0:
        # Named ckpt_state so it does not shadow torch.utils.checkpoint.checkpoint,
        # which CheckpointedChimeraBlock.forward still looks up as a global
        ckpt_state = {
            'step': global_step,
            'tokens_seen': tokens_seen,
            'model': model.state_dict() if not hasattr(model, '_orig_mod') else model._orig_mod.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'config': vars(model_config),
            'train_config': vars(train_config),
        }
        
        ckpt_path = f"{train_config.checkpoint_dir}/step_{global_step}.pt"
        torch.save(ckpt_state, ckpt_path)
        print(f"\n💾 Saved checkpoint: {ckpt_path}")
        
        # Also save as latest
        torch.save(ckpt_state, f"{train_config.checkpoint_dir}/latest.pt")

pbar.close()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Total steps: {global_step:,}")
print(f"Total tokens: {tokens_seen:,}")
print(f"Time: {(time.time() - start_time)/3600:.2f} hours")

if train_config.use_wandb:
    wandb.finish()

In [None]:
#@title 10. Extract Model Weights { display-mode: "form" }

print("Extracting model weights for inference...")

# Load latest checkpoint
ckpt = torch.load(f"{train_config.checkpoint_dir}/latest.pt", map_location='cpu')

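# Save just the model weights. If gradient checkpointing wrapped the layers in
# CheckpointedChimeraBlock, keys look like "layers.0.block.<name>"; strip the extra
# ".block" level so the weights load into a plain Chimera model.
import re

model_weights = ckpt['model']
if train_config.gradient_checkpointing:
    model_weights = {
        re.sub(r'^layers\.(\d+)\.block\.', r'layers.\1.', k): v
        for k, v in model_weights.items()
    }
torch.save(model_weights, f"{train_config.checkpoint_dir}/prose_500m.pt")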

print(f"✓ Saved model weights to {train_config.checkpoint_dir}/prose_500m.pt")
print(f"  Size: {os.path.getsize(f'{train_config.checkpoint_dir}/prose_500m.pt') / 1e9:.2f} GB")

In [None]:
#@title 11. Test Generation { display-mode: "form" }

from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Put model in eval mode
model.eval()

@torch.no_grad()
def generate(prompt, max_new_tokens=200, temperature=0.8, top_p=0.9):
    """Generate text from a prompt."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    
    generated = input_ids
    cache = None
    
    for _ in range(max_new_tokens):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            if cache is None:
                logits, cache = model(generated, use_cache=True)
            else:
                logits, cache = model(generated[:, -1:], cache=cache, 
                                     position_offset=generated.size(1)-1, use_cache=True)
        
        # Sample next token
        logits = logits[:, -1, :] / temperature
        
        # Top-p sampling
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = float('-inf')
        
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        generated = torch.cat([generated, next_token], dim=1)
        
        if next_token.item() == tokenizer.eos_token_id:
            break
    
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Test prompts
prompts = [
    "The old lighthouse keeper had seen many storms, but none quite like",
    "In the depths of the ancient library, she discovered a book that",
    "The city had changed since I'd been away. The streets were",
]

print("="*60)
print("GENERATION SAMPLES")
print("="*60)

for prompt in prompts:
    print(f"\n📝 Prompt: {prompt}")
    print("-" * 40)
    output = generate(prompt, max_new_tokens=150, temperature=0.8)
    print(output)
    print()
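
The top-p step inside `generate` is easier to follow on a plain list of probabilities. The sketch below mirrors the tensor logic in pure Python under one reading of that code: tokens are kept in descending probability order up to and including the first one that pushes the cumulative mass past `top_p`. `nucleus_keep` and `renormalize` are hypothetical helper names.

```python
def nucleus_keep(probs, top_p):
    """Indices kept by top-p filtering, in descending probability order.

    The token that first pushes the cumulative mass above top_p is kept,
    matching the shift-by-one mask in the tensor version.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for idx in order:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative > top_p:
            break
    return kept


def renormalize(probs, kept):
    """Rescale the kept probabilities so they sum to 1 before sampling."""
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


probs = [0.5, 0.3, 0.15, 0.05]
kept = nucleus_keep(probs, top_p=0.9)
print(kept)
print(renormalize(probs, kept))
```

Keeping the threshold-crossing token guarantees at least one candidate survives, even when a single token already holds more than `top_p` of the mass.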

In [None]:
#@title 12. Save to Google Drive (Optional) { display-mode: "form" }

save_to_drive = True  #@param {type:"boolean"}

if save_to_drive:
    from google.colab import drive
    drive.mount('/content/drive')
    
    drive_path = "/content/drive/MyDrive/chimera_checkpoints"
    os.makedirs(drive_path, exist_ok=True)
    
    # Copy model weights
    !cp {train_config.checkpoint_dir}/prose_500m.pt {drive_path}/
    !cp {train_config.checkpoint_dir}/latest.pt {drive_path}/
    
    print(f"✓ Saved checkpoints to Google Drive: {drive_path}")

## Training Tricks Summary

This notebook incorporates lessons from:

### Architecture (Griffin + Improvements)
- **Dual-gate RG-LRU**: Input gate + recurrence gate (not just one)
- **Learned base rates**: `a_t = σ(Λ)^(c·r_t)` per dimension
- **Log-space computation**: Numerical stability
- **2:1 recurrent:attention ratio**: `attention_every_n=3` makes every third layer attention (16 recurrent, 8 attention), the same interleaving Griffin uses
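
The gated recurrence in the bullets above can be sketched for a single hidden dimension. This follows the RG-LRU update as described in the Griffin paper, `h_t = a_t·h_{t-1} + sqrt(1 - a_t²)·(i_t·x_t)`, and is not taken from the Chimera source. The gate values `r_t` and `i_t` are passed in directly instead of being computed from learned projections, and `c = 8` is Griffin's constant, assumed here.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def rg_lru_step(h_prev, x_t, r_t, i_t, lam, c=8.0):
    """One scalar RG-LRU update (sketch; gates r_t, i_t are given in (0, 1))."""
    # log a_t = c * r_t * log(sigmoid(lam)) = -c * r_t * softplus(-lam)
    log_a = -c * r_t * softplus(-lam)
    a_t = math.exp(log_a)
    # sqrt(1 - a_t^2), computed via expm1 to stay accurate when a_t is near 1
    input_scale = math.sqrt(-math.expm1(2.0 * log_a))
    return a_t * h_prev + input_scale * (i_t * x_t)


# Run a short sequence with fixed, made-up gate values.
h = 0.0
for x_t in [1.0, 0.5, -0.25, 0.0]:
    h = rg_lru_step(h, x_t, r_t=0.5, i_t=0.9, lam=2.0)
    print(round(h, 4))
```

Working with `log a_t` avoids raising a number close to 1 to a power directly, which is where precision is otherwise lost in low-precision training.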

### Optimization (Llama-2 Style)
- **AdamW β2=0.95**: More stable than 0.999
- **Weight decay 0.1**: Only on non-embedding weights
- **Gradient clipping 1.0**: Prevents explosions
- **Cosine LR with warmup**: Smooth convergence
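
The warmup-plus-cosine schedule is simple enough to evaluate by hand. The sketch below returns the absolute learning rate at a given step using the notebook's default values. `lr_at` is a hypothetical standalone helper, and unlike the notebook's lambda it clamps progress at 1 so the rate stays at `min_lr` past `max_steps`.

```python
import math


def lr_at(step, base_lr=2e-4, min_lr=2e-5, warmup_steps=1000, max_steps=50_000):
    """Learning rate at `step`: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped so the LR never rises again.
    progress = min(1.0, (step - warmup_steps) / (max_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


for step in (0, 500, 1000, 25_500, 50_000):
    print(f"step {step:>6}: lr = {lr_at(step):.2e}")
```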

### Efficiency (A100 Optimized)
- **bf16 mixed precision**: Native A100 support
- **Gradient checkpointing**: 3x memory savings
- **torch.compile**: 20-40% speedup
- **Packed sequences**: No padding waste
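
Sequence packing amounts to concatenating tokenized documents, separated by an end-of-sequence id, and cutting the stream into fixed-length rows. The sketch below shows the idea on small integer lists. `pack_sequences` and `shift_for_next_token` are hypothetical helpers, and the token ids and `eos_id=0` are made up for the example.

```python
def pack_sequences(docs, seq_len, eos_id):
    """Concatenate token lists (each followed by eos_id) and cut into rows.

    Any tail shorter than seq_len is dropped, so no padding is ever needed.
    """
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)
    n_seqs = len(stream) // seq_len
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_seqs)]


def shift_for_next_token(seq):
    """Inputs are all tokens but the last; targets are all but the first."""
    return seq[:-1], seq[1:]


# Three toy "documents" with made-up token ids; 0 plays the role of EOS.
packed = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], seq_len=4, eos_id=0)
print(packed)

inputs, targets = shift_for_next_token(packed[1])
print(inputs, targets)
```

A row can start in the middle of one document and end in the next, which is the price paid for never spending compute on padding tokens.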

### Data (Quality Focus)
- **40% literary prose**: pg19/BookCorpus
- **30% creative fiction**: WritingPrompts
- **20% conversational**: OpenAssistant
- **10% quality web**: FineWeb-Edu

### Next Steps
1. Run full 50k steps (~6-8 hours on A100)
2. Monitor loss curve in wandb
3. Instruction fine-tune with `train_instruct.py`
4. Deploy with `chat_ui.py`