In [None]:
# Install PyTorch if it is not already available in the kernel's environment:
# %pip install torch

In [None]:
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from copy import deepcopy
import logging

In [None]:
# Define the Gomoku environment
class GomokuEnv:
    """Two-player Gomoku (five in a row) on a square board.

    Board cells hold 0 (empty), 1 (first player) or -1 (second player).
    Actions are flat indices, decoded as ``row, col = divmod(action, size)``.
    """

    def __init__(self, size=8):
        """Create an empty ``size`` x ``size`` board with player 1 to move."""
        self.size = size
        self.reset()

    def reset(self):
        """Clear the board and game state; return a copy of the empty board."""
        self.board = np.zeros((self.size, self.size), dtype=int)
        self.current_player = 1
        self.done = False
        self.winner = None
        return self.board.copy()

    def step(self, action):
        """Place a stone for the current player.

        Args:
            action: flat cell index in ``[0, size * size)``.

        Returns:
            Tuple ``(board_copy, reward, is_done, info)``. ``reward`` is 1 for
            a winning move and 0 otherwise; ``info`` is ``'player N wins'``,
            ``'draw'`` or ``''``.

        Raises:
            RuntimeError: if the game has already finished.
            ValueError: if ``action`` is outside the board or the cell is
                already occupied.
        """
        if self.done:
            raise RuntimeError("Game is over")
        # Reject out-of-range indices explicitly: a negative action would
        # otherwise wrap around through NumPy's negative indexing and silently
        # place a stone on the wrong cell.
        if not 0 <= action < self.size * self.size:
            raise ValueError(
                f"Invalid move: action {action} is outside the board "
                f"(expected 0..{self.size * self.size - 1})"
            )
        row, col = divmod(action, self.size)
        if self.board[row, col] != 0:
            raise ValueError("Invalid move")
        self.board[row, col] = self.current_player
        if self._check_winner(row, col):
            self.done = True
            self.winner = self.current_player
            return self.board.copy(), 1, True, f'player {self.current_player} wins'
        if np.all(self.board != 0):
            self.done = True
            return self.board.copy(), 0, True, 'draw'
        # The turn only passes to the other player when the game continues.
        self.current_player *= -1
        return self.board.copy(), 0, False, ''

    def _check_winner(self, row, col):
        """Return True if the stone at (row, col) completes a line of 5 or more."""
        player = self.board[row, col]
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
            count = 1  # the stone that was just placed
            # Walk up to four cells each way along the line through (row, col).
            for sign in (1, -1):
                for i in range(1, 5):
                    r, c = row + sign * dr * i, col + sign * dc * i
                    if not (0 <= r < self.size and 0 <= c < self.size):
                        break
                    if self.board[r, c] != player:
                        break
                    count += 1
            if count >= 5:
                return True
        return False

    def render(self):
        """Print the board and write the same lines to the log at INFO level."""
        symbols = {1: 'X', -1: 'O'}
        lines = [
            '   ' + ' '.join(str(i) for i in range(self.size)),
            '  +' + '-' * (self.size * 2 - 1) + '+',
        ]
        for i in range(self.size):
            cells = ' '.join(symbols.get(int(x), '.') for x in self.board[i])
            lines.append(f'{i} |' + cells + '|')
        # Each line is built once, then sent to both stdout and the log.
        for line in lines:
            print(line)
            logging.info(line)

