In [1]:
import numpy as np
import random
import torch
from tqdm.auto import *

class OmokGame:
    """State of a 15x15 Omok (five-in-a-row) game.

    ``board`` is a 15x15 float tensor holding ``1`` for the first player's
    stones, ``-1`` for the second player's stones and ``0`` for empty cells.
    ``turn`` is the player (``1`` or ``-1``) who moves next and
    ``move_history`` is the list of ``(x, y)`` moves played so far.
    """

    # Side length of the (square) board.
    BOARD_SIZE = 15

    def __init__(self):
        # The tensor device is a notebook-level global defined after this class.
        global device
        self.board = torch.zeros(15, 15, dtype=torch.float, device=device, requires_grad=False)
        self.turn = 1
        self.move_history = []

    def reset(self):
        """Clear the board in place and give the move back to player ``1``."""
        self.board.zero_()
        self.turn = 1
        self.move_history = []

    def get_board(self):
        """Return the live board tensor (not a copy)."""
        return self.board

    def get_legal_moves(self):
        """Return a flat float mask of length 225: ``1`` for empty cells, else ``0``.

        Index ``x * 15 + y`` corresponds to cell ``(x, y)``. The mask is derived
        from the board itself (the same data ``makeMove`` uses to reject
        occupied cells) in a single vectorised operation.
        """
        return (self.board == 0).to(torch.float).reshape(-1)

    def get_turn(self):
        """Return the player (``1`` or ``-1``) who moves next."""
        return self.turn

    def makeMove(self, x, y):
        """Place the current player's stone at ``(x, y)``.

        Returns ``True`` when the stone was placed, ``False`` when the cell is
        off the board or already occupied (the state is left untouched).
        """
        size = self.BOARD_SIZE
        # Reject out-of-range coordinates explicitly: negative indices would
        # otherwise wrap around and silently place a stone on the opposite edge.
        if not (0 <= x < size and 0 <= y < size):
            return False
        if self.board[x, y] != 0:
            return False
        self.board[x, y] = self.turn
        self.move_history.append((int(x), int(y)))
        self.turn = -self.turn
        return True

    def checkWin(self):
        """Return the winner created by the most recent move, or ``0``.

        Only lines through the last stone are inspected. A win requires
        exactly five stones in a row; longer lines (overlines) do not count.
        Returns the plain int ``1`` or ``-1`` for a win and ``0`` otherwise,
        including when no move has been played yet.
        """
        if not self.move_history:
            return 0

        size = self.BOARD_SIZE
        last_x, last_y = self.move_history[-1]
        # One transfer to host memory instead of one tensor read per probed cell.
        cells = self.board.tolist()
        player = cells[last_x][last_y]

        # Four axes are enough: each one is scanned in both directions.
        for dx, dy in ((1, 0), (0, 1), (1, 1), (1, -1)):
            count = 1
            for sign in (1, -1):
                for step in range(1, 5):
                    x = last_x + sign * dx * step
                    y = last_y + sign * dy * step
                    if not (0 <= x < size and 0 <= y < size) or cells[x][y] != player:
                        break
                    count += 1
            # In this rule, only exactly 5 stones is considered a win
            if count == 5:
                return int(player)
        return 0

    def play(self, x, y):
        """Play ``(x, y)`` and report the result of the move.

        Returns ``False`` for an illegal move, ``1`` / ``-1`` when that player
        has just won, ``2`` when the board is full without a winner, and ``0``
        when the game continues. The win is checked before the full-board test
        so a winning final stone is not reported as a draw.
        """
        if not self.makeMove(x, y):
            return False
        winner = self.checkWin()
        if winner != 0:
            return winner
        if len(self.move_history) == self.BOARD_SIZE * self.BOARD_SIZE:
            return 2
        return 0

    def __str__(self):
        """Render the board as 15 lines of ``X`` (1), ``O`` (-1) and ``.`` (empty)."""
        symbols = {1: "X", -1: "O"}
        return "".join(
            "".join(symbols.get(cell, ".") for cell in row) + "\n"
            for row in self.board.tolist()
        )

    def __repr__(self):
        return "Project Nexus : Omok Game"

    def copy(self):
        """Return an independent copy (cloned board, copied history)."""
        # Skip __init__ so no throwaway zero board is allocated per copy.
        game = OmokGame.__new__(OmokGame)
        game.board = self.board.clone()
        game.turn = self.turn
        game.move_history = self.move_history.copy()
        return game

