# Imports

In [2]:
from __future__ import annotations

import os
import time
import glob
import pygame

import numpy as np
from numpy import copy

import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import EzPickle

from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy, MlpPolicy

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker

import pettingzoo
from pettingzoo import ParallelEnv, AECEnv
from pettingzoo.utils import wrappers
from pettingzoo.utils.agent_selector import agent_selector
from pettingzoo.classic import connect_four_v3
from pettingzoo.butterfly import knights_archers_zombies_v10


pygame 2.6.1 (SDL 2.28.4, Python 3.13.2)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Action Mask

In [3]:
class SB3ActionMaskWrapper(pettingzoo.utils.BaseWrapper, gym.Env):
    def reset(self, seed=None, options=None):
        super().reset(seed, options)

        self.observation_space = super().observation_space(self.possible_agents[0])["observation"]
        self.action_space = super().action_space(self.possible_agents[0])

        return self.observe(self.agent_selection), {}


    def step(self, action):
        current_agent = self.agent_selection

        super().step(action)

        next_agent = self.agent_selection
        return (
            self.observe(next_agent),
            self._cumulative_rewards[current_agent],
            self.terminations[current_agent],
            self.truncations[current_agent],
            self.infos[current_agent],
        )


    def observe(self, agent):
        return super().observe(agent)["observation"]


    def action_mask(self):
        return super().observe(self.agent_selection)["action_mask"]

# Environment

In [None]:
class TicTacToeAecEnv(AECEnv, EzPickle):
    metadata = {
        "name": "tttAec-v0",
        "is_parallelizable": False,
    }

    def __init__(self):
        super().__init__()
        self.board = np.zeros(9, dtype=np.int8)

        self.agents = ["p0", "p1"]
        self.possible_agents = self.agents[:]

        self.action_spaces = {i: spaces.Discrete(9) for i in self.agents}
        self.observation_spaces = {
            i: spaces.Dict({
                "observation": spaces.Box(low=0, high=1, shape=(3, 3, 2), dtype=np.int8),
                "action_mask": spaces.Box(low=0, high=1, shape=(3, 3), dtype=np.int8),
            })
            for i in self.agents
        }


    def observe(self, agent):
        board_vals = np.array(self.board).reshape(3, 3)
        cur_player = self.possible_agents.index(agent)
        opp_player = (cur_player + 1) % 2

        cur_p_board = np.equal(board_vals, cur_player + 1)
        opp_p_board = np.equal(board_vals, opp_player + 1)

        observation = np.stack([cur_p_board, opp_p_board], axis=2).astype(np.int8)
        legal_moves = self._legal_moves() if agent == self.agent_selection else []

        action_mask = np.zeros(9, "int8")
        for i in legal_moves:
            action_mask[i] = 1

        return {"observation": observation, "action_mask": action_mask}


    def _legal_moves(self):
        return [i for i in range(9) if self.board[i] == 0]


    def observation_space(self, agent):
        return self.observation_spaces[agent]


    def action_space(self, agent):
        return self.action_spaces[agent]


    def step(self, action):
        if (
            self.truncations[self.agent_selection]
            or self.terminations[self.agent_selection]
        ):
            print(f"Agent {self.agent_selection} tried to step in a terminated or truncated state.")
            return self._was_dead_step(action)

        assert self.board[0:9][action] == 0, "played illegal move."

        piece = self.agents.index(self.agent_selection) + 1
        self.board[action] = piece

        next_agent = self._agent_selector.next()

        winner = self.check_for_winner()

        if winner:
            self.rewards[self.agent_selection] += 1
            self.rewards[next_agent] -= 1
            self.terminations = {i: True for i in self.agents}
        elif not any(x == 0 for x in self.board):
            self.terminations = {i: True for i in self.agents}

        self.agent_selection = next_agent
        self._accumulate_rewards()


    def reset(self, seed=None, options=None):
        self.board = np.zeros(9, dtype=np.int8)

        self.agents = self.possible_agents[:]
        self.rewards = {i: 0 for i in self.agents}
        self._cumulative_rewards = {name: 0 for name in self.agents}
        self.terminations = {i: False for i in self.agents}
        self.truncations = {i: False for i in self.agents}
        self.infos = {i: {} for i in self.agents}

        self._agent_selector = agent_selector(self.agents)

        self.agent_selection = self._agent_selector.reset()


    def check_for_winner(self):
        board = np.reshape(self.board, (3, 3))
        for i in range(3):
            if np.all(board[i, :] == 1) or np.all(board[:, i] == 1):
                return True
            if np.all(board[i, :] == 2) or np.all(board[:, i] == 2):
                return True
        if np.all(np.diag(board) == 1) or np.all(np.diag(np.fliplr(board)) == 1):
            return True
        if np.all(np.diag(board) == 2) or np.all(np.diag(np.fliplr(board)) == 2):
            return True
        return False
    

