<a href="https://colab.research.google.com/github/keith-leung/cis667/blob/master/MCTS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# MCTS to solve tic-tac-toe
# state[r,c] is the content at position (r,c) of the board: "_", "X", or "O"
# TODO: implement exploration strategy
# TODO: increase number of rollouts
# TODO: implement interactive game-play with AI

import numpy as np

def state_string(state):
    """Render a board as text: cells of a row are concatenated, rows are newline-separated."""
    rendered_rows = map("".join, state)
    return "\n".join(rendered_rows)

def score(state):
    """Return the game value of a board.

    +1 if "X" owns a complete column, row or diagonal, -1 if "O" does,
    and 0 otherwise. "X" is checked first, so a board on which both
    symbols own a line scores +1.
    """
    outcomes = (("X", 1), ("O", -1))
    for mark, outcome in outcomes:
        owned = state == mark  # boolean mask of the cells holding `mark`
        has_line = (
            owned.all(axis=0).any()               # some column is fully owned
            or owned.all(axis=1).any()            # some row is fully owned
            or np.diag(owned).all()               # main diagonal
            or np.diag(np.rot90(owned)).all()     # anti-diagonal
        )
        if has_line:
            return outcome
    return 0

def get_player(state):
    """Return the symbol to place next: "O" when it has fewer marks than "X", otherwise "X"."""
    x_marks = np.count_nonzero(state == "X")
    o_marks = np.count_nonzero(state == "O")
    return "O" if o_marks < x_marks else "X"

def children_of(state):
    """Return every board reachable from `state` by one move of the player to move.

    Each child is a copy of `state` with one "_" cell replaced by the symbol
    from get_player(state). Children are ordered from the last empty cell
    (row-major) to the first.
    """
    mark = get_player(state)
    n_rows, n_cols = state.shape[0], state.shape[1]
    empty_cells = [
        (row, col)
        for row in range(n_rows)
        for col in range(n_cols)
        if state[row, col] == "_"
    ]

    successors = []
    for row, col in reversed(empty_cells):
        board = state.copy()
        board[row, col] = mark
        successors.append(board)
    return successors

def is_leaf(state):
    """Return True when the game is over at `state`.

    A state is terminal when score(state) is non-zero (someone has completed
    a line) or when no empty "_" cell is left on the board.
    """
    if score(state) != 0:
        return True
    # Look for an empty cell directly rather than building every child board
    # just to count them; this check runs on every node of every rollout.
    return not (state == "_").any()


In [None]:
class Node:
    """Search-tree node holding one board state and its rollout statistics."""

    def __init__(self, state):
        """Create an unvisited node for `state`; its children are built lazily."""
        self.state = state
        self.visit_count = 0     # incremented once per rollout through this node
        self.score_total = 0     # running sum of rollout results
        self.score_estimate = 0  # mean result: score_total / visit_count
        self.child_list = None   # None until children() is first called

    def children(self):
        """Return the child nodes, creating and caching them on first use."""
        if self.child_list is None:
            self.child_list = [Node(child) for child in children_of(self.state)]
        return self.child_list

    def N_values(self):
        """Return the visit count of each child, in children() order."""
        return [c.visit_count for c in self.children()]

    def Q_values(self):
        """Return each child's value estimate from the mover's point of view.

        Scores are stored from "X"'s perspective, so they are negated when
        "O" is the player to move at this node; a larger Q is then always
        better for the mover.
        """
        children = self.children()
        sign = +1 if get_player(self.state) == "X" else -1
        # The +1 in the denominator avoids dividing by zero for unvisited
        # children (and shrinks estimates toward 0); dividing by
        # max(c.visit_count, 1) would be the unsmoothed alternative.
        Q = [sign * c.score_total / (c.visit_count+1) for c in children]
        return Q

def exploit(node):
    """Pure exploitation: return the child with the largest Q value (first one on ties)."""
    best_index = np.argmax(node.Q_values())
    return node.children()[best_index]

