In [1]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np

MCTS is equivalent to Monte Carlo control with UCB applied to the sub-MDP starting from the current state, where the lookup table is represented by a tree. Since actions lead to deterministic results, $q(s,a)=v(s'_{\rm a})$. $\gamma=1$ and $r=0$ until the end of the game.

# Barebone MCTS with UCT
https://youtu.be/ItMutbeOHtc?t=3922, https://www.youtube.com/watch?v=Bm7zah_LrmE

1. From each node (a node = a state of the board), repeatedly try unexplored moves and roll out randomly until game over. The win/loss count is propagated back along the path taken, updating each node on it (only nodes at or above the level where new moves are being tried out -- because of memory constraints).
2. When all moves of a node have been tried at least once, *select* the one to try based on the **Upper Confidence Bound** (UCB): mean value + C $\times$ uncertainty of the mean, where the uncertainty $\propto\sqrt{\frac{\log N_{\rm visit}}{n_{\rm child\ visit}}}$. Repeat this descent until reaching a node with at least one untried move.
3. repeat 1&2 until desired time/iteration limit, then select the best move at root node (current real situation of the board) based on *mean value* of each move

4. Some nodes may be reached in different ways, so keep a {move: child_node} dictionary. When selecting a child node, return both the move and the child node (do not use a node's parent to backtrack).
5. ONLY for terminating (acyclic) games
6. Keep only the nodes that could (in principle) be reached from the current state in the future. The game class needs to implement `Reachable` with respect to its key for this to work; otherwise every node is kept (memory...).
7. Prune the tree from the leaves: once the search has reached the end of the game, the value of that node is known exactly, so we can tell whether the previous state will definitely (or definitely not) move there.

## Code: [mcts.ai](http://mcts.ai/code/python.html) modified + reuse nodes

instance of game class is passed to `MCTS`'s `search` function, and must have:
* `playerJustMoved`: 1 or 2
* `GetMoves()`: returns list of (hashable) valid moves from current state
* `Move(m)`: execute the passed in move (which is guaranteed to be chosen from `GetMoves()`)
* `RollOut()`: randomly execute the game until game over
* `GetResult(viewpoint)`: (after game over) returns the reward (1 or 0) from the specified viewpoint (1 or 2 -- corresponding to playerJustMoved)
* `key()`: returns a string that fully and uniquely characterizes the state of the game
* [Optional]`Reachable(k)`: returns whether (True/False) the game could in principle reach the state represented by the passed-in key `k`, which was generated by the `key()` function of another state. If not provided, no pruning is done and every node is kept

In [13]:
class Node:
    """ A node of the search tree, i.e. one game state.

    Wins is always from the viewpoint of playerJustMoved (the player whose move led
    into this state). MCTS may drive wins to +inf/-inf once the node is proven to be
    a sure win/loss for that player.
    """
    def __init__(self, state = None):
        """ Create a node for `state`, which must provide GetMoves() and playerJustMoved
        (the None default is not usable: it raises AttributeError). """
        self.childNodes = {} # {move: node}, the move that gets *into* each child node
        self.wins = 0  # accumulated reward for playerJustMoved
        self.visits = 0
        # Copy the moves: AddChild removes entries, which must not alter data the state still owns.
        self.untriedMoves = list(state.GetMoves()) # future child nodes
        self.playerJustMoved = state.playerJustMoved # the only part of the state that the Node needs later

    def UCTSelectChild(self, explore=1):
        """ Select a child node with the UCB1 formula:
                c.wins/c.visits + explore * sqrt(log(self.visits)/c.visits)
            `explore` trades exploration against exploitation; explore=0 picks the best mean value.
            Returns the (move, child_node) pair with the highest score.
            Raises ValueError if the node has no children.
        """
        if not self.childNodes:
            raise ValueError("UCTSelectChild() called on a node without children")
        log_visits = math.log(self.visits)
        return max(self.childNodes.items(),
                   key = lambda item: item[1].wins/item[1].visits + explore*math.sqrt(log_visits/item[1].visits))

    def AddChild(self, move, s, n = None):
        """ Remove `move` from untriedMoves and attach a child node for it.
            If `n` is given it is reused as the child; otherwise a new Node is built from state `s`.
            Returns the child node.
        """
        if n is None:
            n = Node(state = s)
        self.untriedMoves.remove(move)
        self.childNodes[move] = n
        return n

    def Update(self, result, visits=None):
        """ Record one more visit and add `result` to the win count.

            result: reward from the viewpoint of playerJustMoved, in [0.0, 1.0],
                    or +/-inf for a proven win/loss.
            visits: unused; kept for backward compatibility.
        """
        self.visits += 1
        self.wins += result

    def __repr__(self):
        return f"[W/V: {self.wins:6g}/{self.visits:6d} | UnXplrd: {len(self.untriedMoves)}] "

    def TreeToString(self, indent, move=None):
        """ Return this node and its descendants, one node per line, indented by depth.
            Each child line is labelled with the move that leads into it.
        """
        label = "" if move is None else f"{move}: "
        s = "\n" + "| "*indent + label + str(self)
        s += ''.join(c.TreeToString(indent+1, m) for m, c in self.childNodes.items())
        return s

    def ChildrenToString(self):
        """ One line per child, "move|stats", sorted by ascending mean value (wins/visits). """
        return "\n".join(f'{str(m):8s}|{c}' for m,c in sorted(self.childNodes.items(),key=lambda e: e[1].wins/e[1].visits))
    
class MCTS():   #reuses nodes
    """ UCT Monte Carlo tree search that keeps its nodes between searches.

    Nodes are stored in `self.nodes` keyed by `state.key()`, so a position reached via
    different move orders shares a single Node. The game state passed to `search` must
    provide key(), GetMoves(), Move(m), RollOut(), GetResult(viewpoint) and
    playerJustMoved, and optionally Reachable(key) to prune stored nodes.
    """
    def __init__(self, explore=1, verbosity=0):
        self.nodes = {}               # key -> Node, all nodes previously explored
        self.explore = explore        # UCB exploration constant used while descending
        self.verbosity = verbosity    # 1: print root children after a search, 2: print the whole tree
        self.inf = float('inf')

    def search(self, state, timemax=None, itermax=None):
        """ Run MCTS from `state` and return the move to play.

            timemax: time budget in seconds (None = no time limit).
            itermax: maximum number of simulations (None = no limit).
            If both are None, an endless debug loop prints root statistics and never returns.
            With a budget, searching also stops as soon as the root's value is proven (+/-inf).
            Simulations run on deep copies of `state`.
            Returns the root move whose child has the best mean value (wins/visits).
            Raises ValueError if the root has no expanded child (e.g. itermax=0 on a new root).
        """
        key = state.key()
        if key not in self.nodes:
            self.nodes[key] = Node(state = state)
        self.rootnode = self.nodes[key]

        if hasattr(state, 'Reachable'): # prune stored nodes the game reports as no longer reachable
            self.nodes = {k:v for k,v in self.nodes.items() if k==key or state.Reachable(k)}

        # Only search when there is a choice to make (untried moves, or more than one child).
        if self.rootnode.untriedMoves or len(self.rootnode.childNodes)>1:
            if itermax is not None or timemax is not None:
                self._SearchWithBudget(state, timemax, itermax)
            else: #debug: search forever, printing the root's children sorted by mean value
                while True:
                    for _ in range(10000): self.simulate(deepcopy(state),self.rootnode)
                    print(sum(c.visits for c in self.rootnode.childNodes.values()),
                          sorted(list(self.rootnode.childNodes.items()),
                              key = lambda c: -c[1].wins/c[1].visits ),
                          end='\r')

        moveToChild, bestChild = self.rootnode.UCTSelectChild(explore=0)
        if self.verbosity==2:   print(self.rootnode.TreeToString(0))
        elif self.verbosity==1: print(self.rootnode.ChildrenToString())

        return moveToChild

    def _SearchWithBudget(self, state, timemax, itermax):
        """ Simulate from the root until the time/iteration budget is used up or the
            root's value is proven (+/-inf). Returns the number of simulations run. """
        deadline = None if timemax is None else time.time() + timemax
        done = 0
        while (deadline is None or time.time() < deadline) and (itermax is None or done < itermax):
            # The clock is only checked between batches of at most 100 simulations.
            batch = 100 if itermax is None else min(100, itermax - done)
            for _ in range(batch):
                v = self.simulate(deepcopy(state), self.rootnode)
                done += 1
                if v == self.inf or v == -self.inf: # root result is proven: stop searching
                    return done
        return done

    def simulate(self,state,node):
        """ Run one simulation from `node`, whose position is `state` (mutated in place).

            Descends by UCB through fully expanded nodes, expands one untried move and rolls
            out randomly, or scores a terminal node, updating nodes on the way back.
            Returns the result from the viewpoint of node.playerJustMoved: a value in
            [0, 1], or +inf/-inf when the node is proven to be a sure win/loss.
        """
        if not node.untriedMoves and node.childNodes: # fully expanded, non-terminal
            move, nnode = node.UCTSelectChild(self.explore) # descend (following UCB)
            if nnode.wins == self.inf or nnode.wins == -self.inf: # the selected child is already proven: backprop directly
                node.Update(-nnode.wins)
                return -nnode.wins
            state.Move(move)
            v = self.simulate(state, nnode)
            if v == -self.inf: v = 0 # the next mover will not choose this branch as it leads to a loss
            # Note: other branches may still have finite values, as UCT saw this branch as finite before this simulation
            node.Update(1-v)
            return 1-v
        elif node.untriedMoves: # not fully expanded: expand one move, then roll out
            move = random.choice(node.untriedMoves)
            state.Move(move)
            k = state.key()
            if k in self.nodes:
                nnode = self.nodes[k] # reuse the node reached through another move order
            else:
                nnode = Node(state = state)
                self.nodes[k] = nnode
            node.AddChild(move, state, nnode)
            state.RollOut()
            v = state.GetResult(nnode.playerJustMoved) # not +/-inf: unexplored moves leave the value uncertain
            nnode.Update(v)
            node.Update(1-v)
            return 1-v
        else: # terminal: no untried moves and no children
            v = state.GetResult(node.playerJustMoved)
            if v == 1:   v = self.inf
            elif v == 0: v = -self.inf
            node.Update(v)
            return v

    def __simulate(self,state,node): #old, iterative version; not called by search()
        path = [node]
        #descend (following UCB) to the first branch not fully expanded (some moves not tried), or reaching end of game
        while node.untriedMoves == [] and node.childNodes != {}:
            move,node = node.UCTSelectChild(self.explore)
            if node.wins==float('inf') or node.wins==float('-inf'): #if the best is -inf or inf, can backprop already
                path[-1].Update(-node.wins)
                return
            state.Move(move)
            path += node,

        # if there are unexplored moves, expand (add a childNode) and move the state into it
        if node.untriedMoves != []:
            move = random.choice(node.untriedMoves)
            state.Move(move)
            k = state.key()
            if k in self.nodes:
                nextnode = self.nodes[k]
            else:
                nextnode = Node(state = state)

            node.AddChild(move,state,nextnode)
            self.nodes[k] = nextnode
            node = nextnode
            path += node,
        elif node.childNodes == {}: #nothing unexplored and is terminal: mark the node as a proven win/loss
            reward = state.GetResult(node.playerJustMoved)
            if reward==1:
                node.Update(float('inf'))
            elif reward==0:
                node.Update(float('-inf'))

        # Rollout to END of a game randomly (not expanding childNodes -- just want to estimate the newly added node's value)
        state.RollOut()

        # state is now terminal; backpropagate this game's result to its path's nodes' win counts
        reward = state.GetResult(node.playerJustMoved)
        for n in path:
            if n.playerJustMoved == node.playerJustMoved:
                n.Update(reward)
            else:
                n.Update(1-reward)
        # `prune`/`visitedNode` are never set in __init__; only record visited nodes if a caller opted in.
        if getattr(self, 'prune', False): self.visitedNode|=set(path)

## [m,n,k-game](https://en.wikipedia.org/wiki/M,n,k-game)

In [3]:
#contains current state and who has last moved
import itertools
class OXOState:
    """ State of an m,n,k-game (tic-tac-toe family) and who has last moved.

    Cells are indexed row-major: cell = row*width + col. A move is a tuple of cells.
    """
    def __init__(self,width=3,height=3,inarow=None,moves_per_round=None,end1=True):
        """ width, height: board size.
            inarow: length of a winning line (default: min(width, height)).
            moves_per_round: cells each player fills per turn (default: 1).
            end1: if True, the game ends as soon as one player has completed more lines than the other.
        """
        self.playerJustMoved = 2 # so that player 1 makes the first move
        self.width = width
        self.height = height
        self.board = [0]*width*height # 0 = empty, 1 = player 1, 2 = player 2
        self.empty = list(range(width*height))
        self.moves_per_round = moves_per_round or 1
        self.end1 = end1
        self.lines = self._WinningLines(width, height, inarow or min(width,height))

    @staticmethod
    def _WinningLines(width, height, inarow):
        """ Return all winning lines as tuples of cell indices: horizontal, vertical,
            diagonal (down-right) and anti-diagonal (down-left), each in row-major order
            of its first cell. Every line is bounds-checked, so width=1 boards work too.
        """
        lines = []
        for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
            for r in range(height):
                for c in range(width):
                    end_r, end_c = r + dr*(inarow-1), c + dc*(inarow-1)
                    if 0 <= end_r < height and 0 <= end_c < width:
                        lines.append(tuple((r + dr*k)*width + c + dc*k for k in range(inarow)))
        return lines

    def Move(self, move):
        """ Update a state by carrying out the given move (a tuple of cells). Must also update playerJustMoved. """
        self.playerJustMoved = 3 - self.playerJustMoved
        for m in move:
            self.board[m] = self.playerJustMoved
            self.empty.remove(m)
        if self.end1 and self.GetResult(self.playerJustMoved) in [0,1]: # someone leads in lines: game over
            self.empty=[]

    def GetMoves(self): # returns a new list each call, since Node uses it as untriedMoves
        """ All combinations of `moves_per_round` empty cells (ascending tuples); empty when the game is over. """
        return list(itertools.combinations(self.empty,self.moves_per_round))

    def RollOut(self):
        """ Play uniformly random moves until no move is left. """
        possibleMoves = self.GetMoves()
        while possibleMoves:
            self.Move(random.choice(possibleMoves))
            possibleMoves = self.GetMoves()

    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        """ 1 if `viewpoint` has completed more lines than the opponent, 0 if fewer, 0.5 if equal. """
        myscore = opposcore = 0
        for l in self.lines:
            one=self.board[l[-1]]
            if one!=0 and all(self.board[p]==one for p in l):
                if one == viewpoint:
                    myscore += 1
                else:
                    opposcore += 1
        if myscore>opposcore: return 1
        elif opposcore>myscore: return 0
        return 0.5

    def __repr__(self): # for output: 1 (X), 2 = (O)
        s = ''
        for left in range(0,self.width*self.height,self.width):
            for sh in range(self.width): s+="·XO"[self.board[left+sh]]
            s+='\n'
        return s.strip()

    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        """ True if every stone placed in this state is also present in the state of `key`. """
        mykey = self.key()
        return all(ch in '·12' or ch==key[i] for i,ch in enumerate(mykey))

Fails to draw on (5,5,4) as the 2nd player: after the first player places at the center, the 2nd player should play at a position "diagonal" to the center (6, 8, 16, 18). MCTS failed to identify these four as optimal until after ~10 min.

In [None]:
state = OXOState(width=5,height=5,inarow=4)
mcts = MCTS(verbosity=1,explore=1)
# Opening of the experiment; each note records what the search found after that move.
opening = [
    12,  # after this, found (7,11,13,17) leads to sure loss after ? simulations
    11,  # after this, found sure win policy after 260663 simulations
    16,  # after this, found sure loss after 189564 simulations
    8,   # after this, found sure win policy after 35905 simulations
    18,  # after this, found sure loss after 25918 simulations
    6,   # after this, found sure win policy after 1601 simulations
]
for cell in opening:
    state.Move((cell,))
mcts.search(state,timemax=None)  # no timemax/itermax: endless debug search printing root statistics

In [None]:
state = OXOState(width=3,height=3,inarow=3,moves_per_round=1,end1=True) #traditional
# state = OXOState(width=5,height=5,inarow=4) #http://boulter.com/ttt/
# state = OXOState(width=8,height=8,inarow=4,moves_per_round=2,end1=False) #http://www.atksolutions.com/games/tictactoedeluxe.html
# state = OXOState(width=15,height=15,inarow=5) #Gomoku
mcts = MCTS(verbosity=1,explore=1)

def _ReadHumanMove(state):
    """ Prompt until the user enters `state.moves_per_round` cell numbers forming a legal move.
        One cell per input() line; cells are sorted to match the ascending tuples of GetMoves().
    """
    legal = state.GetMoves()
    while True:
        try:
            cells = tuple(sorted(int(input()) for _ in range(state.moves_per_round)))
        except ValueError:
            print('please enter cell numbers')
            continue
        if cells in legal:
            return cells
        print(f'illegal move {cells}; choose from: {legal[:10]}...')

while state.GetMoves():
    if state.playerJustMoved==1: #computer plays 1st/2nd
#         clear_output()
        board=np.array([f'{("·XO"[s] if s else str(pos)):2s}' for pos,s in enumerate(state.board)]).reshape(state.height,state.width)
        print(board)
        if state.key() in mcts.nodes:
            n=mcts.nodes[state.key()]
            print(n.ChildrenToString())
            print(f'choose from: {state.GetMoves()[:10]}...',end='')
        m = _ReadHumanMove(state)   # validated; reads moves_per_round cells
    else:
        m = mcts.search(state,timemax=10)
    state.Move(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1(X) wins!")
elif state.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

## Connect4

In [27]:
class ConnectNState:
    def __init__(self,width=7,height=6,inarow=4):
        self.playerJustMoved = 2
        self.width = width
        self.height = height
        self.board = [0]*width*height
        self.moves = list(range(width))
        self.firstEmpty = [height-1]*width
        
        # winning lines
        inarow = inarow or min(width,height)
        self.lines = [tuple(range(i*width+j,i*width+j+inarow)) for i in range(0,height) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,(i+inarow)*width+j,width)) for i in range(0,height-inarow+1) for j in range(width)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width+1),width+1)) for i in range(height-inarow+1) for j in range(width-inarow+1)]+\
                     [tuple(range(i*width+j,i*width+j+inarow*(width-1),width-1)) for i in range(height-inarow+1) for j in range(inarow-1,width)]

    def Move(self, move):  #move=column idx
        assert move in self.moves
        self.playerJustMoved = 3 - self.playerJustMoved
        self.board[self.firstEmpty[move]*self.width+move]=self.playerJustMoved
        self.firstEmpty[move] -= 1
        if self.firstEmpty[move]<0:
            self.moves.remove(move)
        if self.GetResult(self.playerJustMoved) in [0,1]:  #indicate game ended as someone won
            self.moves=[] 
        
    def GetMoves(self):  # legal moves for next player. empty if game is over
        return self.moves[:]
    
    def RollOut(self):
        while self.moves:
            self.Move(random.choice(self.moves))
    
    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        for l in self.lines:
            one=self.board[l[-1]]
            if one!=0 and all(self.board[p]==one for p in l):
                if one == viewpoint:
                    return 1
                else:
                    return 0
        return 0.5

    def __repr__(self): # string representation: 1 (X), 2 = (O)
        s = ''
        for left in range(0,self.width*self.height,self.width):
            for sh in range(self.width):
                s+="·XO"[self.board[left+sh]]
            s+='\n'
        return s.strip()
    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        mykey = self.key()
        return all(ch in '12·' or ch==key[i] for i,ch in enumerate(mykey)) 

