In [1]:
%load_ext autoreload
%autoreload 2
#!pip install sb3-contrib

In [2]:
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback
import time

In [13]:
class TimerCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.start_time = time.time()
    
    def _on_step(self):
        if self.num_timesteps % 1000 == 0:
            elapsed = time.time() - self.start_time
            rate = self.num_timesteps / elapsed
            remaining = (self.locals['total_timesteps'] - self.num_timesteps) / rate
            print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

class GradNormCallback(BaseCallback):
    def _on_step(self):
        if hasattr(self.model.policy, 'parameters'):
            total_norm = 0
            for p in self.model.policy.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
            
            self.logger.record("train/grad_norm", total_norm)
        return True

In [16]:
# !! Discussion below points to the constraint Simplex being a bug, not a problem with my env or training.
# These lines turn off many validation checks, including the Simplex one causing us problems.
#
# Alternatively, the discussions also suggest modifying the check threshold in torch/distributions/constraints.py:: class _Simplex.
#        # Current:
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
#        
#        # Fix:
#        tol = torch.finfo(value.dtype).eps * 10 * value.size(-1) ** 0.5
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < tol)
#
# Both seem to work. I'm going for the latter. Its possibly is making my training slower and less stable though.
#
# Discussion:
# https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209/9
# https://github.com/pytorch/pytorch/issues/87468
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/81

# from torch.distributions import Distribution
# Distribution.set_default_validate_args(False)

In [26]:
from apad_env import APADEnv

In [32]:
env = APADEnv(-1,-1)

In [33]:
env.reset()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    n_steps = 128,
    tensorboard_log="./maskable_ppo_logs_5/",
    verbose=1,
)
model.learn(total_timesteps=25000, reset_num_timesteps=True, callback=[TimerCallback(), GradNormCallback()])
model.save(f"mppo_model_5")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_5/PPO_5
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.06     |
|    ep_rew_mean     | 2.06     |
| time/              |          |
|    fps             | 37       |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 128      |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.95        |
|    ep_rew_mean          | 1.95        |
| time/                   |             |
|    fps                  | 36          |
|    iterations           | 2           |
|    time_elapsed         | 7           |
|    total_timesteps      | 256         |
| train/                  |             |
|    approx_kl            | 0.018130466

KeyboardInterrupt: 

In [30]:
for i in range(4):
    env.reset()
    model = None
    model = MaskablePPO(
        "MlpPolicy",
        env,
        #ent_coef=0.05,
        n_steps = 128,
        tensorboard_log="./maskable_ppo_logs_5/",
        verbose=1,
    )
    model.learn(total_timesteps=25000, reset_num_timesteps=True, callback=[TimerCallback(), GradNormCallback()])
    model.save(f"mppo_model_{i}")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_5/PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.85     |
|    ep_rew_mean     | 1.85     |
| time/              |          |
|    fps             | 36       |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 128      |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.05       |
|    ep_rew_mean          | 2.05       |
| time/                   |            |
|    fps                  | 36         |
|    iterations           | 2          |
|    time_elapsed         | 7          |
|    total_timesteps      | 256        |
| train/                  |            |
|    approx_kl            | 0.01508404 |
|    cli

In [None]:
model.save("mppo_model_25k_2025-06-19_1500")