In [1]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pydot
from networkx.drawing.nx_pydot import graphviz_layout
import torch
import pyro.distributions as dist

In [2]:
# 0 is nought
# 1 is cross
# -1 is empty

In [3]:
board = np.ones((3,3), dtype=np.int)*(-1)

In [4]:
board

array([[-1, -1, -1],
       [-1, -1, -1],
       [-1, -1, -1]])

In [5]:
def win(board, player):
    vertical_streak = any(all(board[i] == np.ones(3,dtype=np.int)*player) for i in range(3))
    horizontal_streak = any(all(board[:,i] == np.ones(3,dtype=np.int)*player) for i in range(3))
    diagonal_streak = all(np.diag(board) == np.ones(3,dtype=np.int)*player)
    other_diagonal_streak = all(np.diag(np.flip(board, axis=1)) == np.ones(3,dtype=np.int)*player)
    return any([vertical_streak,horizontal_streak,diagonal_streak,other_diagonal_streak])

In [6]:
(board == -1).sum()

9

In [7]:
def draw(board):
    # no empty positions left
    return not (win(board, 0) or win(board, 1)) and (board == -1).sum() <= 1

In [8]:
def move(board, player, position, in_place = True):
    i,j = position
    assert board[i,j] == -1
    if in_place:
        board[i,j] = player
        return board
    else:
        new_board = board.copy()
        new_board[i,j] = player
        return new_board

In [9]:
def all_possible_moves_from(board):
    i,j = np.where(board == -1)
    return list(zip(list(i),list(j)))

In [10]:
def random_move(board, player):
    # i = np.random.randint(3)
    # j = np.random.randint(3)
    # return move(board, player, (i,j))
    moves = all_possible_moves_from(board)
    n_possible_moves = len(moves)
    move_idx = np.random.choice(n_possible_moves,size=1)
    selected_move = moves[int(move_idx)]
    return move(board, player, selected_move)

In [11]:
def random_game(board):
    game = []
    turn = 1
    while not any([win(board, 0), win(board, 1), draw(board)]):
        try:
            board = random_move(board, turn)
        except AssertionError:
            continue
        turn = 1-turn
        game.append(board.copy())
    return game

In [12]:
def totuple(board):
    return tuple([tuple(row) for row in board])

In [13]:
def UCT(game_tree, node, parent_node = None):
    w = game_tree.nodes[node]['wins']
    n = game_tree.nodes[node]['n_simulations']
    # if a node hasn't been explored, we definitely want to go to that node
    if n == 0:
        return np.inf
    if parent_node is None:
        N = game_tree.in_edges(node)[0][0]['n_simulations']
    else:
        N = game_tree.nodes[parent_node]['n_simulations']
    c = np.sqrt(2)
    return w/np.float(n) + c*np.sqrt(np.log(N)/n)

In [14]:
def selection(game_tree, current_node):
    # from root node
    if current_node is None:
        current_node = totuple(starting_board)
    children = [e[1] for e in game_tree.out_edges(current_node)]
    path = [current_node]
    # and until we reach a leaf node
    while len(children) != 0:
        # if the current node hasn't been visited before
        if game_tree.nodes[current_node]['n_simulations'] == 0:
            # choose a child randomly
            favorite_child = np.random.choice(len(children),size=1)
        else:
            # go to child node that has greatest UCT
            favorite_child = np.argmax([UCT(game_tree, child_node, current_node) for child_node in children])
        current_node = children[int(favorite_child)]
        children = [e[1] for e in game_tree.out_edges(current_node)]
        path.append(current_node)
    return current_node, path

In [15]:
def expansion(current_node, path):
    # if current_node has been reached before
    if game_tree.nodes[current_node]['n_simulations'] != 0:
        player = 1-game_tree.nodes[current_node]['player']
        # add all other moves from parent node
        current_board = np.array(current_node)
        expansions = [totuple(move(current_board,player,pos,in_place=False)) for pos in all_possible_moves_from(current_board)]
        game_tree.add_nodes_from([(expansion, {'player':player,'wins':0, 'n_simulations':0}) for expansion in expansions])
        game_tree.add_edges_from([(current_node, expansion) for expansion in expansions])
        path.append(expansions[0])
        return expansions[0], path
    else:
        return current_node, path

In [16]:
def simulation(current_node):
    board = np.array(current_node)
    # if we didn't reach a terminal node
    if not any([win(board, 0), win(board, 1), draw(board)]):
        # simulate until we reach one
        game = random_game(np.array(current_node))
        board = game[-1]
    winner = win(board,0)*0 + win(board,1)*1 + draw(board)*0.5
    return winner

In [17]:
def backpropagation(game_tree, winner, path):
    for node in path:
        game_tree.nodes[node]['n_simulations'] += 1
        if game_tree.nodes[node]['player'] == winner:
            game_tree.nodes[node]['wins'] += 1
        if winner == 0.5:
            game_tree.nodes[node]['wins'] += winner