Fails to deduce that the 5th move (1st player) should be at center

In [None]:
state = ConnectNState(7,6,4) #https://connect4.gamesolver.org/
for _ in range(4): state.Move(3)
mcts = MCTS(verbosity=1,explore=2)
# Seconds per computer move. None (no timemax/itermax) runs MCTS.search's endless debug
# loop, which prints root statistics but never returns a move; set a number to play a game.
SEARCH_TIME = None

def _ColumnLabel(pos, width):
    """ Label for an empty cell: its index on the top row (= column number), blank elsewhere. """
    return str(pos) if pos < width else ''

def _ReadColumn(state):
    """ Prompt (showing the legal columns) until the user enters a legal column. """
    legal = state.GetMoves()
    while True:
        try:
            col = int(input(legal))
        except ValueError:
            col = None
        if col in legal:
            return col
        print('illegal column; choose from:', end='')

while state.GetMoves():
    if state.playerJustMoved==1:
#         clear_output()
        board=np.array([f'{("·XO"[s] if s else _ColumnLabel(pos, state.width)):2s}' for pos,s in enumerate(state.board)])
        print(board.reshape(state.height,state.width))
        if state.key() in mcts.nodes:
            n=mcts.nodes[state.key()]
            print(n.ChildrenToString())
            print('choose from:',end='')
        m = _ReadColumn(state)
    else:
        m = mcts.search(state,timemax=SEARCH_TIME)
    state.Move(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1(X) wins!")
elif state.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

[(0, [W/V:  45806/ 80549 | UnXplrd: 0] ), (6, [W/V: 38631.5/ 68198 | UnXplrd: 0] ), (2, [W/V: 27781.5/ 49461 | UnXplrd: 0] ), (4, [W/V: 27232.5/ 48510 | UnXplrd: 0] ), (3, [W/V: 20572.5/ 36958 | UnXplrd: 0] ), (1, [W/V: 18619.5/ 33560 | UnXplrd: 0] ), (5, [W/V: 17588.5/ 31764 | UnXplrd: 0] )]

## Othello

In [14]:
class OthelloState:
    """ Othello/Reversi state. Cells are indexed row-major (cell = x*size + y);
        a move is a cell index, or None for a pass.
    """
    def __init__(self, size = 8):  # size must be integral and even
        self.playerJustMoved = 2 # so that player 1 makes the first move
        self.size = size
        self.board = [] # 0 = empty, 1 = player 1, 2 = player 2
        for y in range(self.size):
            self.board.append([0]*size)
        self.board[size//2][size//2] = self.board[size//2-1][size//2-1] = 2
        self.board[size//2][size//2-1] = self.board[size//2-1][size//2] = 1
        self.board = [e for l in self.board for e in l]

    def Move(self, move):
        """ Place a counter for the next player and flip the sandwiched counters; None passes. """
        if move is not None:
            (x,y)=divmod(move,self.size)
            assert self.IsOnBoard(x,y) and self.board[x*self.size+y] == 0
            m = self.GetAllSandwichedCounters(x,y)
            self.playerJustMoved = 3 - self.playerJustMoved
            self.board[x*self.size+y] = self.playerJustMoved
            for (a,b) in m:
                self.board[a*self.size+b] = self.playerJustMoved
        else:
            self.playerJustMoved = 3 - self.playerJustMoved

    def GetMoves(self):
        """ Legal cells for the next player; [None] if they must pass; [] if the game is over. """
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        if not emptypos:
            return []
        else:
            viable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))]
            if viable:
                return viable
            else:  #need to check if opponent also has no viable pos
                self.playerJustMoved = 3 - self.playerJustMoved
                oppoViable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))]
                self.playerJustMoved = 3 - self.playerJustMoved
                if oppoViable:
                    return [None]
                else:
                    return []

    def RollOut(self):
        """ Play uniformly random legal moves, passing when stuck, until the game is over. """
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        passes = 0 # consecutive passes; must be reset by every real move
        while emptypos:
            random.shuffle(emptypos)
            for pos in emptypos:
                if self.ExistsSandwiched(*divmod(pos,self.size)):
                    self.Move(pos)
                    passes = 0
                    emptypos = [i for i,e in enumerate(self.board) if e==0]
                    break
            else:
                if passes<2: #when passes=2, both sides cannot play: game ends
                    passes+=1
                    self.Move(None)
                else:
                    break

    def AdjacentEnemyDirections(self,x,y):# Speeds up GetMoves by only considering squares which are adjacent to an enemy-occupied square.
        return [(dx,dy) for (dx,dy) in [(0,+1),(+1,+1),(+1,0),(+1,-1),(0,-1),(-1,-1),(-1,0),(-1,+1)]
                        if self.IsOnBoard(x+dx,y+dy) and self.board[(x+dx)*self.size+y+dy] == self.playerJustMoved]

    def ExistsSandwiched(self,x,y):# Is there at least one counter which would be flipped if my counter was placed at (x,y)?
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            x1=x+dx
            y1=y+dy
            while self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == self.playerJustMoved:
                x1 += dx
                y1 += dy
            if self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == 3 - self.playerJustMoved:
                return True
        return False

    def GetAllSandwichedCounters(self, x, y):# Coordinates of all opponent counters that placing my counter at (x,y) would flip.
        sandwiched = []
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            sandwiched.extend(self.SandwichedCounters(x,y,dx,dy))
        return sandwiched

    def SandwichedCounters(self, x, y, dx, dy):# Return the coordinates of all opponent counters sandwiched between (x,y) and my counter.
        x += dx
        y += dy
        sandwiched = []
        while self.IsOnBoard(x,y) and self.board[x*self.size+y] == self.playerJustMoved:
            sandwiched.append((x,y))
            x += dx
            y += dy
        if self.IsOnBoard(x,y) and self.board[x*self.size+y] == 3 - self.playerJustMoved:
            return sandwiched
        else:
            return [] # nothing sandwiched

    def IsOnBoard(self, x, y):
        return x >= 0 and x < self.size and y >= 0 and y < self.size

    def GetResult(self, viewpoint): #after gameover
        """ 1.0 if `viewpoint` has more counters than the opponent, 0.0 if fewer, 0.5 if equal. """
        viewpointscore=oppositescore=0
        for e in self.board:
            if e==viewpoint: viewpointscore+=1
            elif e==3-viewpoint: oppositescore+=1
        if viewpointscore > oppositescore: return 1.0
        elif oppositescore > viewpointscore: return 0.0
        else: return 0.5 # draw

    def __repr__(self):
        s= ""
        for x in range(self.size):
            for y in range(self.size):
                s += "·XO"[self.board[x*self.size+y]]
            s += "\n"
        return s.strip()
    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)
    def Reachable(self,key):  #for pruning Tree
        """ True if `key` has strictly fewer empty cells, every cell occupied here is occupied
            there, and each corner occupied here (corners never flip) has the same owner there. """
        mykey = self.key()
        return key.count('·')<mykey.count('·') and \
               all(ch in '12·' or key[i]!='·' for i,ch in enumerate(mykey)) and\
               (mykey[1]=='·' or mykey[1]==key[1]) and\
               (mykey[self.size]=='·' or mykey[self.size]==key[self.size]) and\
               (mykey[-self.size]=='·' or mykey[-self.size]==key[-self.size]) and\
               (mykey[-1]=='·' or mykey[-1]==key[-1])

