# Module 5: Full Training Pipeline and Story Generation

This is the capstone notebook where we bring everything together! We'll train a complete MoE-based story generation model and use it to create stories.

## Learning Objectives

By the end of this notebook, you will:
1. Set up the complete Storyteller model with MoE
2. Load and prepare real story datasets
3. Configure MLflow for experiment tracking
4. Train the full model end-to-end
5. Monitor training with MLflow UI
6. Implement text generation (sampling strategies)
7. Generate creative stories with your trained model
8. Evaluate and compare different checkpoints

## What We'll Build

- Complete MoE transformer for story generation
- Production-ready training pipeline
- MLflow experiment tracking integration
- Multiple generation strategies (greedy, sampling, top-k, nucleus)
- Interactive story generation interface

In [None]:
import sys
from pathlib import Path

# Add src to path so we can import our modules
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / "src"))

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
import mlflow
from datetime import datetime

# Import our modules
from storyteller.model.transformer import StorytellerModel, TransformerConfig
from storyteller.data.dataset import TextDataset
from storyteller.training.trainer import Trainer
from storyteller.inference.generator import Generator, GenerationConfig

# Set seeds
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (14, 6)

# Device setup
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## Part 1: Configuration

Let's load our model configuration from the YAML file.

In [None]:
# Load configuration
config_path = project_root / "configs" / "moe_model.yaml"

with open(config_path, "r") as f:
    config = yaml.safe_load(f)

print("Model Configuration:")
print(yaml.dump(config["model"], default_flow_style=False))

print("\nTraining Configuration:")
print(yaml.dump(config["training"], default_flow_style=False))

### Understanding Our MoE Configuration

Key parameters:
- **Total parameters**: ~500M
- **Active per token**: ~100M (using 2 of 8 experts)
- **Layers**: 16 transformer blocks
- **Hidden size**: 1024
- **MoE frequency**: Every 2 layers (8 total MoE layers)
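
The gap between total and active parameters comes from counting: attention, dense feed-forward, and embedding weights are used for every token, while each MoE layer only runs `top_k_experts` of its `num_experts` feed-forward blocks. The sketch below is a rough back-of-the-envelope count under stated assumptions. `intermediate_size=4096` and `vocab_size=32000` are placeholders (the real values live in `configs/moe_model.yaml`), each feed-forward block is assumed to have two weight matrices, and biases, layer norms, and router weights are ignored. The printed numbers will therefore not necessarily match the figures quoted above.

```python
def estimate_moe_params(
    hidden_size=1024,
    num_layers=16,
    moe_frequency=2,
    num_experts=8,
    top_k_experts=2,
    intermediate_size=4096,  # assumption: not stated in this notebook
    vocab_size=32000,        # assumption: not stated in this notebook
):
    """Rough (total, active-per-token) parameter counts for a MoE transformer."""
    attention = 4 * hidden_size * hidden_size   # Q, K, V, and output projections
    ffn = 2 * hidden_size * intermediate_size   # up- and down-projection of one feed-forward block
    embeddings = vocab_size * hidden_size

    moe_layers = num_layers // moe_frequency
    dense_layers = num_layers - moe_layers

    # Weights that every token passes through
    shared = embeddings + num_layers * attention + dense_layers * ffn

    total = shared + moe_layers * num_experts * ffn
    active = shared + moe_layers * top_k_experts * ffn
    return total, active


total, active = estimate_moe_params()
print(f"Total:  {total / 1e6:.1f}M")
print(f"Active: {active / 1e6:.1f}M per token")
```

Passing the values from your own config file in place of the placeholders gives an estimate you can compare against the exact count printed in Part 3.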

In [None]:
# Create model config
model_config = TransformerConfig(
    vocab_size=config["model"]["vocab_size"],
    max_seq_length=config["model"]["max_seq_length"],
    hidden_size=config["model"]["hidden_size"],
    num_layers=config["model"]["num_layers"],
    num_attention_heads=config["model"]["num_attention_heads"],
    intermediate_size=config["model"]["intermediate_size"],
    attention_dropout=config["model"]["attention_dropout"],
    hidden_dropout=config["model"]["hidden_dropout"],
    positional_encoding=config["model"]["positional_encoding"],
    activation=config["model"]["activation"],
    gradient_checkpointing=config["model"]["gradient_checkpointing"],
    use_moe=config["model"]["use_moe"],
    num_experts=config["model"]["num_experts"],
    top_k_experts=config["model"]["top_k_experts"],
    moe_frequency=config["model"]["moe_frequency"],
    expert_capacity_factor=config["model"]["expert_capacity_factor"],
    load_balancing_loss_weight=config["model"]["load_balancing_loss_weight"],
)

print("Model Config created successfully!")

## Part 2: Data Preparation

Load the datasets we prepared in Module 1.

In [None]:
# Paths to data
data_dir = project_root / "data" / "processed"
train_data_path = data_dir / "train.txt"
val_data_path = data_dir / "val.txt"

# Check if data exists
if not (train_data_path.exists() and val_data_path.exists()):
    print("⚠️  Training or validation data not found!")
    print("Please run the data preparation scripts first:")
    print("  1. storyteller-download")
    print("  2. storyteller-tokenizer")
    print("  3. storyteller-preprocess")
else:
    print(f"✓ Training data found: {train_data_path}")
    print(f"✓ Validation data found: {val_data_path}")

    # Create datasets
    train_dataset = TextDataset(
        data_path=str(train_data_path),
        max_seq_length=config["model"]["max_seq_length"],
    )

    val_dataset = TextDataset(
        data_path=str(val_data_path),
        max_seq_length=config["model"]["max_seq_length"],
    )

    print("\nDataset sizes:")
    print(f"  Training examples: {len(train_dataset):,}")
    print(f"  Validation examples: {len(val_dataset):,}")

### Create DataLoaders

In [None]:
# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=True,
    num_workers=config["training"]["num_workers"],
    pin_memory=config["training"]["pin_memory"],
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config["training"]["batch_size"],
    shuffle=False,
    num_workers=config["training"]["num_workers"],
    pin_memory=config["training"]["pin_memory"],
)

effective_batch_size = (
    config["training"]["batch_size"] * config["training"]["gradient_accumulation_steps"]
)

print("DataLoaders created:")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Gradient accumulation: {config['training']['gradient_accumulation_steps']}")
print(f"  Effective batch size: {effective_batch_size}")
print(f"  Training batches per epoch: {len(train_loader):,}")

## Part 3: Model Initialization

Create our MoE Storyteller model.

In [None]:
# Create model
model = StorytellerModel(model_config)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

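# Crude estimate of active parameters: scales *all* parameters by
# top_k/num_experts. Only expert weights are sparsely activated, so this
# understates the true active count (shared weights are always used).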
moe_layers = config["model"]["num_layers"] // config["model"]["moe_frequency"]
dense_layers = config["model"]["num_layers"] - moe_layers
expert_activation_ratio = (
    config["model"]["top_k_experts"] / config["model"]["num_experts"]
)

print("Model Architecture:")
print(f"  Total layers: {config['model']['num_layers']}")
print(f"  MoE layers: {moe_layers}")
print(f"  Dense layers: {dense_layers}")
print(f"  Experts per MoE layer: {config['model']['num_experts']}")
print(f"  Active experts per token: {config['model']['top_k_experts']}")
print("\nParameter Counts:")
print(f"  Total parameters: {total_params:,} ({total_params / 1e6:.1f}M)")
print(f"  Trainable parameters: {trainable_params:,} ({trainable_params / 1e6:.1f}M)")
print(
    f"  Approximate active per forward: ~{int(total_params * expert_activation_ratio):,} "
    f"({total_params * expert_activation_ratio / 1e6:.1f}M)"
)
print(f"\nModel size: ~{total_params * 4 / 1e9:.2f} GB (float32)")

## Part 4: MLflow Setup

Configure experiment tracking with MLflow.

In [None]:
# Set MLflow tracking URI
mlflow_uri = config["training"].get("mlflow_tracking_uri", "http://localhost:8080")
mlflow.set_tracking_uri(mlflow_uri)

# Set experiment
experiment_name = config["training"]["mlflow_experiment_name"]
mlflow.set_experiment(experiment_name)

print("MLflow Configuration:")
print(f"  Tracking URI: {mlflow_uri}")
print(f"  Experiment: {experiment_name}")
print("\nTo view experiments, run in terminal:")
print("  mlflow ui --port 8080")
print("Then open: http://localhost:8080")

## Part 5: Create Trainer

Set up the complete training pipeline with our Trainer class.
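
The scheduler cell below subtracts `warmup_steps` from `T_max`, but `CosineAnnealingLR` on its own does not ramp the learning rate up from a small value. Whether `Trainer` applies warmup internally is not shown in this notebook. If it does not, one option is a single schedule that combines linear warmup with cosine decay. The following is a minimal sketch under that assumption; `warmup_cosine_factor` is a hypothetical helper, and `min_ratio=0.1` mirrors the `eta_min = learning_rate * 0.1` used in the cell.

```python
import math


def warmup_cosine_factor(step, warmup_steps, total_steps, min_ratio=0.1):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp from 1/warmup_steps up to 1.0
        return (step + 1) / warmup_steps
    # Cosine decay from 1.0 down to min_ratio over the remaining steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Possible wiring with PyTorch (not executed here):
#   from torch.optim.lr_scheduler import LambdaLR
#   scheduler = LambdaLR(
#       optimizer,
#       lr_lambda=lambda s: warmup_cosine_factor(s, warmup_steps, num_training_steps),
#   )

for step in (0, 99, 100, 550, 1000):
    factor = warmup_cosine_factor(step, warmup_steps=100, total_steps=1000)
    print(f"step {step:>4}: lr multiplier = {factor:.4f}")
```

`LambdaLR` multiplies the optimizer's base learning rate by the returned factor, so the scheduler must be stepped once per optimizer update rather than once per micro-batch when gradient accumulation is in use.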

In [None]:
# Create optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["training"]["learning_rate"],
    weight_decay=config["training"]["weight_decay"],
    betas=(0.9, 0.95),
)

# Create learning rate scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR

num_training_steps = len(train_loader) * config["training"]["num_epochs"]
num_training_steps //= config["training"]["gradient_accumulation_steps"]

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_training_steps - config["training"]["warmup_steps"],
    eta_min=config["training"]["learning_rate"] * 0.1,
)

print("Optimizer: AdamW")
print(f"  Learning rate: {config['training']['learning_rate']:.2e}")
print(f"  Weight decay: {config['training']['weight_decay']}")
print("\nScheduler: Cosine Annealing")
print(f"  Total steps: {num_training_steps:,}")
print(f"  Warmup steps: {config['training']['warmup_steps']:,}")

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=config["training"]["device"],
    use_amp=config["training"]["use_amp"],
    amp_dtype=config["training"]["amp_dtype"],
    gradient_accumulation_steps=config["training"]["gradient_accumulation_steps"],
    max_grad_norm=config["training"]["max_grad_norm"],
    save_dir=config["training"]["save_dir"],
    save_every_n_steps=config["training"]["save_every_n_steps"],
    eval_every_n_steps=config["training"]["eval_every_n_steps"],
    log_every_n_steps=config["training"]["log_every_n_steps"],
    keep_last_n_checkpoints=config["training"]["keep_last_n_checkpoints"],
    use_mlflow=config["training"]["use_mlflow"],
    mlflow_experiment_name=config["training"]["mlflow_experiment_name"],
    mlflow_run_name=config["training"]["mlflow_run_name"],
    mlflow_tracking_uri=mlflow_uri,  # same URI configured in Part 4
)

print("✓ Trainer created successfully!")

## Part 6: Training!

Now let's train our model. This may take a while depending on your hardware.

**Note**: For a full training run, you'd typically train for many more epochs. This is a demonstration.

In [None]:
# Start training
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

try:
    trainer.train(num_epochs=config["training"]["num_epochs"])
    print("\n✓ Training completed successfully!")
except KeyboardInterrupt:
    print("\n⚠️  Training interrupted by user")
    # Save checkpoint
    trainer.save_checkpoint("interrupted.pt")
except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    raise

## Part 7: Viewing Training Metrics

Let's visualize the training progress.

In [None]:
# If MLflow is running, you can fetch metrics programmatically
try:
    # active_run() returns None once the run has been closed,
    # e.g. if the Trainer ended it when training finished
    run = mlflow.active_run()
    if run:
        client = mlflow.tracking.MlflowClient()

        # Latest logged value of each metric
        metrics = client.get_run(run.info.run_id).data.metrics

        print("Training Metrics (final values):")
        for key, value in sorted(metrics.items()):
            print(f"  {key}: {value:.4f}")
    else:
        print("No active MLflow run; check the MLflow UI for logged metrics")
except Exception as e:
    print(f"Could not fetch MLflow metrics: {e}")
    print("Check MLflow UI for detailed metrics and visualizations")

## Part 8: Text Generation

Now for the fun part - let's generate stories!

### Generation Strategies

1. **Greedy**: Always pick the highest-probability token (deterministic, often repetitive)
2. **Sampling**: Sample from the full probability distribution (random, diverse)
3. **Top-K**: Sample from the K most likely tokens
4. **Nucleus (Top-P)**: Sample from the smallest set of tokens whose cumulative probability is ≥ p
5. **Temperature**: Divide the logits by a temperature before the softmax to control randomness; it combines with any of the sampling strategies above
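
To make these strategies concrete, here is a small stand-alone sketch that applies temperature, top-k, and top-p filtering to a list of logits and then samples a token index. It uses plain Python lists so each step is easy to inspect. The function names are hypothetical and this is not the `Generator` implementation, which is not shown in this notebook. Following the convention used in the cells below, `top_k=0` disables top-k and `top_p=1.0` disables nucleus filtering.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (numerically stabilised)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def filtered_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return the renormalised distribution after top-k and top-p filtering."""
    probs = softmax(logits, temperature)
    # Token indices ordered from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    if top_k > 0:
        order = order[:top_k]  # keep only the K most likely tokens

    if top_p < 1.0:
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        # Assumption: the cumulative sum uses the pre-filter probabilities;
        # some libraries renormalise after top-k first.
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept

    mass = sum(probs[i] for i in order)
    filtered = [0.0] * len(probs)
    for i in order:
        filtered[i] = probs[i] / mass
    return filtered


def sample_token(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Sample one token index from the filtered distribution."""
    dist = filtered_distribution(logits, temperature, top_k, top_p)
    return rng.choices(range(len(dist)), weights=dist, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
print("greedy (top_k=1):", filtered_distribution(logits, top_k=1))
print("nucleus (top_p=0.9):", [round(p, 3) for p in filtered_distribution(logits, top_p=0.9)])
print("sampled index:", sample_token(logits, temperature=0.9, top_k=3, top_p=0.95))
```

Setting `top_k=1` leaves all probability mass on the most likely token, which is why the greedy configuration below is expressed that way.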

In [None]:
# Load tokenizer
from tokenizers import Tokenizer

tokenizer_path = project_root / "models" / "tokenizer" / "storyteller_tokenizer.json"

if tokenizer_path.exists():
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    print(f"✓ Tokenizer loaded from {tokenizer_path}")
    print(f"  Vocabulary size: {tokenizer.get_vocab_size()}")
else:
    print("⚠️  Tokenizer not found! Please train tokenizer first.")
    tokenizer = None

### Create Generator

In [None]:
# Create generator
generator = Generator(
    model=model,
    tokenizer=tokenizer,
    device=device,
)

print("✓ Generator created!")

### Generate Stories with Different Strategies

In [None]:
# Story prompt
prompt = "Once upon a time, in a land far away, there lived a"

print(f"Prompt: '{prompt}'")
print("\n" + "=" * 70)

# 1. Greedy decoding (deterministic)
print("\n1. GREEDY DECODING (deterministic)")
print("-" * 70)
greedy_config = GenerationConfig(
    max_new_tokens=100,
    temperature=1.0,
    top_k=1,  # Greedy
    top_p=1.0,
)
story = generator.generate(prompt, greedy_config)
print(story)

# 2. High temperature (creative/random)
print("\n" + "=" * 70)
print("\n2. HIGH TEMPERATURE (creative, random)")
print("-" * 70)
creative_config = GenerationConfig(
    max_new_tokens=100,
    temperature=1.2,
    top_k=50,
    top_p=0.95,
)
story = generator.generate(prompt, creative_config)
print(story)

# 3. Low temperature (focused/coherent)
print("\n" + "=" * 70)
print("\n3. LOW TEMPERATURE (focused, coherent)")
print("-" * 70)
focused_config = GenerationConfig(
    max_new_tokens=100,
    temperature=0.7,
    top_k=40,
    top_p=0.9,
)
story = generator.generate(prompt, focused_config)
print(story)

# 4. Nucleus sampling (balanced)
print("\n" + "=" * 70)
print("\n4. NUCLEUS SAMPLING (balanced)")
print("-" * 70)
nucleus_config = GenerationConfig(
    max_new_tokens=100,
    temperature=0.9,
    top_k=0,  # Disable top-k
    top_p=0.95,
)
story = generator.generate(prompt, nucleus_config)
print(story)

## Part 9: Interactive Story Generation

Create an interactive interface for generating stories.

In [None]:
def generate_interactive_story():
    """
    Interactive story generation with custom prompts.
    """
    print("=" * 70)
    print("INTERACTIVE STORY GENERATOR")
    print("=" * 70)
    print("\nEnter your story prompt (or 'quit' to exit)")
    print("Example: 'In a dark forest, a young wizard discovered'")
    print()

    while True:
        # Get prompt
        prompt = input("\nPrompt: ").strip()

        if prompt.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break

        if not prompt:
            print("Please enter a prompt!")
            continue

        # Get generation parameters
        print("\nGeneration parameters:")
        try:
            max_tokens = int(input("  Max tokens (default 150): ") or "150")
            temperature = float(input("  Temperature 0.1-2.0 (default 0.9): ") or "0.9")
            top_p = float(input("  Top-p 0.0-1.0 (default 0.95): ") or "0.95")
        except ValueError:
            print("Invalid input! Using defaults.")
            max_tokens, temperature, top_p = 150, 0.9, 0.95

        # Generate
        print("\nGenerating...")
        gen_config = GenerationConfig(
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
        )

        story = generator.generate(prompt, gen_config)

        print("\n" + "=" * 70)
        print("GENERATED STORY:")
        print("=" * 70)
        print(story)
        print("=" * 70)


# Uncomment to run interactive mode
# generate_interactive_story()

## Part 10: Evaluating Different Checkpoints

Compare models from different training stages.

In [None]:
def compare_checkpoints(checkpoint_paths, prompt, gen_config):
    """
    Generate stories from multiple checkpoints for comparison.
    """
    results = []

    for ckpt_path in checkpoint_paths:
        # Load checkpoint
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Generate
        story = generator.generate(prompt, gen_config)

        results.append(
            {
                "checkpoint": ckpt_path.name,
                "step": checkpoint.get("global_step", "unknown"),
                "val_loss": checkpoint.get("best_val_loss", "unknown"),
                "story": story,
            }
        )

    # Display results
    print("=" * 70)
    print("CHECKPOINT COMPARISON")
    print("=" * 70)
    print(f"\nPrompt: '{prompt}'\n")

    for i, result in enumerate(results, 1):
        print(f"\n{i}. {result['checkpoint']}")
        print(f"   Step: {result['step']}, Val Loss: {result['val_loss']}")
        print("-" * 70)
        print(result["story"])
        print("=" * 70)


# Example usage (uncomment when you have checkpoints)
# checkpoint_dir = Path(config['training']['save_dir'])
# checkpoints = sorted(
#     checkpoint_dir.glob('checkpoint_step_*.pt'),
#     key=lambda p: p.stat().st_mtime,  # oldest first; glob order is arbitrary
# )
#
# if checkpoints:
#     compare_checkpoints(
#         checkpoints[:3],  # Compare first 3 checkpoints
#         prompt="The brave knight ventured into",
#         gen_config=GenerationConfig(max_new_tokens=80, temperature=0.9),
#     )
# else:
#     print("No checkpoints found yet!")

## Part 11: Understanding Generation Quality

### What Makes Good Generated Text?

1. **Coherence**: Logical flow, consistent context
2. **Creativity**: Novel ideas, not just memorization
3. **Fluency**: Natural language, proper grammar
4. **Relevance**: Stays on topic from prompt

### Hyperparameter Effects:

- **Temperature**:
  - Low (0.5-0.7): Safe, coherent, repetitive
  - Medium (0.8-1.0): Balanced creativity
  - High (1.2-2.0): Creative but may be nonsensical

- **Top-P (Nucleus)**:
  - Low (0.7-0.8): Conservative choices
  - Medium (0.9-0.95): Good balance
  - High (0.95-1.0): More diversity

- **Top-K**:
  - Small (10-20): Very focused
  - Medium (40-50): Balanced
  - Large (100+): More random
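
One way to quantify the temperature ranges above is to look at how peaked the resulting distribution is. The sketch below reuses the ten example logits from the plotting cell that follows and reports the top-token probability and the entropy (in bits) at three temperatures; lower entropy means more predictable sampling. The helper names are illustrative only.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits divided by the given temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_bits(probs):
    """Shannon entropy of a probability distribution, in bits."""
    return -sum(p * math.log2(p) for p in probs if p > 0)


logits = [2.0, 1.5, 1.0, 0.5, 0.3, 0.1, -0.5, -1.0, -2.0, -3.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: top prob = {max(probs):.3f}, entropy = {entropy_bits(probs):.2f} bits")
```

As the temperature rises, the top probability falls and the entropy grows toward the maximum of log2(10) bits for a uniform distribution over ten tokens.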

In [None]:
# Visualize the effect of temperature
def plot_temperature_effects(logits, temperatures=(0.5, 1.0, 2.0)):
    """
    Visualize how temperature affects probability distribution.
    """
    fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 4))

    for idx, temp in enumerate(temperatures):
        # Apply temperature
        scaled_logits = logits / temp
        probs = F.softmax(torch.tensor(scaled_logits), dim=-1).numpy()

        # Plot
        axes[idx].bar(range(len(probs)), probs, color="steelblue", alpha=0.7)
        axes[idx].set_title(f"Temperature = {temp}")
        axes[idx].set_xlabel("Token")
        axes[idx].set_ylabel("Probability")
        axes[idx].set_ylim(0, max(probs) * 1.1)

    plt.tight_layout()
    plt.show()


# Example logits (10 tokens)
example_logits = np.array([2.0, 1.5, 1.0, 0.5, 0.3, 0.1, -0.5, -1.0, -2.0, -3.0])
plot_temperature_effects(example_logits)

## Summary and Next Steps

Congratulations! You've completed the full Storyteller training pipeline! 🎉

### What You've Learned:

1. **Complete Training Pipeline**:
   - Data loading and preprocessing
   - Model initialization (MoE transformer)
   - Optimizer and scheduler configuration
   - Mixed precision training
   - MLflow experiment tracking

2. **Text Generation**:
   - Different sampling strategies
   - Temperature and nucleus sampling
   - Interactive generation
   - Quality evaluation

3. **Production Practices**:
   - Checkpointing and model saving
   - Experiment tracking with MLflow
   - Hyperparameter tuning
   - Model comparison

### Next Steps for Improvement:

1. **Training Improvements**:
   - Train for more epochs (10-50)
   - Experiment with learning rate schedules
   - Try different optimizer settings
   - Use larger/more diverse datasets

2. **Model Improvements**:
   - Scale up model size (more layers/larger hidden size)
   - Adjust MoE parameters (num experts, top-k)
   - Experiment with different architectures

3. **Generation Improvements**:
   - Implement beam search
   - Add repetition penalty
   - Try constrained generation
   - Fine-tune on specific story genres

4. **Evaluation**:
   - Implement perplexity evaluation
   - Human evaluation of generated stories
   - Automatic metrics (BLEU, ROUGE)
   - A/B testing different configurations
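
As a starting point for the perplexity item above: perplexity is the exponential of the average per-token negative log-likelihood. The sketch below assumes you have collected, for each validation batch, the mean cross-entropy loss in nats (which is what PyTorch's `F.cross_entropy` returns by default) and the number of target tokens in that batch. The `perplexity` helper is hypothetical and not part of the `storyteller` package.

```python
import math


def perplexity(batch_losses, batch_token_counts):
    """Perplexity from per-batch mean cross-entropy losses (in nats).

    Each batch is weighted by its token count so that short and long
    batches contribute proportionally to the average.
    """
    total_tokens = sum(batch_token_counts)
    total_nll = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    return math.exp(total_nll / total_tokens)


# A model that assigns probability 1/4 to every target token has loss ln(4)
print(perplexity([math.log(4)], [10]))

# Two batches of different sizes
print(perplexity([2.1, 2.4], [512, 128]))
```

Lower is better; a perplexity equal to the vocabulary size corresponds to guessing uniformly at random.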

### Resources:

- **Model Architecture**: See `src/storyteller/model/`
- **Training Code**: See `src/storyteller/training/`
- **Generation Code**: See `src/storyteller/inference/`
- **Configs**: See `configs/`
- **MLflow**: View at http://localhost:8080

### Community:

Share your generated stories and improvements! This is an educational project - experiment, break things, and learn!

Happy storytelling! 📚✨

## Bonus: Export Model for Production

Export your trained model for deployment.

In [None]:
def export_model_for_production(model, tokenizer, save_dir):
    """
    Export model in production-ready format.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save model state dict
    model_path = save_dir / "model.pt"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "model_config": model.config.__dict__,
        },
        model_path,
    )
    print(f"✓ Model saved to {model_path}")

    # Save tokenizer
    if tokenizer is not None:
        tokenizer_path = save_dir / "tokenizer.json"
        tokenizer.save(str(tokenizer_path))
        print(f"✓ Tokenizer saved to {tokenizer_path}")

    # Save config
    config_path = save_dir / "config.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    print(f"✓ Config saved to {config_path}")

    # Save generation config template
    gen_config_path = save_dir / "generation_config.yaml"
    default_gen_config = {
        "max_new_tokens": 512,
        "temperature": 0.9,
        "top_p": 0.95,
        "top_k": 50,
        "repetition_penalty": 1.1,
    }
    with open(gen_config_path, "w") as f:
        yaml.dump(default_gen_config, f)
    print(f"✓ Generation config saved to {gen_config_path}")

    print(f"\n✓ Model exported successfully to {save_dir}")
    print("\nTo use in production:")
    print("  1. Load model state dict")
    print("  2. Load tokenizer")
    print("  3. Initialize Generator")
    print("  4. Generate stories!")


# Export the trained model
export_dir = (
    project_root
    / "models"
    / "production"
    / f"storyteller_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
)
export_model_for_production(model, tokenizer, export_dir)