# Curriculum Learning - Self-Contained

Automatic curriculum learning for MiniGrid environments.

In [8]:
import os, warnings
os.environ["PYTHONWARNINGS"] = "ignore"
warnings.filterwarnings("ignore")

from typing import List, Dict
import numpy as np
import time
import torch

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import SubprocVecEnv

from src.environment import make_vec_env
from src.evaluation import evaluate
from src.cnn import get_policy_kwargs
from src.filemanager import FileManager

# Result directory

In [9]:
fm: FileManager = FileManager("curriculum_learning", output_dir="results")

Experiment directory initialized: results/curriculum_learning_20251028_174618


## Curriculum Teacher


In [10]:
class CurriculumTeacher:
    """Manages automatic progression through stages."""
    
    def __init__(self, stages: List[str], threshold: float = 0.90, window: int = 5) -> None:
        self.stages: List[str] = stages
        self.threshold: float = threshold
        self.window: int = window
        self.stage_idx: int = 0
        self.performance: Dict[str, List[float]] = {s: [] for s in stages}
        self._completed: bool = False
    
    def record(self, success_rate: float) -> None:
        """Record performance for current stage."""
        self.performance[self.current_stage()].append(success_rate)
    
    def should_advance(self) -> bool:
        """No eval below threshold."""
        history = self.performance[self.current_stage()]
        
        if len(history) < self.window:
            return False
        
        recent = history[-self.window:]
        min_recent = float(np.min(recent))
        
        return min_recent >= self.threshold

    
    def advance(self) -> None:
        """Move to next stage or mark as complete."""
        if self.stage_idx < len(self.stages) - 1:
            self.stage_idx += 1
        else:
            self._completed = True
    
    def current_stage(self) -> str:
        """Get current stage name."""
        return self.stages[self.stage_idx]
    
    def is_all_stages_complete(self) -> bool:
        """Check if all stages completed."""
        return self._completed
    
    def get_current_performance_summary(self) -> str:
        """Get human-readable summary of current stage performance."""
        history = self.performance[self.current_stage()]
        if not history:
            return "No evaluations yet"
        
        recent = history[-self.window:] if len(history) >= self.window else history
        mean = np.mean(recent)
        summary = f"Last {len(recent)}: {mean:.1%}"
       
        return summary


## Step Callback


In [11]:
class StepCallback(BaseCallback):
    """Periodic evaluation and entropy management."""
    
    def __init__(
        self,
        teacher: CurriculumTeacher,
        eval_freq: int = 5_000,
        n_eval: int = 30,
        visualize: bool = True,
    ) -> None:
        super().__init__()
        self.teacher: CurriculumTeacher = teacher
        self.eval_freq: int = eval_freq
        self.n_eval: int = n_eval
        self.visualize: bool = visualize
        self.stage_steps: int = 0
        self.total_steps: int = 0
        self.stage_start_time: float = time.time()
        self.total_training_time: float = 0.0
    
    def _on_step(self) -> bool:
        self.stage_steps += 1
        self.total_steps += 1
        
        if self.stage_steps % self.eval_freq == 0:
            assert isinstance(self.model, PPO)
            
            # Evaluate
            episode_batch = evaluate(self.model, self.teacher.current_stage(), self.n_eval)
            self.teacher.record(episode_batch.success_rate)
            
            # Calculate elapsed times
            stage_elapsed: float = time.time() - self.stage_start_time
            total_elapsed: float = self.get_total_time()
            
            # Print evaluation results
            print(
                f"  Eval @ Stage {self.stage_steps:,} | Total: {self.total_steps:,} | "
                f"Success: {episode_batch.success_rate:.1%} | "
                f"Reward: {episode_batch.mean_reward:.2f} | "
                f"PolicyEnt: {episode_batch.mean_entropy:.3f} | "
                f"StageTime: {int(stage_elapsed//60):02d}:{int(stage_elapsed%60):02d} | "
                f"TotalTime: {int(total_elapsed//60):02d}:{int(total_elapsed%60):02d}"
            )
            
            # Write evaluation
            fm.dump_eval_to_csv(
                total_step=self.total_steps,
                stage=self.teacher.current_stage(),
                stage_step=self.stage_steps,
                batch=episode_batch,
                model=self.model
            )
            
            # Visualize
            if self.visualize:
                from src.episode_visualization import visualize_eval_episode
                visualize_eval_episode(
                    model=self.model,
                    episode=episode_batch.episodes[0],
                    timestep=self.total_steps,
                    output_dir=fm.get_visualization_dir()
                )
        
        return True

    
    def reset_for_stage(self) -> None:
        """Reset stage counter and timer for new environment."""
        # Accumulate time from completed stage
        stage_elapsed: float = time.time() - self.stage_start_time
        self.total_training_time += stage_elapsed
        
        # Reset for new stage
        self.stage_steps = 0
        self.stage_start_time = time.time()
    
    def get_stage_elapsed(self) -> float:
        """Get elapsed time for current stage in seconds."""
        return time.time() - self.stage_start_time
    
    def get_total_time(self) -> float:
        """Get total training time across all stages in seconds."""
        return self.total_training_time + self.get_stage_elapsed()

