# Introduction

In this notebook we will implement a tree-search algorithm for mastering the game of Go.

OpenAI Gym Environment: https://github.com/aigagror/GymGo

In [1]:
import gym
import torch
import numpy as np

## Random Play Tree

The first step is to implement a Random Play Tree. In this implementation, sequences of moves are organized in a tree structure, and any node can be expanded. The Random Play Tree chooses a random possible move at every turn and keeps playing until no possible moves are left.

In [2]:
from monte_carlo_tree import RandomPlayTree

# Side length of the square Go board used by the search trees and the network
# in this notebook.
BOARD_SIZE = 4

def random_play():
    """Play one game in which every move is chosen by a fresh random play tree.

    Returns:
        (depth, result): ``depth()`` of the terminal node reached by the
        simulation, and ``evaluate_node`` of that terminal node.
    """
    play_tree = RandomPlayTree(BOARD_SIZE)
    final_node = play_tree.simulate(play_tree.root_node)
    return final_node.depth(), play_tree.evaluate_node(final_node)
    
def build_random_play_stats(n_games=100):
    """Play ``n_games`` random games and print a win / game-length summary.

    A game counts as a black win when ``random_play`` reports a winner of
    ``1``; every other result is counted as a white win. The printed
    "Moves mean" is the average terminal-node depth over all games.
    """
    results = [random_play() for _ in range(n_games)]
    game_lengths = [depth for depth, _ in results]
    black_wins = sum(1 for _, winner in results if winner == 1)
    white_wins = len(results) - black_wins
    print("Blacks: ", black_wins, "Whites: ", white_wins, "Moves mean:", np.mean(game_lengths))


In [3]:
# Time 100 random self-play games and print the win / game-length summary.
%time build_random_play_stats(100)

Exception: ('Invalid Move', (0, 0), array([[[0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [1., 0., 0., 1.]],

       [[0., 1., 1., 0.],
        [1., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 1., 1., 0.],
        [1., 0., 0., 1.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]]))

## Monte Carlo Tree Search

Next, we subclass the Random Play Tree to implement all the methods of the Monte Carlo Tree Search algorithm.

MCTS involves following stages:

### 1. Simulate

A simulation in MCTS is a sequence of moves that starts in the current node and ends in a terminal node.
During simulation, moves are chosen according to the **rollout policy function**, which is usually uniformly random.

### 2. Expand

* Expanded node: a playout has been started from this node
* Fully expanded node: all children of the node have been visited

### 3. Backpropagate

Once a playout from an expanded node has finished, its result and the node statistics are propagated
all the way back to the root node through the parent nodes.

Node Statistics:

* Q(v) - Total simulation reward
* N(v) - Total number of visits
* U(v) - Upper confidence bound

### 4. Select 

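UCT (Upper Confidence bound applied to Trees) is the core of MCTS: it chooses which already-visited child node to descend into next. For a child $v_i$ of node $v$:

$$\mathrm{UCT}(v_i) = \frac{Q(v_i)}{N(v_i)} + c\,\sqrt{\frac{\ln N(v)}{N(v_i)}}$$

* $\frac{Q(v_i)}{N(v_i)}$ — exploitation component (favors nodes that have been winning)
* $\sqrt{\frac{\ln N(v)}{N(v_i)}}$ — exploration component (favors nodes that have rarely been visited)
* $c$ — constant controlling the exploration/exploitation trade-off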

In competitive games, Q is always computed relative to the player who makes the move.


In [4]:
from monte_carlo_tree import MonteCarloPlayTree

# A single search tree that is shared by every game played in the cells below.
# Any state it keeps between games (presumably visit statistics) therefore
# carries over from game to game — confirm in monte_carlo_tree.
mcst = MonteCarloPlayTree(BOARD_SIZE)

def mtsc_play(tree):
    """Play one game on ``tree`` by simulating from its root node.

    Args:
        tree: search tree exposing ``root_node``, ``simulate`` and
            ``evaluate_node`` (e.g. a ``MonteCarloPlayTree``).

    Returns:
        (depth, result): ``depth()`` of the terminal node reached by the
        simulation, and ``tree.evaluate_node`` of that node.
    """
    end_node = tree.simulate(tree.root_node)
    game_length = end_node.depth()
    return game_length, tree.evaluate_node(end_node)

def build_mcst_stats(n_games=100, tree=None):
    """Play ``n_games`` MCTS games with ``mtsc_play`` and print a summary.

    Args:
        n_games: number of games to play.
        tree: search tree to play every game on. Defaults to the notebook-level
            ``mcst`` tree, which is then reused by all games.

    Prints the number of black wins (winner ``1``), white wins (any other
    result) and the mean terminal-node depth (printed as "Moves mean").
    """
    if tree is None:
        tree = mcst

    black_wins = 0
    white_wins = 0
    moves = []

    for _ in range(n_games):
        m, winner = mtsc_play(tree)
        if winner == 1:
            black_wins += 1
        else:
            white_wins += 1
        moves.append(m)

    print("Blacks: ", black_wins, "Whites: ", white_wins, "Moves mean:", np.mean(moves))

In [5]:
# Time 10 MCTS games, all played on the shared `mcst` tree.
%time build_mcst_stats(10)

TypeError: move() missing 1 required positional argument: 'prob'

## Play Random Policy vs Monte Carlo Tree Search Policy

In [6]:
def match_random_mcts(tree):
    """Play one game between the tree's search policy and its rollout policy.

    Player ``1`` (black, who moves first) chooses moves with
    ``tree.pick_move``. Player ``-1`` (white) chooses moves with
    ``tree.rollout_policy``. The game ends as soon as the policy of the side
    to move returns ``None``.

    Args:
        tree: search tree exposing ``root_node``, ``pick_move``,
            ``rollout_policy``, ``move`` and ``evaluate_node``.

    Returns:
        (depth, result): ``depth()`` of the final node reached, and
        ``tree.evaluate_node`` of that node.
    """
    current_node = tree.root_node
    player = 1  # 1 = black (search policy), -1 = white (rollout policy)

    while True:
        if player == 1:
            action = tree.pick_move(current_node)
        else:
            action = tree.rollout_policy(current_node)

        if action is None:  # side to move has no possible moves
            break

        # NOTE(review): the recorded output of this notebook shows
        # "move() missing 1 required positional argument: 'prob'"; this call
        # may need to pass that argument — confirm against monte_carlo_tree.
        current_node = tree.move(current_node, action)
        # Alternate turns. Multiplying by 1 left black moving forever.
        player = -player

    return (current_node.depth(), tree.evaluate_node(current_node))

def build_mcst_stats(n_games=100, tree=None):
    """Play ``n_games`` search-policy vs rollout-policy games and print a summary.

    Each game is played with ``match_random_mcts``, where black (winner ``1``)
    uses ``tree.pick_move`` and white uses ``tree.rollout_policy``.

    NOTE(review): this redefines ``build_mcst_stats`` from an earlier cell;
    a distinct name (e.g. ``build_match_stats``) would avoid the shadowing.

    Args:
        n_games: number of games to play.
        tree: search tree to play every game on. Defaults to the notebook-level
            ``mcst`` tree, which is then reused by all games.

    Prints the number of black wins (winner ``1``), white wins (any other
    result) and the mean terminal-node depth (printed as "Moves mean").
    """
    if tree is None:
        tree = mcst

    black_wins = 0
    white_wins = 0
    moves = []

    for _ in range(n_games):
        m, winner = match_random_mcts(tree)
        if winner == 1:
            black_wins += 1
        else:
            white_wins += 1
        moves.append(m)

    print("Blacks: ", black_wins, "Whites: ", white_wins, "Moves mean:", np.mean(moves))

In [7]:
# Uses the build_mcst_stats defined in the previous cell (search policy vs.
# rollout policy), which shadows the earlier MCTS-only version.
build_mcst_stats(n_games=50)

TypeError: move() missing 1 required positional argument: 'prob'

## Guided Tree Search

### Neural Network

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    """Small actor-critic network over a single-plane board encoding.

    The board is fed as a one-channel ``board_size x board_size`` image. A 2x2
    convolution and a fully-connected trunk feed two heads:

    * policy head: a masked softmax over the ``board_size**2`` cells, where
      cells holding ``+1`` or ``-1`` get probability zero;
    * critic head: a scalar value squashed to ``[-1, 1]`` by ``tanh``.

    NOTE(review): the state printed with the "Invalid Move" exception earlier
    in this notebook has six board planes, while ``self.conv`` expects a single
    channel. The caller presumably has to collapse them into one +1/-1/0
    plane — confirm before training.

    Args:
        board_size: side length of the board. Defaults to the notebook's
            ``BOARD_SIZE`` constant when omitted.

    Raises:
        ValueError: if ``board_size`` is smaller than the 2x2 kernel.
    """

    def __init__(self, board_size=None):
        super(Policy, self).__init__()
        if board_size is None:
            board_size = BOARD_SIZE
        if board_size < 2:
            raise ValueError("board_size must be at least 2 for the 2x2 convolution")
        self.board_size = board_size
        n_cells = board_size ** 2

        # A 2x2 kernel with stride 1 and no padding maps a board_size x board_size
        # input to (board_size - 1) x (board_size - 1), with n_cells channels.
        self.conv = nn.Conv2d(1, n_cells, kernel_size=2, stride=1, bias=False)
        # Flattened conv output size. The former hard-coded 2*2*BOARD_SIZE**2
        # only matched this when board_size == 3.
        self.size = n_cells * (board_size - 1) ** 2
        self.fc = nn.Linear(self.size, 32)

        # Policy head: one logit per board cell.
        self.fc_action1 = nn.Linear(32, 16)
        self.fc_action2 = nn.Linear(16, n_cells)

        # Critic head
        self.fc_value1 = nn.Linear(32, 8)
        self.fc_value2 = nn.Linear(8, 1)
        self.tanh_value = nn.Tanh()

    def forward(self, x):
        """Compute move probabilities and a state value.

        Args:
            x: one-channel board tensor whose trailing dims are
                ``(board_size, board_size)``, e.g. shape
                ``(1, 1, board_size, board_size)``. Cells equal to ``+1`` or
                ``-1`` are treated as unavailable moves.

        Returns:
            (prob, value): ``prob`` is reshaped to ``(board_size, board_size)``,
            so a single board is expected. ``value`` is the critic output with
            shape ``(batch, 1)``.
        """
        n_cells = self.board_size ** 2

        y = F.relu(self.conv(x))
        y = y.view(-1, self.size)
        y = F.relu(self.fc(y))

        # Policy head: raw logits, one per board cell.
        a = F.relu(self.fc_action1(y))
        a = self.fc_action2(a)

        # Availability mask: cells whose value is +/-1 cannot be played.
        # `.to(a.dtype)` keeps the mask on the logits' device. The former
        # `.type(torch.FloatTensor)` always produced a CPU tensor.
        avail = (torch.abs(x) != 1).to(a.dtype).view(-1, n_cells)

        # Masked softmax per board. Subtract each row's max for numerical
        # stability, then zero out the unavailable cells before normalising.
        maxa = a.max(dim=1, keepdim=True).values
        exp = avail * torch.exp(a - maxa)
        prob = exp / exp.sum(dim=1, keepdim=True)

        # Critic head
        value = F.relu(self.fc_value1(y))
        value = self.tanh_value(self.fc_value2(value))
        return prob.view(self.board_size, self.board_size), value

# Network instance for the notebook's board size (BOARD_SIZE).
policy = Policy()

### Train

In [None]:
# NOTE(review): scratch expression left in the "Train" section. It evaluates
# to 1.0 and performs no training; replace it with the training loop or remove it.
(-1)**(-1-1)

## References

[1] https://int8.io/monte-carlo-tree-search-beginners-guide/