In [5]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np

MCTS is equivalent Monte Carlo control with UCB applied to the sub-MDP starting from the current state, where table lookup is represented by a tree. Since the action leads to deterministic results, $q(s,a)=v(s'_{\rm a})$. $\gamma=1$ and r=0 until endgame

# Barebone MCTS with UCT
https://youtu.be/ItMutbeOHtc?t=3922, https://www.youtube.com/watch?v=Bm7zah_LrmE

1. From each node (a node = a state of the board), repeat trying unexplored moves and rollout randomly unil gameover. The win/lose count is propagated back along the path (only above or at the level where new moves are being tried out -- for memory constraints) taken for each node
2. when all moves of the node have been tried at least once, *select* the one to try based on **Upper Confidence Bound** UCB (mean value + C\*uncertainty of the mean $\propto\sqrt{\frac{\log N_{\rm visit}}{n_{\rm child's visit}}}$). Repeat doing this until there is at least an untried move at the node
3. repeat 1&2 until desired time/iteration limit, then select the best move at root node (current real situation of the board) based on *mean value* of each move

4. consider some nodes may be reached from different ways; so keep the {move: child_node} dictionary. when selecting child node, return the corresponding move and child_node also (do not use parent of a node to backtrack)
5. ONLY for terminating (acyclic) games
6. Keep the nodes that (in principle) could be reached from the current state in the future. Game class need to implement `Reachable` wrt its key for this to work. Otherwise keep every node (memory...)
7. prune tree from the leaf. If searched to the end game already, we know the value of node exactly at that point, and can know if previous state will definitely (not) move towards there.

## Code: [mcts.ai](http://mcts.ai/code/python.html) recursive + node reuse + pruning

instance of game class is passed to `MCTS`'s `search` function, and must have:
* `playerToMove`: 1 or 2
* `GetMoves()`: returns list of (hashable) valid moves from current state
* `Move(m)`: execute the passed in move (which is guaranteed to be chosen from `GetMoves()`)
* `RollOut()`: randomly execuate the game until game over
* `GetResult(viewpoint)`: (after game over) returns the reward (1 or 0) from the specified viewpoint (1 or 2 -- corresponding to playerToMove)
* `key()`: returns a string fully and uniquely characterize the status of the game
* [Optional]`Reachable(k)`: returns whether (True/False) the game could in principle reach the state represented by the passed-in key `k`, which was generated by the `key()` function of another state. If not provided, prune tree by deleting nodes not visited during a search

In [31]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np
#     def TreeToString(self, indent):
#         s = "\n"+ "| "*indent + str(self)
#         s += ''.join(c.TreeToString(indent+1)+str(m) for m,c in self.childNodes.items())
#         return s
    
class MCTS():   #reuses nodes
    def __init__(self, game):
        self.nodes = {}          #{key: {viewpoint:,wins:,visits:,childs:{mv:key},untried:[]}}     #store all nodes previously explored
        self.game = game
        self.inf = float('inf')

    def search(self, timemax=None, itermax=None, explore=1, verbose=0):
        key=self.game.key()
        if 'Reachable' in dir(self.game): #pruning by (game provided) finding which node will be reachable
            self.nodes = {k:v for k,v in self.nodes.items() if k==key or self.game.Reachable(k)}

        node = self.rootnode
        if node['untriedMoves'] or len(node['childNodes'])>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                OneBatch = max(1,itermax//10) if itermax else 100
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+OneBatch
                    for _ in range(OneBatch):
                        v=self.simulate(deepcopy(self.game),node)
                    if abs(v)==self.inf: break
            else: #debug
                start=time.time()
                while True:
                    for _ in range(1000): v=self.simulate(deepcopy(self.game),node)
                    print(time.time()-start,sum(self.nodes[k]['visits'] for k in node['childNodes'].values()),
                          self.ChildrenToString(node).replace('\n',' '),
                          end='\r')
                    if abs(v)==self.inf: break
                    
        moveToChild, bestChild = self.UCTSelectMove(node,explore=0)
        if verbose: 
            print(self.ChildrenToString(node))
            print(sum(self.nodes[k]['visits'] for k in node['childNodes'].values()))
            
        return moveToChild
    def ChildrenToString(self,node):
        return "\n".join(f'{str(mv):6s}-> [W/V: {n["wins"]:6g}/{n["visits"]:6.0f} | UnXplrd: {len(n["untriedMoves"])}]'
                            for mv,k in sorted(node['childNodes'].items(),key=lambda e: self.nodes[e[1]]['wins']/self.nodes[e[1]]['visits'])
                            for n in [self.nodes[k]]
                           )
    @property
    def rootnode(self):
        key = self.game.key()
        if key not in self.nodes:
            self.nodes[key] = {'viewpoint':3-self.game.playerToMove,
                               'wins':0.,
                               'visits':0,
                               'childNodes':{},
                               'untriedMoves':self.game.GetMoves()}
        return self.nodes[key]
    
    def UCTSelectMove(self,n,explore=1):
        bestval=None
        bestmv=None
        bestnn=None
        for mv,k in n['childNodes'].items():
            nn = self.nodes[k]
            curval = nn['wins']/nn['visits'] + explore*math.sqrt(math.log(n['visits'])/nn['visits'])
            if bestval is None or curval>bestval:
                bestmv = mv
                bestval = curval
                bestnn = nn
        return bestmv,bestnn
    
    def simulate(self,game,node,explore=1):
        if node['untriedMoves'] == [] and node['childNodes'] != {}: #fully expanded, non-terminal
            move,nnode = self.UCTSelectMove(node,explore)            
            if nnode['wins']==self.inf or nnode['wins']==-self.inf: #if the best is -inf or inf already, can backprop
                node['wins']+= -nnode['wins']; node['visits']+=1
                return -nnode['wins']
            game.Move(move)
            v = self.simulate(game,nnode)
            if v==-self.inf: v=0 # if v==-inf, next mover will not choose this branch as it leads to loss.
                                 # Note: other branch may have finite v as UCT  saw this branch as finite v before simulation
            node['wins']+=1-v; node['visits']+=1;
            return 1-v
        elif node['untriedMoves'] != []: #not fully expanded: expand and then rollout
            move = random.choice(node['untriedMoves']) 
            game.Move(move)
            k = game.key()
            if k not in self.nodes:
                self.nodes[k] = {'viewpoint':3-game.playerToMove,
                                 'wins':0.,
                                 'visits':0.,
                                 'childNodes':{},
                                 'untriedMoves':game.GetMoves()}
            nnode = self.nodes[k]
            node['untriedMoves'].remove(move)
            node['childNodes'][move]=k
            game.RollOut()
            v = game.GetResult(nnode['viewpoint']) #not setting inf or -inf because there are indeterminism (unexplored moves)
            nnode['wins']+=v; nnode['visits']+=1
            node['wins']+=1-v; node['visits']+=1
            return 1-v
        elif node['childNodes'] == {}: #terminal
            v = game.GetResult(node['viewpoint'])
            if v==1:    v=self.inf
            elif v==0:  v=-self.inf
            node['wins']+=v; node['visits']+=1
            return v

separate `node` dict into multiple dicts

In [33]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np
inf = float('inf')

class MCTS():   #reuses nodes
    def __init__(self, game):
#         self.nodes = {}          #{key: {viewpoint:,wins:,visits:,childs:{mv:key},untried:[]}}     #store all nodes previously explored
        self.game = game
        self.viewpoint = {}
        self.wins = {}
        self.visits = {}
        self.childNodes = {}
        self.untriedMoves = {}

    def search(self, timemax=None, itermax=None, explore=1, verbose=0):
        rootkey=self.game.key()
        if rootkey not in self.wins:
            self.viewpoint[rootkey] = 3 - self.game.playerToMove
            self.wins[rootkey] = 0
            self.visits[rootkey] = 0
            self.childNodes[rootkey] = {}
            self.untriedMoves[rootkey] = self.game.GetMoves()
            
        if 'Reachable' in dir(self.game): #pruning by (game provided) finding which node will be reachable
            self.wins = {k:v for k,v in self.wins.items() if k==rootkey or self.game.Reachable(k)}
            self.viewpoint = {k:self.viewpoint[k] for k in self.wins}
            self.visits = {k:self.visits[k] for k in self.wins}
            self.childNodes = {k:self.childNodes[k] for k in self.wins}
            self.untriedMoves = {k:self.untriedMoves[k] for k in self.wins}

        if self.untriedMoves[rootkey] or len(self.childNodes[rootkey])>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                OneBatch = max(1,itermax//10) if itermax else 100
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+OneBatch
                    for _ in range(OneBatch):
                        v=self.simulate(deepcopy(self.game),rootkey)
                    if abs(v)==inf: break
            else: #debug
                start=time.time()
                while True:
                    for _ in range(1000): v=self.simulate(deepcopy(self.game),rootkey)
                    print(time.time()-start,
                          sum(self.visits[k] for k in self.childNodes[rootkey].values()),
                          self.ChildrenToString(rootkey).replace('\n',' '),
                          end='\r')
                    if abs(v)==inf: break
                    
        moveToChild = self.UCTSelectMove(rootkey,explore=0)
        if verbose: 
            print(self.ChildrenToString(rootkey))
            print(sum(self.visits[k] for k in self.childNodes[rootkey].values()))
            
        return moveToChild
    def ChildrenToString(self,key):
        return "\n".join(f'{str(mv):6s}-> [W/V: {self.wins[k]:6g}/{self.visits[k]:6.0f} | UnXplrd: {len(self.untriedMoves[k])}]'
                            for mv,k in sorted(self.childNodes[key].items(),key=lambda e: self.wins[e[1]]/self.visits[e[1]])
                           )
    
    def UCTSelectMove(self,key,explore=1):
        bestval=None
        bestmv=None
        for mv,k in self.childNodes[key].items():
            curval = self.wins[k]/self.visits[k] + explore*math.sqrt(math.log(self.visits[key])/self.visits[k])
            if bestval is None or curval>bestval:
                bestmv = mv
                bestval = curval
        return bestmv
    
    def simulate(self,game,key,explore=1):
        if self.untriedMoves[key] == [] and self.childNodes[key] != {}: #fully expanded, non-terminal
            move = self.UCTSelectMove(key,explore)
            nkey = self.childNodes[key][move]
            if self.wins[nkey]==inf or self.wins[nkey]==-inf: #if the best is -inf or inf already, can backprop
                self.wins[key] = -self.wins[nkey]; self.wins[key]+=1
                return -self.wins[nkey]
            game.Move(move)
            v = self.simulate(game,nkey)
            if v==-inf: v=0 # if v==-inf, next mover will not choose this branch as it leads to loss.
                                 # Note: other branch may have finite v as UCT  saw this branch as finite v before simulation
            self.wins[key]+=1-v; self.visits[key]+=1
            return 1-v
        elif self.untriedMoves[key] != []: #not fully expanded: expand and then rollout
            move = random.choice(self.untriedMoves[key]) 
            game.Move(move)
            k = game.key()
            if k not in self.wins:
                self.viewpoint[k] = 3-game.playerToMove
                self.wins[k] = 0
                self.visits[k] = 0
                self.childNodes[k] = {}
                self.untriedMoves[k] = game.GetMoves()
            self.untriedMoves[key].remove(move)
            self.childNodes[key][move]=k
            game.RollOut()
            v = game.GetResult(self.viewpoint[k]) #not setting inf or -inf because there are indeterminism (unexplored moves)
            self.wins[k]+=v; self.visits[k]+=1
            self.wins[key]+=1-v; self.visits[key]+=1
            return 1-v
        elif self.childNodes[key] == {}: #terminal
            v = game.GetResult(self.viewpoint[key])
            if v==1:    v=inf
            elif v==0:  v=-inf
            self.wins[key]+=v; self.visits[key]+=1
            return v

## [m,n,k-game](https://en.wikipedia.org/wiki/M,n,k-game)

In [34]:
class OXO():
    def __init__(self,width=3,height=3,inarow=None): # game ends as one success
        self.playerToMove = 1
        self.width = width
        self.height = height
        self.board = [0]*width*height # 0 = empty, 1 = player 1, 2 = player 2
        self.empty = list(range(width*height))
        
        # winning linesrequires
        inarow = inarow or min(width,height)
        self.lines = [tuple(range(i*width+j,i*width+j+inarow)) for i in range(0,height) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,(i+inarow)*width+j,width)) for i in range(0,height-inarow+1) for j in range(width)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width+1),width+1)) for i in range(height-inarow+1) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width-1),width-1)) for i in range(height-inarow+1) for j in range(inarow-1,width)]

    def Move(self, m):
        """ Update a game by carrying out the given move. Must also update playerToMove.  """
        self.board[m] = self.playerToMove
        self.empty.remove(m)
        if self.GetResult(self.playerToMove) in [0,1]:  
            self.empty=[]
        self.playerToMove = 3 - self.playerToMove
        
    def GetMoves(self): # must clone as Node use this return as untriedMoves  # return self.empty[:]
        return self.empty[:]
    
    def RollOut(self):
        possibleMoves = self.GetMoves()
        while possibleMoves:
            self.Move(random.choice(possibleMoves))
            possibleMoves = self.GetMoves()
    
    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        myscore = opposcore = 0
        for l in self.lines:
            one=self.board[l[-1]]
            if one!=0 and all(self.board[p]==one for p in l):
                if one == viewpoint:
                    myscore += 1
                else:
                    opposcore += 1
        if myscore>opposcore: return 1
        elif opposcore>myscore: return 0
        return 0.5

    def __repr__(self): # for output: 1 (X), 2 = (O)
        lines=[''.join("·XO"[self.board[i*self.width+j]] for j in range(self.width)) for i in range(self.height)]
        return '\n'.join(lines)    
    def key(self): #string containing all information for hashing the node
        return str(self.playerToMove)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return all(ch in '·12' or ch==key[i] for i,ch in enumerate(mykey))

