Build the environement, first approach with tictactoe

In [1]:
import numpy as np

In [2]:
class TicTacToe:

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count))
    
    def get_next_state(self, state, action, player):
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state
    
    def get_valid_moves(self, state):
        return (state.reshape(-1) == 0).astype(np.uint8)
        
    def check_win(self, state, action):

        if action is None:
            return False

        row = action // self.row_count
        column = action % self.column_count
        player = state[row, column]

        # check row
        if np.all(state[row, :] == player):
            return True

        # check column
        if np.all(state[:, column] == player):
            return True
        
        # check diagonal
        if row == column and np.all(np.diag(state) == player):
            return True

        # check anti-diagonal
        if row + column == self.row_count - 1 and np.all(np.diag(np.fliplr(state)) == player):
            return True
        
        return False

    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False
    
    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        return state * player



In [4]:
tictactoe = TicTacToe()
player = 1

state = tictactoe.get_initial_state()

while True:
    print(state)
    valid_moves = tictactoe.get_valid_moves(state)
    print("valid moves :", str([i for i in range(tictactoe.action_size) if i in valid_moves]))
    action = int(input(f"{player} :"))

    if valid_moves[action] == 0:
        print("invalid move")
        continue

    state = tictactoe.get_next_state(state, action, player)

    value, is_terminal = tictactoe.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(f"{player} win")
        else:
            print("draw")
        break

    player = tictactoe.get_opponent(player)


[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves : [1]
[[1. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves : [0, 1]
[[ 1. -1.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
valid moves : [0, 1]
[[ 1. -1.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  0.]]
valid moves : [0, 1]
[[ 1. -1.  0.]
 [ 0.  1. -1.]
 [ 0.  0.  0.]]
valid moves : [0, 1]
[[ 1. -1.  0.]
 [ 0.  1. -1.]
 [ 0.  0.  1.]]
1 win


In [37]:
class Node:
    def __init__(self, game, args, state, parent= None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        self.expandable_moves = game.get_valid_moves(state)

        self.value_sum = 0
        self.visit_count = 0
    
    def is_expanded(self):
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.calculate_ucb(child)
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        
        return best_child

    def calculate_ucb(self, child):
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2 
        return q_value + self.args['C'] * np.sqrt(np.log(self.visit_count) / child.visit_count)

    def expand(self):


        action = np.random.choice(np.where(self.expandable_moves == 1)[0])

        self.expandable_moves[action] = 0

        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, 1)
        child_state = self.game.change_perspective(child_state, player=-1)
        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)

        return child
    
    def simulate(self):
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
        value = self.game.get_opponent_value(value)

        if is_terminal:
            return value
        
        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)
            value = self.game.get_opponent_value(value)

            if is_terminal:
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value
            
            rollout_player = self.game.get_opponent(rollout_player)
        
        return None
    
    def backpropagation(self, value):
        node = self
        while node is not None:
            node.visit_count += 1
            node.value_sum += value

            value = self.game.get_opponent_value(value)
            node = node.parent

class MCTS:

    def __init__(self, game, args):
        self.game = game
        self.args = args
    
    def search(self, state):

        root = Node(self.game, self.args, state)

        for search in range(self.args['num_searches']):

            node = root

            while node.is_expanded():

                node = node.select()

            
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)
            
            if not is_terminal:
                # expansion
                node = node.expand()
                # simulation
                value = node.simulate()
            
            # backpropagation (iterative)
            node.backpropagation(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        action_probs = action_probs / np.sum(action_probs)
        
        return action_probs

In [38]:
tictactoe = TicTacToe()
player = 1

args = {
    'num_searches': 1000,
    'C': 1.41 # sqrt(2)
}

mcts = MCTS(tictactoe, args)

state = tictactoe.get_initial_state()

while True:
    print(state)

    if player == 1:

        valid_moves = tictactoe.get_valid_moves(state)
        print("valid moves :", str([i for i in range(tictactoe.action_size) if i in valid_moves]))
        action = int(input(f"{player} :"))

        if valid_moves[action] == 0:
            print("invalid move")
            continue
    else:
        neutral_state = tictactoe.change_perspective(state, player=-1)
        mcts_probs = mcts.search(state)
        action = np.argmax(mcts_probs)

    state = tictactoe.get_next_state(state, action, player)

    value, is_terminal = tictactoe.get_value_and_terminated(state, action)

    if is_terminal:
        print(state)
        if value == 1:
            print(f"{player} win")
        else:
            print("draw")
        break

    player = tictactoe.get_opponent(player)


[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
valid moves : [1]
[[0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 0.]]
[[ 0.  0.  0.]
 [-1.  0.  1.]
 [ 0.  0.  0.]]
valid moves : [0, 1]
[[ 0.  0.  1.]
 [-1.  0.  1.]
 [ 0.  0.  0.]]
[[ 0.  0.  1.]
 [-1.  0.  1.]
 [ 0.  0. -1.]]
valid moves : [0, 1]
[[ 0.  1.  1.]
 [-1.  0.  1.]
 [ 0.  0. -1.]]
[[-1.  1.  1.]
 [-1.  0.  1.]
 [ 0.  0. -1.]]
valid moves : [0, 1]
invalid move
[[-1.  1.  1.]
 [-1.  0.  1.]
 [ 0.  0. -1.]]
valid moves : [0, 1]
[[-1.  1.  1.]
 [-1.  1.  1.]
 [ 0.  0. -1.]]
[[-1.  1.  1.]
 [-1.  1.  1.]
 [-1.  0. -1.]]
-1 win
