# Curriculum Learning Visualization

In [None]:
# Imports
from dataclasses import dataclass, field


# Patch is used to build the state legend of the lesson lifecycle timeline
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from marin.rl.curriculum import (
    Curriculum,
    CurriculumConfig,
    LessonConfig,
    LessonDependency,
)
from marin.rl.environments.base import EnvConfig
from marin.rl.types import RolloutStats

# Set style
sns.set_style("whitegrid")
sns.set_palette("husl")
plt.rcParams['figure.dpi'] = 100

print("✓ Imports successful")

## Data Structures and Simulation Helpers

In [None]:
@dataclass
class CurriculumHistory:
    """Per-step record of a simulated curriculum, kept for plotting.

    Every per-lesson mapping is keyed by lesson id and holds one entry per
    recorded step, so its lists line up index-for-index with ``steps``.

    Attributes:
        steps: Step numbers at which a snapshot was recorded.
        lesson_states: lesson_id -> [state1, state2, ...].
        sampling_weights: lesson_id -> [weight1, weight2, ...].
        success_rates: lesson_id -> [rate1, rate2, ...].
        eval_success_rates: lesson_id -> evaluation success rate per step.
        sample_counts: lesson_id -> sample count per step.
        metrics: metric_name -> [value1, value2, ...].
        events: [{step, type, lesson_id, details}, ...].
    """

    # Each field gets its own fresh container per instance via default_factory.
    steps: list[int] = field(default_factory=list)
    lesson_states: dict[str, list[str]] = field(default_factory=dict)
    sampling_weights: dict[str, list[float]] = field(default_factory=dict)
    success_rates: dict[str, list[float]] = field(default_factory=dict)
    eval_success_rates: dict[str, list[float]] = field(default_factory=dict)
    sample_counts: dict[str, list[int]] = field(default_factory=dict)
    metrics: dict[str, list[float]] = field(default_factory=dict)
    events: list[dict] = field(default_factory=list)


def flat_rewards(base_reward: float = 0.8, noise_std: float = 0.01):
    """Generate flat rewards with minimal noise (should plateau quickly)."""
    def gen(step: int) -> float:
        return base_reward + np.random.default_rng().normal(0, noise_std)
    return gen


def improving_rewards(start: float = 0.1, end: float = 0.9, num_steps: int = 500):
    """Generate linearly improving rewards (should not plateau until end)."""
    def gen(step: int) -> float:
        progress = min(step / num_steps, 1.0)
        base = start + (end - start) * progress
        return base + np.random.default_rng().normal(0, 0.02)
    return gen


def noisy_stable_rewards(base_reward: float = 0.5, noise_std: float = 0.05):
    """Generate noisy but statistically stable rewards (should plateau)."""
    def gen(step: int) -> float:
        return base_reward + np.random.default_rng().normal(0, noise_std)
    return gen


def sigmoid_rewards(midpoint: int = 250, steepness: float = 0.02, start: float = 0.2, end: float = 0.9):
    """Generate sigmoid improvement curve (plateaus after reaching end)."""
    def gen(step: int) -> float:
        x = (step - midpoint) * steepness
        sigmoid = 1 / (1 + np.exp(-x))
        base = start + (end - start) * sigmoid
        return base + np.random.default_rng().normal(0, 0.01)
    return gen

print("✓ Helper functions defined")

