# Curriculum Learning Visualization

Demonstrates an adaptive curriculum with 3 parallel chains of 5 sequential lessons each (15 tasks total).

In [None]:
import jax
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from marin.rl.curriculum import (
    Curriculum,
    CurriculumConfig,
    LessonConfig,
    LessonDependency,
    compute_success_ratio,
)
from marin.rl.environments.base import EnvConfig
from marin.rl.types import RolloutStats

# Global plot styling for every figure in this notebook:
# white-grid background, a 5-color "husl" palette, and 120-dpi figures.
sns.set_style(style="whitegrid")
sns.set_palette("husl", n_colors=5)
plt.rcParams.update({"figure.dpi": 120})

In [None]:
def create_success_trajectory(
    start_step: int,
    learning_speed: float,
    final_success: float,
    noise: float = 0.02,
    *,
    seed: int = 42,
    midpoint: float = 5.0,
):
    """Generate a noisy sigmoid learning curve.

    The noise-free curve is ``final_success / (1 + exp(midpoint - progress))`` with
    ``progress = (step - start_step) * learning_speed``. Gaussian noise is added to it
    and the result is clipped to ``[0, 1]``.

    Args:
        start_step: When learning begins; steps before it return exactly 0.0.
        learning_speed: Rate of improvement (higher = faster learning).
        final_success: Asymptotic success rate.
        noise: Standard deviation of the Gaussian noise added to each sample.
        seed: Seed for this trajectory's private noise generator. The default (42)
            reproduces the previous fixed-seed behavior; pass distinct seeds so that
            different trajectories do not share an identical noise sequence.
        midpoint: Value of ``progress`` at which the curve reaches half of
            ``final_success`` (previously hard-coded to 5).

    Returns:
        Function that maps step -> reward. The function is stateful: every call at or
        after ``start_step`` draws fresh noise from the generator, so repeated calls
        with the same step generally return different values.
    """
    rng = np.random.default_rng(seed)

    def trajectory(step: int) -> float:
        if step < start_step:
            return 0.0
        progress = (step - start_step) * learning_speed
        # Sigmoid centered at progress == midpoint
        base = final_success / (1 + np.exp(midpoint - progress))
        return float(np.clip(base + rng.normal(0, noise), 0, 1))

    return trajectory

