## AlphaZero Implementation
Self-play of one full game.

In [1]:
import warnings; warnings.simplefilter('ignore')

#### Import common libraries

In [2]:
import torch
import time
import shlex
import hashlib
from collections import defaultdict
import numpy as np
import dill
import random

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

#### Import MCTS library

In [4]:
from MCTS.utils import *
from MCTS.Agent import *
from MCTS.MCTS import *

#### Import game rules

In [5]:
import rules.Othello as Othello
# shorthands: module-level aliases so later cells can refer to the two
# Othello classes without the `Othello.` prefix
OthelloGame   = Othello.OthelloGame   
OthelloHelper = Othello.OthelloHelper

#### Define hyperparameters

In [6]:
# MCTS search related
# c: exploration constant multiplying the prior-weighted U term when
#    start_mcts scores the actions of an expanded node
c = 1.2
# allowed_time: wall-clock budget handed to start_mcts for each move; it is
#    compared against differences of time.time(), i.e. seconds
allowed_time = 1

# Game specific
# state_memory_n: number of state entries placed in the initial state list
state_memory_n = 1
# board_size: NOTE(review): not referenced by any other cell visible in this
#    notebook -- presumably [rows, cols] of the Othello board; confirm
board_size = [8, 8]

#### Import and implement data structure for the game

In [7]:
class OthelloDataNode(ZeroDataNode):
    '''ZeroDataNode whose `Game` argument defaults to OthelloGame.

    Parameters
    ----------
    name   : identifier given to the node (the notebook uses "root").
    Game   : game-rules class forwarded to the base class.
    player : player id forwarded to the base class (defaults to 1).
    '''
    def __init__(self, name, Game=OthelloGame, player=1):
        # every argument is forwarded by keyword to the base constructor
        super().__init__(name=name, player=player, Game=Game)
    # end def
# end class

#### Import a pre-trained model (on human expert dataset)

In [8]:
class Net(torch.nn.Module):
    '''One-hidden-layer network with a joint policy / value output.

    Input  : float tensor of shape (batch, 64).
    Output : float tensor of shape (batch, 62); columns 0..60 are a softmax
             distribution over 61 move labels, column 61 is a sigmoid value.
    '''
    def __init__(self):
        super().__init__()
        # define functionals
        self.fc1     = torch.nn.Linear(64, 100)
        self.norm1   = torch.nn.BatchNorm1d(100)
        self.relu1   = torch.nn.ReLU()
        self.fc2     = torch.nn.Linear(100, 62)
        # dim=1 normalizes across the move columns of each sample. Leaving
        # `dim` unset relies on the deprecated implicit-dimension behaviour,
        # which emits a warning on every forward pass.
        self.softmax = torch.nn.Softmax(dim=1)
        self.sigmoid = torch.nn.Sigmoid()
    # end def
    
    def forward(self, x):
        '''Run the network on a (batch, 64) tensor; returns a (batch, 62) tensor.'''
        out = self.fc1(x)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        # first 61 columns -> move logits, last column -> value logit
        out_p, out_v = torch.split(out, (61, 1), 1)
        out_p = self.softmax(out_p)
        out_v = self.sigmoid(out_v)
        # re-join so callers receive a single tensor per sample
        out = torch.cat((out_p, out_v), 1)
        return out
    # end def
# end class

# Location of the serialized, pre-trained network.
model_file = 'expert_prediction/data/models/oth_exp_pred-iter03500.dill'
# NOTE(review): dill.load, like pickle.load, can execute arbitrary code while
# deserializing -- only load model files that come from a trusted source.
with open(model_file, 'rb') as fin:
    net = dill.load(fin)
# end with

# move the network to the CPU and switch it to evaluation mode
net = net.cpu().eval()

#### Wrap the neural network in a predictor

In [9]:
# Network label -> move, built once instead of on every call.
# Labels 0..59 enumerate the board squares in row-major order, skipping the
# four centre squares; label 60 is the pass move.
_CENTRE_SQUARES = frozenset([(3, 3), (3, 4), (4, 3), (4, 4)])
_LABEL2MOVE = tuple(
    [(row, col)
     for row in range(8)
     for col in range(8)
     if (row, col) not in _CENTRE_SQUARES]
) + ('PASS',)

def parse_move_probs(move_probs):
    '''Parse the neural-net move output into a dict keyed by move.

    Parameters
    ----------
    move_probs : indexable sequence of 61 values, one per network label.

    Returns
    -------
    dict mapping each (row, col) tuple, and finally the string 'PASS', to the
    value stored at the corresponding label. Insertion order follows the label
    order, so 'PASS' is the last key.

    Raises
    ------
    ValueError if `move_probs` does not hold exactly one value per label.
    '''
    if len(move_probs) != len(_LABEL2MOVE):
        raise ValueError('expected %d move probabilities, got %d'
                         % (len(_LABEL2MOVE), len(move_probs)))
    # end if

    # A fresh dict is returned, so the caller's input is never mutated and no
    # deep copy is required.
    return {move: move_probs[label] for label, move in enumerate(_LABEL2MOVE)}