In [None]:
def simulate_curriculum_run(config: CurriculumConfig, num_steps: int = 1000, seed: int = 42) -> CurriculumHistory:
    """Simulate a curriculum run with synthetic rollout data.

    Each step refreshes the unlocked/graduated sets, logs any transitions as
    events, feeds one synthetic training rollout for a sampled lesson into
    the curriculum, periodically feeds synthetic evaluation rollouts, and
    finally snapshots per-lesson state and curriculum-level metrics.

    Args:
        config: Curriculum configuration. Lessons named "easy", "medium",
            "hard", "intermediate" or "advanced" get a dedicated reward
            pattern; any other lesson id falls back to flat rewards at 0.5.
        num_steps: Number of steps to simulate.
        seed: Seed for the generator that perturbs evaluation rewards.
            NOTE: the reward-pattern helpers draw their own noise and lesson
            sampling is seeded with the step index, so ``seed`` alone does
            not make an entire run reproducible.

    Returns:
        A CurriculumHistory with one entry per simulated step in every
        per-lesson list, plus the unlock/graduate events observed.
    """
    # Imported once per call instead of once per graduation event in the loop.
    from marin.rl.curriculum import compute_success_ratio, is_plateaued

    rng = np.random.default_rng(seed)
    curriculum = Curriculum(config)
    history = CurriculumHistory()

    # Initialize tracking for each lesson
    for lesson_id in config.lessons:
        history.lesson_states[lesson_id] = []
        history.sampling_weights[lesson_id] = []
        history.success_rates[lesson_id] = []
        history.eval_success_rates[lesson_id] = []
        history.sample_counts[lesson_id] = []

    # Define reward patterns for different lessons (for simulation)
    reward_patterns = {
        "easy": flat_rewards(base_reward=0.85),
        "medium": sigmoid_rewards(midpoint=200, start=0.3, end=0.75),
        "hard": improving_rewards(start=0.1, end=0.6, num_steps=800),
        "intermediate": sigmoid_rewards(midpoint=400, start=0.2, end=0.65),
        "advanced": improving_rewards(start=0.05, end=0.5, num_steps=1000),
    }
    # Built once: used for any lesson id without a dedicated pattern.
    fallback_pattern = flat_rewards(0.5)

    for step in range(num_steps):
        curriculum.current_step = step

        # Snapshot the sets before updating so transitions can be detected.
        prev_unlocked = set(curriculum.unlocked)
        prev_graduated = set(curriculum.graduated)

        curriculum.update_lessons()

        # Record unlock events
        new_unlocked = curriculum.unlocked - prev_unlocked
        for lesson_id in new_unlocked:
            history.events.append(
                {"step": step, "type": "unlock", "lesson_id": lesson_id, "details": "Dependencies satisfied"}
            )

        # Record graduation events
        new_graduated = curriculum.graduated - prev_graduated
        for lesson_id in new_graduated:
            stats = curriculum.stats[lesson_id]
            history.events.append(
                {
                    "step": step,
                    "type": "graduate",
                    "lesson_id": lesson_id,
                    "details": {
                        "success_rate": compute_success_ratio(stats, step),
                        "plateaued": is_plateaued(stats),
                    },
                }
            )

        # Sample from curriculum and generate synthetic rollout, but only
        # while at least one unlocked lesson has not graduated yet.
        if curriculum.unlocked and not all(lid in curriculum.graduated for lid in curriculum.unlocked):
            lesson_id = curriculum.sample_lesson(prng_seed=step)

            # Generate synthetic reward based on pattern
            reward_gen = reward_patterns.get(lesson_id, fallback_pattern)
            reward = np.clip(reward_gen(step), 0.0, 1.0)

            # Update curriculum stats
            rollout_stats = RolloutStats(lesson_id=lesson_id, episode_reward=reward, env_example_id=f"ex_{step}")
            curriculum.update_lesson_stats([rollout_stats], mode="training")

            # Periodic evaluation (never at step 0)
            if step % config.eval_frequency == 0 and step > 0:
                for eval_lesson_id in curriculum.unlocked:
                    if eval_lesson_id not in curriculum.graduated:
                        # Simulate evaluation: same reward pattern as training
                        # with additional zero-mean noise on top.
                        eval_gen = reward_patterns.get(eval_lesson_id, fallback_pattern)
                        eval_rewards = [
                            np.clip(eval_gen(step) + rng.normal(0, 0.02), 0.0, 1.0)
                            for _ in range(config.eval_n_examples)
                        ]
                        for eval_reward in eval_rewards:
                            eval_stats = RolloutStats(
                                lesson_id=eval_lesson_id,
                                episode_reward=eval_reward,
                                env_example_id=f"eval_{step}",
                            )
                            curriculum.update_lesson_stats([eval_stats], mode="eval")

        # Record state
        history.steps.append(step)
        weights = curriculum.compute_sampling_weights()

        for lesson_id in config.lessons:
            # Determine state; graduated takes precedence over unlocked.
            if lesson_id in curriculum.graduated:
                state = "graduated"
            elif lesson_id in curriculum.unlocked:
                state = "active"
            else:
                state = "locked"

            history.lesson_states[lesson_id].append(state)
            history.sampling_weights[lesson_id].append(weights.get(lesson_id, 0.0))

            stats = curriculum.stats[lesson_id]
            history.success_rates[lesson_id].append(stats.training_stats.smoothed_success)
            # NaN marks steps recorded before the lesson's eval stats report
            # a non-negative last_update_step.
            history.eval_success_rates[lesson_id].append(
                stats.eval_stats.smoothed_success if stats.eval_stats.last_update_step >= 0 else np.nan
            )
            history.sample_counts[lesson_id].append(stats.training_stats.total_samples)

        # Record curriculum-level metrics
        curriculum_metrics = curriculum.get_metrics()
        for key in ["sampling_entropy", "effective_lessons", "mean_success"]:
            history.metrics.setdefault(key, []).append(curriculum_metrics[key])

    return history

