In [1]:
%load_ext autoreload
%autoreload 2
#!pip install sb3-contrib

In [2]:
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback
import time
from collections import deque

In [3]:
class TimerCallback(BaseCallback):
    """Print elapsed wall-clock time and an estimate of the time remaining.

    A line is printed whenever ``num_timesteps`` is a multiple of 1000.

    Throughput is measured from the start of the current ``learn()`` call,
    using only the steps taken during that call. ``num_timesteps`` keeps
    counting across calls when ``reset_num_timesteps=False``, so dividing the
    cumulative count by the time spent in this call alone would overstate the
    rate and understate the time remaining.
    """

    def __init__(self):
        super().__init__()
        # Provisional values; both are refreshed in _on_training_start.
        self.start_time = time.time()
        self._start_timesteps = 0

    def _on_training_start(self):
        """Snapshot the clock and the step counter at the start of ``learn()``."""
        self.start_time = time.time()
        # NOTE(review): relies on BaseCallback syncing num_timesteps with the
        # model before this hook runs — confirm for the installed SB3 version.
        # If it does not, this stays 0 and the estimate matches the old behavior.
        self._start_timesteps = self.num_timesteps

    def _on_step(self):
        """Report progress every 1000 timesteps; never stops training."""
        if self.num_timesteps % 1000 == 0:
            elapsed = time.time() - self.start_time
            steps_done = self.num_timesteps - self._start_timesteps
            # Without a measurable rate there is nothing sensible to print,
            # and dividing would raise ZeroDivisionError.
            if elapsed > 0 and steps_done > 0:
                rate = steps_done / elapsed
                steps_left = max(self.locals['total_timesteps'] - self.num_timesteps, 0)
                remaining = steps_left / rate
                print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

class GradNormCallback(BaseCallback):
    """Record the global L2 norm of the policy gradients as ``train/grad_norm``.

    The norm is the square root of the sum of squared per-parameter L2 norms,
    taken over every policy parameter that currently has a gradient.
    Parameters without a gradient are skipped, so the value is 0 when no
    gradients exist yet.
    """

    def _on_step(self):
        """Compute and record the gradient norm; never stops training."""
        # NOTE(review): this runs on every environment step, so the gradients
        # read here are presumably those left by the most recent optimizer
        # update rather than fresh ones — confirm.
        if not hasattr(self.model.policy, 'parameters'):
            return True

        squared_norms = (
            param.grad.data.norm(2).item() ** 2
            for param in self.model.policy.parameters()
            if param.grad is not None
        )
        grad_norm = sum(squared_norms) ** (1. / 2)

        self.logger.record("train/grad_norm", grad_norm)
        return True

from stable_baselines3.common.callbacks import BaseCallback

# up the difficulty upon reaching certain (rolling) mean episode length thresholds
class CurriculumCallback(BaseCallback):
    """Raise the environment difficulty as the rolling mean episode length grows.

    Each finished episode's length is pushed into a fixed-size window. Once
    the window is full and its mean exceeds the threshold for the current
    stage, the stage is incremented and passed to ``env.set_difficulty``.
    The window is then emptied, so the next decision uses only episodes
    played at the new difficulty.

    Args:
        env: Object exposing ``set_difficulty(stage)``.
        verbose: When truthy, print a message on every difficulty switch.
    """

    def __init__(self, env, verbose = 1):
        super().__init__(verbose = verbose)
        self.env = env
        self.verbose = verbose
        self.thresholds = [7.5, 5] # episode length required before upping difficulty
        self.current_stage = 0
        self.ep_lengths = deque(maxlen=30)

    def _on_step(self):
        """Fold finished episodes into the window and advance the stage if due."""
        # NOTE(review): the "episode" entry is presumably added by the Monitor
        # wrapper when an episode ends — confirm.
        for info in self.locals.get("infos", []):
            if "episode" not in info:
                continue

            self.ep_lengths.append(info["episode"]["l"])

            if self.current_stage >= len(self.thresholds):
                continue  # already at the final stage
            # Wait for a full window. A mean over a few episodes is noisy, and
            # after a switch it would still reflect the previous difficulty.
            if len(self.ep_lengths) < self.ep_lengths.maxlen:
                continue

            avg_len = sum(self.ep_lengths) / len(self.ep_lengths)
            if avg_len > self.thresholds[self.current_stage]:
                self.current_stage += 1
                self.env.set_difficulty(self.current_stage)
                # Discard lengths from the easier stage so they cannot
                # immediately satisfy the next (possibly lower) threshold.
                self.ep_lengths.clear()
                if self.verbose:
                    print(f"Average episode length: {avg_len:.2f} — Switched to difficulty: {self.current_stage}")
        return True

In [4]:
# !! Discussion below points to the constraint Simplex being a bug, not a problem with my env or training.
# These lines turn off many validation checks, including the Simplex one causing us problems.
#
# Alternatively, the discussions also suggest modifying the check threshold in torch/distributions/constraints.py:: class _Simplex.
#        # Current:
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
#        
#        # Fix:
#        tol = torch.finfo(value.dtype).eps * 10 * value.size(-1) ** 0.5
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < tol)
#
# Both seem to work. I'm going for the latter. It is possibly making my training slower and less stable, though.
#
# Discussion:
# https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209/9
# https://github.com/pytorch/pytorch/issues/87468
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/81