In [35]:
mcts = MCTS(OXO(width=5,height=5,inarow=4))
mcts.game.Move(12) # after this, found (6/8/16/18) is optimal after ~1,500,000 simulations
# mcts.game.Move(11) # after this, found sure win policy after 260663 simulations
# mcts.game.Move(16) # after this, found sure loss after 189564 simulations
# mcts.game.Move(8)  # after this, found sure win policy after 38150 simulations
# mcts.game.Move(18) # after this, found sure loss after 25918 simulations
# mcts.game.Move(6)  # after this, found sure win policy after 1601 simulations
mcts.search(timemax=60,verbose=1)

14    -> [W/V:  208.5/   741 | UnXplrd: 0]
20    -> [W/V:    299/   999 | UnXplrd: 0]
24    -> [W/V:  372.5/  1205 | UnXplrd: 0]
4     -> [W/V:  374.5/  1211 | UnXplrd: 0]
2     -> [W/V:  409.5/  1308 | UnXplrd: 0]
0     -> [W/V:  421.5/  1341 | UnXplrd: 0]
22    -> [W/V:    528/  1634 | UnXplrd: 0]
3     -> [W/V:  558.5/  1717 | UnXplrd: 0]
23    -> [W/V:    562/  1727 | UnXplrd: 0]
21    -> [W/V:    567/  1740 | UnXplrd: 0]
10    -> [W/V:  600.5/  1831 | UnXplrd: 0]
9     -> [W/V:  692.5/  2080 | UnXplrd: 0]
5     -> [W/V:  693.5/  2083 | UnXplrd: 0]
19    -> [W/V:    927/  2709 | UnXplrd: 0]
1     -> [W/V:  952.5/  2776 | UnXplrd: 0]
15    -> [W/V:   1025/  2969 | UnXplrd: 0]
6     -> [W/V: 3410.5/  9154 | UnXplrd: 0]
18    -> [W/V:   3542/  9491 | UnXplrd: 0]
8     -> [W/V:   4317/ 11469 | UnXplrd: 0]
16    -> [W/V:   4533/ 12019 | UnXplrd: 0]
11    -> [W/V: 6624.5/ 17323 | UnXplrd: 0]
7     -> [W/V: 6292.5/ 16368 | UnXplrd: 0]
13    -> [W/V: 9362.5/ 24226 | UnXplrd: 0]
17    -> [W

17

Fail to draw on 5,5,4, 2nd: after first player places at center, 2nd player should play at a position "diagonal" to the center (6,8,16,18). MCTS failed to identify these four as optimal until after ~10min

In [None]:
game = OXO(width=3,height=3,inarow=3,moves_per_round=1,end1=True) #traditional
# game = OXO(width=5,height=5,inarow=4) #http://boulter.com/ttt/
# game = OXO(width=8,height=8,inarow=4,moves_per_round=2,end1=False) #http://www.atksolutions.com/games/tictactoedeluxe.html
# game = OXO(width=15,height=15,inarow=5) #Gomoku
mcts = MCTS(game)
while mcts.game.GetMoves():
    if mcts.game.playerToMove==0:
#         clear_output()
        board=np.array([f'{("·XO"[s] if s else str(pos)):2s}' for pos,s in enumerate(mcts.game.board)]).reshape(mcts.game.height,mcts.game.width)
        print(board)
        n=mcts.rootnode
        print(n.ChildrenToString())
        print(f'choose from: {mcts.game.GetMoves()[:10]}...',end='')
        m = (int(input()),)   #2 moves (int(input()),int(input()))
    else:
        m = mcts.search(timemax=1,verbose=1)
    mcts.game.Move(m)
    print(mcts.game)
    print(len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

## Connect4

In [110]:
class ConnectN:
    def __init__(self,width=7,height=6,inarow=4):
        self.playerToMove = 1
        self.width = width
        self.height = height
        self.board = [0]*width*height
        self.moves = list(range(width))
        self.firstEmpty = [height-1]*width
        
        # winning lines
        inarow = inarow or min(width,height)
        self.lines = [tuple(range(i*width+j,i*width+j+inarow)) for i in range(0,height) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,(i+inarow)*width+j,width)) for i in range(0,height-inarow+1) for j in range(width)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width+1),width+1)) for i in range(height-inarow+1) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width-1),width-1)) for i in range(height-inarow+1) for j in range(inarow-1,width)]

    def Move(self, move):  #move=column idx
        assert move in self.moves
        self.board[self.firstEmpty[move]*self.width+move]=self.playerToMove
        self.firstEmpty[move] -= 1
        if self.firstEmpty[move]<0:
            self.moves.remove(move)
        if self.GetResult(self.playerToMove) in [0,1]:  #indicate game ended as someone won
            self.moves=[] 
        self.playerToMove = 3 - self.playerToMove
        
    def GetMoves(self):  # legal moves for next player. empty if game is over
        return self.moves[:]
    
    def RollOut(self):
        while self.moves:
            self.Move(random.choice(self.moves))
    
    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        for l in self.lines:
            one=self.board[l[-1]]
            if one!=0 and all(self.board[p]==one for p in l):
                if one == viewpoint:
                    return 1
                else:
                    return 0
        return 0.5

    def __repr__(self): # string representation: 1 (X), 2 = (O)
        s = ''
        for left in range(0,self.width*self.height,self.width):
            for sh in range(self.width):
                s+="·XO"[self.board[left+sh]]
            s+='\n'
        return s.strip()
    def key(self): #string containing all information for hashing the node
        return str(self.playerToMove)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return all(ch in '12·' or ch==key[i] for i,ch in enumerate(mykey)) 

Fails to deduce that the 5th move (1st player) should be at center

In [None]:
game = ConnectN(7,6,4) #https://connect4.gamesolver.org/
for _ in range(4): game.Move(3)
mcts = MCTS(game)
while mcts.game.GetMoves():
    if mcts.game.playerToMove==2:
#         clear_output()
        fn = lambda s: s if int(s)<7 else ''
        board=np.array([f'{("·XO"[s] if s else fn(str(pos))):2s}' for pos,s in enumerate(mcts.game.board)])
        print(board.reshape(mcts.game.height,mcts.game.width))
        print(mcts.rootnode.ChildrenToString())
        print('choose from:',end='')
        m = int(input(mcts.game.GetMoves()))
    else:
        m = mcts.search(timemax=10,verbose=1,explore=2)
    mcts.game.Move(m)
    print(mcts.game)
    print(len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

## Othello

In [146]:
class Othello():
    def __init__(self, size = 8):  # size must be integral and even
        self.playerToMove = 1
        self.size = size
        self.board = [] # 0 = empty, 1 = player 1, 2 = player 2
        for y in range(self.size):
            self.board.append([0]*size)
        self.board[size//2][size//2] = self.board[size//2-1][size//2-1] = 2
        self.board[size//2][size//2-1] = self.board[size//2-1][size//2] = 1
        self.board = [e for l in self.board for e in l]

    def Move(self, move):
        if move is not None:
            (x,y)=divmod(move,self.size)
            assert self.IsOnBoard(x,y) and self.board[x*self.size+y] == 0
            m = self.GetAllSandwichedCounters(x,y)
            self.board[x*self.size+y] = self.playerToMove
            for (a,b) in m:
                self.board[a*self.size+b] = self.playerToMove
        self.playerToMove = 3 - self.playerToMove
    
    def GetMoves(self):
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        if not emptypos:
            return []
        else:
            viable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))] 
            if viable:
                return viable
            else:  #need to check if opponent also has no viable pos
                self.playerToMove = 3 - self.playerToMove
                oppoViable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))]
                self.playerToMove = 3 - self.playerToMove
                if oppoViable:
                    return [None]
                else:
                    return []
    
    def RollOut(self):
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        bad=0 
        while emptypos:
            random.shuffle(emptypos)
            for pos in emptypos:
                if self.ExistsSandwiched(*divmod(pos,self.size)):
                    self.Move(pos)
                    bad=0
                    emptypos = [i for i,e in enumerate(self.board) if e==0]
                    break
            else:
                if bad<2: #when bad=2, both side cannot play, game ends
                    bad+=1
                    self.Move(None)
                else:
                    break

    def AdjacentEnemyDirections(self,x,y):# Speeds up GetMoves by only considering squares which are adjacent to an enemy-occupied square.
        return [(dx,dy) for (dx,dy) in [(0,+1),(+1,+1),(+1,0),(+1,-1),(0,-1),(-1,-1),(-1,0),(-1,+1)]
                        if self.IsOnBoard(x+dx,y+dy) and self.board[(x+dx)*self.size+y+dy] == 3-self.playerToMove]
    
    def ExistsSandwiched(self,x,y):# Is there at least one counter which would be flipped if my counter was placed at (x,y)? 
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            x1=x+dx
            y1=y+dy
            while self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == 3 - self.playerToMove:
                x1 += dx
                y1 += dy
            if self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == self.playerToMove:
                return True
        return False
    
    def GetAllSandwichedCounters(self, x, y):# Is (x,y) a possible move (i.e. opponent counters are sandwiched between (x,y) and my counter in some direction)?
        sandwiched = []
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            sandwiched.extend(self.SandwichedCounters(x,y,dx,dy))
        return sandwiched

    def SandwichedCounters(self, x, y, dx, dy):# Return the coordinates of all opponent counters sandwiched between (x,y) and my counter.
        x += dx
        y += dy
        sandwiched = []
        while self.IsOnBoard(x,y) and self.board[x*self.size+y] == 3 - self.playerToMove:
            sandwiched.append((x,y))
            x += dx
            y += dy
        if self.IsOnBoard(x,y) and self.board[x*self.size+y] == self.playerToMove:
            return sandwiched
        else:
            return [] # nothing sandwiched

    def IsOnBoard(self, x, y):
        return x >= 0 and x < self.size and y >= 0 and y < self.size
    
    def GetResult(self, viewpoint): #after gameover
        viewpointscore=oppositescore=0
        for e in self.board:
            if e==viewpoint: viewpointscore+=1
            elif e==3-viewpoint: oppositescore+=1
        if viewpointscore > oppositescore: return 1.0
        elif oppositescore > viewpointscore: return 0.0
        else: return 0.5 # draw

    def __repr__(self):
        s= ""
        for x in range(self.size):
            for y in range(self.size):
                s += "·XO"[self.board[x*self.size+y]]
            s += "\n"
        return s.strip()
    def key(self): #string containing all information for hashing the node
        return str(self.playerToMove)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return key.count('·')<mykey.count('·') and \
               all(ch in '12·' or key[i]!='·' for i,ch in enumerate(mykey)) and\
               (mykey[1]=='·' or mykey[1]==key[1]) and\
               (mykey[self.size]=='·' or mykey[self.size]==key[self.size]) and\
               (mykey[-self.size]=='·' or mykey[-self.size]==key[-self.size]) and\
               (mykey[-1]=='·' or mykey[-1]==key[-1])

In [148]:
#benchmark
mcts = MCTS(Othello(8))
mcts.search(verbose=1, timemax=10) #4700

19      |[W/V:  276.5/   614 | UnXplrd: 0] 
44      |[W/V:  341.5/   741 | UnXplrd: 0] 
37      |[W/V:  387.5/   831 | UnXplrd: 0] 
26      |[W/V:  639.5/  1314 | UnXplrd: 0] 
3500


26

* timemax=60 beats Reversu Difficult
* timemax=15 beats Reversu Medium, lost to Reversu Difficult

