In [1]:
import reversi_board
import numpy as np
import collections
import os
# Ask TensorFlow's native (C++) logger to hide INFO and WARNING lines.
# This has to be set before TensorFlow/Keras is imported further down.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [2]:
class turnBoard(reversi_board.ReversiBoard):
    """A ReversiBoard that additionally remembers which player moves next.

    Players are identified by the integers 1 and 2; ``turn`` holds the id of
    the player whose move it currently is.
    """

    def __init__(self, turn=2):
        # The turn is stored before the parent initialiser runs.
        self.turn = turn
        super().__init__()

    def changeTurn(self, turn):
        """Return the other player's id (1 -> 2, 2 -> 1)."""
        return 3 - turn

    def push(self, p):
        """Play square ``p`` for the current player, then hand over the turn.

        ``p == -1`` means "pass": nothing is placed, only the turn changes.
        """
        if p != -1:
            self.put_piece(p, self.turn)
        self.turn = self.changeTurn(self.turn)

    def isGameOver(self):
        """Report the game state from the point of view of the player to move.

        Returns:
            2  if at least one player still has a placable position,
            1  if the game is over and the player to move has more pieces,
            -1 if the game is over and the player to move has fewer pieces,
            0  if the game is over with equal piece counts.
        """
        # The game continues as long as either side can still place a piece.
        if any(len(self.placable_positions(player)) != 0 for player in (1, 2)):
            return 2

        tally = collections.Counter(self.board)
        ones, twos = tally[1], tally[2]
        if ones == twos:
            return 0
        leader = 1 if ones > twos else 2
        return 1 if self.turn == leader else -1

    def copy(self):
        """Return a new turnBoard with the same turn and a copy of the board."""
        duplicate = turnBoard(self.turn)
        duplicate.board = self.board.copy()
        return duplicate

In [3]:
from scipy.stats import dirichlet
import math
# Weight given to the network's policy when it is mixed with Dirichlet noise;
# the noise receives the remaining (1 - x).
x=0.75
# Concentration parameter used for every component of the Dirichlet noise.
alpha=2
# Multiplier of the exploration term in the move-selection score.
c_puct=4
class MCTS:
    """Monte Carlo tree search statistics guided by a policy/value network.

    All three tables are keyed by ``s.board.tobytes()``:
        Q[key][a] -- running mean of the values backed up through action a
        N[key][a] -- number of times action a was taken from that position
        P[key]    -- prior policy (network output mixed with Dirichlet noise)
    """

    def __init__(self):
        self.Q={}  # mean backed-up value per action, per visited position
        self.N={}  # visit count per action, per visited position
        self.P={}  # prior policy vector, per visited position

    def search(self, s, nnet):
        """Run one simulation starting at position ``s``.

        Args:
            s: game state exposing ``isGameOver()``, ``board`` (array with
               ``tobytes()``), ``turn``, ``placable_positions(turn)``,
               ``copy()`` and ``push(square)``.
            nnet: object whose ``predict(batch)`` returns
               ``[policy_batch, value_batch]`` for a ``(1, 8, 8)`` input.

        Returns:
            The value of ``s`` negated, i.e. expressed from the point of view
            of the player who moved into ``s``.
        """
        gameOver=s.isGameOver()
        if gameOver!=2:  # terminal position: isGameOver() is the mover's result
            return -gameOver

        sh=s.board.tobytes()
        if sh not in self.P:  # first visit: expand the leaf with the network
            # predict() is fed the NumPy board directly, so no TensorFlow
            # symbol is needed in this cell.
            prediction=nnet.predict(np.reshape(s.board,(1,8,8),order='F'))
            prior=np.array(prediction[0][0])
            v=prediction[1][0][0]
            # Mix Dirichlet noise into the prior to encourage exploration.
            noise=dirichlet.rvs(np.ones(len(prior))*alpha)[0]
            self.P[sh]=x*prior+(1-x)*noise
            self.N[sh]=np.zeros(64)
            self.Q[sh]=np.zeros(64)
            return -v

        legal_squares=s.placable_positions(s.turn)
        if len(legal_squares)==0:
            # Forced pass: hand the turn over without recording a visit.
            # Index -1 would otherwise alias square 63 in Q and N.
            sp=s.copy()
            sp.push(-1)
            v=self.search(sp, nnet)
            return -v

        # Same for every candidate move, so compute it once.
        sqrt_total=math.sqrt(self.N[sh].sum())
        max_u, best_square = -float("inf"), -1
        for square in legal_squares:
            u=self.Q[sh][square]+c_puct*self.P[sh][square]*sqrt_total/(1+self.N[sh][square])
            if u>max_u:
                max_u=u
                best_square=square

        sp=s.copy()
        sp.push(best_square)
        v=self.search(sp, nnet)

        # Incremental mean update, then count the visit.
        visits=self.N[sh][best_square]
        self.Q[sh][best_square]=(visits*self.Q[sh][best_square]+v)/(visits+1)
        self.N[sh][best_square]+=1
        # Propagate the value that was actually found (previously a constant -1).
        return -v

