In [11]:
import os
import random
from collections import deque
from typing import Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [12]:
class GomokuEnv(gym.Env):
    """Single-agent Gomoku (five in a row) on a square board.

    The board stores X stones as +1, O stones as -1 and empty cells as 0. Observations
    are always from the agent's point of view: the agent's own stones appear as +1.

    The opponent is a uniform-random mover unless ``opponent_fn`` is supplied (either
    to the constructor or by assigning the ``opponent_fn`` attribute before ``reset``).
    It is called as ``opponent_fn(board_view, valid_mask) -> action`` where
    ``board_view`` shows the opponent's stones as +1 and ``valid_mask`` is the flat
    boolean mask of empty cells; an illegal returned action raises ``ValueError``.

    X always moves first: when the agent plays O, ``reset`` lets the opponent place
    its opening stone before returning the first observation.

    Rewards (agent's perspective): +1 win, -1 loss, 0 draw / game continues, and
    -10 for an out-of-range or occupied action (which also ends the episode).
    """

    metadata = {"render_modes": [], "render_fps": 1}

    def __init__(
        self,
        board_size: int = 15,
        agent_is_x: bool = True,
        opponent_fn: Optional[Callable[[np.ndarray, np.ndarray], int]] = None,
    ):
        self.board_size = board_size
        self.agent_is_x = agent_is_x
        # None -> built-in uniform-random opponent.
        self.opponent_fn = opponent_fn
        self.board = np.zeros((board_size, board_size), dtype=int)
        self.action_space = gym.spaces.Discrete(board_size * board_size)
        self.observation_space = gym.spaces.Box(
            low=-1, high=1, shape=(board_size, board_size), dtype=int
        )

    def _count(self, player: int, r: int, c: int, dr: int, dc: int) -> int:
        """Count consecutive ``player`` stones from (r, c) in direction (dr, dc), excluding (r, c)."""
        count = 0
        nr, nc = r + dr, c + dc
        while 0 <= nr < self.board_size and 0 <= nc < self.board_size and self.board[nr, nc] == player:
            count += 1
            nr += dr
            nc += dc
        return count

    def _check_winner_from_move(self, r: int, c: int) -> int:
        """Return the stone at (r, c) if it completes a line of 5+, else 0."""
        player = self.board[r, c]
        if player == 0:
            return 0
        # Horizontal, vertical, and both diagonals; each line is scanned both ways.
        for dr, dc in [(0, 1), (1, 0), (1, 1), (1, -1)]:
            total = 1 + self._count(player, r, c, dr, dc) + self._count(player, r, c, -dr, -dc)
            if total >= 5:
                return player
        return 0

    def _is_full(self) -> bool:
        """True when no empty cell remains."""
        return not (self.board == 0).any()

    def _get_obs(self) -> np.ndarray:
        """Copy of the board with the agent's stones as +1."""
        return self.board.copy() if self.agent_is_x else -self.board.copy()

    def _opponent_move(self) -> Tuple[int, int]:
        """Place one opponent stone and return its (row, col).

        Callers must ensure at least one empty cell exists.
        """
        opp_symbol = -1 if self.agent_is_x else 1
        valid_mask = self.get_valid_mask()
        if self.opponent_fn is None:
            # Row-major order of empty cells, same as np.argwhere on the 2-D board.
            empty = np.flatnonzero(valid_mask)
            action = int(empty[self.np_random.integers(len(empty))])
        else:
            action = int(self.opponent_fn(self.board * opp_symbol, valid_mask))
            if not (0 <= action < valid_mask.size and valid_mask[action]):
                raise ValueError(f"opponent_fn returned illegal action {action}")
        r, c = divmod(action, self.board_size)
        self.board[r, c] = opp_symbol
        return r, c

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[np.ndarray, dict]:
        """Clear the board (seeding ``np_random``); if the agent is O, the opponent opens as X."""
        super().reset(seed=seed)
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        if not self.agent_is_x:
            # X moves first, so an O agent must respond to the opponent's opening stone.
            self._opponent_move()
        return self._get_obs(), {}

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
        """Play the agent's flat ``action`` (``row * board_size + col``), then the opponent's reply.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated, info)``;
        ``truncated`` is always False and ``info`` is always empty.
        """
        if not (0 <= action < self.board_size * self.board_size):
            return self._get_obs(), -10.0, True, False, {}

        r, c = divmod(action, self.board_size)
        if self.board[r, c] != 0:
            return self._get_obs(), -10.0, True, False, {}

        player_symbol = 1 if self.agent_is_x else -1
        self.board[r, c] = player_symbol
        if self._check_winner_from_move(r, c) == player_symbol:
            return self._get_obs(), 1.0, True, False, {}
        if self._is_full():
            return self._get_obs(), 0.0, True, False, {}

        # The board is not full here, so the opponent always has a legal reply.
        orow, ocol = self._opponent_move()
        if self._check_winner_from_move(orow, ocol) == -player_symbol:
            return self._get_obs(), -1.0, True, False, {}
        return self._get_obs(), 0.0, self._is_full(), False, {}

    def get_valid_mask(self) -> np.ndarray:
        """Flat boolean mask (row-major) that is True for empty cells."""
        return (self.board.flatten() == 0)