# Hyper-parameters and the torch device are taken from `configf`
# (presumably a project-local module; it is not part of this notebook).
from configf import config
import configf
# Module-level `device`: the classes in this notebook read it via `global device`.
device = configf.device

# Echo the active configuration into the cell output.
print(config)

{'C': 2, 'EPSILON': 0.5, 'ALPHA': 0.9, 'MCTS_SEARCHES': 1000, 'LearningRate': 0.001, 'TEMPERATURE': 0.8, 'BatchSize': 64, 'Iterations': 10, 'GamesPerIteration': 128, 'Epochs': 10}


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Node:
    """A node of the search tree.

    Holds the game ``state`` reached by playing ``move`` (a flat index,
    ``x * 15 + y``) from ``parent``, the ``prior`` probability assigned to that
    move, and the running statistics ``visits`` / ``score_sum``.
    """

    def __init__(self, move, state, parent=None, prior=0):
        self.move = move
        self.state = state
        self.parent = parent
        self.prior = prior
        self.children = []
        self.visits = 0
        self.score_sum = 0

    def expand(self, policy):
        """Create one child for every move with a strictly positive probability."""
        for flat_index, probability in enumerate(policy):
            if probability > 0:
                row, col = divmod(flat_index, 15)
                child_state = self.state.copy()
                child_state.makeMove(row, col)
                child = Node(flat_index, child_state, self, probability)
                self.children.append(child)

    def is_leaf(self):
        """True while the node has no children."""
        return not self.children

    def is_root(self):
        """True when the node has no parent."""
        return self.parent is None

    def select(self, policy):
        """Return the child with the highest UCB score.

        ``policy`` is accepted for interface compatibility but is not used;
        the priors stored on the children are used instead.
        """
        global device
        # Calculate ucb, vectorized
        # Assume policy sums up to 1
        mean_scores = []
        visit_counts = []
        priors = []
        for child in self.children:
            mean_scores.append(child.score_sum / child.visits if child.visits != 0 else 0)
            visit_counts.append(child.visits)
            priors.append(child.prior)

        q_values = torch.tensor(mean_scores, dtype=torch.float, requires_grad=False, device=device)
        visits_t = torch.tensor(visit_counts, dtype=torch.float, requires_grad=False, device=device)
        priors_t = torch.tensor(priors, dtype=torch.float, requires_grad=False, device=device)

        exploration = config['C'] * (self.visits ** 0.5) / (1 + visits_t) * priors_t
        ucb = q_values + exploration
        best = torch.argmax(ucb)
        return self.children[best]

    def backpropagate(self, score):
        """Add ``score`` to this node and the sign-flipped score to each ancestor in turn."""
        node, value = self, score
        while True:
            node.visits += 1
            node.score_sum += value
            if node.is_root():
                break
            node, value = node.parent, -value