In [4]:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
class ReversiModel:
    """Two-headed convolutional network for 8x8 Reversi positions.

    After construction ``self.model`` is a compiled Keras ``Model`` that maps
    an ``(8, 8)`` board to ``[policy_out, value_out]``:
        policy_out -- softmax over the 64 squares
        value_out  -- single tanh-squashed scalar
    """

    def __init__(self):
        inx = x=Input((8,8))
        x=Reshape((8,8,1))(x)  # add a trailing channel axis for Conv2D
        # Shared trunk: ten plain Conv -> BatchNorm -> ReLU blocks.
        # There are no skip connections, so these are not residual blocks.
        for _ in range(10):
            x=Conv2D(filters=64, kernel_size=(3,3), padding='same', data_format='channels_last')(x)
            x=BatchNormalization(axis=3)(x)
            x=Activation("relu")(x)

        res_out=x

        # Policy head
        x=Conv2D(filters=2, kernel_size=1, data_format='channels_last')(res_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        policy_out=Dense(8*8, activation="softmax", name="policy_out")(x)

        # Value head
        x=Conv2D(filters=1, kernel_size=1, data_format="channels_last")(res_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        value_out=Dense(1, activation='tanh', name='value_out')(x)

        # One loss per output, in the same order as the outputs list.
        self.model=Model(inx, [policy_out, value_out], name='reversi_model')
        self.model.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer='adam')

In [5]:
def executeEpisode(nnet, num_simulations=5):
    """Play one self-play game and return its training examples.

    Args:
        nnet: network passed through to ``MCTS.search``.
        num_simulations: searches run from each position before a move is
            chosen (must be at least 1 so the position gets expanded).

    Returns:
        A list of ``[board, pi, reward]`` entries, one per position played.
        ``board`` is a snapshot of the position, ``pi`` is the policy target
        over the 64 squares, and ``reward`` is the final result from the point
        of view of the player who was to move in that position.
    """
    examples=[]
    s=turnBoard()
    mcts=MCTS()
    while True:
        for _ in range(num_simulations):
            mcts.search(s,nnet)

        key=s.board.tobytes()
        legalmoves=s.placable_positions(s.turn)

        # The policy target is the normalised visit count of the search.
        # Fall back to the stored prior when no legal move has been visited
        # yet (for example on a forced pass).
        pi=np.array(mcts.P[key],dtype=float)
        if len(legalmoves)>0:
            visits=np.take(mcts.N[key],legalmoves)
            total_visits=visits.sum()
            if total_visits>0:
                pi=np.zeros(len(mcts.N[key]))
                pi[legalmoves]=visits/total_visits

        # Store a copy: the live board keeps changing as the game goes on.
        examples.append([s.board.copy(), pi, None])

        if len(legalmoves)==0:
            a=-1  # pass
        else:
            legalprobs=np.take(pi,legalmoves)
            legalprobs=legalprobs/legalprobs.sum()
            a=np.random.choice(legalmoves, p=legalprobs)
        s.push(a)

        gameover=s.isGameOver()
        if gameover!=2:
            # isGameOver() is now relative to the player to move AFTER the
            # last push, while the last example belongs to the player who
            # just moved, hence the sign flip.
            return assignRewards(examples, -gameover)

In [6]:
def assignRewards(examples, reward):
    """Fill slot 2 of every example with an alternating-sign reward.

    The last example receives ``reward`` as given; walking backwards, each
    earlier example receives the negation of the one after it. The examples
    are modified in place and the same container is returned.
    """
    for example in reversed(examples):
        example[2]=reward
        reward*=-1  # the previous position belongs to the other side

    return examples

In [7]:
def policyIterSP(num_iters=1, num_episodes=2, win_threshold=0.55):
    """Self-play policy iteration: generate games, train a candidate, keep the better net.

    Args:
        num_iters: number of generate/train/evaluate rounds.
        num_episodes: self-play games generated per round.
        win_threshold: fraction of evaluation games the candidate must win to
            replace the current network.

    Returns:
        The network that survived the final round.
    """
    nnet=ReversiModel().model
    examples=[]
    for _ in range(num_iters):
        for _ in range(num_episodes):
            # Keep examples in a plain list: each entry mixes arrays with a
            # scalar, which np.concatenate cannot stack reliably.
            examples.extend(executeEpisode(nnet))

        # clone_model rebuilds the architecture with fresh weights, so copy
        # the current weights over for the candidate to continue training.
        new_nnet=clone_model(nnet)
        new_nnet.set_weights(nnet.get_weights())
        new_nnet.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer='adam')

        order=np.random.permutation(len(examples))
        shuffled=[examples[i] for i in order]
        modelFit(shuffled,new_nnet)

        # pit(old, new) returns the fraction of games won by its second argument.
        frac_win=pit(nnet, new_nnet)
        if frac_win>win_threshold:
            nnet=new_nnet
    return nnet

