In [None]:
from chessrl.cnnextractor import CustomCNNExtractor
import chess_gymnasium_env

In [None]:
import os
import numpy as np

In [None]:
import gymnasium as gym

In [None]:
from sb3_contrib import MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
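# NOTE: the drop-in replacement for EvalCallback is MaskableEvalCallback
# (sb3_contrib.common.maskable.callbacks); it is not imported in this notebook.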
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

In [None]:
def action_mask_fn(env):
    """Return the action mask reported by the base environment.

    Passed to ``ActionMasker`` as its mask callback. Any wrappers around
    ``env`` are bypassed through ``env.unwrapped`` and the result of the
    underlying environment's ``_get_action_mask()`` is returned unchanged.
    """
    base_env = env.unwrapped
    return base_env._get_action_mask()

In [None]:
def make_env(env_id='chess_gymnasium_env/ChessEnv-v0', **env_kwargs):
    """Build a zero-argument factory that creates one action-masked environment.

    Vectorized-environment constructors expect a list of callables rather
    than ready-made environments, so construction is deferred until the
    returned function is actually called.

    Args:
        env_id: Gymnasium id of the environment to create. Defaults to the
            chess environment used throughout this notebook.
        **env_kwargs: Extra keyword arguments forwarded to ``gym.make``
            (for example ``render_mode="human"``).

    Returns:
        A callable taking no arguments that returns the environment wrapped
        in ``ActionMasker`` with ``action_mask_fn`` as its mask callback.
    """
    def _init():
        # Nothing is created until the factory is invoked.
        env = gym.make(env_id, **env_kwargs)
        return ActionMasker(env, action_mask_fn)

    return _init

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# How many environment copies make up the vectorized environment
n_envs = 12

# One factory per copy, all wrapped in a single vectorized environment
env_factories = [make_env() for _ in range(n_envs)]
env = DummyVecEnv(env_factories)

# Alternative: SubprocVecEnv (more efficient for computationally intensive tasks)
# env = SubprocVecEnv([make_env() for _ in range(n_envs)])

In [None]:
# Policy options: swap in the custom CNN feature extractor and set the
# features_dim value it is constructed with.
policy_kwargs = {
    "features_extractor_class": CustomCNNExtractor,
    "features_extractor_kwargs": {"features_dim": 256},
}

In [None]:
# env = gym.make('chess_gymnasium_env/ChessEnv-v0')#, render_mode="human")
# env = ActionMasker(env, action_mask_fn)

In [None]:
# Location handed to the model as its TensorBoard log directory.
# (This only names the path; nothing is created on disk by this cell.)
log_dir = "./tb_logs"

In [None]:
# Build the MaskablePPO agent on the vectorized chess environment.
# The feature-extractor settings are taken from the `policy_kwargs` dict
# defined in an earlier cell, so they are configured in exactly one place
# instead of being re-typed here.
model = MaskablePPO(
    "MlpPolicy",
    env,
    policy_kwargs=policy_kwargs,  # CustomCNNExtractor with features_dim=256
    gamma=0.995,              # Default gamma for PPO is 0.99, higher discount factor for long-term rewards
    # Optional hyperparameters left at their library defaults; uncomment to tune.
    # learning_rate=0.00025,  # Typical learning rate for PPO (can be tuned)
    # n_steps=2048,           # Number of steps to collect before updating the model (affects training stability)
    # batch_size=64,          # Size of the batch used to update the model (affects the stability and speed of learning)
    # n_epochs=10,            # Number of epochs to train on each batch (try increasing to fine-tune)
    # gae_lambda=0.95,        # The lambda for Generalized Advantage Estimation (higher means more variance reduction)
    # ent_coef=0.01,          # Coefficient for entropy regularization to encourage exploration
    # clip_range=0.1,         # Lower value for more conservative updates (reduces large policy shifts)
    # clip_range_vf=0.1,      # Clip range for the value function updates
    # vf_coef=0.5,            # Coefficient for value function loss (fine-tune if necessary)
    # max_grad_norm=0.5,      # Gradient clipping to prevent exploding gradients
    seed=32,                  # Random seed for reproducibility
    verbose=1,                # Set verbosity level to 1 for progress logging
    tensorboard_log=log_dir,  # TensorBoard logging directory (see `log_dir` above)
)
print(f"Model is running on device: {model.policy.device}")

In [None]:
# Inspect the instantiated policy; as the cell's last expression, its repr becomes the cell output.
model.policy

In [None]:
# Train for five million environment steps.
model.learn(total_timesteps=5_000_000)

In [None]:
# Score the trained agent over 20 episodes on the same vectorized `env`.
# NOTE(review): reward_threshold=90 presumably makes this call fail when the
# mean reward is below it — confirm against the evaluate_policy docs.
evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    reward_threshold=90,
    warn=False,
)

# Optional save / reload round-trip (disabled):
# model.save("ppo_mask")
# del model # remove to demonstrate saving and loading

# model = MaskablePPO.load("ppo_mask")

In [None]:
# Separate single environment with render_mode="human", wrapped with the
# same mask callback as the training environments.
render_env = ActionMasker(
    gym.make('chess_gymnasium_env/ChessEnv-v0', render_mode="human"),
    action_mask_fn,
)

In [None]:
# Play one episode in the rendering environment using the trained model.
obs, _ = render_env.reset()
episode_over = False
while not episode_over:
    # The mask is fetched again before every prediction and passed to the model.
    action_masks = get_action_masks(render_env)
    action, _states = model.predict(obs, action_masks=action_masks)
    obs, reward, terminated, truncated, info = render_env.step(action)

    # Stop once the environment reports termination or truncation.
    episode_over = terminated or truncated