In [54]:
import reversi_board
import numpy as np
import collections

In [344]:
class turnBoard(reversi_board.ReversiBoard):
    def __init__(self, turn=2):
        self.turn=turn
        super(turnBoard, self).__init__()
    
    def changeTurn(self, turn):
        #Maps 1->2 and 2->1
        return (-(turn-1)+2)
    
    def push(self, p):
        if p==-1: # pass turn
            self.turn=self.changeTurn(self.turn)
            return
        else:
            self.put_piece(p,self.turn)
            self.turn=self.changeTurn(self.turn)
            return
        
    
    def isGameOver(self):
        #Return 0 if draw, 1 if player whose turn it is wins, -1 if player whose turn it is loses.
        if len(self.placable_positions(1))!=0:
            return 0  # Not draw
        if len(self.placable_positions(2))!=0:
            return 0  # Not draw
        counts=collections.Counter(self.board)
        if counts[1]>counts[2]:
            if self.turn==1:
                return 1
            else:
                return -1
        else:
            if self.turn==2:
                return 1
            else:
                return -1
    
    def copy(self):
        boardCopy=turnBoard(self.turn)
        boardCopy.board=self.board.copy()
        return boardCopy
    
    def hash(self):
        return np.append(self.board,self.turn).tobytes()

In [332]:
from scipy.stats import dirichlet
import math
x=0.75
alpha=2
c_puct=4
class MCTS:
    def __init__(self):
        self.Q={}#Array for some given state as to the rewards for taking each action
        self.N={}#Array for some given state as to the number of times each action has been visited from state
        self.P={}#Policy vector for given state
    
    def search(self, s, nnet):
        gameOver=s.isGameOver()
        if gameOver!=0:  # Is game over?
            return -gameOver
        sh=s.hash()
        if sh not in self.P.keys():  # Not visited
            prediction=nnet.predict(np.reshape(s.board,(1,8,8),order='F'))
            self.P[sh]=prediction[0][0]
            v=prediction[1][0][0]
            #Add dirichlet noise
            self.P[sh]=np.array(self.P[sh])
            self.P[sh]=x*self.P[sh]+(1-x)*dirichlet.rvs(np.ones(len(self.P[sh]))*alpha)[0]
            self.N[sh]=np.zeros(64)
            self.Q[sh]=np.zeros(64)
            return -v
        
        max_u, best_square = -float("inf"), -1
        for square in s.placable_positions(s.turn):
            u=self.Q[sh][square]+c_puct*self.P[sh][square]*math.sqrt(sum(self.N[sh]))/(1+self.N[sh][square])
            if u>max_u:
                max_u=u
                best_square=square
        square=best_square
        sp=s.copy()
        sp.push(square)
        v=self.search(sp, nnet)
        self.Q[sh][best_square]=(self.N[sh][best_square]*self.Q[sh][best_square]+v)/(self.N[sh][best_square]+1)
        self.N[sh][best_square]+=1
        return -1

In [310]:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
class ReversiModel:
    def __init__(self):
        inx = x=Input((8,8))
        x=Reshape((8,8,1))(x)
        for _ in range(10):  # Create residual layers
            x=Conv2D(filters=64, kernel_size=(3,3), padding='same', data_format='channels_last')(x)
            x=BatchNormalization(axis=3)(x)
            x=Activation("relu")(x)
            
        res_out=x
        # Policy output
        x=Conv2D(filters=2, kernel_size=1, data_format='channels_last')(res_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        policy_out=Dense(8*8, activation="softmax", name="policy_out")(x)
        self.model=policy_out
        
        #Value output
        x=Conv2D(filters=1, kernel_size=1, data_format="channels_last")(res_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        value_out=Dense(1, activation='tanh', name='value_out')(x)
        self.model=Model(inx, [policy_out, value_out], name='reversi_model')
        self.model.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer='adam')

In [311]:
model=ReversiModel().model

In [336]:
mcts=MCTS()
s=turnBoard(1)
for i in range(100):
    mcts.search(s, model)
print(mcts.P[s.hash()])
print(mcts.Q[s.hash()])
print(mcts.N[s.hash()])

[0.01253782 0.01469467 0.0166776  0.01600942 0.01454531 0.01541257
 0.01608569 0.01367494 0.01571432 0.01805421 0.01487616 0.01370318
 0.01534016 0.0147552  0.01666236 0.01665322 0.01385942 0.01647034
 0.01493962 0.01318355 0.01641467 0.02096492 0.01618015 0.01522931
 0.0134501  0.01300505 0.01869642 0.01249678 0.01332983 0.01805335
 0.01463793 0.02018527 0.0121688  0.01554972 0.01580851 0.01403049
 0.01656505 0.01209395 0.01411019 0.0136147  0.01906569 0.01787291
 0.01552927 0.01616079 0.02129428 0.01660992 0.01271245 0.01457276
 0.01286269 0.02025693 0.01512416 0.01313635 0.01764382 0.01411331
 0.01577373 0.01830101 0.02441996 0.01486207 0.01303879 0.01725069
 0.01562643 0.0144654  0.01443441 0.01443722]
[ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.         -0.95997622  0.          0.          0.
  0.          0.        

In [369]:
def executeEpisode(nnet):
    examples=[]
    s=turnBoard()
    mcts=MCTS()
    
    while True:
        for _ in range(2):
            mcts.search(s,nnet)
        pi=mcts.P[s.hash()]
        examples.append([s, pi, None])
        legalmoves=s.placable_positions(s.turn)
        legalprobs=np.take(pi,legalmoves)
        legalprobs/=sum(legalprobs)
        a=np.random.choice(legalmoves, p=legalprobs)
        s.push(a)
        if s.isGameOver():
            return s
            #update examples
            #return examples

In [370]:
s=executeEpisode(model)

In [373]:
s.print_board()

  a b c d e f g h
1 * * * * o * * *
2 o * o o o o * *
3 o * o o * * o *
4 o * o o * * o *
5 o * o o * * o *
6 o * o * * * * *
7 o * * * * * * *
8 o * * * * * o o