## Training


In [12]:
N_ENVS: int = 8
N_STEPS: int = 128
STEPS_PER_ROLLOUT = N_STEPS * N_ENVS

# Evaluate every 5 rollout
EVAL_FREQ: int = 5 * STEPS_PER_ROLLOUT
N_EVALS: int = 100

device = ""
if torch.cuda.is_available(): # type: ignore
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Curriculum
THRESHOLD: float = 0.90
WINDOW: int = 4

# https://minigrid.farama.org/environments/minigrid/

STAGES: List[str] = [
    "MiniGrid-DoorKey-5x5-v0",
    "MiniGrid-DoorKey-6x6-v0",
    "MiniGrid-DoorKey-8x8-v0",
]


TOTAL_STEPS: int = 200_000

# PPO
def make_model(env: SubprocVecEnv) -> PPO:
    return PPO(
        "CnnPolicy",
        env,
        policy_kwargs=get_policy_kwargs(),
        learning_rate= 3e-4,
        n_steps=N_STEPS,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02,
        verbose=0,
        device=device
    )

Using device: cuda


In [13]:
def train_curriculum() -> None:

    # Teacher
    teacher: CurriculumTeacher = CurriculumTeacher(STAGES, threshold=THRESHOLD, window=WINDOW)

    # Initial env and model
    env = make_vec_env(teacher.current_stage(), N_ENVS)
    model = make_model(env)

    # Callback with tracking
    callback: StepCallback = StepCallback(
        teacher,
        eval_freq=EVAL_FREQ,
        n_eval=N_EVALS,
        visualize=True,
    )

    # Train through stages
    while not teacher.is_all_stages_complete():
        print(f"\n{'='*60}\nStage {teacher.stage_idx + 1}/{len(STAGES)}: {teacher.current_stage()}\n{'='*60}")
        # Train until stage mastered or max steps
        while callback.total_steps < TOTAL_STEPS:
            model.learn(  # type: ignore
                total_timesteps=EVAL_FREQ,
                callback=callback,
                reset_num_timesteps=False
            )
            
            # Check advancement
            if teacher.should_advance():
                print(f"\n  ✓ Stage passed after {callback.stage_steps:,} steps")
                # Save checkpoint for this completed stage
                fm.save_checkpoint(
                    model=model,
                    stage=teacher.current_stage(),
                    total_step=callback.total_steps
                )
                break

        if callback.total_steps >= TOTAL_STEPS:
            raise ValueError(f"Training failed after {callback.total_steps:,} steps")

        # Advance
        teacher.advance()
        if not teacher.is_all_stages_complete():
            env.close()
            callback.reset_for_stage()
            env = make_vec_env(teacher.current_stage(), N_ENVS)
            model.set_env(env)  # type: ignore
    
    env.close()
    total_time = callback.get_total_time()
    
    print(f"\n{'='*60}")
    print("CURRICULUM COMPLETE")
    print(f"{'='*60}")
    print(f"Total training time: {int(total_time//60):02d}:{int(total_time%60):02d}")
    print(f"Total steps: {callback.total_steps:,}")

In [14]:
train_curriculum()


Stage 1/3: MiniGrid-DoorKey-5x5-v0
  Eval @ Stage 5,120 | Total: 5,120 | Success: 26.0% | Reward: 0.11 | PolicyEnt: 1.714 | StageTime: 02:01 | TotalTime: 02:01
    → Evaluation saved to: results/curriculum_learning_20251028_174618/evaluations.csv
    → Saved visualization: results/curriculum_learning_20251028_174618/visualizations/eval_5120.png
  Eval @ Stage 10,240 | Total: 10,240 | Success: 100.0% | Reward: 0.86 | PolicyEnt: 1.371 | StageTime: 03:13 | TotalTime: 03:13
    → Evaluation saved to: results/curriculum_learning_20251028_174618/evaluations.csv
    → Saved visualization: results/curriculum_learning_20251028_174618/visualizations/eval_10240.png
  Eval @ Stage 15,360 | Total: 15,360 | Success: 100.0% | Reward: 0.96 | PolicyEnt: 0.281 | StageTime: 04:17 | TotalTime: 04:17
    → Evaluation saved to: results/curriculum_learning_20251028_174618/evaluations.csv
    → Saved visualization: results/curriculum_learning_20251028_174618/visualizations/eval_15360.png
  Eval @ Stage 20,48

# Run

In [15]:
# from src.environment import run_episode
# model_path = "results/curriculum_learning_20251027_134503/checkpoints/MiniGrid-KeyCorridorS3R2-v0_step_475136.zip"
# model = PPO.load(model_path) # type: ignore
# episode_data = run_episode(
#     model=model, 
#     env_name="MiniGrid-KeyCorridorS3R2-v0", 
#     seed=42, 
#     render_mode="human", 
#     deterministic=True
# )

# print(episode_data)

In [16]:
# from src.episode_visualization import visualize_eval_episode
# visualize_eval_episode(
#     model=model,
#     episode=episode_data,
#     timestep=-1,
#     output_dir="./"
# )