In [3]:
class MCTS:
    """Monte-Carlo tree search guided by a policy/value network.

    ``model`` maps a ``(1, 1, 15, 15)`` board tensor to ``(policy_logits, value)``.
    The value is treated as the outlook for the player to move in the
    evaluated position, which matches the targets built in ``Nexus_Agent.selfPlay``.
    """

    def __init__(self, model):
        self.model = model

    def _evaluate(self, state, add_noise=False):
        """Run the network on ``state`` and return ``(policy, value)``.

        ``policy`` is a length-225 numpy array restricted to legal moves and
        normalised to sum to 1, or ``None`` when no legal move exists.
        With ``add_noise`` the raw policy is mixed with Dirichlet noise first.
        """
        global device
        # as_tensor/reshape works for boards on any device, unlike the legacy
        # torch.Tensor(...) constructor, which rejects CUDA tensors.
        board = torch.as_tensor(state.get_board(), dtype=torch.float, device=device).reshape(1, 1, 15, 15)
        actor, critic = self.model(board)

        policy = torch.softmax(actor, dim=1).squeeze(0).detach().cpu().numpy()
        if add_noise:
            noise = np.random.dirichlet([config['ALPHA']] * 225)
            policy = (1 - config['EPSILON']) * policy + config['EPSILON'] * noise

        legal = state.get_legal_moves().cpu().numpy()
        policy = policy * legal
        total = policy.sum()
        if total > 0:
            policy = policy / total
        elif legal.sum() > 0:
            # All legal moves underflowed to zero: fall back to uniform.
            policy = legal / legal.sum()
        else:
            policy = None
        return policy, critic.item()

    def search(self, state):
        """Search from ``state`` and return the visit distribution of the root.

        Returns a float tensor of length 225 on ``device`` that sums to 1.
        Raises ``ValueError`` when ``state`` has no legal move.
        """
        global device
        # Use the model this searcher was built with, not a notebook global.
        self.model.eval()

        root = Node(None, state)

        # Inference only: no autograd graph is needed anywhere in the search.
        with torch.no_grad():
            # First, use dirichlet noise to add some randomness to the root node
            root_policy, _ = self._evaluate(state, add_noise=True)
            if root_policy is None:
                raise ValueError("MCTS.search called on a position with no legal moves")
            root.expand(root_policy)

            for _ in range(config['MCTS_SEARCHES']):
                node = root
                while not node.is_leaf():
                    # Node.select ignores its argument and uses the stored priors.
                    node = node.select(root_policy)

                # A node's statistics are read by its parent in select(), so
                # they must be scored for the player who moved INTO the node.
                if node.state.checkWin() != 0:
                    # checkWin reports the player of the last stone, i.e. the
                    # mover into this node: always a win from that side.
                    node.backpropagate(1.0)
                    continue

                policy, value = self._evaluate(node.state)
                if policy is None:
                    # Board is full without a winner: draw.
                    node.backpropagate(0.0)
                    continue

                node.expand(policy)
                # `value` is for the player to move in node.state, i.e. the
                # opponent of the mover into this node, hence the sign flip.
                node.backpropagate(-value)

        action_probs = torch.zeros(225, dtype=torch.float, requires_grad=False, device=device)
        for child in root.children:
            action_probs[child.move] = child.visits
        total_visits = action_probs.sum()
        if total_visits > 0:
            return action_probs / total_visits
        # No simulations were run (MCTS_SEARCHES == 0): fall back to the root priors.
        return torch.as_tensor(root_policy, dtype=torch.float, device=device)

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class Nexus_Small(nn.Module):
    """Small two-layer convolutional actor/critic network for a 15x15 board.

    Input:  ``(batch, 1, H, W)`` board tensor.
    Output: ``(policy_logits, value)`` with shapes ``(batch, 225)`` and
    ``(batch, 1)``; the value is squashed into ``[-1, 1]`` by ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk; 3x3 kernels with padding keep the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Average each of the 64 feature maps down to a single number.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # Two linear heads on the pooled features.
        self.Actor = nn.Linear(64, 225)
        self.Critic = nn.Linear(64, 1)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for the board batch ``x``."""
        features = F.relu(self.bn1(self.conv1(x)))
        features = F.relu(self.bn2(self.conv2(features)))
        pooled = self.global_pool(features).view(-1, 64)
        policy_logits = self.Actor(pooled)
        value = F.tanh(self.Critic(pooled))
        return policy_logits, value