In [13]:
class PolicyNetwork(nn.Module):
    """Actor-critic network made of two independent MLPs over the flattened board.

    ``actor`` maps a board to one logit per cell; ``critic`` maps it to a scalar
    state value. Both have the layout Flatten -> Linear -> ReLU -> Linear -> ReLU ->
    Linear, so their parameters live under ``actor.{1,3,5}`` / ``critic.{1,3,5}``.
    """

    def __init__(self, board_size: int = 15, hidden_dim: int = 256):
        super().__init__()
        self.board_size = board_size
        self.input_dim = board_size * board_size
        self.hidden_dim = hidden_dim

        # Actor is built before critic so parameter initialisation order is fixed.
        self.actor = self._mlp_head(self.input_dim)
        self.critic = self._mlp_head(1)

    def _mlp_head(self, out_features: int) -> nn.Sequential:
        """Two-hidden-layer ReLU MLP from the flattened board to ``out_features``."""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, out_features),
        )

    def forward(self, x):
        """Return ``(logits, value)``; the critic's trailing size-1 dim is squeezed away."""
        return self.actor(x), self.critic(x).squeeze(-1)

In [14]:
class SelfPlayA2C:
    """Episodic advantage actor-critic (A2C) trainer for Gomoku with a self-play opponent pool.

    Every training episode pits the current policy against either a frozen snapshot
    from ``past_policies`` (probability 0.5, once the pool is non-empty) or the current
    policy itself. The opponent is attached to the environment through its
    ``opponent_fn`` attribute; an environment without that hook simply ignores the
    attribute and keeps its own built-in opponent.
    """

    def __init__(
        self,
        board_size: int = 15,
        lr: float = 1e-4,
        gamma: float = 0.99,
        entropy_coef: float = 0.01,
        value_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        past_agents_window: int = 5
    ):
        self.board_size = board_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm

        self.policy = PolicyNetwork(board_size).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        # Bounded pool: appending past `past_agents_window` drops the oldest snapshot.
        self.past_policies = deque(maxlen=past_agents_window)

    def get_action(self, state: np.ndarray, valid_mask: np.ndarray):
        """Sample a legal action from the current policy.

        Args:
            state: board from the acting player's perspective (own stones = +1).
            valid_mask: flat boolean array, True where a stone may be placed.

        Returns:
            ``(action, log_prob, value, entropy)``: ``action`` is a Python int; the
            other three are 1-element tensors that still carry gradients.
        """
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
        valid_mask_tensor = torch.tensor(valid_mask, dtype=torch.bool).unsqueeze(0).to(self.device)

        logits, value = self.policy(state_tensor)
        # A huge negative logit gives illegal cells zero probability.
        masked_logits = logits.masked_fill(~valid_mask_tensor, -1e9)
        # Building the distribution from logits keeps log_prob/entropy numerically stable.
        dist = Categorical(logits=masked_logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value, dist.entropy()

    def compute_returns(self, rewards, dones, final_value, gamma):
        """Discounted returns, computed back to front.

        A done flag of 1 stops the running return from carrying past that step;
        ``final_value`` seeds the recursion after the last step.
        """
        returns = []
        running = final_value
        for reward, done in zip(reversed(rewards), reversed(dones)):
            running = reward + gamma * running * (1 - done)
            returns.append(running)
        returns.reverse()
        return returns

    def _policy_opponent(self, net):
        """Wrap ``net`` as an opponent callable ``(board_view, valid_mask) -> action``.

        Moves are sampled from ``net``'s masked distribution without tracking
        gradients, so opponent moves never contribute to the loss.
        """
        def act(board_view, valid_mask):
            with torch.no_grad():
                s = torch.as_tensor(board_view, dtype=torch.float32, device=self.device).unsqueeze(0)
                m = torch.as_tensor(valid_mask, dtype=torch.bool, device=self.device).unsqueeze(0)
                logits, _ = net(s)
                return Categorical(logits=logits.masked_fill(~m, -1e9)).sample().item()
        return act

    def train_epoch(self, num_episodes: int = 100):
        """Play ``num_episodes`` episodes, taking one optimiser step after each.

        Returns:
            Mean loss over the episodes, or 0.0 when ``num_episodes`` is not positive.
        """
        if num_episodes <= 0:
            return 0.0

        total_loss = 0.0
        for _ in range(num_episodes):
            # Select opponent: 50% a frozen past policy (when available), otherwise the current one.
            if self.past_policies and random.random() < 0.5:
                opponent_net = random.choice(self.past_policies)
            else:
                opponent_net = self.policy
            agent_is_x = random.choice([True, False])

            env = GomokuEnv(board_size=self.board_size, agent_is_x=agent_is_x)
            # Assigned before reset() so an opponent opening as X already uses it.
            env.opponent_fn = self._policy_opponent(opponent_net)
            state, _ = env.reset()

            log_probs, values, rewards, entropies, dones = [], [], [], [], []
            done = False
            while not done:
                action, log_prob, value, entropy = self.get_action(state, env.get_valid_mask())
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                log_probs.append(log_prob)
                values.append(value)
                rewards.append(reward)
                entropies.append(entropy)
                dones.append(float(done))

            # The loop only exits on a terminal step, so there is nothing to bootstrap.
            returns = torch.tensor(
                self.compute_returns(rewards, dones, 0.0, self.gamma),
                dtype=torch.float32,
                device=self.device,
            )
            advantages = returns - torch.cat(values)

            actor_loss = -(torch.cat(log_probs) * advantages.detach()).mean()
            critic_loss = advantages.pow(2).mean()
            entropy_loss = -torch.cat(entropies).mean()
            loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / num_episodes

    def save_policy(self, path: str):
        """Write the current policy's state_dict to ``path``."""
        torch.save(self.policy.state_dict(), path)

    def add_past_policy(self):
        """Snapshot the current policy into the opponent pool."""
        self.past_policies.append(self._copy_policy())

    def _copy_policy(self):
        """Return an independent, eval-mode copy of the current policy."""
        policy_copy = PolicyNetwork(self.board_size).to(self.device)
        policy_copy.load_state_dict(self.policy.state_dict())
        policy_copy.eval()
        return policy_copy

In [15]:
def evaluate_policy(policy, board_size: int = 15, episodes: int = 100):
    """Greedy evaluation of ``policy`` against GomokuEnv's own opponent.

    Each episode assigns the agent a random side and plays the argmax of the
    policy's logits over empty cells.

    Returns:
        ``(win_rate, draw_rate, loss_rate)`` as fractions of ``episodes``, or all
        zeros when ``episodes`` is not positive. Any final reward other than 1.0
        or 0.0 (an opponent win, or the -10 illegal-move penalty) counts as a loss.
    """
    if episodes <= 0:
        return 0.0, 0.0, 0.0

    wins = draws = losses = 0
    device = next(policy.parameters()).device

    for _ in range(episodes):
        env = GomokuEnv(board_size=board_size, agent_is_x=random.choice([True, False]))
        state, _ = env.reset()
        reward, done = 0.0, False

        while not done:
            valid_mask = env.get_valid_mask()
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            valid_mask_tensor = torch.tensor(valid_mask, dtype=torch.bool).unsqueeze(0).to(device)

            with torch.no_grad():
                logits, _ = policy(state_tensor)
                masked_logits = logits.masked_fill(~valid_mask_tensor, -1e9)
                action = masked_logits.argmax(dim=-1).item()

            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

        if reward == 1.0:
            wins += 1
        elif reward == 0.0:
            draws += 1
        else:
            losses += 1

    return wins / episodes, draws / episodes, losses / episodes

In [38]:
# ----------------------------
# Training configuration
# ----------------------------
BOARD_SIZE = 15
TOTAL_EPISODES = 2000
EVAL_FREQ = 20        # episodes between evaluations vs fixed opponents
ELO_FREQ = 100        # episodes between checkpoint + Elo rounds
POOL_FREQ = 50        # episodes between snapshots added to the self-play pool
MAX_CHECKPOINTS = 5   # newest checkpoints kept on disk; older files are deleted
ELO_GAMES = 20        # games per Elo pairing (sides assigned at random)

trainer = SelfPlayA2C(board_size=BOARD_SIZE, lr=1e-4)
device = trainer.device
os.makedirs("models", exist_ok=True)

# Elo setup.
# NOTE(review): EloEvaluator, evaluate_agent, random_opponent and heuristic_opponent are
# not defined in any cell shown here; they must be provided by an earlier cell.
elo = EloEvaluator(k_factor=20.0)
elo.add_player("current", 1200.0)

past_checkpoints = []


def sample_move(policy, board, symbol):
    """Sample a legal move for the player whose stones are ``symbol`` (+1 = X, -1 = O).

    The policy sees the board with the mover's stones as +1, the same perspective
    GomokuEnv gives the agent. Sampling (instead of argmax) stops repeated games
    between the same two networks from all being identical.
    """
    view = board * symbol
    s = torch.tensor(view, dtype=torch.float32, device=device).unsqueeze(0)
    vm = torch.tensor(board.flatten() == 0, dtype=torch.bool, device=device).unsqueeze(0)
    with torch.no_grad():
        logits, _ = policy(s)
    return Categorical(logits=logits.masked_fill(~vm, -1e9)).sample().item()


def play_head_to_head(policy_a, policy_b, a_is_x, board_size=BOARD_SIZE):
    """Play one full game between two policies, X moving first.

    Returns policy_a's score: 1.0 for a win, 0.5 for a draw, 0.0 for a loss.
    """
    env = GomokuEnv(board_size=board_size)  # used only as a board + win checker
    env.reset()
    a_symbol = 1 if a_is_x else -1
    to_move = 1
    while True:
        mover = policy_a if to_move == a_symbol else policy_b
        r, c = divmod(sample_move(mover, env.board, to_move), board_size)
        env.board[r, c] = to_move
        winner = env._check_winner_from_move(r, c)
        if winner != 0:
            return 1.0 if winner == a_symbol else 0.0
        if env._is_full():
            return 0.5
        to_move = -to_move


for episode in range(TOTAL_EPISODES):
    loss = trainer.train_epoch(num_episodes=1)  # one episode = one optimiser step

    # --- Periodic evaluation vs fixed opponents ---
    if episode % EVAL_FREQ == 0:
        win_r, draw_r, loss_r = evaluate_agent(
            trainer.policy,
            random_opponent,
            board_size=BOARD_SIZE,
            episodes=50,
            device=device
        )
        win_h, draw_h, loss_h = evaluate_agent(
            trainer.policy,
            heuristic_opponent,
            board_size=BOARD_SIZE,
            episodes=50,
            device=device
        )

        print(
            f"Ep {episode:4d} | "
            f"Loss: {loss:7.4f} | "
            f"Rand: W{win_r:5.1%} D{draw_r:5.1%} L{loss_r:5.1%} | "
            f"Heur: W{win_h:5.1%} D{draw_h:5.1%} L{loss_h:5.1%} | "
            f"ε: {trainer.entropy_coef:.3f}"
        )

    # --- Checkpoint + Elo vs past checkpoints ---
    if episode % ELO_FREQ == 0 and episode > 0:
        ckpt_path = f"models/gomoku_ep_{episode:04d}.pth"
        torch.save(trainer.policy.state_dict(), ckpt_path)
        past_checkpoints.append(ckpt_path)

        # Bound disk usage: drop the oldest checkpoint file.
        if len(past_checkpoints) > MAX_CHECKPOINTS:
            oldest = past_checkpoints.pop(0)
            if os.path.exists(oldest):
                os.remove(oldest)

        for past_path in past_checkpoints[:-1]:  # last entry is the model just saved
            name = os.path.basename(past_path).replace(".pth", "")
            if name not in elo.ratings:
                elo.add_player(name, 1200.0)

            opp_policy = PolicyNetwork(board_size=BOARD_SIZE).to(device)
            opp_policy.load_state_dict(torch.load(past_path, map_location=device))
            opp_policy.eval()

            # Real head-to-head games: both networks move, draws count half.
            score = sum(
                play_head_to_head(trainer.policy, opp_policy, a_is_x=random.choice([True, False]))
                for _ in range(ELO_GAMES)
            ) / ELO_GAMES
            elo.update_rating("current", name, score)

        current_rating = elo.get_rating("current")
        print(f"Ep {episode:4d} | Elo: {current_rating:6.1f} (vs {len(past_checkpoints)-1} past models)")

    # --- Add to self-play pool periodically ---
    if episode % POOL_FREQ == 0:
        trainer.add_past_policy()

Ep    0 | Loss:  3.9982 | Rand: W46.0% D 0.0% L54.0% | Heur: W50.0% D 0.0% L50.0% | ε: 0.010
Ep   20 | Loss: -3.9414 | Rand: W46.0% D 0.0% L54.0% | Heur: W58.0% D 0.0% L42.0% | ε: 0.010
Ep   40 | Loss:  3.9171 | Rand: W46.0% D 0.0% L54.0% | Heur: W50.0% D 0.0% L50.0% | ε: 0.010
Ep   60 | Loss:  4.4260 | Rand: W44.0% D 0.0% L56.0% | Heur: W50.0% D 0.0% L50.0% | ε: 0.010
Ep   80 | Loss: -3.7417 | Rand: W48.0% D 0.0% L52.0% | Heur: W60.0% D 0.0% L40.0% | ε: 0.010
Ep  100 | Loss:  3.8861 | Rand: W62.0% D 0.0% L38.0% | Heur: W62.0% D 0.0% L38.0% | ε: 0.010
Ep  100 | Elo: 1200.0 (vs 0 past models)
Ep  120 | Loss:  3.9477 | Rand: W48.0% D 0.0% L52.0% | Heur: W52.0% D 0.0% L48.0% | ε: 0.010
Ep  140 | Loss: -4.1917 | Rand: W58.0% D 0.0% L42.0% | Heur: W56.0% D 0.0% L44.0% | ε: 0.010
Ep  160 | Loss:  3.6548 | Rand: W52.0% D 0.0% L48.0% | Heur: W54.0% D 0.0% L46.0% | ε: 0.010
Ep  180 | Loss: -4.2723 | Rand: W54.0% D 0.0% L46.0% | Heur: W44.0% D 0.0% L56.0% | ε: 0.010
Ep  200 | Loss:  3.9486 | Ran

## Play against the bot

In [39]:
import cv2

# ----------------------------
# Config
# ----------------------------
BOARD_SIZE = 15
CELL_SIZE = 30  # pixels
WINDOW_SIZE = BOARD_SIZE * CELL_SIZE
MODEL_PATH = "models/gomoku_ep_1900.pth"
WINDOW_NAME = "Gomoku - Play vs A2C Agent"

# ----------------------------
# Load agent
# ----------------------------
# The checkpoint must come from the actor/critic PolicyNetwork defined above. A file
# saved by an older model with a single `net` trunk fails to load with missing
# `actor.*`/`critic.*` keys; re-run the training cell to regenerate it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy = PolicyNetwork(board_size=BOARD_SIZE).to(device)
policy.load_state_dict(torch.load(MODEL_PATH, map_location=device))
policy.eval()

# ----------------------------
# Game state
# ----------------------------
# The agent plays X (+1) and moves first; the human plays O (-1).
env = GomokuEnv(board_size=BOARD_SIZE, agent_is_x=True)
game_over = False
message = ""

# ----------------------------
# Helper: check win
# ----------------------------
def check_win(board, r, c):
    """Return the stone at (r, c) if it completes a line of 5+, else 0."""
    if board[r, c] == 0:
        return 0
    player = board[r, c]
    for dr, dc in [(0,1), (1,0), (1,1), (1,-1)]:
        count = 1
        for sign in [1, -1]:
            nr, nc = r + sign*dr, c + sign*dc
            while 0 <= nr < BOARD_SIZE and 0 <= nc < BOARD_SIZE and board[nr, nc] == player:
                count += 1
                nr += sign*dr
                nc += sign*dc
        if count >= 5:
            return player
    return 0

# ----------------------------
# Draw board
# ----------------------------
def draw_board(img, board, message):
    """Render grid, stones (agent = filled black, human = outlined white) and message into ``img``."""
    img[:] = 255  # white background

    # Draw grid
    for i in range(1, BOARD_SIZE):
        cv2.line(img, (i * CELL_SIZE, 0), (i * CELL_SIZE, WINDOW_SIZE), (200, 200, 200), 1)
        cv2.line(img, (0, i * CELL_SIZE), (WINDOW_SIZE, i * CELL_SIZE), (200, 200, 200), 1)

    # Draw stones
    for i in range(BOARD_SIZE):
        for j in range(BOARD_SIZE):
            center = (j * CELL_SIZE + CELL_SIZE // 2, i * CELL_SIZE + CELL_SIZE // 2)
            if board[i, j] == 1:  # X (agent)
                cv2.circle(img, center, CELL_SIZE // 2 - 2, (0, 0, 0), -1)
            elif board[i, j] == -1:  # O (you)
                cv2.circle(img, center, CELL_SIZE // 2 - 2, (255, 255, 255), -1)
                cv2.circle(img, center, CELL_SIZE // 2 - 2, (0, 0, 0), 1)

    # Draw message
    cv2.putText(img, message, (10, WINDOW_SIZE + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)

# ----------------------------
# Agent / game control
# ----------------------------
def agent_move():
    """Place the agent's X on its highest-scoring empty cell and return (row, col)."""
    # Read the live board each time so the agent sees the human's latest stone.
    obs = env._get_obs()
    valid_mask = (env.board.flatten() == 0)
    state_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
    valid_mask_tensor = torch.tensor(valid_mask, dtype=torch.bool).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, _ = policy(state_tensor)
        masked_logits = logits.masked_fill(~valid_mask_tensor, -1e9)
        action = masked_logits.argmax(dim=-1).item()
    ar, ac = divmod(action, BOARD_SIZE)
    env.board[ar, ac] = 1
    return ar, ac


def render():
    """Redraw the board and the current message in the window."""
    draw_board(img, env.board, message)
    cv2.imshow(WINDOW_NAME, img)


def start_game(start_message):
    """Clear the board, let the agent open as X, and show ``start_message``."""
    global game_over, message
    env.reset()
    agent_move()  # X moves first, so the human genuinely plays second
    game_over = False
    message = start_message
    render()


def mouse_callback(event, x, y, flags, param):
    """On left click: place the human's O, then let the agent reply with X."""
    global game_over, message
    if event != cv2.EVENT_LBUTTONDOWN or game_over:
        return
    if y >= WINDOW_SIZE:  # clicked on message area
        return
    row, col = y // CELL_SIZE, x // CELL_SIZE
    if not (0 <= row < BOARD_SIZE and 0 <= col < BOARD_SIZE):
        return
    if env.board[row, col] != 0:
        message = "Invalid: cell occupied."
        render()  # show the warning now rather than on the next valid click
        return

    # Human move
    env.board[row, col] = -1
    if check_win(env.board, row, col) == -1:
        message = "You win! Press R to reset."
        game_over = True
    elif not (env.board == 0).any():
        message = "Draw! Press R to reset."
        game_over = True
    else:
        ar, ac = agent_move()
        if check_win(env.board, ar, ac) == 1:
            message = "Agent wins! Press R to reset."
            game_over = True
        elif not (env.board == 0).any():
            message = "Draw! Press R to reset."
            game_over = True
        else:
            message = "Your turn: click to place O."

    render()

# ----------------------------
# Main loop
# ----------------------------
img = np.ones((WINDOW_SIZE + 30, WINDOW_SIZE, 3), dtype=np.uint8) * 255
cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_AUTOSIZE)
cv2.setMouseCallback(WINDOW_NAME, mouse_callback)
start_game("Click to place O (you play second). Press R to reset, Q to quit.")

print("Game started. Click to play. Press R to reset, Q to quit.")

while True:
    key = cv2.waitKey(1) & 0xFF
    if key == ord('q') or cv2.getWindowProperty(WINDOW_NAME, cv2.WND_PROP_VISIBLE) < 1:
        break
    if key == ord('r'):
        start_game("Click to place O (you play second).")

cv2.destroyAllWindows()

Game started. Click to play. Press R to reset, Q to quit.


Random Opponent:   Win=100.00%, Draw=0.00%, Loss=0.00%
Heuristic Opponent: Win=100.00%, Draw=0.00%, Loss=0.00%


RuntimeError: Error(s) in loading state_dict for PolicyNetwork:
	Missing key(s) in state_dict: "actor.1.weight", "actor.1.bias", "actor.3.weight", "actor.3.bias", "actor.5.weight", "actor.5.bias", "critic.1.weight", "critic.1.bias", "critic.3.weight", "critic.3.bias", "critic.5.weight", "critic.5.bias". 
	Unexpected key(s) in state_dict: "net.1.weight", "net.1.bias", "net.3.weight", "net.3.bias", "net.5.weight", "net.5.bias". 