In [1]:
import reversi_board
import numpy as np
import collections
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [2]:
class turnBoard(reversi_board.ReversiBoard):
    """ReversiBoard that also remembers whose turn it is (players are 1 and 2)."""

    def __init__(self, turn=2):
        # The turn is stored before the base initialiser runs.
        self.turn = turn
        super(turnBoard, self).__init__()

    def changeTurn(self, turn):
        """Return the opponent of `turn`: 1 -> 2 and 2 -> 1."""
        return 3 - turn

    def push(self, p):
        """Play square `p` for the side to move (p == -1 passes), then hand over the turn."""
        if p != -1:
            self.put_piece(p, self.turn)
        self.turn = self.changeTurn(self.turn)

    def isGameOver(self):
        """Score the position for the player whose turn it is.

        Returns 2 while either player still has a placable position (game not
        over). Otherwise returns 1 if the side to move owns more squares,
        -1 if it owns fewer, and 0 on an equal count.
        """
        if len(self.placable_positions(1)) or len(self.placable_positions(2)):
            return 2  # someone can still move: game not over
        tally = collections.Counter(self.board)
        ones, twos = tally[1], tally[2]
        if ones == twos:
            return 0
        leader = 1 if ones > twos else 2
        return 1 if self.turn == leader else -1

    def copy(self):
        """Return a new turnBoard with the same turn and a copy of the board array."""
        clone = turnBoard(self.turn)
        clone.board = self.board.copy()
        return clone