In [5]:
import datetime
def DateTimeGet():
    """Return the current time as a ``YYYY-MM-DD_HH-MM-SS`` string (used in checkpoint file names)."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d_%H-%M-%S}"
print(DateTimeGet())

2024-10-22_15-46-24


In [6]:
class Nexus_Agent:
    """Self-play reinforcement-learning agent.

    Generates games with MCTS guided by ``model`` and trains ``model`` on the
    resulting (board, visit distribution, outcome) samples using ``optimizer``.
    """

    def __init__(self, model, optimizer):
        global device
        self.model = model
        model.to(device)
        self.optimizer = optimizer
        self.mcts = MCTS(model)

    def selfPlay(self):
        """Play one game against itself and return its training samples.

        Each sample is ``(board, action_probs, outcome)``: the board before the
        move, the MCTS visit distribution, and the final result seen by the
        player to move on that board (``1`` win, ``-1`` loss, ``0`` draw).
        """
        history = []
        board = OmokGame()
        winner = 0  # stays 0 for a draw or when the move limit is reached

        for _ in range(15 * 15):
            # Accept a tensor or an array from the search; always store tensors
            # so that train() can stack them.
            action_probs = torch.as_tensor(self.mcts.search(board), dtype=torch.float).detach()
            history.append((board.get_board().clone(), action_probs.clone(), board.get_turn()))

            # Apply temperature
            policy = action_probs ** (1 / config['TEMPERATURE'])

            # Select action. torch.multinomial treats the values as weights, so
            # no renormalisation is needed and it works on any device.
            action = int(torch.multinomial(policy, 1).item())
            x, y = divmod(action, 15)

            # Play the move
            result = int(board.play(x, y))
            if result == 1 or result == -1:
                winner = result
                break
            if result == 2:  # board full, draw
                break

        # Assign outcomes: +1 when the player to move on that board went on to win.
        return [(state, probs, winner * turn) for state, probs, turn in history]

    def train(self, memory):
        """Run one pass over ``memory`` in shuffled mini-batches.

        ``memory`` is a list of samples as returned by ``selfPlay``; it is
        shuffled in place.
        """
        random.shuffle(memory)
        batch_size = config['BatchSize']
        # Train the agent's own model, not whatever `model` is bound to globally.
        self.model.train()

        for start in range(0, len(memory), batch_size):
            sample = memory[start:start + batch_size]
            state, probs, outcome = zip(*sample)
            # Boards are stored as (15, 15); Conv2d needs an explicit channel
            # dimension, i.e. (batch, 1, 15, 15).
            state = torch.stack(state).to(device).float().view(-1, 1, 15, 15)
            policy_target = torch.stack(probs).to(device).float()
            value_target = torch.tensor([float(o) for o in outcome], dtype=torch.float, device=device)

            # forward pass
            actor, critic = self.model(state)

            actor_loss = F.kl_div(F.log_softmax(actor, dim=1), policy_target, reduction='batchmean')
            critic_loss = F.mse_loss(critic.squeeze(1), value_target)

            loss = actor_loss + critic_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def learn(self):
        """Alternate self-play and training for ``config['Iterations']`` rounds.

        After every round the model and optimizer state dicts are written to
        ``model_<timestamp>.pt`` and ``optimizer_<timestamp>.pt``.
        """
        for iteration in trange(config['Iterations']):
            memory = []

            self.model.eval()
            for game in range(config['GamesPerIteration']):
                memory.extend(self.selfPlay())

            self.model.train()
            for epoch in range(config['Epochs']):
                self.train(memory)

            timestamp = DateTimeGet()
            torch.save(self.model.state_dict(), f"model_{timestamp}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{timestamp}.pt")

In [7]:
# Build the network, an Adam optimizer over its parameters, and the agent.
# Nexus_Agent.__init__ moves the model to `device`.
model = Nexus_Small()
optimizer = torch.optim.Adam(model.parameters(), lr=config['LearningRate'])
agent = Nexus_Agent(model, optimizer)

In [8]:
# Run the full self-play / training loop; checkpoints are saved after each iteration.
agent.learn()

  policy_TNSR = torch.tensor(policy, dtype=torch.float, device=device)
  0%|          | 0/10 [02:57<?, ?it/s]


KeyboardInterrupt: 