In [1]:
%load_ext autoreload
%autoreload 2
#!pip install sb3-contrib

In [2]:
from sb3_contrib import MaskablePPO
from apad_env import APADEnv
from stable_baselines3.common.callbacks import BaseCallback
import time

In [3]:
class TimerCallback(BaseCallback):
    """Print elapsed wall-clock time and an ETA every ``print_freq`` timesteps.

    Throughput is measured only over the timesteps taken during the current
    ``learn()`` call. Dividing the model's total ``num_timesteps`` by the elapsed
    time (the previous approach) inflates the rate whenever ``learn()`` is called
    with ``reset_num_timesteps=False``, as the chunked training loops in this
    notebook do, because the counter then already includes every earlier chunk.

    Args:
        print_freq: report interval in timesteps (default 1000, the previous
            hard-coded value).

    Raises:
        ValueError: if ``print_freq`` is not positive.
    """

    def __init__(self, print_freq=1000):
        super().__init__()
        if print_freq <= 0:
            raise ValueError(f"print_freq must be positive, got {print_freq}")
        self.print_freq = print_freq
        self.start_time = time.time()
        # num_timesteps at the moment training started; re-anchored in _on_training_start.
        self.start_timesteps = 0

    def _on_training_start(self):
        # Per SB3's BaseCallback.on_training_start, this hook runs once per learn() after
        # self.num_timesteps has been synced from the model. Re-anchor the clock and the
        # step baseline here so neither the construction-to-training delay nor steps from
        # earlier learn() calls distort the rate.
        self.start_time = time.time()
        self.start_timesteps = self.num_timesteps

    def _on_step(self):
        if self.num_timesteps % self.print_freq != 0:
            return True
        elapsed = time.time() - self.start_time
        steps_done = self.num_timesteps - self.start_timesteps
        if steps_done <= 0 or elapsed <= 0:
            return True  # nothing measured yet; avoid dividing by zero
        rate = steps_done / elapsed
        # NOTE(review): relies on SB3's learn() offsetting its local total_timesteps by the
        # prior num_timesteps when reset_num_timesteps=False, so both values are absolute.
        # Clamp at 0: PPO collects whole rollouts and can run past the requested target.
        steps_left = max(0, self.locals['total_timesteps'] - self.num_timesteps)
        remaining = steps_left / rate
        print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

In [4]:
# Environment instance passed to the MaskablePPO models built below. MaskablePPO wraps it
# in Monitor / DummyVecEnv itself (see the "Wrapping the env ..." lines in the training log).
env = APADEnv()

In [None]:
# Entropy-bonus weight: penalises overly confident (near 0/1) action probabilities so the
# policy keeps some uncertainty, which is meant to limit the numerical errors that
# accumulate at those extremes.
# NOTE(review): entropy_coeff is defined but NOT passed to the model below (ent_coef stays at
# the library default). Add `ent_coef=entropy_coeff` to the constructor to enable it;
# `batch_size=32` was also left disabled in this cell.
entropy_coeff = 0.05

model = MaskablePPO(
    policy="MlpPolicy",
    env=env,
    learning_rate=1e-3,
    tensorboard_log="./maskable_ppo_logs/",
    verbose=1,
)

In [None]:
# Train in fixed-size chunks and checkpoint after every chunk that finishes cleanly.
#
# MaskablePPO occasionally raises ValueError("... constraint Simplex ...") when the action
# probability vector fails torch's distribution validation (presumably NaN weights or
# rounding error; TODO confirm root cause). On that error, roll back to the last clean
# checkpoint rather than saving the errored model. The previous version saved the errored
# model and reloaded that very file, which restores identical weights and so could not undo
# any drift.
#
# Chunks that error are not retried, so the final step count can fall short of
# total_timesteps.
total_timesteps = 300000
checkpoint_interval = 50000
last_good_checkpoint = None  # path of the most recent checkpoint saved after a clean chunk

for i in range(0, total_timesteps, checkpoint_interval):
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    # Named by the chunk's *starting* offset: checkpoint_0 holds the model after the first chunk.
    checkpoint_path = f"checkpoint_{i}"

    try:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=False, callback=TimerCallback())
    except ValueError as e:
        if "constraint Simplex" not in str(e):
            raise
        print(f"Numerical precision error at step {model.num_timesteps}, continuing...")
        if last_good_checkpoint is not None:
            print(f"Rolling back to {last_good_checkpoint} to discard possibly corrupted weights...")
            model = MaskablePPO.load(last_good_checkpoint, env=env)
        # Without a clean checkpoint to restore, keep training the current model.
        continue

    model.save(checkpoint_path)
    last_good_checkpoint = checkpoint_path

