In [8]:
import numpy as np

In [15]:
class TicTacToe:
    """Tic-tac-toe rules on a 3x3 numpy board.

    Cells hold 1 (player 1), -1 (player -1) or 0 (empty). An action is a flat,
    row-major cell index in ``range(action_size)``: ``action = row * 3 + column``.
    """

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return an empty int8 board of shape (row_count, column_count)."""
        return np.zeros((self.row_count, self.column_count), dtype=np.int8)

    def get_next_state(self, state, action, player):
        """Put ``player``'s mark on cell ``action`` and return the board.

        Note: ``state`` is modified in place (and also returned); pass a copy
        if the previous position must be preserved.
        """
        row = action // self.column_count
        column = action % self.column_count

        state[row, column] = player
        return state

    def get_valid_moves(self, state: np.ndarray):
        """Return a flat uint8 mask: 1 for every empty cell, 0 for occupied ones."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the mark on cell ``action`` is part of a full line.

        Checks the action's row, the action's column and both diagonals for
        three copies of the mark found at ``action``. Returns False when
        ``action`` is None (no move played yet) or points at an empty cell.
        """
        if action is None:
            return False

        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]

        if player == 0:
            # An empty cell cannot complete a line; without this guard an empty
            # row/column would match because its sum is 0 == 0 * 3.
            return False

        return bool(
            np.sum(state[row, :]) == player * self.column_count
            # every row of this column (the previous `state[: column]` summed
            # whole leading rows instead)
            or np.sum(state[:, column]) == player * self.row_count
            or np.sum(np.diag(state)) == player * self.row_count
            # flipping rows turns the anti-diagonal into the main diagonal
            or np.sum(np.diag(np.flip(state, axis=0))) == player * self.row_count
        )

    def get_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` for the position reached by ``action``.

        ``(1, True)`` if ``action`` won the game, ``(0, True)`` if the board is
        full without a winner, ``(0, False)`` otherwise.
        """
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (1 <-> -1)."""
        return -player

    def get_opponent_value(self, value):
        """Convert a game value to the opponent's point of view."""
        return -value

    def change_perspective(self, state, player):
        """Return a new board scaled by ``player``; ``player=-1`` swaps the marks."""
        return state * player


In [16]:
# Quick sanity check: player 1 fills the main diagonal (cells 8, 4, 0),
# then we look at the remaining empty cells and ask whether the game is over.
ttt = TicTacToe()
state = ttt.get_initial_state()

for diagonal_cell in (8, 4, 0):
    state = ttt.get_next_state(state, diagonal_cell, 1)

moves = ttt.get_valid_moves(state)
print(moves)

ttt.get_value_and_terminated(state, 8)



[0 1 1 1 0 1 1 1 0]


(1, True)

In [17]:
# How numpy reads the two diagonals of a 3x3 grid: np.diag gives the main
# diagonal; flipping the columns first (fliplr) yields the anti-diagonal.
arr = np.arange(9).reshape(3, 3)

arr, np.diag(arr), np.diag(np.fliplr(arr))

(array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]),
 array([0, 4, 8]),
 array([2, 4, 6]))

In [18]:
import time

# Interactive two-player game in the console. Off by default so that
# "Restart & Run All" does not block waiting on input(); set True to play.
PLAY_INTERACTIVE = False

ttt = TicTacToe()

player = 1
state = ttt.get_initial_state()

while PLAY_INTERACTIVE:
    # The board is shown upside down (np.flipud), and moves are typed
    # 1-based, i.e. cell index + 1.
    print("adjusted state:\n", np.flipud(state))
    valid_moves = ttt.get_valid_moves(state)
    print("valid_moves: ", [i + 1 for i in range(ttt.action_size) if valid_moves[i] == 1])

    # need to sleep otherwise sometimes the input pops up before the previous output
    time.sleep(.5)

    try:
        action = int(input(f"player {player} move: ")) - 1
    except ValueError:
        # non-numeric input: ask again instead of crashing the cell
        print("action not valid")
        continue

    # Range check first: a negative index would otherwise silently wrap to a
    # cell at the end of the board, and a too-large one raises IndexError.
    if not 0 <= action < ttt.action_size or valid_moves[action] == 0:
        print("action not valid")
        continue

    state = ttt.get_next_state(state, action, player)

    value, is_terminal = ttt.get_value_and_terminated(state, action)
    if is_terminal:
        if value == 1:
            print(f"player {player} won")
        else:
            print("draw")
        print("adjusted state:\n", np.flipud(state))
        print("true state:\n", state)

        break

    player = ttt.get_opponent(player)