In [None]:
game = Othello(8)
mcts = MCTS(game)
while mcts.game.GetMoves():
    if mcts.game.playerToMove==2:
        clear_output()
        board=np.array([f'{("·XO"[e] if e else str(p)):2s}' for p,e in enumerate(mcts.game.board)]).reshape(8,8)
        print(board)
        print(mcts.rootnode.ChildrenToString())
        print('choose from:',end='')
        moves = mcts.game.GetMoves()
        if moves[0] is None: m = None
        else: m=int(input(moves))
    else:
        m = mcts.search(verbose=1,timemax=3+len(list(filter(None,mcts.game.board)))*0 )  #len...=number of pieces on board
    mcts.game.Move(m)
    print(mcts.game)
    print(len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

# MCTS with Multiprocessing

In [17]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np
import multiprocessing as mp
inf = float('inf')
    
class MCTS():   #reuses nodes
    def __init__(self, game):
        self.nodes = {}          #{key: {viewpoint:,wins:,visits:,childs:{mv:key},untried:[]}}     #store all nodes previously explored
        self.game = game
        self.rootkey = self.game.key()

    def search(self, timemax=None, itermax=None, explore=1, verbose=0):
        self.rootkey = self.game.key()
        cpus = mp.cpu_count()
        pool = mp.Pool(cpus)
        if 'Reachable' in dir(self.game): #pruning by (game provided) finding which node will be reachable
            self.nodes = {k:v for k,v in self.nodes.items() if k==rootkey or self.game.Reachable(k)}
        
        nextmoves =  self.rootnode['untriedMoves']+list(self.rootnode['childNodes'])
        if len(nextmoves)>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                OneBatch=100000000
                if timemax: self.simulate_multiple(3); OneBatch=int(timemax/(time.time()-start))
                if itermax: OneBatch = min(OneBatch,max(1,itermax//3))
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+OneBatch
#                     for _ in range(OneBatch): v=self.simulate(deepcopy(self.game),self.rootnode,self.nodes)
                    vs,nodes = zip(*pool.map(self.simulate_multiple,range(OneBatch//cpus*100,OneBatch//cpus*100+cpus)))
                    v=max(vs)                    
                    for d in nodes:     # subtract the input wins and visits from the output
                        for key, node in d.items():
                            if key in self.nodes:
                                selfnode = self.nodes[key]
                                if np.isfinite(selfnode['wins']):
                                    node['wins']-=selfnode['wins']
                                node['visits']-=selfnode['visits']
                    for d in nodes:     # add the generated wins and visits to the record
                        for key, node in d.items():
                            if key in self.nodes:
                                selfnode = self.nodes[key]
                                selfnode['wins'] += node['wins']
                                selfnode['visits'] += node['visits']
                                if set(node['childNodes'])-set(selfnode['childNodes']):
                                    selfnode['childNodes'].update(node['childNodes'])
                                    selfnode['untriedMoves'] = [mv for mv in selfnode['untriedMoves'] if mv not in selfnode['childNodes']]
                            else:
                                self.nodes[key]=deepcopy(node)
                    if abs(v)==inf: break
            else:
                start=time.time()
                while True:
                    vs,nodes = zip(*pool.map(self.simulate_multiple,range(1000*100,1000*100+cpus)))
                    v=max(vs)
                    wins = {}      #(for speed), just choose a node and update the nextlevel childNodes counts only from other returned nodes
                    visits = {}
                    for mv,key in self.nodes[self.rootkey]['childNodes'].items():
                        if np.isfinite(self.nodes[key]['wins']):
                            wins[mv] = self.nodes[key]['wins']
                        visits[mv] = self.nodes[key]['visits']                    
                    self.nodes=nodes[0]
                    for d in nodes[1:]:
                        for mv,key in d[self.rootkey]['childNodes'].items():
                            if key in self.nodes:
                                self.nodes[key]['wins']+=d[key]['wins']-wins.get(mv,0)
                                self.nodes[key]['visits']+=d[key]['visits']-visits.get(mv,0)
                    
                    print(time.time()-start,sum(self.nodes[k]['visits'] for k in self.rootnode['childNodes'].values()),
                          self.ChildrenToString().replace('\n',' '),
                          end='\r')
                    if abs(v)==inf: break
                    
        if len(nextmoves)>1:
            moveToChild, bestChild = self.UCTSelectMove(self.rootnode,explore=0)
        else:
            moveToChild = nextmoves[0]
        if verbose: 
            print(self.ChildrenToString(self.rootnode))
            print('#simulations: ',sum(self.nodes[k]['visits'] for k in self.rootnode['childNodes'].values()))

        pool.close()
        return moveToChild
    def ChildrenToString(self,node=None):
        if node is None: node=self.rootnode
        return "\n".join(f'{str(mv):6s}-> [W/V: {n["wins"]:6g}/{n["visits"]:6.0f} | UnXplrd: {len(n["untriedMoves"])}]'
                            for mv,k in sorted(node['childNodes'].items(),key=lambda e: self.nodes[e[1]]['wins']/self.nodes[e[1]]['visits'])
                            for n in [self.nodes[k]]
                           )
    @property
    def rootnode(self):
        if self.rootkey not in self.nodes:
            self.nodes[self.rootkey] = {'viewpoint':3-self.game.playerToMove,
                                        'wins':0.,
                                        'visits':0,
                                        'childNodes':{},
                                        'untriedMoves':self.game.GetMoves()}
        return self.nodes[self.rootkey]
    
    def UCTSelectMove(self,n,nodes=None,explore=1):
        bestval=None
        bestmv=None
        bestnn=None
        if nodes is None: nodes=self.nodes
        for mv,k in n['childNodes'].items():
            nn = nodes[k]
            curval = nn['wins']/nn['visits'] + explore*math.sqrt(math.log(n['visits'])/nn['visits'])
            if bestval is None or curval>bestval:
                bestmv = mv
                bestval = curval
                bestnn = nn
        return bestmv,bestnn

    def simulate_multiple(self,i):
        rootkey = self.game.key()
        nodes=deepcopy(self.nodes)
        node=nodes[rootkey]
        np.random.seed(i)
        random.seed(i)
        OneBatch,i=divmod(i,100)
        greatestval = -inf
        for _ in range(OneBatch+1):
            game=deepcopy(self.game)
            v = self.simulate(game,node,nodes)
            if v>greatestval:greatestval=v
        return greatestval,nodes
    def simulate(self,game,node,nodes,explore=1):
        if node['untriedMoves'] == [] and node['childNodes'] != {}: #fully expanded, non-terminal
            move,nnode = self.UCTSelectMove(node,nodes=nodes,explore=explore)            
            if nnode['wins']==inf or nnode['wins']==-inf: #if the best is -inf or inf already, can backprop
                node['wins']+= -nnode['wins']; node['visits']+=1
                return -nnode['wins']
            game.Move(move)
            v = self.simulate(game,nnode,nodes)
            if v==-inf: v=0 # if v==-inf, next mover will not choose this branch as it leads to loss.
                                 # Note: other branch may have finite v as UCT  saw this branch as finite v before simulation
            node['wins']+=1-v; node['visits']+=1;
            return 1-v
        elif node['untriedMoves'] != []: #not fully expanded: expand and then rollout
            move = random.choice(node['untriedMoves']) 
            game.Move(move)
            k = game.key()
            if k not in nodes:
                nodes[k] = {'viewpoint':3-game.playerToMove,
                             'wins':0.,
                             'visits':0,
                             'childNodes':{},
                             'untriedMoves':game.GetMoves()}
            nnode = nodes[k]
            node['untriedMoves'].remove(move)
            node['childNodes'][move]=k
            game.RollOut()
            v = game.GetResult(nnode['viewpoint']) #not setting inf or -inf because there are indeterminism (unexplored moves)
            nnode['wins']+=v; nnode['visits']+=1
            node['wins']+=1-v; node['visits']+=1
            return 1-v
        elif node['childNodes'] == {}: #terminal
            v = game.GetResult(node['viewpoint'])
            if v==1:    v=inf
            elif v==0:  v=-inf
            node['wins']+=v; node['visits']+=1
            return v

In [18]:
mcts = MCTS(OXO(width=5,height=5,inarow=4))
mcts.game.Move(12) # after this, found (6/8/16/18) is optimal after ~1,500,000 simulations
mcts.game.Move(11) # after this, found sure win policy after 260663 simulations
# mcts.game.Move(16) # after this, found sure loss after 189564 simulations
# mcts.game.Move(8)  # after this, found sure win policy after 38150 simulations
# mcts.game.Move(18) # after this, found sure loss after 25918 simulations
# mcts.game.Move(6)  # after this, found sure win policy after 1601 simulations
start=time.time()
mcts.search(itermax=100000,verbose=1)
time.time()-start

10    -> [W/V:  284.5/   560 | UnXplrd: 0]
14    -> [W/V:  312.5/   606 | UnXplrd: 0]
0     -> [W/V:  441.5/   853 | UnXplrd: 0]
24    -> [W/V:    506/   953 | UnXplrd: 0]
4     -> [W/V:  731.5/  1354 | UnXplrd: 0]
20    -> [W/V:  601.5/  1109 | UnXplrd: 0]
22    -> [W/V:  779.5/  1414 | UnXplrd: 0]
9     -> [W/V:  883.5/  1587 | UnXplrd: 0]
2     -> [W/V:    819/  1461 | UnXplrd: 0]
13    -> [W/V: 1200.5/  2137 | UnXplrd: 0]
21    -> [W/V:   1783/  3120 | UnXplrd: 0]
5     -> [W/V:   1554/  2717 | UnXplrd: 0]
15    -> [W/V:   1868/  3246 | UnXplrd: 0]
19    -> [W/V: 2006.5/  3467 | UnXplrd: 0]
23    -> [W/V:   2040/  3512 | UnXplrd: 0]
3     -> [W/V: 2045.5/  3510 | UnXplrd: 0]
1     -> [W/V: 2167.5/  3716 | UnXplrd: 0]
7     -> [W/V:   6982/ 11560 | UnXplrd: 0]
8     -> [W/V: 7323.5/ 12039 | UnXplrd: 0]
18    -> [W/V:   6412/ 10415 | UnXplrd: 0]
17    -> [W/V:   8904/ 14378 | UnXplrd: 0]
6     -> [W/V: 16503.5/ 26487 | UnXplrd: 0]
16    -> [W/V:  14482/ 23143 | UnXplrd: 0]
#simulatio

20.77542018890381

# MCTS written in Cython

## mcts library

In [1]:
%load_ext Cython
import os, tempfile
os.chdir(tempfile.mkdtemp())

In [2]:
%%file mcts.pxd
# cython: language_level=3
from libcpp.vector cimport vector
cdef class Game:
    cdef public int playerToMove
    cpdef list GetMoves(self)
    cpdef void Move(self,int move)
    cpdef void RollOut(self)
    cpdef str key(self)
    cpdef double GetResult(self,int viewpoint)
cdef class Node:
    cdef public int visits
    cdef public double wins
    cdef public dict childNodes # {move:nextnode} the move that get *into* next node
    cdef public int viewpoint
    cdef public list untriedMoves
    cpdef int UCTSelectChild(self,double explore=?)
    cpdef void AddChild(self, int move, n)
    cpdef void Update(self, double result)
cdef class MCTS:
    cdef public dict nodes
    cdef public Game game
    cdef double simulate(self,Game game,Node node,double explore=?)

Writing mcts.pxd


In [3]:
%%file mcts.pyx
# %%cython -+ -a
cimport cython
from copy import deepcopy
import time, random
from libc.math cimport sqrt, log
from libc.stdlib cimport rand
from libcpp.vector cimport vector
from cython.operator cimport dereference as deref, preincrement
from numpy.math cimport INFINITY as inf

cdef class Game:
    cpdef list GetMoves(self): return []
    cpdef void Move(self,int move):pass
    cpdef void RollOut(self):pass
    cpdef str key(self): return "";
    cpdef double GetResult(self,int viewpoint): return 0.5;

cdef class Node:
    def __init__(self, Game game):
        self.wins = 0.  #for the player Just Moved
        self.visits = 0
        self.childNodes={}
        self.viewpoint = 3-game.playerToMove
        self.untriedMoves = game.GetMoves() # future child nodes
        
    cpdef int UCTSelectChild(self,double explore=1):
        cdef double bestval=-19278, curval=-1
        cdef int k,bestmove=0
        cdef Node bestchild,v
        for k,v in self.childNodes.items():
            curval = v.wins/v.visits + explore*sqrt(log(self.visits)/v.visits)
            if curval>bestval or bestval==-19278:
                bestval=curval
                bestmove=k
        return bestmove
    
    cpdef void AddChild(self, int move, n):
        self.untriedMoves.remove(move)
        self.childNodes[move] = n
    
    cpdef void Update(self, double result):
        self.visits += 1
        self.wins += result # game results in the range [0.0, 1.0]
    def __repr__(self): return f"[W/V: {self.wins:6g}/{self.visits:6d} | UnXplrd: {len(self.untriedMoves)}] "
    def ChildrenToString(self): return "\n".join(f'{str(m):8s}|{c}' for m,c in sorted(self.childNodes.items(),key=lambda e: e[1].wins/e[1].visits))

cdef class MCTS():   #reuses nodes
    def __init__(self, Game game):
        self.nodes = {}               #store all nodes previously explored
        self.game = game

    def search(self, timemax=None, itermax=None, explore=1, verbose=0):
        key=self.game.key()
        if 'Reachable' in dir(self.game): #pruning by (game provided) finding which node will be reachable
            self.nodes = {k:v for k,v in self.nodes.items() if k==key or self.game.Reachable(k)}
        if self.rootnode.untriedMoves or len(self.rootnode.childNodes)>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+100
                    for _ in range(100):
                        v=self.simulate(deepcopy(self.game),self.rootnode)
                    if abs(v)==inf: break
            else: #debug
                start=time.time()
                while True:
                    for _ in range(1000): v=self.simulate(deepcopy(self.game),self.rootnode)
                    print(time.time()-start,sum(c.visits for c in self.rootnode.childNodes.values()),
                          sorted(list(self.rootnode.childNodes.items()), 
                              key = lambda c: -c[1].wins/c[1].visits ),
                          end='\r')
                    if abs(v)==inf: break
                    
        best = self.rootnode.UCTSelectChild(explore=0)
        if verbose==2:   print(self.rootnode.TreeToString(0))
        elif verbose==1:
            print(self.rootnode.ChildrenToString())
            print(sum(c.visits for c in self.rootnode.childNodes.values()))
        return best    
    
    @property  #  called when accessing self.rootnode
    def rootnode(self):
        key = self.game.key()
        if key not in self.nodes:
            self.nodes[key] = Node(game = self.game)
        return self.nodes[key]
    
    cdef double simulate(self,Game game,Node node,double explore=1.0):
        cdef double v=0 #return value
        cdef int move=0
        cdef Node nnode
        if len(node.untriedMoves) == 0 and node.childNodes != {}: #fully expanded, non-terminal
            move = node.UCTSelectChild(explore)
            nnode = node.childNodes[move]
            if nnode.wins==inf or nnode.wins==-inf: #if the best is -inf or inf already, can backprop
                node.Update(-nnode.wins)
                return -nnode.wins
            game.Move(move)
            v = self.simulate(game,nnode)
            if v==-inf: v=0
            node.Update(1-v)
            return 1-v
        elif len(node.untriedMoves) > 0: #not fully expanded: expand and then rollout
            move = random.choice(node.untriedMoves) 
            game.Move(move)
            k = game.key()
            if k in self.nodes:
                nnode = self.nodes[k]
            else:
                nnode = Node(game = game)
                self.nodes[k] = nnode
            node.AddChild(move,nnode)
            game.RollOut()
            v = game.GetResult(nnode.viewpoint) #not setting inf or -inf because there are indeterminism (unexplored moves)
            nnode.Update(v)
            node.Update(1-v)
            return 1.0-v
        elif node.childNodes == {}: #terminal
            v = game.GetResult(node.viewpoint)
            if v==1.0:    v=inf
            elif v==0.0:  v=-inf
            node.Update(v)
            return v

Writing mcts.pyx


In [4]:
%%file setup.py
import os
import numpy
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

setup(  name = "cythoncode",
        cmdclass = {"build_ext": build_ext},
        ext_modules = [ Extension("mcts",
                                  sources=["mcts.pyx"],
                                  language='c++',   #using C++
                                  libraries=["m"],  #for using C's math lib
                                  extra_compile_args = ["-ffast-math"])],
        include_dirs=[numpy.get_include(),
                      os.path.join(numpy.get_include(), 'numpy')])

Writing setup.py


In [5]:
!python3 setup.py build_ext --inplace

running build_ext
cythoning mcts.pyx to mcts.cpp
building 'mcts' extension
creating build
creating build/temp.linux-x86_64-3.7
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/hoi/.local/lib/python3.7/site-packages/numpy/core/include -I/home/hoi/.local/lib/python3.7/site-packages/numpy/core/include/numpy -I/usr/include/python3.7m -c mcts.cpp -o build/temp.linux-x86_64-3.7/mcts.o -ffast-math
x86_64-linux-gnu-g++ -pthread -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-Bsymbolic-functions -Wl,-z,relro -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 build/temp.linux-x86_64-3.7/mcts.o -lm -o /tmp/tmpg4i0epas/mcts.cpython-37m-x86_64-linux-gnu.so


## m,n,k-game

In [6]:
%%cython -I . -+
from libc.stdlib cimport rand
from libcpp.vector cimport vector
from mcts cimport *
cimport cython
cdef extern from "<algorithm>" namespace "std":
    iter std_find "std::find" [iter, T](iter first, iter last, const T& val)
    
cdef class OXO(Game):
    cdef public int width
    cdef public int height
    cdef public vector[int] board
    cdef public vector[vector[int]] lines
    cdef public vector[int] empty
    
    def __init__(self,int width=3,int height=3,int inarow=0):
        cdef int i,j,iar
        cdef vector[int] v
        self.playerToMove = 1
        self.width = width
        self.height = height
        self.board = vector[int](width*height)
        for i in range(width*height):
            self.empty.push_back(i)
        
        if inarow==0: inarow=width*height
        for j in range(width-inarow+1):
            for i in range(height):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+iar)
                self.lines.push_back(v)
        for j in range(width):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+width*iar)
                self.lines.push_back(v)
        for j in range(width-inarow+1):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+(width+1)*iar)
                self.lines.push_back(v)
        for j in range(inarow-1,width):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+(width-1)*iar)
                self.lines.push_back(v)
    cpdef void Move(self, int m):
        self.board[m] = self.playerToMove
        self.empty.erase(std_find[vector[int].iterator, int](self.empty.begin(), self.empty.end(), m))#         self.empty.remove(m)
        if self.GetResult(self.playerToMove) in [0,1]:  
            self.empty=vector[int](0)
        self.playerToMove = 3 - self.playerToMove
        
    cpdef list GetMoves(self): # must clone as Node use this return as untriedMoves  # return self.empty[:]
        return self.empty
    
    cpdef void RollOut(self):
        cdef int sz = self.empty.size()
        cdef int mv
        while sz>0:
            self.Move(self.empty[rand()%sz])
            sz=self.empty.size()
    
    cpdef double GetResult(self,int viewpoint):
        cdef int myscore = 0, opposcore = 0, pos
        cdef vector[int] l
        for l in self.lines:
            one=self.board[l[0]]
            if one!=0:
                for pos in l:               # all(self.board[p]==one for p in l)
                    if self.board[pos]!=one:
                        break
                else:
                    if one == viewpoint:
                        myscore += 1
                    else:
                        opposcore += 1
        if myscore>opposcore: return 1.0
        elif opposcore>myscore: return 0.0
        return 0.5

    def __repr__(self):
        s=''
        for i in range(self.width):
            for j in range(self.height):
                e=self.board[i*self.width+j]
                s+="·XO"[e]                    #does not allow s+="·XO"[self.board[i*self.width+j]]
            s+='\n'
        return s
    cpdef str key(self):
        cdef int e
        k=str(self.playerToMove)
        for e in self.board: k+="·XO"[e] 
        return k
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return all(ch in '·12' or ch==key[i] for i,ch in enumerate(mykey))

Fail to draw on 5,5,4, 2nd: after first player places at center, 2nd player should play at a position "diagonal" to the center (6,8,16,18). MCTS failed to identify these four as optimal until after ~10min

In [7]:
from mcts import *
mcts = MCTS(OXO(width=5,height=5,inarow=4))
mcts.game.Move(12) # after this, found (6/8/16/18) is optimal after ~1,500,000 simulations
mcts.game.Move(11) # after this, found sure win policy after 260663 simulations
# mcts.game.Move(16) # after this, found sure loss after 189564 simulations
# mcts.game.Move(8)  # after this, found sure win policy after 38150 simulations
# mcts.game.Move(18) # after this, found sure loss after 25918 simulations
# mcts.game.Move(6)  # after this, found sure win policy after 1601 simulations
mcts.search(timemax=10,verbose=1)

20      |[W/V:    195/   390 | UnXplrd: 0] 
2       |[W/V:  271.5/   519 | UnXplrd: 0] 
4       |[W/V:    306/   577 | UnXplrd: 0] 
13      |[W/V:    323/   605 | UnXplrd: 0] 
0       |[W/V:    337/   628 | UnXplrd: 0] 
10      |[W/V:  336.5/   627 | UnXplrd: 0] 
14      |[W/V:  380.5/   700 | UnXplrd: 0] 
24      |[W/V:  412.5/   753 | UnXplrd: 0] 
22      |[W/V:    419/   763 | UnXplrd: 0] 
5       |[W/V:  620.5/  1091 | UnXplrd: 0] 
3       |[W/V:    730/  1267 | UnXplrd: 0] 
19      |[W/V:    866/  1485 | UnXplrd: 0] 
23      |[W/V:  886.5/  1517 | UnXplrd: 0] 
21      |[W/V:  901.5/  1541 | UnXplrd: 0] 
9       |[W/V:  922.5/  1575 | UnXplrd: 0] 
1       |[W/V:  924.5/  1578 | UnXplrd: 0] 
15      |[W/V: 1116.5/  1883 | UnXplrd: 0] 
8       |[W/V: 2561.5/  4143 | UnXplrd: 0] 
18      |[W/V:   3016/  4848 | UnXplrd: 0] 
17      |[W/V: 4022.5/  6401 | UnXplrd: 0] 
7       |[W/V: 5232.5/  8261 | UnXplrd: 0] 
6       |[W/V: 8363.5/ 13050 | UnXplrd: 0] 
16      |[W/V: 8591.5/ 13398 | U

16

## Othello

In [None]:
%%cython -I . -+ -a
from libc.stdlib cimport rand
from libcpp.vector cimport vector
from mcts cimport *
cimport numpy as np
cimport cython
import numpy.random as random
cdef extern from "<algorithm>" namespace "std":
    iter std_find "std::find" [iter, T](iter first, iter last, const T& val)
cdef extern from "<algorithm>" namespace "std":
    iter std_random_shuffle "std::random_shuffle" [iter](iter first, iter last)
    
cdef class Othello(Game):
    cdef public int size
    cdef public vector[int] board
    def __init__(self, int size = 8):
        self.playerToMove = 1
        self.size = size
        self.board = vector[int](size*size)
        self.board[(size//2)*size+size//2] = self.board[(size//2-1)*size+size//2-1] = 2
        self.board[(size//2-1)*size+size//2] = self.board[(size//2)*size+size//2-1] = 1

    cpdef void Move(self, int move):
        if move >= 0:
#             assert self.IsOnBoard(x,y) and self.board[move] == 0
            m = self.GetAllSandwichedCounters(move)
            self.board[move] = self.playerToMove
            for pos in m:
                self.board[pos] = self.playerToMove
        self.playerToMove = 3 - self.playerToMove
    
    cpdef list GetMoves(self):
        cdef vector[int] emptypos, viable
        for i in range(self.board.size()):
            if self.board[i]==0:
                emptypos.push_back(i)
        if emptypos.size()==0: 
            return []
        else:
            for e in emptypos:
                if self.ExistsSandwiched(e):
                    viable.push_back(e)
            if viable.size()>0:
                return viable
            else:  #need to check if opponent also has no viable pos
                self.playerToMove = 3 - self.playerToMove
                viable.clear()
                for e in emptypos:
                    if self.ExistsSandwiched(e):
                        viable.push_back(e)
                self.playerToMove = 3 - self.playerToMove
                if viable.size()>0:
                    return [-1]
                else:
                    return []
    
    cpdef void RollOut(self):
        cdef int bad=0 
#         emptypos=[pos for pos,e in enumerate(self.board) if e==0]
        cdef vector[int] emptypos
        for i in range(self.board.size()):
            if self.board[i]==0:
                emptypos.push_back(i)
        while emptypos.size()>0:
#             random.shuffle(emptypos)
            std_random_shuffle[vector[int].iterator]( emptypos.begin(), emptypos.end() )
            for pos in emptypos:
                if self.ExistsSandwiched(pos):
                    self.Move(pos)
                    bad=0
                    emptypos.clear()
                    for i in range(self.board.size()):
                        if self.board[i]==0:
                            emptypos.push_back(i)
                    break
            else:
                if bad<2: #when bad=2, both side cannot play, game ends
                    bad+=1
                    self.Move(-1)
                else:
                    break

    cdef vector[int] AdjacentEnemyDirections(self,int pos):# Speeds up GetMoves by only considering squares which are adjacent to an enemy-occupied square.
        cdef vector[int] direction
        cdef int x=pos//self.size, y=pos%self.size
        if self.IsOnBoard(x+1,y) and self.board[(x+1)*self.size+y] == 3-self.playerToMove:
            direction.push_back(0)
        if self.IsOnBoard(x+1,y+1) and self.board[(x+1)*self.size+y+1] == 3-self.playerToMove:
            direction.push_back(1)
        if self.IsOnBoard(x,y+1) and self.board[(x)*self.size+y+1] == 3-self.playerToMove:
            direction.push_back(2)
        if self.IsOnBoard(x-1,y+1) and self.board[(x-1)*self.size+y+1] == 3-self.playerToMove:
            direction.push_back(3)
        if self.IsOnBoard(x-1,y) and self.board[(x-1)*self.size+y] == 3-self.playerToMove:
            direction.push_back(4)
        if self.IsOnBoard(x-1,y-1) and self.board[(x-1)*self.size+y-1] == 3-self.playerToMove:
            direction.push_back(5)
        if self.IsOnBoard(x,y-1) and self.board[(x)*self.size+y-1] == 3-self.playerToMove:
            direction.push_back(6)
        if self.IsOnBoard(x+1,y-1) and self.board[(x+1)*self.size+y-1] == 3-self.playerToMove:
            direction.push_back(7)
        return direction
    
    cdef int ExistsSandwiched(self,int pos):# Is there at least one counter which would be flipped if my counter was placed at (x,y)? 
        cdef int x=pos//self.size, y=pos%self.size, dx,dy,x1,y1
        cdef vector[int] directions
        directions=self.AdjacentEnemyDirections(pos)
        for direction in directions:
            if direction==7 or direction==0 or direction==1:
                dx=1
            elif direction==3 or direction==4 or direction==5:
                dx=-1
            else:
                dx=0
            if direction==1 or direction==2 or direction==3:
                dy=1
            elif direction==5 or direction==6 or direction==7:
                dy=-1
            else:
                dy=0
            x1=x+dx
            y1=y+dy
            while self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == 3 - self.playerToMove:
                x1 += dx
                y1 += dy
            if self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == self.playerToMove:
                return 1
        return 0
    
    cdef vector[int] GetAllSandwichedCounters(self, int pos):# Is (x,y) a possible move (i.e. opponent counters are sandwiched between (x,y) and my counter in some direction)?
        cdef vector[int] sandwiched, more, directions
        cdef int direction
        directions=self.AdjacentEnemyDirections(pos)
        for direction in directions:
            more = self.SandwichedCounters(pos,direction)
            sandwiched.insert(sandwiched.end(), more.begin(), more.end());
        return sandwiched

    cdef vector[int] SandwichedCounters(self,int pos, int direction):# Return the coordinates of all opponent counters sandwiched between (x,y) and my counter.
        cdef int x=pos//self.size, y=pos%self.size, dx=0, dy=0
        cdef vector[int] sandwiched
        if direction==7 or direction==0 or direction==1:
            dx=1
        elif direction==3 or direction==4 or direction==5:
            dx=-1
        if direction==1 or direction==2 or direction==3:
            dy=1
        elif direction==5 or direction==6 or direction==7:
            dy=-1        
        x += dx
        y += dy
        while self.IsOnBoard(x,y) and self.board[x*self.size+y] == 3 - self.playerToMove:
            sandwiched.push_back(x*self.size+y)
            x += dx
            y += dy
        if self.IsOnBoard(x,y) and self.board[x*self.size+y] == self.playerToMove:
            return sandwiched
        else:
            sandwiched.clear()
            return sandwiched

    cdef int IsOnBoard(self, int x, int y):
        if x >= 0 and x < self.size and y >= 0 and y < self.size:
            return 1
        else:
            return 0
    
    cpdef double GetResult(self,int viewpoint):
        cdef int viewpointscore=0, oppositescore=0, e
        for e in self.board:
            if e==viewpoint: viewpointscore+=1
            elif e==3-viewpoint: oppositescore+=1
        if viewpointscore > oppositescore: return 1.0
        elif oppositescore > viewpointscore: return 0.0
        else: return 0.5 # draw

    def __repr__(self):
        s=''
        for i in range(self.size):
            for j in range(self.size):
                e=self.board[i*self.size+j]
                s+="·XO"[e]                    #does not allow s+="·XO"[self.board[i*self.width+j]]
            s+='\n'
        return s
    cpdef str key(self):
        cdef int e
        k=str(self.playerToMove)
        for e in self.board: k+="·XO"[e] 
        return k
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return key.count('·')<mykey.count('·') and \
               all(ch in '12·' or key[i]!='·' for i,ch in enumerate(mykey)) and\
               (mykey[1]=='·' or mykey[1]==key[1]) and\
               (mykey[self.size]=='·' or mykey[self.size]==key[self.size]) and\
               (mykey[-self.size]=='·' or mykey[-self.size]==key[-self.size]) and\
               (mykey[-1]=='·' or mykey[-1]==key[-1])

* timemax=15 beats Reversu Difficult
* timemax=40 beaten by Reversu Nightmare

In [None]:
game = Othello(8)
mcts = MCTS(game)
while mcts.game.GetMoves():
    if mcts.game.playerToMove==2:
        clear_output()
        board=np.array([f'{("·XO"[e] if e else str(p)):2s}' for p,e in enumerate(mcts.game.board)]).reshape(8,8)
        print(board)
        print(mcts.rootnode.ChildrenToString())
        print('choose from:',end='')
        moves = mcts.game.GetMoves()
        if moves[0] is None: m = None
        else: m=int(input(moves))
    else:
        m = mcts.search(verbose=1,timemax=45+len(list(filter(None,mcts.game.board)))*0 )  #len...=number of pieces on board
    mcts.game.Move(m)
    print(mcts.game)
    print(len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

# MCTS in Cython with multiprocessing?

## mcts library

In [1]:
%load_ext Cython
import os, tempfile
os.chdir(tempfile.mkdtemp())

In [2]:
%%file mcts.pxd
# cython: language_level=3
from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map
cdef class Game:
    cdef public int playerToMove
    cpdef list GetMoves(self)
    cpdef void Move(self,int move)
    cpdef void RollOut(self)
    cpdef str key(self)
    cpdef double GetResult(self,int viewpoint)
cdef class MCTS:
    cdef public unordered_map[string, vector[int]] untriedMoves
    cdef public unordered_map[string, unordered_map[int, string]] childNodes
    cdef public unordered_map[string, double] wins
    cdef public unordered_map[string, int] visits
    cdef public Game game
    cdef double simulate(self,Game game,dict node,double explore=?)

Writing mcts.pxd


In [None]:
%%cython -+ -a
# %%file mcts.pyx
cimport cython
from copy import deepcopy
import time, random
from libc.math cimport sqrt, log
from libc.stdlib cimport rand
from libcpp.vector cimport vector
from cython.operator cimport dereference as deref, preincrement
from numpy.math cimport INFINITY as inf

cdef class Game:
    cpdef list GetMoves(self): return []
    cpdef void Move(self,int move):pass
    cpdef void RollOut(self):pass
    cpdef str key(self): return "";
    cpdef double GetResult(self,int viewpoint): return 0.5;

cdef class MCTS():
    def __init__(self, Game game):
        self.game = game

    def search(self, timemax=None, itermax=None, explore=1, verbose=0):
        string rootkey=self.game.key()
        if 'Reachable' in dir(self.game): #pruning by (game provided) finding which node will be reachable
            self.nodes = {k:v for k,v in self.nodes.items() if k==key or self.game.Reachable(k)}
            
        if self.untriedMoves[rootkey].size()>0 or self.childNodes[rootkey].size()>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+100
                    for _ in range(100):
                        v=self.simulate(deepcopy(self.game),rootkey)
                    if abs(v)==inf: break
            else: #debug
                start=time.time()
                while True:
                    for _ in range(1000): v=self.simulate(deepcopy(self.game),rootkey)
                    print(time.time()-start,sum(c.visits for c in self.childNodes[rootkey].values()),
                          sorted(list(self.childNodes[rootkey].items()), 
                              key = lambda c: -c[1].wins/c[1].visits ),
                          end='\r')
                    if abs(v)==inf: break
                        
        moveToChild = self.UCTSelectMove(self.rootnode,explore=0)
        if verbose: 
            print(self.ChildrenToString(self.rootnode))
            print(sum(self.nodes[k]['visits'] for k in self.rootnode['childNodes'].values()))
            
        return moveToChild
    
    def ChildrenToString(self,node):
        return "\n".join(f'{str(mv):6s}-> [W/V: {n["wins"]:6g}/{n["visits"]:6.0f} | UnXplrd: {len(n["untriedMoves"])}]'
                            for mv,k in sorted(node['childNodes'].items(),key=lambda e: self.nodes[e[1]]['wins']/self.nodes[e[1]]['visits'])
                            for n in [self.nodes[k]]
                        )
        
    cdef UCTSelectMove(self,n,explore=1):
        cdef double bestval
        cdef int bestmv
        cdef string bestkey
        for mv,k in n['childNodes'].items():
            nn = self.nodes[k]
            curval = nn['wins']/nn['visits'] + explore*sqrt(log(n['visits'])/nn['visits'])
            if bestval is None or curval>bestval:
                bestmv = mv
                bestval = curval
                bestnn = nn
        return bestmv
    
    cdef double simulate(self,Game game,string key,double explore=1.0):
        cdef double v=0 #return value
        cdef int move=0
        if self.untriedMoves[key].size()==0 and self.childNodes[key].size()>0: #fully expanded, non-terminal
            move = self.UCTSelectMove(node,explore)
            nnode = self.nodes[node['childNodes'][move]]
            if nnode['wins']==inf or nnode['wins']==-inf: #if the best is -inf or inf already, can backprop
                node['wins']+= -nnode['wins']; node['visits']+=1
                return -nnode['wins']
            game.Move(move)
            v = self.simulate(game,nnode)
            if v==-inf: v=0
            node['wins']+=1-v; node['visits']+=1;
            return 1-v
        elif self.untriedMoves[key].size()>0: #not fully expanded: expand and then rollout
            move = random.choice(node['untriedMoves'])
            game.Move(move)
            nkey = game.key()
            if nk not in self.nodes:
                self.nodes[nkey] = {'viewpoint':3-game.playerToMove,
                                 'wins':0.,
                                 'visits':0.,
                                 'childNodes':{},
                                 'untriedMoves':game.GetMoves()}
            nnode = self.nodes[nkey]
            node['untriedMoves'].remove(move)
            node['childNodes'][move]=nkey
            game.RollOut()
            v = game.GetResult(nnode['viewpoint']) #not setting inf or -inf because there are indeterminism (unexplored moves)
            self.wins[nkey]+=v; self.visits[nkey]+=1
            self.wins[key]+=1.0-v; self.visits[key]+=1
            return 1.0-v
        elif self.childNodes[key] == {}: #terminal
            v = game.GetResult(node['viewpoint'])
            if v==1:    v=self.inf
            elif v==0:  v=-self.inf
            self.wins[key]+=v; self.visits[key]+=1
            return v

In [4]:
%%file setup.py
import os
import numpy
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

setup(  name = "cythoncode",
        cmdclass = {"build_ext": build_ext},
        ext_modules = [ Extension("mcts",
                                  sources=["mcts.pyx"],
                                  language='c++',   #using C++
                                  libraries=["m"],  #for using C's math lib
                                  extra_compile_args = ["-ffast-math"])],
        include_dirs=[numpy.get_include(),
                      os.path.join(numpy.get_include(), 'numpy')])

Writing setup.py


In [5]:
!python3 setup.py build_ext --inplace

running build_ext
cythoning mcts.pyx to mcts.cpp
building 'mcts' extension
creating build
creating build/temp.linux-x86_64-3.7
x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/home/hoi/.local/lib/python3.7/site-packages/numpy/core/include -I/home/hoi/.local/lib/python3.7/site-packages/numpy/core/include/numpy -I/usr/include/python3.7m -c mcts.cpp -o build/temp.linux-x86_64-3.7/mcts.o -ffast-math
x86_64-linux-gnu-g++ -pthread -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-Bsymbolic-functions -Wl,-z,relro -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 build/temp.linux-x86_64-3.7/mcts.o -lm -o /tmp/tmphrgz58xj/mcts.cpython-37m-x86_64-linux-gnu.so


## m,n,k-game

In [8]:
# %load_ext autoreload
# %autoreload 2
from mcts import *
mcts = MCTS(OXO(width=5,height=5,inarow=4))
mcts.game.Move(12) # after this, found (6/8/16/18) is optimal after ~1,500,000 simulations
mcts.game.Move(11) # after this, found sure win policy after 260663 simulations
# mcts.game.Move(16) # after this, found sure loss after 189564 simulations
# mcts.game.Move(8)  # after this, found sure win policy after 38150 simulations
# mcts.game.Move(18) # after this, found sure loss after 25918 simulations
# mcts.game.Move(6)  # after this, found sure win policy after 1601 simulations
mcts.search(timemax=10,verbose=1)

14    -> [W/V:  215.5/   426 | UnXplrd: 0]
4     -> [W/V:    244/   474 | UnXplrd: 0]
10    -> [W/V:    249/   483 | UnXplrd: 0]
20    -> [W/V:  318.5/   599 | UnXplrd: 0]
5     -> [W/V:  371.5/   687 | UnXplrd: 0]
0     -> [W/V:    399/   733 | UnXplrd: 0]
22    -> [W/V:  426.5/   778 | UnXplrd: 0]
23    -> [W/V:    484/   873 | UnXplrd: 0]
24    -> [W/V:    514/   922 | UnXplrd: 0]
19    -> [W/V:    542/   967 | UnXplrd: 0]
3     -> [W/V:  603.5/  1067 | UnXplrd: 0]
9     -> [W/V:  606.5/  1072 | UnXplrd: 0]
2     -> [W/V:    704/  1230 | UnXplrd: 0]
1     -> [W/V:    737/  1283 | UnXplrd: 0]
15    -> [W/V:    949/  1623 | UnXplrd: 0]
13    -> [W/V: 1127.5/  1908 | UnXplrd: 0]
21    -> [W/V: 1190.5/  2008 | UnXplrd: 0]
18    -> [W/V:   2393/  3898 | UnXplrd: 0]
8     -> [W/V: 2416.5/  3935 | UnXplrd: 0]
17    -> [W/V: 2607.5/  4233 | UnXplrd: 0]
7     -> [W/V:   4439/  7074 | UnXplrd: 0]
16    -> [W/V:   5639/  8924 | UnXplrd: 0]
6     -> [W/V:   6219/  9803 | UnXplrd: 0]
55000.0


6

In [6]:
%%cython -I . -+
from libc.stdlib cimport rand
from libcpp.vector cimport vector
from mcts cimport *
cimport cython
cdef extern from "<algorithm>" namespace "std":
    iter std_find "std::find" [iter, T](iter first, iter last, const T& val)
    
cdef class OXO(Game):
    cdef public int width
    cdef public int height
    cdef public vector[int] board
    cdef public vector[vector[int]] lines
    cdef public vector[int] empty
    
    def __init__(self,int width=3,int height=3,int inarow=0):
        cdef int i,j,iar
        cdef vector[int] v
        self.playerToMove = 1
        self.width = width
        self.height = height
        self.board = vector[int](width*height)
        for i in range(width*height):
            self.empty.push_back(i)
        
        if inarow==0: inarow=width*height
        for j in range(width-inarow+1):
            for i in range(height):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+iar)
                self.lines.push_back(v)
        for j in range(width):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+width*iar)
                self.lines.push_back(v)
        for j in range(width-inarow+1):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+(width+1)*iar)
                self.lines.push_back(v)
        for j in range(inarow-1,width):
            for i in range(height-inarow+1):
                v=vector[int](0)
                for iar in range(inarow):
                    v.push_back(i*width+j+(width-1)*iar)
                self.lines.push_back(v)
    cpdef void Move(self, int m):
        self.board[m] = self.playerToMove
        self.empty.erase(std_find[vector[int].iterator, int](self.empty.begin(), self.empty.end(), m))#         self.empty.remove(m)
        if self.GetResult(self.playerToMove) in [0,1]:  
            self.empty=vector[int](0)
        self.playerToMove = 3 - self.playerToMove
        
    cpdef list GetMoves(self):
        return self.empty
    
    cpdef void RollOut(self):
        cdef int sz = self.empty.size()
        cdef int mv
        while sz>0:
            self.Move(self.empty[rand()%sz])
            sz=self.empty.size()
    
    cpdef double GetResult(self,int viewpoint):
        cdef int myscore = 0, opposcore = 0, pos
        cdef vector[int] l
        for l in self.lines:
            one=self.board[l[0]]
            if one!=0:
                for pos in l:               # all(self.board[p]==one for p in l)
                    if self.board[pos]!=one:
                        break
                else:
                    if one == viewpoint:
                        myscore += 1
                    else:
                        opposcore += 1
        if myscore>opposcore: return 1.0
        elif opposcore>myscore: return 0.0
        return 0.5

    def __repr__(self):
        s=''
        for i in range(self.width):
            for j in range(self.height):
                e=self.board[i*self.width+j]
                s+="·XO"[e]                    #does not allow s+="·XO"[self.board[i*self.width+j]]
            s+='\n'
        return s
    cpdef str key(self):
        cdef int e
        k=str(self.playerToMove)
        for e in self.board: k+="·XO"[e] 
        return k
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return all(ch in '·12' or ch==key[i] for i,ch in enumerate(mykey))

Fail to draw on 5,5,4, 2nd: after first player places at center, 2nd player should play at a position "diagonal" to the center (6,8,16,18). MCTS failed to identify these four as optimal until after ~10min

# AlphaZero-style (guide tree search by policy+value network)

https://www.youtube.com/watch?v=ld28AU7DDB4
* The uncertainty term used in UCB is replaced by **Polynomial Upper Confidence Trees** PUCT $\propto P(s_i|s)\frac{\sqrt{N}}{1+n_i}$ (where N is the total visits of current state s, and n_i are visit counts of each of the next possible nodes). $P(s_i|s)$ is a probability output by policy network given current state
* Rollout replaced by the value estimation of the current state by the network. For a newly-expanded child node (c), use value network to compute all the values of its possible child nodes (c's c), used for computing the mean of the current (c) node, with weights given by the policy network.
* In AlphaGo, policy network is trained with human expert positions. Then policy network is then used to play against itself to generate positions + target (win/loss) for training the value network
* AlphaZero combines policy and value network into a single network. Repeat:
  1. training examples (state,policy of next move,win/loss) are generated by MCTS, used to train the policy+value network.
  2. Use the network trained to guide MCTS
  
Other tweaks
* in the PUCT, normalize (sum of all=1) and *then* exponentiate all counts before putting into formula $N^{1/\tau}$ and $n_i^{1/\tau}$, where $\tau=1$ for the first 10-100 moves (exploratory) and ~0 for later moves (only choose the best).
* After each iteration, pit against previous version of the net to make sure it has really improved

## [Code](http://web.stanford.edu/~surag/posts/alphazero.html) (modified)

`Game` class requires:
* `reset()` returns init board (numpy array passable to net)
* `actionSize` the largest possible number of moves (not given a state). e.g. Othello is size^2-4+1 (for pass)
* `getValidMoves()` returns np.array of 0/1 indicting the corresponding index action is a valid move or not at the moment
* `board1`: Board in such a way that the one next move is always same set of numbers in the numpy array. e.g., for othello if playerToMove=2, return board*(-1). e.g. in chess if playerToMove=2, exchange pieces number 1-16 with number 17-32
* `symConfig()`: input a policy (prob dist), generate tuples of (board**1**,policy) due to board symmetry

### MCTS given guidance from Net, generates (states,policies,values) through self-plays

In [114]:
class Node:
    def __init__(self, state = None):
        self.childNodes = {} # {move:nextnode} the move that get *into* next node
        self.policy = {}     # {move:prob}
        self.viewpoint = 3-state.playerToMove
        self.wins = 0        #for the player Just Moved = self.viewpoint
        self.visits = 0
        
    def UCTSelectChild(self,explore=1):
        mv = max(self.policy, 
                 key = lambda k: self.childNodes[k].wins/(self.childNodes[k].visits) + \
                       explore*self.policy[k]*math.sqrt(self.visits)/(1+self.childNodes[k].visits)\
                       if k in self.childNodes else explore*self.policy[k]*math.sqrt(self.visits+1e-8)
                )
        return mv
    
    def AddChild(self, move, n):
        self.childNodes[move] = n
    
    def Update(self, result):
        self.visits += 1
        self.wins += result # game results in the range [0.0, 1.0]
    def __repr__(self):        return f"[W/V: {self.wins:6g}/{self.visits:6d}] "
    def ChildrenToString(self):return "\n".join(f'{str(m):8s}|{c}' for m,c in sorted(self.childNodes.items(),key=lambda e: e[1].wins/e[1].visits))

In [138]:
from collections import deque
from itertools import count
import numpy as np

class MCTS():
    def __init__(self, game, net):
        self.game = game
        self.net = net
        self.numMCTSSims=25
        self.inf = float('inf')
        self.nodes = {}
    def search(self, timemax=None, itermax=None, explore=1, verbose=False):
        key=self.game.key()
        if 'Reachable' in dir(self.game):
            self.nodes = {k:v for k,v in self.nodes.items() if k==key or self.game.Reachable(k)}

        if len(self.rootnode.policy)>1:  #unimportant -- determine when to stop
            if itermax or timemax:
                start=time.time()
                i = 0
                while (timemax is None or time.time()<start+timemax) and\
                      (itermax is None or i<itermax):
                    i = i+100
                    for _ in range(100):
                        v=self.simulate(deepcopy(self.game),self.rootnode)
                    if v==self.inf: break
            else: #debug
                start=time.time()
                while True:
                    for _ in range(1000): self.simulate(deepcopy(self.game),self.rootnode)
                    print(time.time()-start,sum(c.visits for c in self.rootnode.childNodes.values()),
                          sorted(list(self.rootnode.childNodes.items()), 
                              key = lambda c: -c[1].wins/c[1].visits ),
                          end='\r')
                    
        moveToChild = self.rootnode.UCTSelectChild(explore=0)
        if verbose: print(self.rootnode.ChildrenToString())

        return moveToChild    
    
    def net_predict(self,board1,valids): # returns a dict of {act:prob} for the valid moves, and a value of broad1's position
        probs, v = self.net.predict(board1[np.newaxis])
        probs = probs[0]*valids
        sum_probs = np.sum(probs)
        if sum_probs > 0:
            probs /= sum_probs 
        else:# print("All valid moves were masked, do workaround.")
            probs = probs + valids
            probs /= np.sum(probs)
        return {a:probs[a] for a,val in enumerate(valids) if val}, v[0]
    
    def selfplay(self,temperature=lambda s: 1 if s<5 else 0): #train batch from one complete game
        """
        This function executes one episode of self-play, starting with player 1.
        As the game is played, each turn is added to trainExamples. After the game
        ends, the outcome of the game is used to assign values to each example.
        It uses a temp=1 if episodeStep < tempThreshold, and thereafter
        uses temp=0.
        Returns:
            trainExamples: a list of examples of the form (canonicalBoard,pi,v)
                           pi is the MCTS informed policy vector, v is +1 if
                           the player eventually won the game, else -1.
        """
        trainExamples = []
        board = self.game.reset()

        for movestep in count():
            tau = temperature(movestep)

            probs = self.getActionProb(tau)
            
            for board1,policy in self.game.symConfig(probs):
                trainExamples += (self.game.playerToMove, board1.astype(np.float32), policy),  #value target not known yet

            action = np.random.choice(len(probs), p=probs)
            self.game.Move(action)

            if not any(self.game.getValidMoves()):
                r = self.game.GetResult(self.game.playerToMove)
                return [(e[1],e[2],r if e[0]==self.game.playerToMove else 1-r) for e in trainExamples]
        
    @property  #  called when accessing self.rootnode
    def rootnode(self):
        key = self.game.key()
        if key not in self.nodes:
            self.nodes[key] = Node(state = self.game)
            probs, v = self.net_predict(self.game.board1,self.game.getValidMoves())
            self.nodes[key].policy = probs
        return self.nodes[key]

    def getActionProb(self, temp=1): #Returns: probs for next move by performing multiple searches starting from current state#
        for _ in range(self.numMCTSSims):
            self.simulate(deepcopy(self.game), self.rootnode)
            
        get_count=lambda c: 1e8 if c.wins==self.inf else 1e-8 if c.wins==-self.inf else c.visits #1e-8 to avoid all-zero
        counts = [get_count(self.rootnode.childNodes[a]) if a in self.rootnode.childNodes else 0 for a in range(self.game.actionSize)]
        if temp==0:
            bestA = np.argmax(counts)
            probs = [0]*len(counts)
            probs[bestA]=1
        else:
            counts = [x**(1./temp) for x in counts]
            probs = [x/float(sum(counts)) for x in counts]
        return probs

    def simulate(self, state, node, explore=1): #not aiming to fully expand every node
        """
        At a newly expanded state, the net is called to return an initial P and v for the state.
        If the leaf node is a terminal state, the value is v. the value is propogated up the search path
        """
        if node.policy == {}: #terminal
            v = state.GetResult(node.viewpoint)
            if v==1:    v=self.inf
            elif v==0:  v=-self.inf
            node.Update(v)
            return v
        else:                 # non-terminal
            move = node.UCTSelectChild(explore) #descend (following UCB)
            state.Move(move)
            if move in node.childNodes:
                nnode = node.childNodes[move]
                if nnode.wins==self.inf or nnode.wins==-self.inf: #if the best is -inf or inf already, can backprop
                    node.Update(-nnode.wins)
                    return -nnode.wins
            else:
                k = state.key()                
                if k in self.nodes:
                    nnode = self.nodes[k]
                else:
                    nnode = Node(state = state)
                    valid = state.getValidMoves()
                    if np.any(valid):
                        probs, v = self.net_predict(state.board1,valid)
                        nnode.policy = probs
                    self.nodes[k] = nnode
                node.AddChild(move,nnode)
            
            v = self.simulate(state,nnode)
            if v==-self.inf: v=0 # if v==-inf, next mover will not choose this branch as it leads to loss.
                                 # Note: other branch may have finite v as UCT  saw this branch as finite v before simulation
#             nnode.Update(v)
            node.Update(1-v)
            return 1-v

### Train net with dataset (states,policies,values) given by the self-plays of MCTS

In [8]:
import tensorflow as tf
@tf.function
def train_step(net, boardstates, policies_target, values_target, optimizer):
    with tf.GradientTape() as tape:
        policies, values = net(boardstates, training=True)
        loss_p = tf.losses.softmax_cross_entropy(policies_target, policies)
        loss_v = tf.losses.mean_squared_error(values_target, values)
        combined_loss = loss_p + loss_v

    grad = tape.gradient(combined_loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grad, net.trainable_variables))

In [None]:
# Example training script with a **Game** class and a **net**
from collections import deque
train_set = deque(maxlen=200000)
mcts = MCTS(Game(),net)
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

for _ in range(100):
    train_set += mcts.selfplay()
    for s,p,v in tf.data.Dataset.from_tensor_slices(zip(*train_set)).shuffle(1000).batch(64):
        train_step(net, s,p,v, optimizer)

### MCTS using trained Net: given state, generate next move

In [None]:
# Example running) through script with a **Game** class and a **net**
game = OXO(5,5,4)
mcts = MCTS(game,net)
while np.any(mcts.game.getValidMoves()):
    m = mcts.search(timemax=10,verbose=1)
    mcts.game.Move(m)
    print(mcts.game, len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

## [m,n,k-game](https://en.wikipedia.org/wiki/M,n,k-game)

In [139]:
class OXO:
    def __init__(self,width=3,height=3,inarow=None,end1=True): # game ends as one success
        self.playerToMove = 1
        self.width = width
        self.height = height
        self.end1 = end1
        self.actionSize = width*height
        self.reset()        
        inarow = inarow or min(width,height)
        self.lines = [tuple(range(i*width+j,i*width+j+inarow)) for i in range(0,height) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,(i+inarow)*width+j,width)) for i in range(0,height-inarow+1) for j in range(width)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width+1),width+1)) for i in range(height-inarow+1) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width-1),width-1)) for i in range(height-inarow+1) for j in range(inarow-1,width)]
    
    @property
    def board1(self):
        return self.board.reshape(self.height,self.width) if self.playerToMove==1\
               else np.mod(3-self.board,3).reshape(self.height,self.width)
        
    def reset(self):
        self.playerToMove = 1
        self.board = np.zeros(self.width*self.height,dtype=np.int32)
        self.empty = list(range(self.width*self.height))
        
    def getValidMoves(self): 
        return np.clip(1-self.board,0,None) if self.empty else np.zeros(self.actionSize)

    def symConfig(self,pi):
        board1 = self.board1
        pi=np.array(pi).reshape(self.height,self.width)
        return [(board1,pi.flatten()),
                (np.rot90(board1,1),np.rot90(pi,1).flatten()),
                (np.rot90(board1,2),np.rot90(pi,2).flatten()),
                (np.rot90(board1,3),np.rot90(pi,3).flatten()),
                (np.flip(board1,0), np.flip(pi,0).flatten()),
                (np.rot90(np.flip(board1,0),1), np.rot90(np.flip(pi,0),1).flatten()),
                (np.rot90(np.flip(board1,0),2), np.rot90(np.flip(pi,0),2).flatten()),
                (np.rot90(np.flip(board1,0),3), np.rot90(np.flip(pi,0),3).flatten()),
               ]

    def Move(self, move):
        assert move in self.empty
        self.board[move] = self.playerToMove
        self.empty.remove(move)
        if self.end1 and self.GetResult(self.playerToMove) in [0,1]:  
            self.empty=[]
        self.playerToMove = 3 - self.playerToMove

    def GetResult(self, viewpoint):
        myscore = opposcore = 0
        for l in self.lines:
            one=self.board[l[-1]]
            if one!=0 and all(self.board[p]==one for p in l):
                if one == viewpoint: myscore += 1
                else: opposcore += 1
        if myscore>opposcore: return 1
        elif opposcore>myscore: return 0
        return 0.5

    def __repr__(self): return '\n'.join(''.join("·XO"[self.board[i*self.width+j]] for j in range(self.width)) for i in range(self.height))    
    def key(self): return str(self.playerToMove)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key): return all(ch in '·12' or ch==key[i] for i,ch in enumerate(self.key()))