def explore(node):
    """Pure exploration: return the least-visited child (first one on ties)."""
    # TODO: implement a better exploration strategy
    least_visited = np.argmin(node.N_values())
    return node.children()[least_visited]

def uct(node, exploration_weight=1.0):
    """Select a child of `node` with an upper-confidence-bound rule.

    Each child c is scored as

        Q_c + exploration_weight * sqrt(ln(N_p + 1) / (N_c + 1))

    where Q_c comes from node.Q_values(), N_c from node.N_values() and N_p is
    node.visit_count. The +1 terms keep the bonus finite when the parent or
    a child has not been visited yet.

    Args:
        node: object providing children(), Q_values(), N_values() and
            visit_count.
        exploration_weight: multiplier on the exploration bonus. The default
            of 1.0 gives the unweighted rule; 0 reduces to greedy selection
            on Q.

    Returns:
        The child with the highest score (the first one on ties).
    """
    q_values = np.asarray(node.Q_values())
    visit_counts = np.asarray(node.N_values())
    bonus = np.sqrt(np.log(node.visit_count + 1) / (visit_counts + 1))
    upper_bounds = q_values + exploration_weight * bonus
    return node.children()[np.argmax(upper_bounds)]

# Child-selection policy called by rollout() at every non-leaf node.
# Swap in one of the alternatives below to compare strategies:
# choose_child = exploit
# choose_child = explore
choose_child = uct

def rollout(node):
    """Play one simulated game starting at `node` and back its result up.

    Descends from `node` with choose_child until a leaf state is reached,
    scores that leaf, then adds one visit and the leaf's score to every node
    on the path (including `node` itself) and refreshes each node's mean
    score estimate.

    Returns:
        The score of the leaf state that was reached.
    """
    # Selection: walk down the tree, remembering the nodes passed through.
    path = [node]
    while not is_leaf(path[-1].state):
        path.append(choose_child(path[-1]))

    outcome = score(path[-1].state)

    # Back-up: deepest node first, finishing at the starting node.
    for visited in reversed(path):
        visited.visit_count += 1
        visited.score_total += outcome
        visited.score_estimate = visited.score_total / visited.visit_count
    return outcome


In [None]:
if __name__ == "__main__":

    # Start from an empty 3x3 board and place X's opening move.
    state = np.full((3, 3), "_")
    state[0, 0] = "X"  # optimal, according to https://xkcd.com/832/
    # state[0, 1] = "X"  # suboptimal

    # Gauge how good the opening is: run rollouts from the resulting position
    # and print the root's running mean score after each one.
    node = Node(state)
    for r in range(10000):  # TODO: increase number of rollouts
        rollout(node)
        print(r, node.score_estimate)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
5000 0.03699260147970406
5001 0.03698520591763295
5002 0.03697781331201279
5003 0.036970423661071145
5004 0.03696303696303696
5005 0.03695565321614063
5006 0.03694827241861394
5007 0.0369408945686901
5008 0.03693351966460371
5009 0.036926147704590816
5010 0.036918778686888844
5011 0.036911412609736634
5012 0.03690404947137443
5013 0.03689668927004388
5014 0.036689930209371883
5015 0.03668261562998405
5016 0.036675303966513854
5017 0.036667995217218016
5018 0.036660689380354654
5019 0.036653386454183264
5020 0.03664608643696475
5021 0.03663878932696137
5022 0.03663149512243679
5023 0.03662420382165605
5024 0.036616915422885574
5025 0.03660962992439316
5026 0.03660234732444798
5027 0.03659506762132061
5028 0.036786637502485585
5029 0.036779324055666
5030 0.03677201351619956
5031 0.03676470588235294
5032 0.0367574011523942
5033 0.03675009932459277
5034 0.036941410129096325
5035 0.036934074662430504
5036 0.03692674210839786
5