print("✓ Simulation function defined")

## Setup: Create Curriculum Configuration

Define a 5-lesson curriculum with dependencies:
- **easy** → unlocked from start
- **medium** → requires 50% on easy
- **intermediate** → requires 60% on easy
- **hard** → requires 60% on medium
- **advanced** → requires 70% on medium AND 65% on intermediate

In [None]:
# Every lesson in this demo runs on the same mock environment class and
# differs only in task type, prerequisites and stop threshold.
_MOCK_ENV_CLASS = "marin.rl.environments.mock_env.MockEnv"


def _mock_env(task_type: str) -> EnvConfig:
    """Build the mock-environment config for one task type."""
    return EnvConfig(env_class=_MOCK_ENV_CLASS, env_args={"task_type": task_type})


def _requires(**thresholds: float) -> list[LessonDependency]:
    """Turn ``dependency_id=reward_threshold`` pairs into dependencies, in order."""
    return [
        LessonDependency(dependency_id=dependency_id, reward_threshold=reward_threshold)
        for dependency_id, reward_threshold in thresholds.items()
    ]


lessons = {
    # Entry lesson: no dependencies are passed, so the LessonConfig default applies.
    "easy": LessonConfig(
        lesson_id="easy",
        env_config=_mock_env("cats"),
        stop_threshold=0.9,
    ),
    "medium": LessonConfig(
        lesson_id="medium",
        env_config=_mock_env("addition"),
        dependencies=_requires(easy=0.5),
        stop_threshold=0.85,
    ),
    "hard": LessonConfig(
        lesson_id="hard",
        env_config=_mock_env("opposites"),
        dependencies=_requires(medium=0.6),
        stop_threshold=0.75,
    ),
    "intermediate": LessonConfig(
        lesson_id="intermediate",
        env_config=_mock_env("number_comparison"),
        dependencies=_requires(easy=0.6),
        stop_threshold=0.8,
    ),
    # The only lesson with two prerequisites.
    "advanced": LessonConfig(
        lesson_id="advanced",
        env_config=_mock_env("opposites"),
        dependencies=_requires(medium=0.7, intermediate=0.65),
        stop_threshold=0.7,
    ),
}

config = CurriculumConfig(
    lessons=lessons,
    eval_frequency=100,
    eval_n_examples=32,
    eval_n_generations=1,
    temperature=1.0,
    minimum_sample_probability=0.01,
)

print("✓ Curriculum configured with 5 lessons")

## Figure 1: Plateau Detection

Shows how the plateau detection algorithm works with 4 different reward patterns:
- **Flat**: Stable rewards trigger plateau quickly
- **Improving**: Linear improvement prevents plateau
- **Noisy Stable**: Plateau despite variance
- **Sigmoid**: Eventually plateaus after curve flattens

In [None]:
# Imports live at the top of the cell so they run once, not once per subplot.
from scipy import stats as scipy_stats

from marin.rl.curriculum import is_plateaued, LessonStats, PerformanceStats

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("Plateau Detection Examples (Conservative Algorithm)", fontsize=16, fontweight="bold")

# One (panel title, reward generator) pair per subplot.
patterns = [
    ("Flat Rewards (Plateaus)", flat_rewards(0.8, 0.01)),
    ("Improving Rewards (No Plateau)", improving_rewards(0.1, 0.9, 300)),
    ("Noisy Stable (Plateaus)", noisy_stable_rewards(0.5, 0.05)),
    ("Sigmoid (Eventually Plateaus)", sigmoid_rewards(150, 0.02, 0.2, 0.85)),
]

window = 50  # number of most recent rewards handed to the plateau check
threshold = 0.01  # forwarded unchanged to is_plateaued
steps = 300  # length of every simulated reward sequence