# end def

def nnet_pred(state_in):
    '''Evaluate a game state with the global network `net`.

    Parameters
    ----------
    state_in : sequence of states; only the last entry is used. That entry
               must expose `.data` (an array with a `flatten` method) and
               `.player`.

    Returns
    -------
    (P, v) : P is the dict produced by parse_move_probs from all but the last
             network output; v is the last network output (the value head).
    '''
    # TODO: currently only for black
    latest = state_in[-1] # Othello uses only last step

    # assume white's strategy is the same as black: multiplying by the player
    # id flips the sign of the board when player == -1
    data_in = latest.data.flatten() * latest.player

    # call the net for prediction; no_grad skips building an autograd graph,
    # which is wasted work for a pure inference call made on every simulation
    batch = torch.from_numpy(np.array([data_in])).float()
    with torch.no_grad():
        net_outputs = net(batch)[0].numpy()
    # end with

    # split into move prob and winner pred
    P, v = net_outputs[:-1], net_outputs[-1]
    P = parse_move_probs(P)
    return P, v
# end def

#### Strategy to choose next move

In [None]:
def choose_next_move(node, temperature = 1):
    '''Sample the next move from the visit counts of a searched node.

    Each action is drawn with probability proportional to
    N[action] ** (1 / temperature), normalized over all actions so the
    probabilities sum to one for every temperature.

    Parameters
    ----------
    node        : object exposing `N` (dict action -> visit count) and
                  `child_dict` (dict action -> child node).
    temperature : non-negative number. 1 samples proportionally to the visit
                  counts; 0 always picks the most visited action (first one
                  on ties).

    Returns
    -------
    The entry of `node.child_dict` for the selected action.

    Raises
    ------
    ValueError if `temperature` is negative or `node.N` is empty.
    '''
    if temperature < 0:
        raise ValueError('temperature must be non-negative, got %r' % (temperature,))
    # end if

    actions = list(node.N.keys())
    if not actions:
        raise ValueError('cannot choose a move: node has no visit statistics')
    # end if
    visits = [node.N[action] for action in actions]

    # temperature == 0 is the greedy limit; handle it explicitly because the
    # exponent 1 / temperature is undefined there
    if temperature == 0:
        sel = actions[visits.index(max(visits))]
        return node.child_dict[sel]
    # end if

    exponent = 1. / temperature
    weights = [count**exponent for count in visits]
    total = sum(weights)
    if total <= 0:
        # no action has been visited yet: fall back to a uniform choice
        weights = [1.] * len(actions)
        total = float(len(actions))
    # end if

    # inverse-CDF sampling; the last action is the fallback in case rounding
    # leaves the cumulative sum marginally below the random draw
    r = random.random()
    sel = actions[-1]
    cumulative = 0
    for action, weight in zip(actions, weights):
        cumulative += weight / total
        if cumulative >= r:
            sel = action
            break
        # end if
    # end for
    
    return node.child_dict[sel]
# end def

#### Define exit conditions

In [None]:
def exit_cond(time0, time_thr):
    '''Tell whether the search should stop.

    Parameters
    ----------
    time0    : start timestamp, as returned by time.time().
    time_thr : allowed duration in seconds.

    Returns
    -------
    bool : True once more than `time_thr` seconds have elapsed since `time0`,
           False otherwise (always a bool, never None).
    '''
    # winning prob based
    # TODO

    # time based
    return time.time() - time0 > time_thr
# end def

#### Main logic

In [None]:
def _select_puct_action(node):
    '''Return the action of an expanded node with the largest Q + U score.

    U = c * sqrt(sum of visit counts) / (1 + N[a]) * P[a], with `c` the global
    exploration constant. Ties go to the first action (np.argmax behaviour).
    '''
    total_visits = sum(node.N.values())
    actions = list(node.Q.keys())
    scores = [
        node.Q[a] + c*np.sqrt(total_visits) / (1+node.N[a]) * node.P[a]
        for a in actions
    ]
    return actions[np.argmax(scores)]
# end def

