In [73]:
import numpy as np
import random
import torch
from device import device

class OmokGame:
    def __init__(self):
        global device
        self.board = torch.zeros(15, 15, dtype=torch.float, device=device, requires_grad=False)
        self.turn = 1
        self.move_history = []

    def reset(self):
        self.board.zero_()
        self.turn = 1
        self.move_history = []
    
    def get_board(self):
        return self.board
    
    def get_legal_moves(self):
        global device
        rtr = torch.ones(15 * 15, dtype=torch.float, device=device, requires_grad=False)
        for move in self.move_history:
            rtr[move[0] * 15 + move[1]] = 0
        return rtr
    
    def get_turn(self):
        return self.turn
    
    def makeMove(self, x, y):
        if self.board[x, y] != 0:
            return False
        self.board[x, y] = self.turn
        self.move_history.append((x, y))
        self.turn = -self.turn
        return True
    
    def checkWin(self):
        lastMove = self.move_history[-1]
        # Search fron last move
        def UtilitySqCheck(x, y):
            if x < 0 or x >= 15 or y < 0 or y >= 15:
                return 0
            return self.board[x, y]
        
        last_Player = self.board[lastMove[0], lastMove[1]]

        dxl = [-1, 0, 1, -1, 1, -1, 0, 1]
        dyl = [-1, -1, -1, 0, 0, 1, 1, 1]

        for i in range(8):
            count = 1
            for j in range(1, 5):
                if UtilitySqCheck(lastMove[0] + dxl[i] * j, lastMove[1] + dyl[i] * j) == last_Player:
                    count += 1
                else:
                    break
            for j in range(1, 5):
                if UtilitySqCheck(lastMove[0] - dxl[i] * j, lastMove[1] - dyl[i] * j) == last_Player:
                    count += 1
                else:
                    break
        
            if count == 5:
                return last_Player
        return 0
    
    def play(self, x, y):
        if not self.makeMove(x, y):
            return False
        return self.checkWin()
    
    def __str__(self):
        ans = ""
        for i in range(15):
            for j in range(15):
                if self.board[i, j] == 1:
                    ans += "X"
                elif self.board[i, j] == -1:
                    ans += "O"
                else:
                    ans += "."
            ans += "\n"
        return ans
    
    def __repr__(self):
        return "Project Nexus : Omok Game"

    def copy(self):
        game = OmokGame()
        game.board = self.board.clone()
        game.turn = self.turn
        game.move_history = self.move_history.copy()
        return game

from configf import config
import configf
device = configf.device

In [74]:
class Node:
    def __init__(self, move, state, parent=None, prior=0):
        self.parent = parent
        self.move = move
        self.state = state
        self.children = []
        self.visits = 0
        self.score_sum = 0
        self.prior = prior
    
    def expand(self, policy):
        for move, prob in enumerate(policy):
            if prob > 0:
                new_state = self.state.copy()
                x, y = move // 15, move % 15
                new_state.makeMove(x, y)
                self.children.append(Node(move, new_state, self, prob))

    def is_leaf(self):
        return len(self.children) == 0
    
    def is_root(self):
        return self.parent == None
    
    def select(self, policy):
        global device
        # Calculate ucb, vectorlized
        # Assume policy sums up to 1
        q_values = torch.tensor([child.score_sum/child.visits if child.visits!=0 else 0 for child in self.children], dtype=torch.float, requires_grad=False, device=device)
        # print(f"Q value shape : {q_values.shape}")
        ucb = q_values + config['C'] * \
            (self.visits ** 0.5) / \
                (1 + torch.tensor([child.visits for child in self.children], dtype=torch.float, requires_grad=False, device=device))\
            * torch.tensor([child.prior for child in self.children], dtype=torch.float, requires_grad=False, device=device)
        return self.children[torch.argmax(ucb)]
    
    def backpropagate(self, score):
        self.visits += 1
        self.score_sum += score
        if not self.is_root():
            self.parent.backpropagate(-score)