for ax, (title, gen) in zip(axes.flat, patterns, strict=False):
    # Generate reward sequence
    rewards = [gen(i) for i in range(steps)]

    # Track plateau status using actual curriculum algorithm
    plateau_status = []
    for i in range(steps):
        if i < window:
            # Not enough history yet to fill one window.
            plateau_status.append(False)
            continue
        # The window holds the `window` rewards preceding step i
        # (the reward at step i itself is not included).
        recent_history = rewards[i - window : i]
        stats = LessonStats(training_stats=PerformanceStats(reward_history=recent_history))
        plateau_status.append(is_plateaued(stats, window=window, threshold=threshold))

    # Plot rewards
    ax.plot(rewards, label="Rewards", alpha=0.7, linewidth=1.5)

    # Shade plateau regions
    plateau_regions = np.array(plateau_status)
    ax.fill_between(
        range(steps), 0, 1, where=plateau_regions, alpha=0.2, color="green", label="Plateaued", step="mid"
    )

    # Add linear regression for last window
    if len(rewards) >= window:
        recent = np.array(rewards[-window:])
        x = np.arange(len(recent))
        result = scipy_stats.linregress(x, recent)
        trend_line = result.slope * x + result.intercept
        # Shift the fitted line so it sits over the final `window` steps.
        start_idx = len(rewards) - window
        ax.plot(range(start_idx, len(rewards)), trend_line, "r--", label="Recent Trend", linewidth=2)

    ax.set_title(title, fontweight="bold")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc="best")
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Run Simulation

Simulate 1000 training steps with synthetic rewards to generate data for visualization.

In [None]:
# Run the synthetic curriculum once; the figures below all read from `history`.
print("Running simulation...")
history = simulate_curriculum_run(config, num_steps=1000, seed=42)
print(f"✓ Simulation complete ({len(history.steps)} steps)")
print(f"  Events recorded: {len(history.events)}")

# Tally events by type for the summary lines.
event_types = [event["type"] for event in history.events]
print(f"  Unlock events: {event_types.count('unlock')}")
print(f"  Graduation events: {event_types.count('graduate')}")

## Figure 2: Sampling Weights Over Time

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

steps = np.array(history.steps)
lesson_ids = list(history.sampling_weights.keys())

# One line per lesson (easier to read than a stacked area chart).
for lesson_id, weights in history.sampling_weights.items():
    ax.plot(steps, weights, label=lesson_id, linewidth=2, alpha=0.8)

ax.set_title("Sampling Weight Distribution Over Time", fontsize=14, fontweight="bold")
ax.set(xlabel="Training Step", ylabel="Sampling Probability", ylim=(0, 1.05))
# Anchor the legend just outside the axes so it never covers the curves.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## Figure 3: Lesson Lifecycle Timeline

Gantt chart showing lesson state transitions:
- **Gray**: Locked (dependencies not met)
- **Green**: Active (unlocked and sampling)
- **Blue**: Graduated (mastered)

Events:
- **Gold diamonds**: Unlock events
- **Blue squares**: Graduation events

In [None]:
from itertools import groupby

from matplotlib.lines import Line2D

fig, ax = plt.subplots(figsize=(14, 6))

lesson_ids = list(history.lesson_states.keys())
state_colors = {"locked": "#cccccc", "active": "#90EE90", "graduated": "#87CEEB"}

# Each lesson gets its own horizontal row, in recording order.
y_pos = {lid: i for i, lid in enumerate(lesson_ids)}

# Plot state regions: one bar per run of consecutive identical states.
for lesson_id in lesson_ids:
    start_idx = 0
    for state, run in groupby(history.lesson_states[lesson_id]):
        run_length = sum(1 for _ in run)
        ax.barh(
            y_pos[lesson_id],
            run_length,
            left=history.steps[start_idx],
            height=0.8,
            color=state_colors[state],
            edgecolor="black",
            linewidth=0.5,
        )
        start_idx += run_length

# Mark events: gold diamonds for unlocks, blue squares for everything else
# (the simulation only emits "unlock" and "graduate" events).
event_styles = {"unlock": ("D", "gold")}
for event in history.events:
    lesson_id = event["lesson_id"]
    step = event["step"]
    marker, color = event_styles.get(event["type"], ("s", "blue"))
    ax.scatter(step, y_pos[lesson_id], marker=marker, s=100, color=color, edgecolor="black", zorder=10)

ax.set_yticks(range(len(lesson_ids)))
ax.set_yticklabels(lesson_ids)
ax.set_xlabel("Training Step")
ax.set_title("Lesson Lifecycle Timeline", fontsize=14, fontweight="bold")
ax.grid(alpha=0.3, axis="x")

# Proxy artists: the bars and scatter points carry no labels of their own,
# so the legend is assembled by hand.
legend_elements = [
    Patch(facecolor=state_colors["locked"], label="Locked"),
    Patch(facecolor=state_colors["active"], label="Active"),
    Patch(facecolor=state_colors["graduated"], label="Graduated"),
    Line2D([0], [0], marker="D", color="w", markerfacecolor="gold", markersize=10, label="Unlock Event"),
    Line2D([0], [0], marker="s", color="w", markerfacecolor="blue", markersize=10, label="Graduate Event"),
]
ax.legend(handles=legend_elements, loc="upper left", bbox_to_anchor=(1, 1))