def start_mcts(node, allowed_time):
    '''Run tree-search simulations from a root node until the time budget ends.

    Each simulation descends from the root, picking the best-scoring action at
    every expanded node, until it reaches a finished game or an unexpanded
    node. A finished game is scored 0.5 for a draw, 1 if the winner is the
    root's player and 0 otherwise. An unexpanded node is expanded and scored
    with nnet_pred. In both cases the value is handed to `node.backprob`.

    Parameters
    ----------
    node         : root of the search; must have no parent.
    allowed_time : time budget forwarded to exit_cond (checked once per
                   simulation, before the simulation starts).

    Returns
    -------
    The root node, with whatever statistics the node methods accumulated.
    '''
    assert node.parent is None

    time0 = time.time()
    # check exit condition before every simulation
    while not exit_cond(time0, allowed_time):
        # `node` is the root here; terminal results are scored from the point
        # of view of the root's player
        player = node.player

        # dynamically expand the tree - search for first un-expanded node
        while True:
            _winner = OthelloGame.get_winner(node.state)
            # stop at an ended game or at the first unexpanded node
            if _winner is not None or node.expanded() is False:
                break
            # end if
            node = node.child_dict[_select_puct_action(node)]
        # end while

        # if end game, simply evaluate
        if _winner is not None:
            if _winner == 0:
                node.backprob(0.5)
            elif _winner == player:
                node.backprob(1)
            else:
                node.backprob(0)
            # end if

            # go back to root
            node = node.root
            continue
        # end if

        # if not, expand and evaluate
        # - list all possible moves
        child_nodes = node.grow_branches()
        # append legal child nodes
        node.append_children(child_nodes)

        # - neuralnet evaluation here
        P, v = nnet_pred(node.state)

        # back-propagation
        node.assign_probs(P)
        node.backprob(v)

        # go back to root
        node = node.root
    # end while
    
    return node
# end def

#### Initialize the board

In [None]:
# board produced by OthelloHelper.new_board(); presumably the standard
# Othello opening position -- TODO confirm against rules.Othello
new_board = OthelloHelper.new_board()

In [None]:
# initial state history: state_memory_n entries, each holding the new board
# with player=1 (the id that player_name labels 'B')
init_state = [stateType(data=new_board, player=1) for _ in range(state_memory_n)]

#### Declare a game record

In [None]:
# list of node states, appended to in play order by the cells below
game_record = []

In [None]:
# display label for each player id; NOTE(review): there is no entry for 0,
# the winner value start_mcts scores as a draw
player_name = {1: 'B', -1: 'W'}

#### Initialize a node

In [None]:
# root of the search tree, created with the default player (1) and given the
# initial state history
current_node = OthelloDataNode("root")
current_node.state = init_state

In [None]:
# the record starts with the initial state
game_record.append(current_node.state)

In [None]:
# Self-play loop: at most 100 steps; each step runs a fresh search from the
# current position, samples a move, records it, and re-roots the tree.
for i in range(100):
    print('Step: %d' % (i+1,), end=' | ')
    # end game check
    _winner = OthelloGame.get_winner(current_node.state)
    if _winner is not None:
        # 0 is the draw value (start_mcts scores it 0.5) and has no entry in
        # player_name, so it must not be looked up there
        if _winner == 0:
            print('Draw')
        else:
            print('Winner: %s' % (player_name[_winner]))
        # end if
        break
    # end if
        
    # sample the node
    sampled_node = start_mcts(current_node, allowed_time=allowed_time)
    # choose the node
    next_node = choose_next_move(sampled_node)
    # print
    print(player_name[sampled_node.player], next_node.name, end=' | ')
    # guard against a chosen move that was never visited during the search
    visit_count = sampled_node.N[next_node.name]
    if visit_count:
        win_prob = sampled_node.W[next_node.name] / visit_count
    else:
        win_prob = float('nan')
    # end if
    print('winning rate: %4.1f %%' % (win_prob*100,))
    
    # record
    game_record.append(next_node.state)
    OthelloHelper.print_board(game_record[-1][-1], highlight=next_node.action, outfilename='last_state.png')
    
    # next
    current_node = next_node
    current_node.parent = None # remove parent node before simulation
    current_node.reset_stats()
# end for

Step: 1 | B (5, 4) | winning rate: 54.9 %
Step: 2 | W (5, 5) | winning rate: 55.8 %
Step: 3 | B (4, 5) | winning rate: 55.6 %
Step: 4 | W (3, 5) | winning rate: 55.5 %
Step: 5 | B (2, 2) | winning rate: 55.7 %
Step: 6 | W (5, 2) | winning rate: 54.1 %
Step: 7 | B (5, 6) | winning rate: 55.6 %
Step: 8 | W (6, 5) | winning rate: 53.3 %
Step: 9 | B (4, 2) | winning rate: 52.9 %
Step: 10 | W (5, 7) | winning rate: 54.3 %
Step: 11 | B (3, 6) | winning rate: 54.6 %
Step: 12 | W (2, 3) | winning rate: 55.7 %
Step: 13 | B (1, 3) | winning rate: 55.8 %
Step: 14 | W (1, 2) | winning rate: 53.9 %
Step: 15 | B (6, 4) | winning rate: 53.6 %
Step: 16 | W (2, 7) | winning rate: 54.7 %
Step: 17 | B (2, 5) | winning rate: 54.0 %
Step: 18 | W (4, 1) | winning rate: 53.6 %
Step: 19 | B (5, 1) | winning rate: 51.8 %
Step: 20 | W (2, 6) | winning rate: 53.4 %
Step: 21 | B (4, 0) | winning rate: 52.6 %
Step: 22 | W (3, 2) | winning rate: 50.2 %
Step: 23 | B (2, 1) | winning rate: 53.8 %
Step: 24 | W (7, 4) 