# from torch.distributions import Distribution
# Distribution.set_default_validate_args(False)

In [5]:
from apad_env import APADEnv

In [6]:
# Build an APADEnv using only its default constructor arguments.
env = APADEnv()

In [None]:
# NOTE(review): `model` is not assigned by any cell above this one, so on a
# fresh kernel this cell only works after a training cell further down has run.
model.save("mppo_model_25k_2025-06-19_1500")

In [9]:
# Train a fresh MaskablePPO (ent_coef=0.1) on APADEnv(-1, -1) in fixed-size
# chunks, saving after each chunk so an interrupted run leaves a usable model.
env = APADEnv(-1,-1)
#env.reset()
model = None  # drop the reference to any previous model before building a new one
model = MaskablePPO(
    "MlpPolicy",
    env,
    tensorboard_log="./maskable_ppo_logs_10/",
    ent_coef=0.1,
    verbose=1,
)

total_timesteps = 500_000
checkpoint_interval = 50_000
for i in range(0, total_timesteps, checkpoint_interval):
    # The last chunk may be shorter if the total is not a multiple of the interval.
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    model.learn(
        total_timesteps=remaining_steps,
        # Only the first chunk resets the step counter; later chunks continue it.
        reset_num_timesteps=(i == 0),
        callback=[TimerCallback(), GradNormCallback()],
    )
    # Every checkpoint is written to the same file name.
    model.save("mppo_model_9")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_10/PPO_3
Step 1000, 24s elapsed, 1169s remaining
Step 2000, 47s elapsed, 1126s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.26     |
|    ep_rew_mean     | 1.1      |
| time/              |          |
|    fps             | 42       |
|    iterations      | 1        |
|    time_elapsed    | 48       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 73s elapsed, 1138s remaining
Step 4000, 96s elapsed, 1099s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.34        |
|    ep_rew_mean          | 1.16        |
| time/                   |             |
|    fps                  | 41          |
|    iterations           | 2           |
|   

In [12]:
# Train a fresh MaskablePPO (ent_coef=0.03) on APADEnv(-1) in fixed-size
# chunks, saving after each chunk so an interrupted run leaves a usable model.
env = APADEnv(-1)
#env.reset()
model = None  # drop the reference to any previous model before building a new one
model = MaskablePPO(
    "MlpPolicy",
    env,
    tensorboard_log="./maskable_ppo_logs_10/",
    ent_coef=0.03,
    verbose=1,
)

total_timesteps = 300_000
checkpoint_interval = 50_000
for i in range(0, total_timesteps, checkpoint_interval):
    # The last chunk may be shorter if the total is not a multiple of the interval.
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    model.learn(
        total_timesteps=remaining_steps,
        # Only the first chunk resets the step counter; later chunks continue it.
        reset_num_timesteps=(i == 0),
        callback=[TimerCallback(), GradNormCallback()],
    )
    # Every checkpoint is written to the same file name.
    model.save("mppo_model_11")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_10/PPO_6
Step 1000, 25s elapsed, 1228s remaining
Step 2000, 50s elapsed, 1209s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3        |
|    ep_rew_mean     | 0.99     |
| time/              |          |
|    fps             | 39       |
|    iterations      | 1        |
|    time_elapsed    | 51       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 78s elapsed, 1224s remaining
Step 4000, 104s elapsed, 1191s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.89        |
|    ep_rew_mean          | 0.925       |
| time/                   |             |
|    fps                  | 38          |
|    iterations           | 2           |
|  

KeyboardInterrupt: 

In [None]:
# One 50k-step run of a fresh MaskablePPO on APADEnv(-1, -1), saved as
# "mppo_model_9".
env = APADEnv(-1,-1)
env.reset()

model = None  # drop the reference to any previous model before building a new one
model = MaskablePPO(
    "MlpPolicy",
    env,
    # n_steps=512,  (override currently disabled)
    tensorboard_log="./maskable_ppo_logs_9/",
    verbose=1,
)

model.learn(
    total_timesteps=50000,
    reset_num_timesteps=False,
    callback=[TimerCallback(), GradNormCallback()],
)
model.save("mppo_model_9")

In [None]:
# Train four independent MaskablePPO models on the `env` left in the kernel by
# an earlier cell; run i is saved as "mppo_model_<i>".
for i in range(4):
    env.reset()

    model = None  # drop the reference to the previous run's model
    model = MaskablePPO(
        "MlpPolicy",
        env,
        # n_steps=512,  (override currently disabled)
        tensorboard_log="./maskable_ppo_logs_9/",
        verbose=1,
    )

    model.learn(
        total_timesteps=50000,
        reset_num_timesteps=True,
        callback=[TimerCallback(), GradNormCallback()],
    )
    model.save("mppo_model_{}".format(i))