In [3]:
from scipy.stats import dirichlet
import math
# MCTS hyper-parameters (module-level; read by MCTS.search at call time).
x=0.75      # weight kept on the network prior; (1 - x) goes to Dirichlet noise
alpha=2     # Dirichlet concentration parameter for the exploration noise
c_puct=4    # exploration constant in the PUCT score
class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Statistics are keyed by ``s.board.tobytes()``. The key does not include
    ``s.turn``, so the same board with different sides to move (which can
    only arise after a pass) shares one entry.
    """
    def __init__(self):
        self.Q={}  # key -> length-64 array of mean backed-up values per square
        self.N={}  # key -> length-64 array of visit counts per square
        self.P={}  # key -> prior over squares (network policy mixed with Dirichlet noise)

    def search(self, s, nnet):
        """Run one simulation from ``s``.

        Args:
            s: board exposing isGameOver(), board, turn, placable_positions(turn),
               copy() and push(p) (p == -1 passes).
            nnet: model whose ``predict`` on a (1, 8, 8) array returns
               ``[policy, value]`` with policy[0] a length-64 vector and
               value[0][0] a scalar for the side to move.

        Returns:
            The negated value of ``s`` for its side to move, i.e. the value
            from the perspective of the player who moved into ``s``.
        """
        gameOver=s.isGameOver()
        if gameOver!=2:  # terminal: isGameOver scores it for the side to move
            return -gameOver
        sh=s.board.tobytes()
        if sh not in self.P:
            # Leaf: expand with the network and back up its value estimate.
            # predict() accepts a NumPy array directly, so no `tf` name is needed.
            prediction=nnet.predict(np.reshape(s.board,(1,8,8),order='F'))
            prior=np.array(prediction[0][0])
            v=prediction[1][0][0]
            # Mix in Dirichlet noise for exploration.
            noise=dirichlet.rvs(np.ones(len(prior))*alpha)[0]
            self.P[sh]=x*prior+(1-x)*noise
            self.N[sh]=np.zeros(64)
            self.Q[sh]=np.zeros(64)
            return -v

        legal=s.placable_positions(s.turn)
        if len(legal)==0:
            # Forced pass: recurse on the passed position without recording a
            # move here. Updating Q/N at index -1 would hit square 63.
            sp=s.copy()
            sp.push(-1)
            return -self.search(sp, nnet)

        # The tiny epsilon keeps the prior ranking moves on a node's first descent.
        # With sum(N) == 0 every score would equal Q == 0, and the first legal
        # square would always win the tie.
        sqrt_total=math.sqrt(self.N[sh].sum()+1e-8)
        max_u, best_square = -float("inf"), -1
        for square in legal:
            u=self.Q[sh][square]+c_puct*self.P[sh][square]*sqrt_total/(1+self.N[sh][square])
            if u>max_u:
                max_u=u
                best_square=square
        sp=s.copy()
        sp.push(best_square)
        v=self.search(sp, nnet)  # value for this node's side to move
        n=self.N[sh][best_square]
        self.Q[sh][best_square]=(n*self.Q[sh][best_square]+v)/(n+1)
        self.N[sh][best_square]=n+1
        return -v

In [4]:
from keras.models import *
from keras.layers import *
from keras.optimizers import *
class ReversiModel:
    """Policy/value network for Reversi, built with the Keras functional API.

    After construction ``self.model`` is a ``Model`` named ``'reversi_model'``.
    It takes an (8, 8) input and outputs ``[policy_out, value_out]``: a 64-way
    softmax over squares and a single tanh unit.
    """
    def __init__(self, num_layers=10, filters=64):
        """Build the network.

        Args:
            num_layers: Conv-BN-ReLU layers in the shared trunk (default 10).
            filters: channels per trunk convolution (default 64).
        """
        board_in=Input((8,8))
        x=Reshape((8,8,1))(board_in)
        # Shared trunk: a plain Conv-BN-ReLU stack. There are no skip
        # connections, so this is not a residual tower.
        for _ in range(num_layers):
            x=Conv2D(filters=filters, kernel_size=(3,3), padding='same', data_format='channels_last')(x)
            x=BatchNormalization(axis=3)(x)
            x=Activation("relu")(x)
        trunk_out=x

        # Policy head: 1x1 conv -> BN -> ReLU -> dense softmax over the 64 squares.
        x=Conv2D(filters=2, kernel_size=1, data_format='channels_last')(trunk_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        policy_out=Dense(8*8, activation="softmax", name="policy_out")(x)

        # Value head: 1x1 conv -> BN -> ReLU -> single tanh unit in [-1, 1].
        x=Conv2D(filters=1, kernel_size=1, data_format="channels_last")(trunk_out)
        x=BatchNormalization(axis=3)(x)
        x=Activation("relu")(x)
        x=Flatten()(x)
        value_out=Dense(1, activation='tanh', name='value_out')(x)
        self.model=Model(board_in, [policy_out, value_out], name='reversi_model')

In [5]:
def executeEpisode(nnet, num_sims=25, temp_moves=30):
    """Play one self-play game and return its training examples.

    Before every move, MCTS simulations are run from the current position.
    The policy target ``pi`` is the root's normalised visit counts; if the
    root recorded no visits, the stored prior is used instead. For the first
    ``temp_moves`` plies the move is sampled from ``pi`` over the legal
    squares; after that it is chosen greedily. -1 means pass.

    Args:
        nnet: network passed through to ``MCTS.search``.
        num_sims: MCTS simulations per move (at least one is always run so the
            root gets expanded).
        temp_moves: number of plies played by sampling before switching to argmax.

    Returns:
        The ``[board, pi, z]`` examples as returned by ``assignRewards``, where
        ``z`` is the final result for the player to move at ``board``.
    """
    examples=[]
    s=turnBoard()
    mcts=MCTS()
    move=0
    while True:
        for _ in range(max(1, num_sims)):
            mcts.search(s,nnet)
        key=s.board.tobytes()
        visits=np.asarray(mcts.N[key], dtype=float)
        total=visits.sum()
        # Train on the search-improved policy (visit counts), not the raw noisy prior.
        pi=visits/total if total>0 else np.array(mcts.P[key], dtype=float)
        # Store a snapshot: push() may update the board array in place
        # (turnBoard.copy() copies it for that reason).
        examples.append([s.board.copy(), pi, None])
        legalmoves=s.placable_positions(s.turn)
        if len(legalmoves)==0:
            a=-1  # no legal move: pass
        else:
            legalprobs=np.take(pi,legalmoves)
            if move>=temp_moves:
                a=legalmoves[int(np.argmax(legalprobs))]
            elif legalprobs.sum()>0:
                a=np.random.choice(legalmoves, p=legalprobs/legalprobs.sum())
            else:
                a=np.random.choice(legalmoves)
        s.push(a)
        move+=1
        gameover=s.isGameOver()
        if gameover!=2:
            # isGameOver() scores the final position for the player now to move,
            # who is NOT the player at the last recorded example, so flip the sign.
            return assignRewards(examples, -gameover)

In [6]:
def assignRewards(examples, reward):
    """Fill in the outcome slot (index 2) of each example, in place.

    The last example receives ``reward``; walking backwards, the sign
    alternates on every step. Returns the same ``examples`` object.
    """
    current = reward
    for example in reversed(examples):
        example[2] = current
        current = -current
    return examples

In [7]:
from keras.optimizer_v2.adam import Adam
from keras.optimizer_v2.learning_rate_schedule import PiecewiseConstantDecay
import keras.backend as K
def policyIterSP(nnetIterStart=0, end_iter=60, episodes_per_iter=100):
    """Run policy iteration by self-play.

    Each iteration does the following:
    - generates ``episodes_per_iter`` self-play games with the current best network;
    - trains a weight-identical copy on those examples;
    - pits the trained copy against the current best and promotes it if it scores at least 0.55.

    The best network and the loss histories are saved to ``models/model<iter+1>``
    after every iteration.

    Args:
        nnetIterStart: 0 to start from a fresh network; otherwise the index of
            a saved ``models/model<N>`` checkpoint to resume from.
        end_iter: exclusive upper bound of the iteration index.
        episodes_per_iter: self-play games generated per iteration.

    Returns:
        The best network after the final iteration.
    """
    # Step-based LR schedule. The third value was 2.5e-4, a 5x *increase* inside a
    # decay schedule; 2.5e-5 continues the halving sequence.
    learning_rate_fn = PiecewiseConstantDecay(boundaries=[20000, 40000, 50000], values=[1e-4, 5e-5, 2.5e-5, 1e-5])
    losses=['categorical_crossentropy','mean_squared_error']

    def _clone_compiled(model):
        """Return a compiled copy of `model` with the same weights and optimizer step count."""
        clone=clone_model(model)
        # clone_model builds freshly initialised weights; copy the trained ones over.
        clone.set_weights(model.get_weights())
        clone.compile(loss=losses, optimizer=Adam(learning_rate_fn))
        K.set_value(clone.optimizer.iterations, model.optimizer.iterations.numpy())
        return clone

    if nnetIterStart==0:
        nnet=ReversiModel().model
        nnet.compile(loss=losses, optimizer=Adam(learning_rate_fn))
        loss=[]
        value_out_loss=[]
        policy_out_loss=[]
    else:
        print("Loading...")
        nnet=load_model("models/model"+str(nnetIterStart))
        with open('models/model'+str(nnetIterStart)+'/loss.npy', 'rb') as f:
            loss=np.load(f)
            value_out_loss=np.load(f)
            policy_out_loss=np.load(f)
            iters=np.load(f)
        K.set_value(nnet.optimizer.iterations, iters)
    for nnetIter in range(nnetIterStart, end_iter):
        episodes=[]
        for ep in range(episodes_per_iter):
            episodes.append(np.array(executeEpisode(nnet), dtype='O'))
            print("ep: "+str(ep))
        examples=np.concatenate(episodes)  # concatenate once instead of per episode
        new_nnet=_clone_compiled(nnet)
        np.random.shuffle(examples)
        hist=modelFit(examples,new_nnet)
        loss=np.append(loss,hist.history['loss'])
        value_out_loss=np.append(value_out_loss,hist.history['value_out_loss'])
        # Extend the policy history itself (it was being built from value_out_loss).
        policy_out_loss=np.append(policy_out_loss,hist.history['policy_out_loss'])
        # pit(a, b) returns b's score, so the trained candidate goes second.
        frac_win=pit(nnet, new_nnet)
        print(frac_win, nnetIter)
        if frac_win>=0.55:
            nnet=_clone_compiled(new_nnet)
        save_model(nnet, "models/model"+str(nnetIter+1))
        with open('models/model'+str(nnetIter+1)+'/loss.npy', 'wb') as f:
            np.save(f, loss)
            np.save(f, value_out_loss)
            np.save(f, policy_out_loss)
            np.save(f, nnet.optimizer.iterations.numpy())
    return nnet

In [8]:
def modelFit(examples,nnet):
    """Fit ``nnet`` on ``[board, pi, z]`` examples and return the fit history.

    Boards are reshaped to (8, 8) in Fortran order, matching the layout used
    when querying the network during search. The policy and value targets are
    passed as a two-element list.
    """
    boards = np.array([np.reshape(example[0], (8, 8), order='F') for example in examples])
    policy_targets = np.array([example[1] for example in examples])
    value_targets = np.array([example[2] for example in examples])
    return nnet.fit(boards, [policy_targets, value_targets], epochs=10, batch_size=64, verbose=0)

In [9]:
def choose_move_gameplay(nnet, mcts, s, num_sims=25):
    """Pick a move greedily for the side to move, using MCTS.

    Runs ``num_sims`` simulations from ``s``, then returns the legal square
    with the most root visits (ties go to the first in ``placable_positions``
    order). If no legal square was visited, it falls back to the highest
    prior. Returns -1 (pass) when there is no legal move.
    """
    for _ in range(num_sims):
        mcts.search(s,nnet)
    legalmoves=s.placable_positions(s.turn)
    if len(legalmoves)==0:
        return -1
    key=s.board.tobytes()
    # Choose by visit counts so the search actually informs the move; the
    # prior alone ignores every simulation just run.
    scores=np.take(mcts.N[key], legalmoves)
    if not np.any(scores>0):
        scores=np.take(mcts.P[key], legalmoves)
    return legalmoves[int(np.argmax(scores))]

In [10]:
def _play_pit_game(first_nnet, second_nnet):
    """Play one greedy evaluation game with fresh search trees; first_nnet moves first.

    Returns +1 if first_nnet wins, -1 if second_nnet wins, 0 for a draw.
    """
    s=turnBoard()
    players=((first_nnet, MCTS()), (second_nnet, MCTS()))
    mover=0
    while True:
        net, tree=players[mover]
        s.push(choose_move_gameplay(net, tree, s))
        gameOver=s.isGameOver()
        if gameOver!=2:
            # isGameOver() scores the position for the player now to move,
            # i.e. the one who did NOT just move.
            return -gameOver if mover==0 else gameOver
        mover=1-mover


def pit(nnet, new_nnet, n_games=40):
    """Evaluate ``new_nnet`` against ``nnet``.

    Plays ``n_games`` rounds of two games each, so that each network moves
    first once per round.

    Returns:
        new_nnet's score in [0, 1]: wins count 1 and draws count 0.5.

    Raises:
        ValueError: if ``n_games`` is less than 1.
    """
    if n_games<1:
        raise ValueError("n_games must be >= 1")
    record=[0,0,0]  # new_nnet wins, nnet wins, draws
    for i in range(n_games):
        print("pit: "+str(i))
        # Both outcomes are expressed from new_nnet's perspective.
        outcome_nnet_first=-_play_pit_game(nnet, new_nnet)
        outcome_new_first=_play_pit_game(new_nnet, nnet)
        for outcome in (outcome_nnet_first, outcome_new_first):
            if outcome==1:
                record[0]+=1
            elif outcome==-1:
                record[1]+=1
            else:
                record[2]+=1
    return (record[0]+0.5*record[2])/sum(record)

In [11]:
import time
# Run training and report wall-clock duration in seconds.
# perf_counter is monotonic, unlike time.time, so the measured interval cannot jump.
start=time.perf_counter()
model=policyIterSP()
print(time.perf_counter()-start)

ep: 0
ep: 1
ep: 2
ep: 3
ep: 4
ep: 5
ep: 6
ep: 7
ep: 8
ep: 9
ep: 10
ep: 11
ep: 12
ep: 13
ep: 14
ep: 15
ep: 16
ep: 17
ep: 18
ep: 19
ep: 20
ep: 21
ep: 22
ep: 23
ep: 24
ep: 25
ep: 26
ep: 27
ep: 28
ep: 29
ep: 30
ep: 31
ep: 32
ep: 33
ep: 34
ep: 35
ep: 36
ep: 37
ep: 38
ep: 39
ep: 40
ep: 41
ep: 42
ep: 43
ep: 44
ep: 45
ep: 46
ep: 47
ep: 48
ep: 49
ep: 50
ep: 51
ep: 52
ep: 53
ep: 54
ep: 55
ep: 56
ep: 57
ep: 58
ep: 59
ep: 60
ep: 61
ep: 62
ep: 63
ep: 64
ep: 65
ep: 66
ep: 67
ep: 68
ep: 69
ep: 70
ep: 71
ep: 72
ep: 73
ep: 74
ep: 75
ep: 76
ep: 77
ep: 78
ep: 79
ep: 80
ep: 81
ep: 82
ep: 83
ep: 84
ep: 85
ep: 86
ep: 87
ep: 88
ep: 89
ep: 90
ep: 91
ep: 92
ep: 93
ep: 94
ep: 95
ep: 96
ep: 97
ep: 98
ep: 99
pit: 0
pit: 1
pit: 2
pit: 3
pit: 4
pit: 5
pit: 6
pit: 7
pit: 8
pit: 9
pit: 10
pit: 11
pit: 12
pit: 13
pit: 14
pit: 15
pit: 16
pit: 17
pit: 18
pit: 19
pit: 20
pit: 21
pit: 22
pit: 23
pit: 24
pit: 25
pit: 26
pit: 27
pit: 28
pit: 29
pit: 30
pit: 31
pit: 32
pit: 33
pit: 34
pit: 35
pit: 36
pit: 37
pit: 38
pit: 39


KeyboardInterrupt: 