In [1]:
%load_ext autoreload
%autoreload 2
#!pip install sb3-contrib

In [22]:
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback
import time
from collections import deque

In [26]:
class TimerCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.start_time = time.time()
    
    def _on_step(self):
        if self.num_timesteps % 1000 == 0:
            elapsed = time.time() - self.start_time
            rate = self.num_timesteps / elapsed
            remaining = (self.locals['total_timesteps'] - self.num_timesteps) / rate
            print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

class GradNormCallback(BaseCallback):
    def _on_step(self):
        if hasattr(self.model.policy, 'parameters'):
            total_norm = 0
            for p in self.model.policy.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
            
            self.logger.record("train/grad_norm", total_norm)
        return True

from stable_baselines3.common.callbacks import BaseCallback

# up the difficulty upon reaching certain (rolling) mean episode length thresholds
class CurriculumCallback(BaseCallback):
    def __init__(self, env, verbose = 1):
        super().__init__(verbose = verbose)
        self.env = env
        self.verbose = verbose
        self.thresholds = [7.5, 5] # episode length required before upping difficulty
        self.current_stage = 0
        self.ep_lengths = deque(maxlen=30)

    def _on_step(self):
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" not in info:
                continue
                
            ep_len = info["episode"]["l"]
            self.ep_lengths.append(ep_len)
            avg_len = sum(self.ep_lengths) / len(self.ep_lengths)
            if self.current_stage < len(self.thresholds) and avg_len > self.thresholds[self.current_stage]:
                self.current_stage += 1
                self.env.set_difficulty(self.current_stage)
                if self.verbose:
                    print(f"Average episode length: {avg_len:.2f} — Switched to difficulty: {self.current_stage}")
        return True

In [5]:
# !! Discussion below points to the constraint Simplex being a bug, not a problem with my env or training.
# These lines turn off many validation checks, including the Simplex one causing us problems.
#
# Alternatively, the discussions also suggest modifying the check threshold in torch/distributions/constraints.py:: class _Simplex.
#        # Current:
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
#        
#        # Fix:
#        tol = torch.finfo(value.dtype).eps * 10 * value.size(-1) ** 0.5
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < tol)
#
# Both seem to work. I'm going for the latter. Its possibly is making my training slower and less stable though.
#
# Discussion:
# https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209/9
# https://github.com/pytorch/pytorch/issues/87468
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/81

# from torch.distributions import Distribution
# Distribution.set_default_validate_args(False)

In [27]:
from apad_env import APADEnv

In [28]:
env = APADEnv()

In [None]:
model.save("mppo_model_25k_2025-06-19_1500")

In [21]:
env = APADEnv()
env.reset()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    #n_steps = 512,
    tensorboard_log="./maskable_ppo_logs_9/",
    verbose=1,
)
total_timesteps = 200000
checkpoint_interval = 50000
for i in range(0, total_timesteps, checkpoint_interval):
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    if i == 0:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=True, callback=[TimerCallback(),GradNormCallback()])
    else:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=False, callback=[TimerCallback(),GradNormCallback()])
    model.save(f"mppo_model_{9}")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_9/PPO_2
Step 1000, 23s elapsed, 1150s remaining
Step 2000, 47s elapsed, 1134s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.17     |
|    ep_rew_mean     | 1.06     |
| time/              |          |
|    fps             | 42       |
|    iterations      | 1        |
|    time_elapsed    | 48       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 72s elapsed, 1134s remaining
Step 4000, 95s elapsed, 1096s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.16        |
|    ep_rew_mean          | 1.06        |
| time/                   |             |
|    fps                  | 42          |
|    iterations           | 2           |
|    

KeyboardInterrupt: 

In [19]:
env = APADEnv(-1,-1)
env.reset()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    #n_steps = 512,
    tensorboard_log="./maskable_ppo_logs_9/",
    verbose=1,
)
model.learn(total_timesteps=50000, reset_num_timesteps=False, callback=[TimerCallback(), GradNormCallback()])
model.save(f"mppo_model_9")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_9/PPO_1
Step 1000, 23s elapsed, 1119s remaining
Step 2000, 46s elapsed, 1108s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.27     |
|    ep_rew_mean     | 1.11     |
| time/              |          |
|    fps             | 43       |
|    iterations      | 1        |
|    time_elapsed    | 47       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 72s elapsed, 1124s remaining
Step 4000, 95s elapsed, 1088s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.37        |
|    ep_rew_mean          | 1.17        |
| time/                   |             |
|    fps                  | 42          |
|    iterations           | 2           |
|    

KeyboardInterrupt: 

In [12]:
for i in range(4):
    env.reset()
    model = None
    model = MaskablePPO(
        "MlpPolicy",
        env,
        #n_steps = 512,
        tensorboard_log="./maskable_ppo_logs_9/",
        verbose=1,
    )
    model.learn(total_timesteps=50000, reset_num_timesteps=True, callback=[TimerCallback(), GradNormCallback()])
    model.save(f"mppo_model_{i}")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_7/PPO_4
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 2.64     |
|    ep_rew_mean     | 0.167    |
| time/              |          |
|    fps             | 39       |
|    iterations      | 1        |
|    time_elapsed    | 12       |
|    total_timesteps | 512      |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 1000, 26s elapsed, 1254s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.73        |
|    ep_rew_mean          | 0.182       |
| time/                   |             |
|    fps                  | 39          |
|    iterations           | 2           |
|    time_elapsed         | 26          |
|    total_timesteps      | 1024        |
| train/                  |             |

KeyboardInterrupt: 