In [None]:
def simulate_parallel_curriculum(num_steps: int = 1200, seed: int = 42):
    """Simulate curriculum with 3 parallel chains (15 tasks total).

    Structure:
    - Chain A: Fast learning, high success (a1 → a2 → a3 → a4 → a5)
    - Chain B: Medium learning, medium success (b1 → b2 → b3 → b4 → b5)
    - Chain C: Slow learning, lower success (c1 → c2 → c3 → c4 → c5)

    Every lesson after the first in a chain depends on its predecessor. Rewards come
    from synthetic sigmoid curves built with ``create_success_trajectory``.

    Args:
        num_steps: Number of simulated training steps.
        seed: Seed for lesson sampling (the JAX key passed to
            ``curriculum.sample_lesson``). Reward noise is drawn from the generators
            created inside ``create_success_trajectory`` and does not depend on it.

    Returns:
        Dictionary with simulation history:
        - steps: list of step numbers
        - success_rates: dict mapping lesson_id -> list of success rates
        - sampling_weights: dict mapping lesson_id -> list of sampling weights
        - events: list of (step, event_type, lesson_id) tuples
        - chains: dict mapping chain name -> list of lesson_ids
    """
    # Per-chain parameters:
    # (chain name, lesson-id prefix, stop_threshold of the first lesson,
    #  per-lesson learning speeds, per-lesson asymptotic success, reward noise std).
    # Speeds span roughly a 10x range so the chains progress at visibly different rates.
    chain_specs = (
        # Chain A: very fast learning, plateaus at 95-100% (high performers)
        ("A", "a", 0.88, (0.080, 0.060, 0.050, 0.040, 0.035), (1.00, 0.98, 0.97, 0.96, 0.95), 0.012),
        # Chain B: medium learning, plateaus at 75-85% (medium performers)
        ("B", "b", 0.80, (0.030, 0.025, 0.022, 0.020, 0.018), (0.85, 0.82, 0.80, 0.78, 0.75), 0.015),
        # Chain C: slow learning, plateaus at 60-70% (lower performers)
        ("C", "c", 0.72, (0.012, 0.010, 0.009, 0.008, 0.007), (0.70, 0.68, 0.65, 0.63, 0.60), 0.018),
    )
    # stop_threshold decreases by this amount for each later lesson in a chain
    threshold_decrement = 0.02

    lessons = {}
    chains = {}
    trajectories = {}
    for chain_name, prefix, first_threshold, speeds, final_successes, noise in chain_specs:
        chains[chain_name] = []
        for idx, (speed, final_success) in enumerate(zip(speeds, final_successes)):
            lesson_id = f"{prefix}{idx + 1}"
            chains[chain_name].append(lesson_id)
            # Each lesson (except the first) depends on the previous lesson in its chain
            deps = [LessonDependency(dependency_id=f"{prefix}{idx}", reward_threshold=0.0)] if idx > 0 else []
            lessons[lesson_id] = LessonConfig(
                lesson_id=lesson_id,
                env_config=EnvConfig(env_class="marin.rl.environments.mock_env.MockEnv", env_args={}),
                dependencies=deps,
                stop_threshold=first_threshold - idx * threshold_decrement,
            )
            # All trajectories start at step 0 (no artificial offsets); staggering emerges
            # from the different learning speeds. The curve is indexed by the global step,
            # so it advances even while the lesson is still locked.
            trajectories[lesson_id] = create_success_trajectory(0, speed, final_success, noise=noise)

    config = CurriculumConfig(
        lessons=lessons,
        eval_frequency=40,
        eval_n_examples=32,
        eval_n_generations=1,
        minimum_sample_probability=0.005,
    )
    curriculum = Curriculum(config)

    history = {
        "steps": [],
        "success_rates": {lid: [] for lid in lessons},
        "sampling_weights": {lid: [] for lid in lessons},
        "events": [],
        "chains": chains,
    }

    # Derive a distinct key per step from one base key. Seeding with `step + seed`
    # would make runs with neighboring seeds reuse the same keys shifted by one step.
    base_key = jax.random.PRNGKey(seed)

    for step in range(num_steps):
        curriculum.current_step = step

        # Track unlock and graduation events by diffing state around the update.
        # NOTE(review): relies on the private Curriculum._unlock_and_graduate_lessons.
        prev_unlocked = set(curriculum.unlocked)
        prev_graduated = set(curriculum.graduated)
        curriculum._unlock_and_graduate_lessons()

        # Sorted so that events within one step are recorded in a deterministic order
        # (iteration order of a set of str depends on PYTHONHASHSEED).
        for lid in sorted(curriculum.unlocked - prev_unlocked):
            history["events"].append((step, "unlock", lid))
        for lid in sorted(curriculum.graduated - prev_graduated):
            history["events"].append((step, "graduate", lid))

        # Sample lesson and generate rollout
        active_lessons = curriculum.unlocked - curriculum.graduated
        if active_lessons:
            lesson_id = curriculum.sample_lesson(jax.random.fold_in(base_key, step))
            reward = trajectories[lesson_id](step)
            curriculum.update_lesson_stats(
                [RolloutStats(lesson_id=lesson_id, episode_reward=reward, env_example_id=f"ex_{step}")],
                mode="training",
                current_step=step,
            )

            # Periodic evaluation of every active lesson
            if step % config.eval_frequency == 0 and step > 0:
                for eval_lid in sorted(active_lessons):
                    eval_rewards = [trajectories[eval_lid](step) for _ in range(config.eval_n_examples)]
                    eval_stats = [
                        RolloutStats(lesson_id=eval_lid, episode_reward=r, env_example_id=f"eval_{step}_{i}")
                        for i, r in enumerate(eval_rewards)
                    ]
                    curriculum.update_lesson_stats(eval_stats, mode="eval", current_step=step)

        # Record state for every lesson (locked lessons get weight 0.0)
        history["steps"].append(step)
        weights = curriculum.compute_sampling_weights()
        for lid in lessons:
            history["success_rates"][lid].append(compute_success_ratio(curriculum.stats[lid], step))
            history["sampling_weights"][lid].append(weights.get(lid, 0.0))

    return history