In [None]:
# Define the DQN Agent
class DQNAgent:
    """DQN agent with a convolutional Q-network, replay memory and target network.

    States are two-plane arrays (current player's stones, opponent's stones)
    and actions are flat board indices, so ``action_size`` is expected to equal
    ``state_size[1] * state_size[2]``.
    """

    def __init__(self, state_size, action_size, env, device, learning_rate=0.0001, gamma=0.95, epsilon=1.0, epsilon_decay=0.999, epsilon_min=0.05):
        """Initialize the DQN agent with hyperparameters.

        Args:
            state_size: ``(channels, height, width)`` of one state; the network
                is built for 2 input channels.
            action_size: number of discrete actions (network outputs).
            env: environment reference, stored on the agent.
            device: anything accepted by ``torch.device`` (e.g. ``"cpu"``).
            learning_rate: Adam learning rate.
            gamma: discount factor used in the bootstrap target.
            epsilon: initial exploration rate.
            epsilon_decay: multiplicative decay applied after each replay step.
            epsilon_min: decay stops once epsilon is at or below this value.
        """
        self.state_size = state_size  # (2, size, size)
        self.action_size = action_size
        self.memory = deque(maxlen=10000)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.device = torch.device(device)
        self.model = self._build_model().to(self.device)
        self.target_model = deepcopy(self.model).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.env = env
        self.update_target_freq = 100  # replay steps between target-network syncs
        self.step_counter = 0

    def _build_model(self):
        """Build the CNN architecture for Q-value approximation."""
        model = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            # padding=1 with 3x3 kernels keeps the spatial size, so the
            # flattened feature count is 64 * height * width.
            nn.Linear(64 * self.state_size[1] * self.state_size[2], 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_size)
        )
        return model

    def get_state(self, board, current_player):
        """Convert board to two-channel state (current player's pieces, opponent's pieces)."""
        player_pieces = (board == current_player).astype(np.float32)
        opponent_pieces = (board == -current_player).astype(np.float32)
        return np.stack([player_pieces, opponent_pieces], axis=0)

    def remember(self, state, action, reward, next_state, done):
        """Store experience in replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Select an action using an epsilon-greedy policy over empty cells.

        Args:
            state: array of shape ``(1, 2, height, width)``.

        Returns:
            Flat index of the chosen cell. Occupied cells are never chosen.
        """
        occupied = (state[0, 0].flatten() + state[0, 1].flatten()) != 0
        if np.random.rand() <= self.epsilon:
            return np.random.choice(np.where(~occupied)[0])
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            act_values = self.model(state_tensor).cpu().numpy()[0]
        act_values[occupied] = float('-inf')
        return np.argmax(act_values)

    def replay(self, batch_size=64):
        """Train the online network on one random minibatch from replay memory.

        Does nothing until the memory holds at least ``batch_size`` transitions.
        Every ``update_target_freq`` calls the target network is synced to the
        online network, and epsilon is decayed after each training step.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        self.step_counter += 1
        if self.step_counter % self.update_target_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())

        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = zip(*minibatch)
        states = torch.from_numpy(np.array(batch_states)).float().to(self.device)
        next_states = torch.from_numpy(np.array(batch_next_states)).float().to(self.device)
        actions = torch.tensor(list(batch_actions), dtype=torch.long).to(self.device)
        rewards = torch.tensor(list(batch_rewards), dtype=torch.float32).to(self.device)
        dones = torch.tensor(list(batch_dones), dtype=torch.float32).to(self.device)

        q_values = self.model(states)
        with torch.no_grad():
            next_q_values = self.target_model(next_states)
            # Only empty cells are legal in the next state, so occupied cells
            # must not contribute to the bootstrap maximum.
            occupied = next_states.sum(dim=1).flatten(start_dim=1) != 0
            next_q_values = next_q_values.masked_fill(occupied, float('-inf'))
            max_next_q = next_q_values.max(dim=1)[0]
            # Terminal or completely full next states have no future value;
            # zeroing them also prevents -inf * 0 from producing NaN below.
            no_future = dones.bool() | occupied.all(dim=1)
            max_next_q = torch.where(no_future, torch.zeros_like(max_next_q), max_next_q)
        # next_states is the opponent's state, so its value is negated (and
        # scaled by an extra 0.9) before being added to the reward.
        targets = rewards + self.gamma * max_next_q * (1 - dones) * -0.9
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        loss = F.mse_loss(q_values, targets)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.001)  # Gradient clipping
        self.optimizer.step()
        logging.info(f'loss={loss.item()}')

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def save_model(self, filename):
        """Save the online network's weights to ``filename``."""
        torch.save(self.model.state_dict(), filename)

    def load_model(self, filename):
        """Load weights from ``filename`` if it exists; otherwise do nothing.

        The target network is synced to the loaded weights so bootstrap
        targets do not come from a freshly initialised network after resuming.
        """
        if os.path.exists(filename):
            # map_location lets a checkpoint saved on another device load here.
            state_dict = torch.load(filename, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.target_model.load_state_dict(self.model.state_dict())
            # eval() has no effect on this architecture (no dropout or batch
            # norm layers); kept so the module's mode flag matches prior behavior.
            self.model.eval()

In [None]:
# Evaluation function
def evaluate(agent, env, num_games=100):
    """Play the agent (as player 1) against a uniformly random opponent.

    Exploration is switched off for the duration of the evaluation so the
    result reflects the learned policy rather than the current training
    epsilon. The agent's epsilon is restored afterwards, even on error.

    Args:
        agent: object exposing ``epsilon``, ``get_state(board, player)`` and
            ``act(state)``.
        env: environment exposing ``reset()``, ``step(action)``,
            ``current_player``, ``winner`` and ``size``.
        num_games: number of games to play.

    Returns:
        Fraction of games won by the agent (``wins / num_games``).
    """
    wins = 0
    training_epsilon = agent.epsilon
    agent.epsilon = 0.0  # act greedily while being evaluated
    try:
        for _game in range(num_games):
            state = env.reset()
            done = False
            while not done:
                if env.current_player == 1:
                    state_channels = agent.get_state(state, 1)
                    state_channels = np.reshape(state_channels, [1, 2, env.size, env.size])
                    action = agent.act(state_channels)
                else:
                    # Random opponent: pick any empty cell.
                    available_actions = np.where(state.flatten() == 0)[0]
                    action = np.random.choice(available_actions)
                state, _, done, info = env.step(action)
            if env.winner == 1:
                wins += 1
    finally:
        agent.epsilon = training_epsilon
    return wins / num_games

In [None]:
# Training function
def train_dqn(episodes=10000000000, batch_size=64, model_filename='gomoku_dqn.pth'):
    """Train the DQN agent using self-play.

    Each episode opens with one random move, after which the same agent plays
    both sides. One replay step runs per episode once the memory holds more
    than ``batch_size`` transitions. Every 1000 episodes the agent is
    evaluated against a random opponent and the weights are saved whenever
    the win rate is at least the best seen so far.

    Args:
        episodes: number of self-play episodes to run.
        batch_size: minibatch size passed to ``agent.replay``.
        model_filename: checkpoint path; loaded at start if it exists and
            overwritten whenever a new best win rate is reached.
    """
    env = GomokuEnv(size=8)
    best_win_rate = 0
    state_size = (2, env.size, env.size)
    action_size = env.size ** 2
    agent = DQNAgent(state_size, action_size, env, "cpu")
    agent.load_model(model_filename)
    for e in range(episodes):
        state = env.reset()
        # Initial random move by player 1
        available_actions = np.where(state.flatten() == 0)[0]
        action = np.random.choice(available_actions)
        state, _, done, _ = env.step(action)
        if done:
            continue
        # One cell is already taken by the opening move, so at most
        # size**2 - 1 moves remain. Derived from the env, not hard-coded.
        for time in range(env.size ** 2 - 1):
            current_player = env.current_player
            state_channels = agent.get_state(state, current_player)
            state_channels_reshaped = np.reshape(state_channels, [1, 2, env.size, env.size])
            action = agent.act(state_channels_reshaped)
            next_state, reward, done, info = env.step(action)
            next_current_player = env.current_player

            # If a player wins, the other loses. The most recently stored
            # transition is the loser's last move, so its reward is
            # overwritten with a penalty. Guarded so an empty memory cannot
            # raise IndexError.
            if done and env.winner is not None and agent.memory:
                prev_state, prev_action, _, prev_next_state, prev_done = agent.memory[-1]
                agent.memory[-1] = (prev_state, prev_action, -2, prev_next_state, prev_done)

            next_state_channels = agent.get_state(next_state, next_current_player)
            agent.remember(state_channels, action, reward, next_state_channels, done)
            state = next_state
            if done:
                logging.info(f"episode: {e}/{episodes}, time: {time}, e: {agent.epsilon}, info: {info}")
                break
        if len(agent.memory) > batch_size:
            agent.replay(batch_size)
        if e % 1000 == 0:
            win_rate = evaluate(agent, env)
            if win_rate >= best_win_rate:
                logging.info(f"Episode {e}, Win rate vs random: {win_rate}")
                print(f"Episode {e}, Win rate vs random: {win_rate}")
                agent.save_model(model_filename)
                env.render()
                # Cap the stored best at 0.99 so a perfect score does not
                # block later checkpoints that score slightly lower.
                best_win_rate = win_rate if win_rate != 1 else 0.99

In [None]:
if __name__ == "__main__":
    # Send INFO-level records to a log file with timestamps, then start training.
    logging.basicConfig(
        filename='gomoku_dqn1.log',
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    train_dqn()