In [75]:
class MCTS:
    def __init__(self, model):
        global device
        self.model = model.to(device)
    
    def search(self, state):
        global device
        model.eval()

        # First, use dirichlet noise to add some randomness to the root node
        root = Node(None, state)

        bd = torch.Tensor(state.get_board()).to(device).view(1, 1, 15, 15)
        actor, critic = model(bd)

        policy = torch.softmax(actor, dim=1).squeeze(0).detach().cpu().numpy()
        policy = (1 - config['EPSILON']) * policy + config['EPSILON'] * np.random.dirichlet([config['ALPHA']] * 225)


        valid_moves = state.get_legal_moves().cpu().numpy()
        policy = policy * valid_moves
        policy = policy / policy.sum()

        root.expand(policy)

        for search in range(config['MCTS_SEARCHES']):
            node = root
            while not node.is_leaf():
                node = node.select(policy)
            
            # Check if node is terminal
            gameResult = node.state.checkWin()
            if gameResult != 0:
                node.backpropagate(gameResult)
                continue
            if gameResult == 0:
                actor, critic = model(torch.Tensor(node.state.get_board()).to(device).view(1, 1, 15, 15))
                policy = torch.softmax(actor, dim=1).squeeze(0).detach().cpu().numpy()
                valid_moves = node.state.get_legal_moves().cpu().numpy()
                policy = policy * valid_moves
                policy = policy / policy.sum()

                value = critic.item()

                node.expand(policy)
                node.backpropagate(value)
        
        action_probs = torch.zeros(225, dtype=torch.float, requires_grad=False, device=device)
        for child in root.children:
            action_probs[child.move] = child.visits
        action_probs = action_probs / action_probs.sum()
        return action_probs

In [76]:
import torch.nn as nn
import torch.nn.functional as F

class Nexus_Small(nn.Module):
    def __init__(self):
        super(Nexus_Small, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.Actor = nn.Linear(64, 225)
        self.Critic = nn.Linear(64, 1)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.global_pool(x)
        x = x.view(-1, 64)
        return self.Actor(x), F.tanh(self.Critic(x))

In [77]:
model = Nexus_Small().to(device)
mcts = MCTS(model)

print(mcts.search(OmokGame()))  

tensor([0.0067, 0.0133, 0.0067, 0.0067, 0.0000, 0.0000, 0.0067, 0.0133, 0.0000,
        0.0067, 0.0067, 0.0067, 0.0000, 0.0067, 0.0000, 0.0067, 0.0000, 0.0067,
        0.0133, 0.0000, 0.0000, 0.0133, 0.0067, 0.0067, 0.0067, 0.0000, 0.0067,
        0.0000, 0.0067, 0.0067, 0.0000, 0.0000, 0.0067, 0.0133, 0.0067, 0.0067,
        0.0000, 0.0067, 0.0000, 0.0067, 0.0200, 0.0067, 0.0000, 0.0067, 0.0067,
        0.0000, 0.0067, 0.0067, 0.0067, 0.0067, 0.0000, 0.0000, 0.0067, 0.0067,
        0.0067, 0.0000, 0.0067, 0.0000, 0.0067, 0.0067, 0.0067, 0.0000, 0.0067,
        0.0067, 0.0067, 0.0067, 0.0000, 0.0067, 0.0000, 0.0000, 0.0000, 0.0067,
        0.0067, 0.0067, 0.0000, 0.0000, 0.0067, 0.0000, 0.0067, 0.0000, 0.0067,
        0.0000, 0.0067, 0.0000, 0.0067, 0.0000, 0.0000, 0.0067, 0.0000, 0.0000,
        0.0000, 0.0067, 0.0067, 0.0067, 0.0000, 0.0000, 0.0000, 0.0067, 0.0067,
        0.0267, 0.0067, 0.0067, 0.0067, 0.0000, 0.0067, 0.0067, 0.0000, 0.0067,
        0.0133, 0.0000, 0.0000, 0.0067, 