# Direct Training - No Curriculum

Train directly on target environment without curriculum learning.


In [1]:
import os, warnings
os.environ["PYTHONWARNINGS"] = "ignore"
warnings.filterwarnings("ignore")

import time
import torch

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

from src.environment import make_vec_env
from src.evaluation import evaluate
from src.cnn import get_policy_kwargs
from src.filemanager import FileManager

In [2]:
fm: FileManager = FileManager("direct_training", output_dir="results")      

Experiment directory initialized: results/direct_training_20251028_181546


## Evaluation Callback

In [None]:
class StepCallback(BaseCallback):
    """Periodic evaluation"""
    
    def __init__(
        self,
        env_name: str,
        eval_freq: int = 5_000,
        n_eval: int = 30,
        visualize: bool = True,
        checkpoint_freq: int = 100_000
    ) -> None:
        super().__init__()
        self.env_name: str = env_name
        self.eval_freq: int = eval_freq
        self.n_eval: int = n_eval
        self.visualize: bool = visualize
        self.checkpoint_freq: int = checkpoint_freq
        self.steps: int = 0
        self.start_time: float = time.time()
        self.total_training_time: float = 0.0
    
    def _on_step(self) -> bool:
        self.steps += 1
        
        if self.steps % self.eval_freq == 0:
            assert isinstance(self.model, PPO)
            
            # Evaluate
            episode_batch = evaluate(self.model, self.env_name, self.n_eval)
            
            # Calculate elapsed time for this stage
            self.total_training_time = time.time() - self.start_time
            
            # Print evaluation results
            print(
                f"  Eval @ Total: {self.steps:,} | "
                f"Success: {episode_batch.success_rate:.1%} | "
                f"Len: {episode_batch.mean_length:.1f} | "
                f"Reward: {episode_batch.mean_reward:.2f} | "
                f"PolicyEnt: {episode_batch.mean_entropy:.3f} | "
                f"Time: {int(self.total_training_time//60):02d}:{int(self.total_training_time%60):02d}"
            )
            
            # Write evaluation
            fm.dump_eval_to_csv(
                total_step=self.steps,
                stage="direct",
                stage_step=self.steps,
                batch=episode_batch,
                model=self.model,
                allocation={}
            )
            
            # Visualize
            if self.visualize:
                from src.episode_visualization import visualize_eval_episode
                visualize_eval_episode(
                    model=self.model,
                    episode=episode_batch.episodes[0],
                    timestep=self.steps,
                    output_dir=fm.get_visualization_dir()
                )
        
        if self.steps % self.checkpoint_freq == 0:
            fm.save_checkpoint(self.model, "direct", self.steps)

        return True

    def get_total_time(self) -> float:
        """Get total training time."""
        return self.total_training_time

## Training Configuration

In [4]:
N_ENVS: int = 8
N_STEPS: int = 128
STEPS_PER_ROLLOUT = N_STEPS * N_ENVS

# Evaluate every 5 rollout
EVAL_FREQ: int = 5 * STEPS_PER_ROLLOUT
N_EVALS: int = 100

device = ""
if torch.cuda.is_available(): # type: ignore
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# https://minigrid.farama.org/environments/minigrid/
ENV_NAME = "MiniGrid-DoorKey-8x8-v0"

TOTAL_STEPS: int = 70_000

Using device: cuda


In [5]:
# Create env and model
env = make_vec_env(ENV_NAME, N_ENVS)
model: PPO = PPO(
        "CnnPolicy",
        env,
        policy_kwargs=get_policy_kwargs(),
        learning_rate= 3e-4,
        n_steps=N_STEPS,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02,
        verbose=0,
        device=device
    )

# Callback with tracking
callback = StepCallback(
    ENV_NAME,
    EVAL_FREQ,
    n_eval=N_EVALS,
    visualize=True,
    checkpoint_freq=100_000
)

In [6]:
# Train
print(f"\n{'='*60}\nDirect Training: {ENV_NAME}\n{'='*60}")

while callback.steps < TOTAL_STEPS:
    model.learn( # type: ignore
        total_timesteps=EVAL_FREQ,
        callback=callback,
        reset_num_timesteps=False
    )

env.close()

# Save final checkpoint
fm.save_checkpoint(model, "direct", callback.steps)

total_time: float = callback.get_total_time()
print(f"\n{'='*60}")
print("DIRECT TRAINING COMPLETE")
print(f"{'='*60}")
print(f"Total training time: {int(total_time//60):02d}:{int(total_time%60):02d}")
print(f"Total steps: {callback.steps:,}")


Direct Training: MiniGrid-DoorKey-8x8-v0
  Eval @ Total: 5,120 | Success: 3.0% | Len: 634.0 | Reward: 0.01 | PolicyEnt: 1.774 | Time: 04:21
    → Evaluation saved to: results/direct_training_20251028_181546/evaluations.csv
    → Saved visualization: results/direct_training_20251028_181546/visualizations/eval_5120.png
  Eval @ Total: 10,240 | Success: 0.0% | Len: 640.0 | Reward: 0.00 | PolicyEnt: 1.657 | Time: 08:11
    → Evaluation saved to: results/direct_training_20251028_181546/evaluations.csv
    → Saved visualization: results/direct_training_20251028_181546/visualizations/eval_10240.png
  Eval @ Total: 15,360 | Success: 0.0% | Len: 640.0 | Reward: 0.00 | PolicyEnt: 1.706 | Time: 12:36
    → Evaluation saved to: results/direct_training_20251028_181546/evaluations.csv
    → Saved visualization: results/direct_training_20251028_181546/visualizations/eval_15360.png
  Eval @ Total: 20,480 | Success: 0.0% | Len: 640.0 | Reward: 0.00 | PolicyEnt: 1.747 | Time: 16:37
    → Evaluation sav