# Training

In [5]:
def mask_fn(env: TicTacToeAecEnv):
    return env.action_mask()

In [6]:
env = TicTacToeAecEnv()
env = wrappers.OrderEnforcingWrapper(env)
env = SB3ActionMaskWrapper(env)
env.reset(seed=np.random.randint(0, 1000))
env = ActionMasker(env, mask_fn)

In [7]:
model = MaskablePPO(
    MaskableActorCriticPolicy, 
    env, 
    verbose=1,
    ent_coef=0.01,
)
model.set_random_seed(seed=np.random.randint(0, 1000))

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [8]:
model.learn(total_timesteps=10_000)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 7.78     |
|    ep_rew_mean     | 0.88     |
| time/              |          |
|    fps             | 416      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.44        |
|    ep_rew_mean          | 0.93        |
| time/                   |             |
|    fps                  | 347         |
|    iterations           | 2           |
|    time_elapsed         | 11          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.011186462 |
|    clip_fraction        | 0.0803      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.59       |
|    explained_variance   | -0.532      |
|    learning_rate        | 0.

<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x77b84b37af90>

In [27]:
obs, info = env.reset(seed=np.random.randint(0, 1000))
done = False

while not done:
    # If obs is not a dictionary, assume it directly contains the observation
    if isinstance(obs, dict):
        observation, action_mask = obs["observation"], obs["action_mask"]
    else:
        observation = obs
        action_mask = env.env.action_mask()  # Retrieve the action mask separately

    action, _states = model.predict(observation, action_masks=action_mask)

    obs, reward, termination, truncation, info = env.step(action)

    print(env.env.board.reshape(3, 3))

    done = termination or truncation
    if done:
        print(f"Game ended. Reward: {env.env.agent_selection}")
    print()

[[0 0 0]
 [0 0 0]
 [0 0 1]]

[[0 0 0]
 [2 0 0]
 [0 0 1]]

[[0 0 0]
 [2 0 0]
 [1 0 1]]

[[0 0 0]
 [2 0 0]
 [1 2 1]]

[[1 0 0]
 [2 0 0]
 [1 2 1]]

[[1 0 0]
 [2 0 2]
 [1 2 1]]

[[1 0 1]
 [2 0 2]
 [1 2 1]]

[[1 2 1]
 [2 0 2]
 [1 2 1]]

[[1 2 1]
 [2 1 2]
 [1 2 1]]
Game ended. Reward: p1



In [10]:
obs, info = env.reset(seed=np.random.randint(0, 1000))
done = False

print("You are playing as Player 0 (p0). The model is Player 1 (p1).")
print("Board positions are numbered as follows:")
print(np.arange(9).reshape(3, 3))  # Display board positions for reference
print()
print(env.env.board.reshape(3, 3), flush=True)  # Display initial board state

while not done:
    # If obs is not a dictionary, assume it directly contains the observation
    if isinstance(obs, dict):
        observation, action_mask = obs["observation"], obs["action_mask"]
    else:
        observation = obs
        action_mask = env.env.action_mask()  # Retrieve the action mask separately

    # Determine the current agent
    current_agent = env.env.agent_selection

    if current_agent == "p0":  # Human player's turn
        print("Your turn!", flush=True)

        inp = input("Enter your move (0-8): ")
        if inp.isdigit() and int(inp) in range(9) and action_mask[int(inp)] == 1:
            action = int(inp)
        else:
            print("Invalid move")
            break

    else:  # Model's turn
        print("Model's turn...", flush=True)
        action, _states = model.predict(
            observation, action_masks=action_mask, deterministic=True
        )

    # Take the chosen action
    obs, reward, termination, truncation, info = env.step(action)

    # Print the updated board state
    print(env.env.board.reshape(3, 3), flush=True)  # Display the board state after the move

    # Check if the game has ended
    done = termination or truncation
    if done:
        print("Game over!")
        if reward > 0:
            if current_agent == "p0":
                print("You win!")
            else:
                print("The model wins!")
        elif reward < 0:
            if current_agent == "p0":
                print("The model wins!")
            else:
                print("You win!")
        else:
            print("It's a draw!")

You are playing as Player 0 (p0). The model is Player 1 (p1).
Board positions are numbered as follows:
[[0 1 2]
 [3 4 5]
 [6 7 8]]

[[0 0 0]
 [0 0 0]
 [0 0 0]]
Your turn!


[[0 0 0]
 [0 1 0]
 [0 0 0]]
Model's turn...
[[0 0 0]
 [0 1 0]
 [2 0 0]]
Your turn!
[[1 0 0]
 [0 1 0]
 [2 0 0]]
Model's turn...
[[1 0 2]
 [0 1 0]
 [2 0 0]]
Your turn!
[[1 0 2]
 [0 1 0]
 [2 0 1]]
Game over!
You win!