In [19]:
import random

# Random self-play: both players pick uniformly among the empty cells until
# the game ends. A dedicated seeded generator makes the game reproducible
# without touching the global `random` state used elsewhere.
RANDOM_SEED = 0
move_rng = random.Random(RANDOM_SEED)

ttt = TicTacToe()

num_moves = 0
player = 1
state = ttt.get_initial_state()

# A game has at most action_size moves; the terminal check below ends it on a
# win, or as a draw once the board is full.
while num_moves < ttt.action_size:
    num_moves += 1
    print("adjusted state:\n", np.flipud(state))
    valid_moves = ttt.get_valid_moves(state)
    # NOTE: these are raw 0-based action indices (the interactive cell shows 1-based ones).
    choices = [i for i in range(ttt.action_size) if valid_moves[i] == 1]
    print(f"adjusted valid_moves: {choices}")

    # `choices` only holds empty cells, so no validity check is needed here.
    action = move_rng.choice(choices)
    print(f"player {player} choice: {action} | move # {num_moves}")

    state = ttt.get_next_state(state, action, player)

    value, is_terminal = ttt.get_value_and_terminated(state, action)
    if is_terminal:
        if value == 1:
            print(f"player {player} won")
        else:
            print("draw")
        print("adjusted state:\n", np.flipud(state))
        print("true state:\n", state)

        break

    player = ttt.get_opponent(player)

adjusted state:
 [[0 0 0]
 [0 0 0]
 [0 0 0]]
adjusted valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
player 1 choice: 7 | move # 1
adjusted state:
 [[0 1 0]
 [0 0 0]
 [0 0 0]]
adjusted valid_moves: [0, 1, 2, 3, 4, 5, 6, 8]
player -1 choice: 5 | move # 2
adjusted state:
 [[ 0  1  0]
 [ 0  0 -1]
 [ 0  0  0]]
adjusted valid_moves: [0, 1, 2, 3, 4, 6, 8]
player 1 choice: 4 | move # 3
adjusted state:
 [[ 0  1  0]
 [ 0  1 -1]
 [ 0  0  0]]
adjusted valid_moves: [0, 1, 2, 3, 6, 8]
player -1 choice: 6 | move # 4
adjusted state:
 [[-1  1  0]
 [ 0  1 -1]
 [ 0  0  0]]
adjusted valid_moves: [0, 1, 2, 3, 8]
player 1 choice: 2 | move # 5
adjusted state:
 [[-1  1  0]
 [ 0  1 -1]
 [ 0  0  1]]
adjusted valid_moves: [0, 1, 3, 8]
player -1 choice: 3 | move # 6
adjusted state:
 [[-1  1  0]
 [-1  1 -1]
 [ 0  0  1]]
adjusted valid_moves: [0, 1, 8]
player 1 choice: 0 | move # 7
adjusted state:
 [[-1  1  0]
 [-1  1 -1]
 [ 1  0  1]]
