# Setup

In [1]:
%load_ext autoreload
%autoreload 2
#!pip install sb3-contrib

In [2]:
from sb3_contrib import MaskablePPO
from stable_baselines3.common.callbacks import BaseCallback
import time
from collections import deque

In [10]:
from stable_baselines3.common.utils import get_schedule_fn

In [3]:
class TimerCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.start_time = time.time()
    
    def _on_step(self):
        if self.num_timesteps % 1000 == 0:
            elapsed = time.time() - self.start_time
            rate = self.num_timesteps / elapsed
            remaining = (self.locals['total_timesteps'] - self.num_timesteps) / rate
            print(f"Step {self.num_timesteps}, {elapsed:.0f}s elapsed, {remaining:.0f}s remaining")
        return True

class GradNormCallback(BaseCallback):
    def _on_step(self):
        if hasattr(self.model.policy, 'parameters'):
            total_norm = 0
            for p in self.model.policy.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
            
            self.logger.record("train/grad_norm", total_norm)
        return True

from stable_baselines3.common.callbacks import BaseCallback

# up the difficulty upon reaching certain (rolling) mean episode length thresholds
class CurriculumCallback(BaseCallback):
    def __init__(self, env, verbose = 1):
        super().__init__(verbose = verbose)
        self.env = env
        self.verbose = verbose
        self.thresholds = [7.5, 5] # episode length required before upping difficulty
        self.current_stage = 0
        self.ep_lengths = deque(maxlen=30)

    def _on_step(self):
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" not in info:
                continue
                
            ep_len = info["episode"]["l"]
            self.ep_lengths.append(ep_len)
            avg_len = sum(self.ep_lengths) / len(self.ep_lengths)
            if self.current_stage < len(self.thresholds) and avg_len > self.thresholds[self.current_stage]:
                self.current_stage += 1
                self.env.set_difficulty(self.current_stage)
                if self.verbose:
                    print(f"Average episode length: {avg_len:.2f} — Switched to difficulty: {self.current_stage}")
        return True

In [None]:
# !! Discussion below points to the constraint Simplex being a bug, not a problem with my env or training.
# These lines turn off many validation checks, including the Simplex one causing us problems.
#
# Alternatively, the discussions also suggest modifying the check threshold in torch/distributions/constraints.py:: class _Simplex.
#        # Current:
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
#        
#        # Fix:
#        tol = torch.finfo(value.dtype).eps * 10 * value.size(-1) ** 0.5
#        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < tol)
#
# Both seem to work. I'm going for the latter. Its possibly is making my training slower and less stable though.
#
# Discussion:
# https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209/9
# https://github.com/pytorch/pytorch/issues/87468
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/81

# from torch.distributions import Distribution
# Distribution.set_default_validate_args(False)

In [14]:
from apad_env import APADEnv

In [15]:
env = APADEnv()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    tensorboard_log="./maskable_ppo_logs_15/",
    ent_coef= 0.2,
    #learning_rate=0.003,
    verbose=1,
)

total_timesteps = 800000
checkpoint_interval = 100000
for i in range(0, total_timesteps, checkpoint_interval):
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    if i == 0:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=True, callback=[TimerCallback(),GradNormCallback()])
    else:
        progress = i / total_timesteps
        model.ent_coef = 0.01 + 0.19 * (1 - progress)
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=False, callback=[TimerCallback(),GradNormCallback()])
    model.save(f"mppo_model_v1")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_15/PPO_1
Step 1000, 27s elapsed, 2690s remaining
Step 2000, 54s elapsed, 2653s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 2.59     |
|    ep_rew_mean     | -0.0425  |
| time/              |          |
|    fps             | 37       |
|    iterations      | 1        |
|    time_elapsed    | 55       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 83s elapsed, 2692s remaining
Step 4000, 110s elapsed, 2644s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.2         |
|    ep_rew_mean          | -0.047      |
| time/                   |             |
|    fps                  | 36          |
|    iterations           | 2           |
|  

In [16]:
model.load("mppo_model_v1")
model.learn(total_timesteps=400000, reset_num_timesteps=False, callback=[TimerCallback(),GradNormCallback()])