In [8]:
def modelFit(examples,nnet):
    """Train ``nnet`` on self-play examples.

    Each example is indexed as ``[board, policy, value]``. Boards are reshaped
    to 8x8 in column-major ('F') order; the three columns are stacked into
    arrays and passed to ``nnet.fit`` for 10 epochs with batch size 64.
    """
    boards=np.array([np.reshape(example[0],(8,8),order='F') for example in examples])
    policies=np.array([example[1] for example in examples])
    values=np.array([example[2] for example in examples])
    nnet.fit(boards,[policies,values], epochs=10, batch_size=64)

In [9]:
def choose_move_gameplay(nnet, mcts, s, num_simulations=25):
    """Pick the move to play in position ``s`` for evaluation games.

    Runs ``num_simulations`` searches from ``s`` and returns the legal square
    with the most visits. If no legal square has been visited yet, the legal
    square with the highest stored prior is returned instead.

    Args:
        nnet: network passed through to ``mcts.search``.
        mcts: search object exposing ``search(s, nnet)`` and the ``N`` / ``P``
            tables keyed by ``s.board.tobytes()``.
        s: current game state.
        num_simulations: searches to run before choosing (at least 1, so the
            position gets expanded).

    Returns:
        The chosen square, or -1 (pass) when there is no legal move.
    """
    for _ in range(num_simulations):
        mcts.search(s,nnet)

    legalmoves=s.placable_positions(s.turn)
    if len(legalmoves)==0:
        return -1

    key=s.board.tobytes()
    # Rank by visit counts so the search actually drives the decision.
    scores=np.take(mcts.N[key],legalmoves)
    if not np.any(scores):
        scores=np.take(mcts.P[key],legalmoves)
    return legalmoves[np.argmax(scores)]

In [10]:
def _play_evaluation_game(first_net, second_net):
    """Play one game; return 1 if first_net wins, -1 if second_net wins, 0 on a draw."""
    s=turnBoard()
    first_to_move=True
    while True:
        mover_net=first_net if first_to_move else second_net
        # A fresh search tree is built for every move, as before.
        a=choose_move_gameplay(mover_net, MCTS(), s)
        s.push(a)
        gameOver=s.isGameOver()
        if gameOver!=2:
            if gameOver==0:
                return 0
            # After push() the result is relative to the side now to move,
            # i.e. the opponent of the network that just played.
            mover_won=(gameOver==-1)
            return 1 if mover_won==first_to_move else -1
        first_to_move=not first_to_move


def pit(nnet, new_nnet, num_rounds=2):
    """Evaluate ``new_nnet`` against ``nnet`` over paired games.

    Each round consists of two games, one with each network moving first.

    Args:
        nnet: the reference (current) network.
        new_nnet: the candidate network.
        num_rounds: number of two-game rounds to play (at least 1).

    Returns:
        Fraction of all games won by ``new_nnet``; draws count as games played.
    """
    record=[0,0,0]#new_nnet wins, nnet wins, draws
    for _ in range(num_rounds):
        for first_net, second_net, new_moves_first in ((nnet, new_nnet, False),
                                                       (new_nnet, nnet, True)):
            outcome=_play_evaluation_game(first_net, second_net)
            if outcome==0:
                record[2]+=1
            elif (outcome==1)==new_moves_first:
                record[0]+=1
            else:
                record[1]+=1
    return record[0]/sum(record)

In [11]:
# Run one full self-play training loop and print how many wall-clock
# seconds it took.
import time
s=time.time()  # start timestamp
policyIterSP()
print(time.time()-s)

KeyboardInterrupt: 