adjusted valid_moves: [1, 8]
player -1 choice: 8 | move # 8
adjusted state:
 [[-1  1 -

In [20]:
from __future__ import annotations
import math


class Node:
    """A node of the Monte Carlo search tree.

    ``state`` is stored from the perspective of the player to move at this
    node (their marks are 1): ``expand`` plays the move as player 1 and then
    flips the board with ``change_perspective(..., player=-1)``. ``value_sum``
    accumulates results from that same player's perspective.

    ``args`` must provide ``"C"``, the UCB exploration constant.
    """

    def __init__(
        self, game: TicTacToe, args, state, parent=None, action_taken=None
    ) -> None:
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken  # move that led from parent to this node

        self.children = []
        # 1 for every legal move that has not been turned into a child yet
        self.expandable_moves = self.game.get_valid_moves(state)

        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """True once every legal move has a child (and at least one child exists)."""
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self) -> Node:
        """Return the child with the highest UCB score (first one on ties)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child: Node):
        """UCB1 score of ``child`` from this node's point of view.

        The child's mean value (in [-1, 1], child's perspective) is negated and
        rescaled to [0, 1] for this node's player, plus the exploration term.
        Unvisited children score +inf so they are tried first (and no division
        by zero occurs).
        """
        if child.visit_count == 0:
            return np.inf
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args["C"] * math.sqrt(
            math.log(self.visit_count) / child.visit_count
        )

    def expand(self) -> Node:
        """Create a child for one random not-yet-expanded move and return it.

        Raises ValueError (from np.random.choice) if no expandable move is left.
        """
        action = np.random.choice(np.where(self.expandable_moves == 1)[0])
        self.expandable_moves[action] = 0

        child_state = self.state.copy()
        child_state = self.game.get_next_state(child_state, action, 1)
        # hand the board to the opponent: their marks become 1
        child_state = self.game.change_perspective(child_state, player=-1)

        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child

    def simulate(self):
        """Play uniformly random moves from this node until the game ends.

        Returns the outcome (1 win, -1 loss, 0 draw) from the perspective of
        the player to move at this node.
        """
        value, is_terminal = self.game.get_value_and_terminated(self.state, self.action_taken)
        # action_taken was played by the opponent, so flip its value
        value = self.game.get_opponent_value(value)
        if is_terminal:
            return value

        rollout_state = self.state.copy()
        rollout_player = 1
        while True:
            valid_moves = self.game.get_valid_moves(rollout_state)
            action = np.random.choice(np.where(valid_moves == 1)[0])
            rollout_state = self.game.get_next_state(rollout_state, action, rollout_player)
            value, is_terminal = self.game.get_value_and_terminated(rollout_state, action)
            if is_terminal:
                # value is for rollout_player (who just moved); this node's
                # player plays as 1 in the rollout.
                if rollout_player == -1:
                    value = self.game.get_opponent_value(value)
                return value
            rollout_player = self.game.get_opponent(rollout_player)

    def backpropagate(self, value):
        """Add ``value`` (this node's perspective) up the path to the root.

        The sign flips at each level because players alternate.
        """
        node = self
        while node is not None:
            node.value_sum += value
            node.visit_count += 1
            value = node.game.get_opponent_value(value)
            node = node.parent


class MonteCarloTreeSearch:
    """UCT Monte Carlo tree search with random rollouts.

    ``args`` must provide ``"C"`` (exploration constant, read by ``Node``) and
    ``"num_searches"`` (number of select/expand/simulate/backprop iterations).
    """

    def __init__(self, game:TicTacToe, args) -> None:
        self.game = game
        self.args = args

    def search(self, state):
        """Search from ``state`` and return the root visit-count distribution.

        ``state`` must be from the perspective of the player to move (their
        marks are 1). Returns a float array of length ``game.action_size``
        holding each root child's visit count divided by the total, or all
        zeros when nothing was expanded (e.g. ``state`` is already terminal).
        """
        # define root node
        root = Node(self.game, self.args, state)

        for _ in range(self.args["num_searches"]):
            node = root

            # selection: descend while every move of the node has a child
            while node.is_fully_expanded():
                node = node.select()

            # value of the reached node for its player to move; action_taken
            # was made by the opponent, hence the flip
            value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # expansion
                node = node.expand()
                # simulation
                value = node.simulate()

            # backprop
            node.backpropagate(value)

        # return visit counts, normalized into a probability distribution
        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        total = np.sum(action_probs)
        if total > 0:
            action_probs /= total
        return action_probs

# pause video at 1:21:52