In [6]:
# Second run: a fresh model with much more conservative optimisation settings, trained in
# 25k-step chunks with a checkpoint after every chunk that finishes cleanly.
env.reset()  # NOTE(review): MaskablePPO wraps/resets the env itself; kept to preserve the call sequence.
model = None  # drop the previous model before allocating the new one
model = MaskablePPO(
    "MlpPolicy",
    env,
    learning_rate=1e-4,        # Extremely conservative
    max_grad_norm=0.1,         # Very tight gradient clipping
    ent_coef=0.001,            # Minimal entropy coefficient
    tensorboard_log="./maskable_ppo_logs_2/",
    verbose=1,
)

total_timesteps = 500000
checkpoint_interval = 25000
last_good_checkpoint = None  # path of the most recent checkpoint saved after a clean chunk

for i in range(0, total_timesteps, checkpoint_interval):
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    checkpoint_path = f"checkpoint_1_{i}"
    try:
        # Only the first chunk resets the counter (starting a fresh TensorBoard run); later
        # chunks continue it. The first chunk is now guarded like every other chunk.
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=(i == 0))
    except ValueError as e:
        if "constraint Simplex" not in str(e):
            raise
        print(f"Numerical precision error at step {model.num_timesteps}, continuing...")
        # Don't checkpoint the errored state; restore the last clean one if there is one.
        if last_good_checkpoint is not None:
            print(f"Rolling back to {last_good_checkpoint} to discard possibly corrupted weights...")
            model = MaskablePPO.load(last_good_checkpoint, env=env)
        continue

    model.save(checkpoint_path)
    last_good_checkpoint = checkpoint_path
    # Report the model's real counter. learn() proceeds in whole 2048-step rollouts (see the
    # log), so it overshoots each 25k chunk, and i + remaining_steps would misreport it.
    print(f"Saved checkpoint at {model.num_timesteps} steps")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_2/PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.79     |
|    ep_rew_mean     | 36.8     |
| time/              |          |
|    fps             | 43       |
|    iterations      | 1        |
|    time_elapsed    | 47       |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 3.6          |
|    ep_rew_mean          | 32.9         |
| time/                   |              |
|    fps                  | 42           |
|    iterations           | 2            |
|    time_elapsed         | 96           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0008493009 |
|    clip_fraction        | 0            |
|    clip_r

In [None]:
# Train the current model for 50k more steps and checkpoint it.
# NOTE(review): learn() is called without reset_num_timesteps; SB3 documents its default as
# True, which restarts model.num_timesteps at 0 (and starts a new TensorBoard run). The "50k"
# in the filename therefore counts only this call's steps, not the model's cumulative
# training. Confirm this is intended, given that earlier cells already trained `model`.
model.learn(total_timesteps=50000, callback=TimerCallback())
model.save("checkpoint_50k")

In [None]:
# Continue training the same model for another 50k steps.
# reset_num_timesteps=False keeps model.num_timesteps cumulative (SB3's default resets it to
# 0 on every learn() call), so the saved step counter and the TensorBoard curve agree with
# the "100k" in the checkpoint name.
model.learn(total_timesteps=50000, reset_num_timesteps=False, callback=TimerCallback())
model.save("checkpoint_100k")

In [None]:
# Restore the 50k checkpoint *with* the environment attached. Without `env`, the loaded
# model can only predict, and the following learn() call has no environment to train on.
# This matches how the chunked training loop reloads checkpoints.
model = MaskablePPO.load("checkpoint_50k", env=env)

In [None]:
# Resume from the reloaded 50k checkpoint for another 50k steps.
# reset_num_timesteps=False continues the checkpoint's own step counter instead of restarting
# at 0 (SB3's default), so the saved model really reflects ~100k cumulative steps.
# NOTE: this overwrites any checkpoint_100k produced by the earlier cell.
model.learn(total_timesteps=50000, reset_num_timesteps=False, callback=TimerCallback())
model.save("checkpoint_100k")

Let's keep going