plt.tight_layout()
plt.show()

## Figure 4: Weight Function

**Left**: Base quadratic weight function w = max(0, -4s² + 4s) peaks at 50% success

**Right**: Exploration bonus amplifies weights for under-sampled lessons

In [None]:
def _quadratic_weight(success):
    """Base weight max(0, -4s² + 4s) for a success rate (scalar or array)."""
    return np.maximum(0, -4 * success**2 + 4 * success)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Panel 1: base quadratic weight over the full success-rate range
success_rates = np.linspace(0, 1, 100)
base_weights = _quadratic_weight(success_rates)

ax1.plot(success_rates, base_weights, linewidth=3, color="steelblue")
ax1.axvline(0.5, color="red", linestyle="--", alpha=0.5, label="Peak at 50%")
ax1.fill_between(success_rates, base_weights, alpha=0.3, color="steelblue")

ax1.set_title("Base Weight Function", fontsize=14, fontweight="bold")
ax1.set(xlabel="Success Rate", ylabel="Base Weight", xlim=(0, 1), ylim=(0, 1.1))
ax1.grid(alpha=0.3)
ax1.legend()

# Panel 2: exploration bonus, a multiplier that decays from 2 toward 1
# as the sample count grows
sample_counts = np.arange(0, 200)
exploration_bonus = 1.0 + np.exp(-0.01 * sample_counts)

# One curve per fixed success rate shows how the bonus scales its base weight
for rate in (0.3, 0.5, 0.7):
    ax2.plot(
        sample_counts,
        _quadratic_weight(rate) * exploration_bonus,
        label=f"Success={rate:.1f}",
        linewidth=2,
    )

ax2.set_title("Exploration Bonus Effect on Weights", fontsize=14, fontweight="bold")
ax2.set(xlabel="Number of Samples", ylabel="Final Weight")
ax2.grid(alpha=0.3)
ax2.legend()

plt.tight_layout()
plt.show()

## Figure 5: Health Metrics Dashboard

Four key curriculum health indicators:
1. **Sampling Entropy**: Diversity measure (low = collapse to single lesson)
2. **Effective Lessons**: Inverse Simpson index (how many meaningfully contribute)
3. **Mean Success**: Average performance across active lessons
4. **Lesson Counts**: Number in each state over time

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("Curriculum Health Metrics", fontsize=16, fontweight="bold")

steps = np.array(history.steps)

# Panels 1-3 share one layout, so they are described as data:
# (axes, metric key, line color, title, y-label, optional (level, label) reference line)
metric_panels = [
    (
        axes[0, 0],
        "sampling_entropy",
        "steelblue",
        "Sampling Entropy (Diversity)",
        "Entropy",
        (0.5, "Low diversity threshold"),
    ),
    (
        axes[0, 1],
        "effective_lessons",
        "darkgreen",
        "Effective Lessons (Inverse Simpson)",
        "Effective Count",
        (2.0, "Warning threshold"),
    ),
    (
        axes[1, 0],
        "mean_success",
        "darkorange",
        "Mean Success Rate (Active Lessons)",
        "Success Rate",
        None,
    ),
]

for ax, metric_key, line_color, panel_title, y_label, reference in metric_panels:
    ax.plot(steps, history.metrics[metric_key], linewidth=2, color=line_color)
    if reference is not None:
        level, reference_label = reference
        ax.axhline(level, color="red", linestyle="--", alpha=0.5, label=reference_label)
    ax.set_title(panel_title, fontweight="bold")
    ax.set_xlabel("Step")
    ax.set_ylabel(y_label)
    ax.grid(alpha=0.3)
    # Only panels with a labelled reference line get a legend.
    if reference is not None:
        ax.legend()

# Only the success-rate panel is pinned to the [0, 1] range.
axes[1, 0].set_ylim(0, 1)

# Panel 4: how many lessons are in each state at every recorded step
ax = axes[1, 1]
lesson_ids = list(history.lesson_states.keys())

locked_count = []
active_count = []
graduated_count = []
tallies = {"locked": locked_count, "active": active_count, "graduated": graduated_count}

for step_idx in range(len(steps)):
    states_now = [history.lesson_states[lid][step_idx] for lid in lesson_ids]
    for state_name, series in tallies.items():
        series.append(states_now.count(state_name))

for state_name, series in tallies.items():
    ax.plot(steps, series, label=state_name.capitalize(), linewidth=2)
ax.set_title("Lesson State Counts", fontweight="bold")
ax.set_xlabel("Step")
ax.set_ylabel("Count")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()