In [18]:
def MCTS(game_tree, starting_node = None, n_iter = 10):
    i = 0
    while i < n_iter:
        current_node, path = selection(game_tree, starting_node)
        board = np.array(current_node)
        # if current node doesn't end the game
        if not any([win(board, 0), win(board, 1), draw(board)]):
            current_node, path = expansion(current_node, path)
            winner = simulation(current_node)
        else:
            winner = win(board,0)*0 + win(board,1)*1 + draw(board)*0.5
        backpropagation(game_tree, winner, path)
        i += 1

In [19]:
game_tree = nx.DiGraph()
starting_board = np.ones((3,3), dtype=np.int)*(-1)
game_tree.add_nodes_from([(totuple(starting_board), {'player':0,'wins':0, 'n_simulations':0})])
MCTS(game_tree)

In [20]:
def sample_policy(game_tree, game_state):
    game_tree.nodes[totuple(game_state)]
    children = [e[1] for e in game_tree.out_edges(totuple(game_state))]
    # assuming the move with the most simulations made is best
    n_simulated = [game_tree.nodes[child]['n_simulations'] for child in children]
    #win_ratios = [game_tree.nodes[child]['wins']/game_tree.nodes[child]['n_simulations'] for child in children]
    policy = np.random.choice(np.where(np.array(n_simulated) == max(n_simulated))[0])
    return np.array(children[policy])

In [21]:
#pos = graphviz_layout(game_tree, prog="dot")
#
#%matplotlib qt
#nx.draw(game_tree, pos=pos)
#nx.draw_networkx_labels(game_tree, pos, win_ratio, font_size=5);

In [22]:
def click_coords2move(coords):
    x,y = coords
    x_norm = x/2.5*3
    y_norm = y/2.5*3
    j = int(x_norm)
    i = int(y_norm)
    return (i,j)

In [23]:
import time

In [24]:
class TicTacToe:
    def __init__(self, board):
        self.board = board
        self.game_state = self.board.get_array()
        self.game = [self.game_state]
        self.end = False
        self.cid = self.board.figure.canvas.mpl_connect('button_press_event', self)

    def __call__(self, event):
        if not self.end:
            # get player move
            input_move = click_coords2move((event.xdata, event.ydata))
            board_after_player_move = move(self.game_state, 1, input_move)
            self.game_state = board_after_player_move

            # render player move
            self.board.set_data(self.game_state)
            self.board.figure.canvas.draw()
            # check if anyone won
            if any([win(self.game_state, 0), win(self.game_state, 1), draw(self.game_state)]):
                winner = win(self.game_state,0)*0 + win(self.game_state,1)*1 + draw(self.game_state)*0.5
                if winner == 1:
                    self.board.axes.set_title(f'You win!')
                elif winner == 0:
                    self.board.axes.set_title(f'AI wins!')
                else:
                    self.board.axes.set_title(f'Draw!')
                    
                self.board.figure.canvas.draw()
                self.end = True

            # get AI move
            # if player has made a move the AI has never seen, add it to the game tree
            if totuple(self.game_state) not in game_tree.nodes:
                game_tree.add_nodes_from([(totuple(self.game_state),{'player':player,'wins':0, 'n_simulations':0})])
                game_tree.add_edges_from([(self.game[-1],self.game_state)])
            # run search from board after player's move
            MCTS(game_tree, totuple(self.game_state), 400)
            # sample AI move
            board_after_AI_move = sample_policy(game_tree, self.game_state)
            self.game_state = board_after_AI_move
            self.game.append(self.game_state)

            # render AI move
            self.board.set_data(self.game_state)
            self.board.figure.canvas.draw()

            # check if anyone won
            if any([win(self.game_state, 0), win(self.game_state, 1), draw(self.game_state)]):
                winner = win(self.game_state,0)*0 + win(self.game_state,1)*1 + draw(self.game_state)*0.5
                if winner == 1:
                    self.board.axes.set_title(f'You win!')
                elif winner == 0:
                    self.board.axes.set_title(f'AI wins!')
                else:
                    self.board.axes.set_title(f'Draw!')
                self.board.figure.canvas.draw()
                self.end = True
            #print(self.board.get_array())
        

In [25]:
%matplotlib qt

In [30]:
starting_board = np.ones((3,3), dtype=np.int)*(-1)
#starting_board = np.array([[ 1, -1, -1],[-1, -1, -1],[-1, -1,  0]])
fig = plt.figure()
ax = fig.add_subplot(111)

#ax.set_title('Click')
board = ax.imshow(starting_board, vmin=-1,vmax=1)
plt.axis('off')

tictactoe = TicTacToe(board)

plt.show()

Traceback (most recent call last):
  File "/home/folzd/anaconda3/lib/python3.7/site-packages/matplotlib/cbook/__init__.py", line 216, in process
    func(*args, **kwargs)
  File "<ipython-input-24-159cf8887101>", line 13, in __call__
    board_after_player_move = move(self.game_state, 1, input_move)
  File "<ipython-input-8-ab8bfc1f275b>", line 3, in move
    assert board[i,j] == -1
AssertionError


In [27]:
len(game_tree.nodes)

10

In [28]:
len(game_tree.edges)

9

In [29]:
np.unique(np.array(game_tree.nodes), axis=-1).shape

(10, 3, 3)