In [140]:
import tensorflow as tf
from tensorflow.keras.layers import *

def genNet(oxo):
    inp = tf.keras.layers.Input(shape=[oxo.height,oxo.width])
    x = Reshape([oxo.height,oxo.width,1])(inp)
    x = Activation('relu')(BatchNormalization(axis=3)(Conv2D(512, 2, padding='same')(x)))         # batch_size  x board_x x board_y x num_channels
    x = Activation('relu')(BatchNormalization(axis=3)(Conv2D(512, 2, padding='same')(x)))         # batch_size  x board_x x board_y x num_channels
    x = Activation('relu')(BatchNormalization(axis=3)(Conv2D(512, 2, padding='valid')(x)))        # batch_size  x (board_x-2) x (board_y-2) x num_channels
    x = Activation('relu')(BatchNormalization(axis=3)(Conv2D(512, 2, padding='valid')(x)))        # batch_size  x (board_x-4) x (board_y-4) x num_channels
    x = Flatten()(x)       
    
    s_fc1 = Dropout(.3)(Activation('relu')(BatchNormalization(axis=1)(Dense(1024)(x))))  # batch_size x 1024
    s_fc2 = Dropout(.3)(Activation('relu')(BatchNormalization(axis=1)(Dense(512)(x))))   # batch_size x 1024
    p = Dense(oxo.actionSize, activation='softmax', name='pi')(s_fc2)
    v = Dense(1, activation='tanh', name='v')(s_fc2)
    
    m = tf.keras.Model(inputs=inp, outputs=[p,v])
    m.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer=tf.keras.optimizers.Adam(0.001))
    return m

In [None]:
import tensorflow as tf
from collections import deqrrue

game = OXO(5,5,4)
# net = genNet(game)
train_set = deque(maxlen=12800)
mcts = MCTS(game,net)

while True:
    train_set += mcts.selfplay()
    clear_output()
    s,p,v=map(np.asarray,zip(*random.sample(train_set,len(train_set))))
    net.fit(x = s, y = [p, v], verbose=1, batch_size=64)

Train on 12800 samples


single game

In [None]:
game = OXO(5,5,4)
mcts = MCTS(game,net)
for s,p,v in mcts.selfplay():
    print(s,'\n',p.reshape(5,5),v)

full game playout

In [None]:
game = OXO(5,5,4)
mcts = MCTS(game,net)
while np.any(mcts.game.getValidMoves()):
    m = mcts.search(timemax=10,verbose=1)
    mcts.game.Move(m)
    print(mcts.game, len(mcts.nodes))
if mcts.game.GetResult(1) == 1:   print("Player 1(X) wins!")
elif mcts.game.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

## [Connect4](https://github.com/plkmo/AlphaZero_Connect4)

## [Othello](https://github.com/suragnair/alpha-zero-general)