Logging to ./maskable_ppo_logs_15/PPO_1
Step 803000, 5s elapsed, 2s remaining
Step 804000, 31s elapsed, 15s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.24     |
|    ep_rew_mean     | -0.007   |
| time/              |          |
|    fps             | 38       |
|    iterations      | 1        |
|    time_elapsed    | 53       |
|    total_timesteps | 804864   |
| train/             |          |
|    grad_norm       | 0.399    |
---------------------------------
Step 805000, 60s elapsed, 29s remaining
Step 806000, 86s elapsed, 42s remaining
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.22        |
|    ep_rew_mean          | -0.0045     |
| time/                   |             |
|    fps                  | 37          |
|    iterations           | 2           |
|    time_elapsed         | 109         |
|    total_timesteps      | 806912      |
| train/        

<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x3296094e0>

# step 1: no date, 6 pieces = win
v0

In [7]:
env = APADEnv(-1,-1,2)
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    tensorboard_log="./maskable_ppo_logs_14/",
    ent_coef=0.1,
    learning_rate=0.003,
    verbose=1,
)
model.learn(total_timesteps=75000, callback=[TimerCallback(), GradNormCallback()])
model.save(f"mppo_model_v0_2")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./maskable_ppo_logs_14/PPO_15
Step 1000, 22s elapsed, 1618s remaining
Step 2000, 44s elapsed, 1596s remaining
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 5.72     |
|    ep_rew_mean     | -1.51    |
| time/              |          |
|    fps             | 45       |
|    iterations      | 1        |
|    time_elapsed    | 44       |
|    total_timesteps | 2048     |
| train/             |          |
|    grad_norm       | 0        |
---------------------------------
Step 3000, 68s elapsed, 1629s remaining
Step 4000, 90s elapsed, 1593s remaining
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 5.69       |
|    ep_rew_mean          | -1.35      |
| time/                   |            |
|    fps                  | 44         |
|    iterations           | 2          |
|    time_

# step 2:

## a) no date, 7 pieces
v0a

In [None]:
model = MaskablePPO.load("mppo_model_v0")
env = APADEnv(-1,-1,1)
model.set_env(env)
model.learn(100000, reset_num_timesteps=True)
model.save(f"mppo_model_v0a")

## b) day only, 6 pieces
v0b

In [None]:
model = MaskablePPO.load("mppo_model_v0")
env = APADEnv(-1,None,2)
model.set_env(env)
model.learn(100000, reset_num_timesteps=True)
model.save(f"mppo_model_v0b")

## c) month only, 6 pieces
v0c

In [None]:
model = MaskablePPO.load("mppo_model_v0")
env = APADEnv(None,-1,2)
model.set_env(env)
model.learn(100000, reset_num_timesteps=True)
model.save(f"mppo_model_v0c")

# Old

In [None]:
#total_timesteps = 100000
#checkpoint_interval = 25000
#for i in range(0, total_timesteps, checkpoint_interval):
#    remaining_steps = min(checkpoint_interval, total_timesteps - i)
#    if i == 0:
#        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=True, callback=[TimerCallback(),GradNormCallback()])
#    else:
#        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=False, callback=[TimerCallback(),GradNormCallback()])
#

In [None]:
env = APADEnv(-1)
#env.reset()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    tensorboard_log="./maskable_ppo_logs_10/",
    ent_coef=0.03,
    verbose=1,
)
total_timesteps = 300000
checkpoint_interval = 50000
for i in range(0, total_timesteps, checkpoint_interval):
    remaining_steps = min(checkpoint_interval, total_timesteps - i)
    if i == 0:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=True, callback=[TimerCallback(),GradNormCallback()])
    else:
        model.learn(total_timesteps=remaining_steps, reset_num_timesteps=False, callback=[TimerCallback(),GradNormCallback()])
    model.save(f"mppo_model_{11}")

In [None]:
env = APADEnv(-1,-1)
env.reset()
model = None
model = MaskablePPO(
    "MlpPolicy",
    env,
    #n_steps = 512,
    tensorboard_log="./maskable_ppo_logs_9/",
    verbose=1,
)
model.learn(total_timesteps=50000, reset_num_timesteps=False, callback=[TimerCallback(), GradNormCallback()])
model.save(f"mppo_model_9")

In [None]:
for i in range(4):
    env.reset()
    model = None
    model = MaskablePPO(
        "MlpPolicy",
        env,
        #n_steps = 512,
        tensorboard_log="./maskable_ppo_logs_9/",
        verbose=1,
    )
    model.learn(total_timesteps=50000, reset_num_timesteps=True, callback=[TimerCallback(), GradNormCallback()])
    model.save(f"mppo_model_{i}")

In [None]:
model.save("mppo_model_25k_2025-06-19_1500")