* timemax=60 beats Reversu Difficult
* timemax=15 beats Reversu Medium, lost to Reversu Difficult

In [17]:
# state = OthelloState(8)
# mcts = MCTS(verbosity=1)
# (uncomment the two lines above for a new game; as written, this cell reuses the existing `state` and `mcts`)

def _ReadCell(moves):
    """ Prompt (showing the legal cells) until the user enters one of `moves`. """
    while True:
        try:
            cell = int(input(moves))
        except ValueError:
            cell = None
        if cell in moves:
            return cell
        print('illegal move; choose from:', end='')

while state.GetMoves():
    if state.playerJustMoved==1:
        clear_output()
        board=np.array([f'{("·XO"[e] if e else str(p)):2s}' for p,e in enumerate(state.board)]).reshape(state.size,state.size)
        print(board)
        if state.key() in mcts.nodes:
            print(mcts.nodes[state.key()].ChildrenToString())
            print('choose from:',end='')
        moves = state.GetMoves()
        if moves[0] is None: m = None # forced pass
        else: m = _ReadCell(moves)
    else:
        n_pieces = len(list(filter(None,state.board)))  # number of pieces on board
        m = mcts.search(state,timemax=1+n_pieces*0)  # 1 s per move; raise the 0 factor to think longer as the board fills
    state.Move(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1(X) wins!")
elif state.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

[['X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ']
 ['X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ']
 ['O ' 'X ' 'X ' 'X ' 'X ' 'O ' 'X ' 'X ']
 ['O ' 'O ' 'X ' 'X ' 'X ' 'X ' 'O ' 'X ']
 ['O ' 'O ' 'X ' 'O ' 'O ' 'O ' 'X ' 'X ']
 ['40' 'O ' 'X ' 'O ' 'X ' 'X ' 'X ' 'X ']
 ['48' 'X ' 'O ' 'X ' 'X ' 'X ' 'X ' 'X ']
 ['X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ']]
48      |[W/V:   -inf/     4 | UnXplrd: 0] 
choose from:[48]48
XXXXXXXX
XXXXXXXX
OXXXXOXX
OOXXXXOX
OOXOOOXX
·OXOXXXX
OOOXXXXX
XXXXXXXX
455
40      |[W/V:    inf/     4 | UnXplrd: 0] 
XXXXXXXX
XXXXXXXX
XXXXXOXX
XOXXXXOX
XXXOOOXX
XXXOXXXX
XXOXXXXX
XXXXXXXX
38
Player 1(X) wins!


# AlphaZero-style (guide tree search by policy+value network)

https://www.youtube.com/watch?v=ld28AU7DDB4
* The uncertainty term used in UCB is replaced by **Polynomial Upper Confidence Trees** PUCT $\propto P(s_i|s)\frac{\sqrt{N}}{1+n_i}$ (where N is the total visits of current state s, and n_i are visit counts of each of the next possible nodes). $P(s_i|s)$ is a probability output by policy network given current state
* Rollout replaced by the value estimation of the current state by the network. For a newly-expanded child node (c), use value network to compute all the values of its possible child nodes (c's c), used for computing the mean of the current (c) node, with weights given by the policy network.
* In AlphaGo, the policy network is first trained on human expert positions. The policy network then plays against itself to generate positions + targets (win/loss) for training the value network
* AlphaZero combines policy and value network into a single network. Repeat:
  1. training examples (state,policy of next move,win/loss) are generated by MCTS, used to train the policy+value network.
  2. Use the network trained to guide MCTS
  
Other tweaks
* When choosing the move to actually play, the move probabilities are proportional to the exponentiated visit counts $n_i^{1/\tau}$, normalized to sum to 1 (see `getActionProb` below), where $\tau=1$ for the first 10-100 moves (exploratory) and $\tau\to0$ for later moves (only choose the best).
* After each iteration, pit against previous version of the net to make sure it has really improved

## Code

### Train net using (states,policies,values) given by the self-plays of MCTS

In [None]:
# Sketch of one training step: accumulate self-play examples across runs of this cell,
# then train the net on all of them. NOTE(review): `mcts.selfplay()` and `net.train()`
# are not implemented in the classes shown in this notebook.
if 'train_Xy' not in globals():
    train_Xy = []
train_Xy += mcts.selfplay()
net.train(train_Xy)

### MCTS given guidance from Net, generates (states,policies,values) through self-play

In [None]:
class MCTS():
    """ AlphaZero-style MCTS: tree search guided by a policy+value network.

    `game` must provide stringRepresentation, getGameEnded, getValidMoves, getNextState,
    getCanonicalForm and getActionSize. `net.predict(board)` must return
    (policy vector indexed by action, value in [-1, 1] for the player to move).
    `args` must provide numMCTSSims (simulations per getActionProb call) and cpuct
    (exploration constant); if omitted, illustrative defaults are used.
    """
    EPS = 1e-8  # keeps sqrt(Ns) > 0 on a node's first visit so the prior still ranks unexplored actions

    def __init__(self, game, net, args=None):
        from types import SimpleNamespace
        self.game = game
        self.net = net
        self.args = args if args is not None else SimpleNamespace(numMCTSSims=25, cpuct=1.0)  # defaults: tune per game
        self.Qsa = {}       # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}       # stores #times edge s,a was visited
        self.Ns = {}        # stores #times board s was visited
        self.Ps = {}        # stores initial policy (returned by neural net)

        self.Es = {}        # stores game.getGameEnded ended for board s
        self.Vs = {}        # stores game.getValidMoves for board s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.
        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp); temp=0 puts all the
                   mass on the most visited action.
        """
        for i in range(self.args.numMCTSSims):
            self.search(canonicalBoard)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [self.Nsa[(s,a)] if (s,a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

        if temp==0:
            bestA = np.argmax(counts)
            probs = [0]*len(counts)
            probs[bestA]=1
            return probs

        counts = [x**(1./temp) for x in counts]
        probs = [x/float(sum(counts)) for x in counts]
        return probs


    def search(self, canonicalBoard):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.
        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
        updated.
        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.
        Returns:
            v: the negative of the value of the current canonicalBoard
        """

        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s]!=0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.net.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s]*valids      # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s    # renormalize
            else:
                # if all valid moves were masked make all valid moves equally probable

                # NB! All valid moves may be masked if either your NNet architecture is insufficient or you've get overfitting or something else.
                # If you have got dozens or hundreds of these messages you should pay attention to your NNet and/or training process.
                print("All valid moves were masked, do workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s,a) in self.Qsa:
                    u = self.Qsa[(s,a)] + self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s])/(1+self.Nsa[(s,a)])
                else:
                    u = self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s] + self.EPS)     # Q = 0 ?

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s)

        if (s,a) in self.Qsa:
            self.Qsa[(s,a)] = (self.Nsa[(s,a)]*self.Qsa[(s,a)] + v)/(self.Nsa[(s,a)]+1)
            self.Nsa[(s,a)] += 1

        else:
            self.Qsa[(s,a)] = v
            self.Nsa[(s,a)] = 1

        self.Ns[s] += 1
        return -v

### MCTS using trained Net: given state, generate next move

### `Game` class

## [Connect4](https://github.com/plkmo/AlphaZero_Connect4)

## [Othello](https://github.com/suragnair/alpha-zero-general)