In [None]:
def plot_curriculum_progression(history):
    """Visualize curriculum learning progression across parallel chains.

    Creates a 2-row grid with one column per chain (3 for the default simulation):
    - Row 1: Success rates over time for each lesson in the chain
    - Row 2: Sampling weights over time for each lesson in the chain
    Dashed vertical lines mark unlock events and dotted lines mark graduations.
    After showing the figure, prints the first unlock/graduation step of each lesson.

    Args:
        history: Dict as returned by ``simulate_parallel_curriculum``, with keys
            "chains", "steps", "success_rates", "sampling_weights" and "events".
    """
    chains = history["chains"]
    if not chains:
        print("No chains to plot.")
        return

    steps = np.array(history["steps"])
    sorted_chains = sorted(chains.items())
    n_chains = len(sorted_chains)
    max_lessons = max(len(lesson_ids) for _, lesson_ids in sorted_chains)

    # Index events once: (event_type, lesson_id) -> steps in recorded order
    event_steps = {}
    for event_step, event_type, event_lid in history["events"]:
        event_steps.setdefault((event_type, event_lid), []).append(event_step)

    # squeeze=False keeps `axes` 2-D even when there is only one chain
    fig, axes = plt.subplots(2, n_chains, figsize=(6 * n_chains, 10), squeeze=False)
    fig.suptitle(
        f"Parallel Curriculum Learning: {n_chains} Chains × {max_lessons} Lessons",
        fontsize=16,
        fontweight="bold",
    )

    # One color per lesson position within a chain, shared across chains
    colors = plt.cm.viridis(np.linspace(0.2, 0.9, max(max_lessons, 1)))

    for col_idx, (chain_name, lesson_ids) in enumerate(sorted_chains):
        # Top row: Success rates
        ax_success = axes[0, col_idx]
        ax_success.set_title(f"Chain {chain_name}: Success Rates", fontweight="bold")
        ax_success.set_xlabel("Training Step")
        ax_success.set_ylabel("Success Rate")
        ax_success.set_ylim(-0.05, 1.05)
        ax_success.grid(True, alpha=0.3)

        # Bottom row: Sampling weights
        ax_weights = axes[1, col_idx]
        ax_weights.set_title(f"Chain {chain_name}: Sampling Weights", fontweight="bold")
        ax_weights.set_xlabel("Training Step")
        ax_weights.set_ylabel("Sampling Probability")
        ax_weights.set_ylim(-0.02, 1.02)
        ax_weights.grid(True, alpha=0.3)

        for lesson_idx, lesson_id in enumerate(lesson_ids):
            color = colors[lesson_idx]
            label = lesson_id.upper()

            ax_success.plot(
                steps, np.array(history["success_rates"][lesson_id]),
                label=label, color=color, linewidth=2, alpha=0.8,
            )
            ax_weights.plot(
                steps, np.array(history["sampling_weights"][lesson_id]),
                label=label, color=color, linewidth=2, alpha=0.8,
            )

            # Mark unlock (dashed) and graduation (dotted) events on both rows
            for unlock_step in event_steps.get(("unlock", lesson_id), []):
                for ax in (ax_success, ax_weights):
                    ax.axvline(unlock_step, color=color, linestyle="--", alpha=0.3, linewidth=1)
            for grad_step in event_steps.get(("graduate", lesson_id), []):
                for ax in (ax_success, ax_weights):
                    ax.axvline(grad_step, color=color, linestyle=":", alpha=0.4, linewidth=1.5)

        ax_success.legend(loc="lower right", fontsize=9)
        ax_weights.legend(loc="upper right", fontsize=9)

    plt.tight_layout()
    plt.show()

    # Print event summary (first occurrence of each event per lesson)
    print("\n📊 Event Timeline:")
    print("  Legend: -- unlock | : graduate\n")

    for chain_name, lesson_ids in sorted_chains:
        print(f"  Chain {chain_name}:")
        for lesson_id in lesson_ids:
            unlock_steps = event_steps.get(("unlock", lesson_id), [])
            grad_steps = event_steps.get(("graduate", lesson_id), [])
            unlock_str = f"unlocked@{unlock_steps[0]}" if unlock_steps else "never unlocked"
            grad_str = f"graduated@{grad_steps[0]}" if grad_steps else "not graduated"
            print(f"    {lesson_id.upper()}: {unlock_str}, {grad_str}")
        print()

In [None]:
# Run simulation
print("Simulating 3-chain parallel curriculum...")
history = simulate_parallel_curriculum(num_steps=1200, seed=42)
print("✓ Simulation complete")
# Derive counts from the returned history instead of hard-coding chain names
total_tasks = sum(len(lesson_ids) for lesson_ids in history["chains"].values())
n_unlocks = sum(1 for _, event_type, _ in history["events"] if event_type == "unlock")
n_graduations = sum(1 for _, event_type, _ in history["events"] if event_type == "graduate")
print(f"  Total tasks: {total_tasks}")
print(f"  Unlock events: {n_unlocks}")
print(f"  Graduate events: {n_graduations}")

# Visualize
plot_curriculum_progression(history)