In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# --------------------------------------------------------------------------------------
# Environment
# --------------------------------------------------------------------------------------

class TicTacToeEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 1}

    def __init__(self):
        super(TicTacToeEnv, self).__init__()
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1

        self.action_space = spaces.Discrete(9)
        self.observation_space = spaces.Box(low=0, high=1, shape=(3, 9), dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1
        return self._get_obs(), {}

    def step(self, action):
        # Invalid move
        if not (0 <= action < 9) or self.board.flatten()[action] != 0:
            return self._get_obs(), -1.0, True, False, {"message": "Invalid move"}

        # Agent move
        row, col = divmod(action, 3)
        self.board[row, col] = self.current_player

        # Check win or draw after agent's move
        if self.check_win(self.current_player):
            return self._get_obs(), 1.0, True, False, {"winner": self.current_player}

        if self.check_draw():
            return self._get_obs(), 0.5, True, False, {"message": "Draw"}

        # Opponent move (Player 2)
        available_moves = self.get_available_moves()
        if available_moves:
            opponent_action = self.rule_based_opponent()
            row, col = divmod(opponent_action, 3)
            self.board[row, col] = 3 - self.current_player

            if self.check_win(3 - self.current_player):
                return self._get_obs(), -1.0, True, False, {"winner": 3 - self.current_player}

            if self.check_draw():
                return self._get_obs(), 0.5, True, False, {"message": "Draw"}

        return self._get_obs(), 0.0, False, False, {}

    def _get_obs(self):
        obs = self.board.flatten()
        player1 = (obs == 1).astype(np.float32)
        player2 = (obs == 2).astype(np.float32)
        current = np.full((9,), self.current_player == 1, dtype=np.float32)
        return np.stack([player1, player2, current], axis=0)

    def check_win(self, player, board=None):
        b = self.board if board is None else board
        for i in range(3):
            if all(b[i, :] == player) or all(b[:, i] == player):
                return True
        return (b[0, 0] == b[1, 1] == b[2, 2] == player) or (b[0, 2] == b[1, 1] == b[2, 0] == player)

    def check_draw(self):
        return not np.any(self.board == 0) and not self.check_win(1) and not self.check_win(2)

    def get_available_moves(self):
        return [i for i, v in enumerate(self.board.flatten()) if v == 0]

    def rule_based_opponent(self):
        for move in self.get_available_moves():
            test_board = self.board.copy()
            r, c = divmod(move, 3)
            test_board[r, c] = 2
            if self.check_win(2, test_board):
                return move
            test_board[r, c] = 1
            if self.check_win(1, test_board):
                return move
        return np.random.choice(self.get_available_moves())

    def render(self):
        symbols = {0: ' ', 1: 'X', 2: 'O'}
        print("\n" + "-" * 13)
        for row in self.board:
            print("| " + " | ".join(symbols[val] for val in row) + " |")
            print("-" * 13)

# --------------------------------------------------------------------------------------
# Train the agent
# --------------------------------------------------------------------------------------

env = Monitor(TicTacToeEnv())

model = DQN(
    policy="MlpPolicy",
    env=env,
    learning_rate=1e-3,
    buffer_size=10000,
    learning_starts=1000,
    batch_size=64,
    gamma=0.95,
    verbose=1
)

model.learn(total_timesteps=100_000)
model.save("dqn_tictactoe")

mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=1000)
print(f"Mean reward over 1000 episodes: {mean_reward}")


# --------------------------------------------------------------------------------------
# Play against trained agent
# --------------------------------------------------------------------------------------

def get_human_move(env):
    while True:
        try:
            inp = input("Enter your move (1-9) or 'Q' to quit: ")
            if inp.lower() == 'q':
                print("Player quit the game.")
                return "QUIT"
            pos = int(inp) - 1
            if 0 <= pos <= 8 and pos in env.get_available_moves():
                return pos
            else:
                print("Invalid move.")
        except ValueError:
            print("Please enter a number from 1 to 9 or 'Q' to quit.")

def play_human_vs_agent():

    env = TicTacToeEnv()
    model = DQN.load("dqn_tictactoe")
    obs, _ = env.reset()
    
    print("You are X (Player 1). DQN Agent is O (Player 2).")
    print("Board index positions:\n 1 | 2 | 3\n---+---+---\n 4 | 5 | 6\n---+---+---\n 7 | 8 | 9")

    done = False
    while not done:
        env.render()
        if env.current_player == 1:
            action = get_human_move(env)
            if action == "QUIT":
                return
        else:
            print("Agent is thinking...")
            action, _states = model.predict(obs, deterministic=True)
        
        obs, reward, done, _, info = env.step(action)

    env.render()
    winner = info.get("winner")
    if winner == 1:
        print("You win!")
    elif winner == 2:
        print("DQN agent wins!")
    else:
        print("It is a draw.")

Using cpu device
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2        |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.999    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 7619     |
|    time_elapsed     | 0        |
|    total_timesteps  | 8        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.38     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.998    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 8207     |
|    time_elapsed     | 0        |
|    total_timesteps  | 19       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.58     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.997    |
| t

In [2]:
play_human_vs_agent()

You are X (Player 1). DQN Agent is O (Player 2).
Board index positions:
 1 | 2 | 3
---+---+---
 4 | 5 | 6
---+---+---
 7 | 8 | 9

-------------
|   |   |   |
-------------
|   |   |   |
-------------
|   |   |   |
-------------

-------------
|   |   |   |
-------------
|   | O |   |
-------------
|   |   | X |
-------------

-------------
|   |   | X |
-------------
|   | O | O |
-------------
|   |   | X |
-------------
Invalid move.

-------------
|   | O | X |
-------------
| X | O | O |
-------------
|   |   | X |
-------------

-------------
|   | O | X |
-------------
| X | O | O |
-------------
| O | X | X |
-------------

-------------
| X | O | X |
-------------
| X | O | O |
-------------
| O | X | X |
-------------
It is a draw.
