In [1]:
import time, math, random
from copy import deepcopy
from IPython.display import clear_output
import numpy as np

# Barebone MCTS with UCT
https://youtu.be/ItMutbeOHtc?t=3922, https://www.youtube.com/watch?v=Bm7zah_LrmE

1. From each node (a node = a state of the board), repeatedly try unexplored moves and roll out randomly until the game is over. The win/loss count is propagated back along the path taken, but only to nodes at or above the level where new moves are being tried out (positions visited during the random rollout are not stored, because of memory constraints)
2. When all moves of a node have been tried at least once, *select* the move to try based on UCB (mean value + C\*uncertainty of the mean, $\sqrt{\frac{2\log N_{\rm visit}}{n_{\rm child's\ visit}}}$). Keep descending this way until reaching a node that has at least one untried move
3. Repeat 1 & 2 until the desired time/iteration limit is reached, then select the best move at the root node (the current real situation of the board) based on the *mean value* of each move

4. Some nodes may be reached in different ways, so each node keeps a {move: child_node} dictionary. When selecting a child node, return the corresponding move together with the child_node (do not use the parent of a node to backtrack)
5. ONLY for terminating (acyclic) games
6. Keep only the nodes that (in principle) could still be reached from the current state in the future. The game class needs to implement `Reachable` with respect to its key for this to work. Otherwise every node is kept (memory...)

## Code: [mcts.ai](http://mcts.ai/code/python.html) modified + reuse nodes

In [2]:
class Node:
    """ One position in the search tree. Wins is always from the viewpoint of playerJustMoved.

    Attributes:
        childNodes:      {move: Node}; the key is the move that leads *into* the child.
        wins:            sum of the rewards passed to Update().
        visits:          number of times Update() was called.
        untriedMoves:    moves returned by state.GetMoves() that have no child node yet.
        playerJustMoved: copied from the state; the only part of the state the node keeps.
    """
    # Sentinel for TreeToString: a game may use None as a real move, so None cannot mean "no label".
    _NO_MOVE = object()

    def __init__(self, state = None):
        # `state` is required in practice (it must provide GetMoves() and playerJustMoved).
        self.childNodes = {} # the move that get *into* next node
        self.wins = 0
        self.visits = 0
        self.untriedMoves = state.GetMoves() # future child nodes
        self.playerJustMoved = state.playerJustMoved # the only part of the state that the Node needs later

    def _mean(self):
        """ Average reward of this node; 0.0 while it has never been visited. """
        return self.wins/self.visits if self.visits else 0.0

    def UCTSelectChild(self,explore=1):
        """ Use the UCB1 formula to select a child node.

        Score of a child c is c.wins/c.visits + explore*sqrt(2*log(self.visits)/c.visits);
        `explore` trades exploration against exploitation (0 = pure mean value).
        A child that was never visited scores +inf when exploring (try it first) and
        -inf when explore<=0 (never pick an unevaluated move as the final answer),
        instead of raising ZeroDivisionError.

        Returns the (move, childNode) pair with the highest score (first one on ties).
        Raises ValueError when the node has no children.
        """
        log_parent = math.log(self.visits) if self.visits > 0 else 0.0

        def score(item):
            child = item[1]
            if child.visits == 0:
                return math.inf if explore > 0 else -math.inf
            return child.wins/child.visits + explore*math.sqrt(2*log_parent/child.visits)

        return max(self.childNodes.items(), key = score)

    def AddChild(self, move, s, n = None):
        """ Remove `move` from untriedMoves and register the child reached by it.

        If `n` is given that existing node is attached (node reuse); otherwise a new
        Node is built from state `s`. Returns the child node.
        Raises ValueError if `move` is not in untriedMoves.
        """
        if n is None:
            n = Node(state = s)
        self.untriedMoves.remove(move)
        self.childNodes[move] = n
        return n

    def Update(self, result):
        """ update visit & win counts. (from the viewpoint of playerJustmoved) """
        self.visits += 1
        self.wins += result # game results in the range [0.0, 1.0]

    def __repr__(self):
        return f"[W/V: {self.wins:6g}/{self.visits:6d} | UnXplrd: {len(self.untriedMoves)}] "

    def TreeToString(self, indent, move=_NO_MOVE):
        """ Render this node and all descendants, one node per line, indented by depth.

        `move` is the move that led into this node; it is appended to the node's OWN
        line (omitted for the node the rendering starts from).
        """
        label = "" if move is self._NO_MOVE else str(move)
        s = "\n"+ "| "*indent + str(self) + label
        s += ''.join(c.TreeToString(indent+1, m) for m,c in self.childNodes.items())
        return s

    def ChildrenToString(self):
        """ One line per child ('move |stats'), sorted by mean value, best child last. """
        ranked = sorted(self.childNodes.items(), key=lambda e: e[1]._mean())
        return "\n".join(f'{str(m):8s}|{c}' for m,c in ranked)
    
class MCTS():   #reuses nodes
    """ Monte-Carlo tree search with UCT that keeps its nodes between searches.

    Nodes are stored in `self.nodes` under `state.key()`, so a position reached again
    (in a later search, or through another move order) reuses its statistics.
    """
    def __init__(self, explore=1, verbosity=0):
        self.nodes = {}               #store all nodes previously explored
        self.explore = explore        # exploration constant handed to Node.UCTSelectChild
        self.verbosity = verbosity    # 0 = silent, 1 = print root's children, 2 = print whole tree

    def search(self, state, timemax=None, itermax=None):
        """ Search from `state` and return the move with the best mean value.

        Args:
            state:   game state providing key(), Clone(), DoMove(), GetMoves(), RollOut(),
                     GetResult() and playerJustMoved; an optional Reachable(key) enables pruning.
            timemax: time budget in seconds, or None for no time limit.
            itermax: maximum number of simulations, or None for no count limit.
                     When timemax is also given, one simulation is always run, even for itermax=0.

        No simulation is run when the root has a single explored move and nothing untried.

        Raises:
            ValueError: if a search is needed but neither limit is given, or if the root ends
                        up without any explored move (terminal state / zero budget).
        """
        key = state.key()
        if key not in self.nodes:
            self.nodes[key] = Node(state = state)
        self.rootnode = self.nodes[key]

        if hasattr(state, 'Reachable'): #pruning: forget stored nodes that can no longer occur
            self.nodes = {k:v for k,v in self.nodes.items() if key==k or state.Reachable(k)}

        if self.rootnode.untriedMoves or len(self.rootnode.childNodes)>1:
            self._run_simulations(state, timemax, itermax)

        if not self.rootnode.childNodes:
            raise ValueError("MCTS.search: no move was evaluated (terminal state or zero simulation budget)")

        moveToChild, bestChild = self.rootnode.UCTSelectChild(explore=0)
        if self.verbosity==2:   print(self.rootnode.TreeToString(0))
        elif self.verbosity==1: print(self.rootnode.ChildrenToString())

        return moveToChild

    def _run_simulations(self, state, timemax, itermax):
        """ Run simulations from the root until the time and/or iteration budget is used up. """
        if timemax is None and itermax is None:
            raise ValueError("MCTS.search needs timemax and/or itermax")

        root = self.rootnode
        if timemax is None:
            for _ in range(itermax):
                self.simulate(state.Clone(), root)
            return

        # Time one simulation, then work in batches of roughly a tenth of the budget so the
        # clock is not read after every single simulation.
        start = time.perf_counter()
        self.simulate(state.Clone(), root)
        done = 1
        elapsed = time.perf_counter() - start
        # max(..., 1e-9) avoids dividing by zero; max(1, ...) keeps slow simulations progressing
        # (a batch size of 0 would spin until the deadline without simulating anything).
        batch = max(1, int(timemax / max(elapsed, 1e-9) / 10))
        deadline = start + timemax
        while time.perf_counter() < deadline and (itermax is None or done < itermax):
            todo = batch if itermax is None else min(batch, itermax - done)
            for _ in range(todo):
                self.simulate(state.Clone(), root)
            done += todo

    def simulate(self,state,node):
        """ One MCTS iteration starting at `node`; `state` is a private copy and gets played to the end. """
        path = [node]
        #descend (following UCB) to the first branch not fully expanded (some moves not tried), or reaching end of game
        while node.untriedMoves == [] and node.childNodes != {}:
            move,node = node.UCTSelectChild(self.explore)
            state.DoMove(move)
            path.append(node)

        # if we can expand (i.e. node is not terminal), expand (add a childNode) and move the state into it
        if node.untriedMoves != []:
            move = random.choice(node.untriedMoves)
            state.DoMove(move)
            k = state.key()
            nextnode = self.nodes.get(k)      # reuse the node if this position is already known
            if nextnode is None:
                nextnode = Node(state = state)
                self.nodes[k] = nextnode
            node.AddChild(move,state,nextnode)
            node = nextnode
            path.append(node)

        # Rollout to END of a game randomly (not expanding childNodes -- just want to estimate the newly added node's value)
        state.RollOut()

        # state is now terminal; backpropagate this game's result to its path's nodes' win counts
        viewpoint = node.playerJustMoved
        reward = state.GetResult(viewpoint)
        for visited in path:
            # each node counts wins for its own playerJustMoved, so flip the reward for the other side
            visited.Update(reward if visited.playerJustMoved == viewpoint else 1-reward)

## [m,n,k-game](https://en.wikipedia.org/wiki/M,n,k-game)

In [143]:
#contains current state and who has last moved
import itertools
class OXOState:
    """ State of an m,n,k-game (generalised tic-tac-toe).

    The board is a flat row-major list of width*height cells (0 = empty, 1 = player 1,
    2 = player 2). A move is a tuple of `moves_per_round` cell indices.
    """
    def __init__(self,width=3,height=3,inarow=None,moves_per_round=None,end1=True, # game ends as one success
                 board=None,empty=None,lines=None):
        """
        Args:
            width, height:   board dimensions.
            inarow:          length of a winning line; defaults to min(width, height).
            moves_per_round: cells a player fills per turn; defaults to 1.
            end1:            if True the game stops as soon as a line is completed; otherwise
                             play continues and GetResult compares the number of completed lines.
            board, empty, lines: pre-built internals (used by Clone). They are used as given,
                             including empty lists.
        """
        self.playerJustMoved = 2 #  (1) will have the first move
        self.width = width
        self.height = height
        # `is None` (not `or`): an empty `empty` list means "no cell playable" and must survive Clone()
        self.board = [0]*width*height if board is None else board # 0 = empty, 1 = player 1, 2 = player 2
        self.empty = list(range(width*height)) if empty is None else empty
        self.moves_per_round = moves_per_round or 1
        self.end1 = end1

        # winning lines
        self.lines = lines
        if self.lines is None:
            self.lines = self._winning_lines(width, height, inarow or min(width,height))

    @staticmethod
    def _winning_lines(width, height, inarow):
        """ All horizontal, vertical, diagonal and anti-diagonal runs of `inarow` cells, as index tuples. """
        lines = []
        for dr, dc in ((0,1),(1,0),(1,1),(1,-1)):
            for r in range(height):
                for c in range(width):
                    last_r = r + (inarow-1)*dr
                    last_c = c + (inarow-1)*dc
                    if 0 <= last_r < height and 0 <= last_c < width: # the whole run fits on the board
                        lines.append(tuple((r+k*dr)*width + (c+k*dc) for k in range(inarow)))
        return lines

    def Clone(self):
        """ Create a deep clone of this game state. """
        st = OXOState(self.width,self.height,None,self.moves_per_round,self.end1,self.board[:],self.empty[:],self.lines)
        st.playerJustMoved = self.playerJustMoved
        return st

    def DoMove(self, move):
        """ Update a state by carrying out the given move. Must also update playerJustMoved.  """
        self.playerJustMoved = 3 - self.playerJustMoved
        for m in move:
            self.board[m] = self.playerJustMoved
            self.empty.remove(m)
        # with end1, a decided game has no playable cells left (GetResult is 0.5 while undecided)
        if self.end1 and self.GetResult(self.playerJustMoved) in [0,1]:
            self.empty=[]

    def GetMoves(self): # must clone as Node use this return as untriedMoves  # return self.empty[:]
        """ All legal moves: every combination of moves_per_round empty cells (empty list = game over). """
        return list(itertools.combinations(self.empty,self.moves_per_round))

    def RollOut(self):
        """ Play uniformly random moves until no move is left. """
        possibleMoves = self.GetMoves()
        while possibleMoves:
            self.DoMove(random.choice(possibleMoves))
            possibleMoves = self.GetMoves()

    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        """ 1 if `viewpoint` owns more completed lines than the opponent, 0 if fewer, 0.5 if equal. """
        myscore = opposcore = 0
        for l in self.lines:
            owner = self.board[l[0]]
            if owner != 0 and all(self.board[p] == owner for p in l):
                if owner == viewpoint:
                    myscore += 1
                else:
                    opposcore += 1
        if myscore>opposcore: return 1
        elif opposcore>myscore: return 0
        return 0.5

    def __repr__(self): # for output: 1 (X), 2 = (O)
        s = ''
        for left in range(0,self.width*self.height,self.width):
            for sh in range(self.width): s+="·XO"[self.board[left+sh]]
            s+='\n'
        return s.strip()

    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)

    def Reachable(self,key):  #for pruning Tree
        """ True if every stone of this state is also present in `key` (stones are never removed). """
        mykey = self.key()
        # the leading player digit and empty cells place no constraint on `key`
        return all(ch in '·12' or ch==key[i] for i,ch in enumerate(mykey))

In [None]:
# Pick ONE game configuration (leave exactly one `state = ...` line active).
# state = OXOState(width=3,height=3,inarow=3,moves_per_round=1,end1=True) #traditional
state = OXOState(width=5,height=5,inarow=4) #http://boulter.com/ttt/
# state = OXOState(width=8,height=8,inarow=4,moves_per_round=2,end1=False) #http://www.atksolutions.com/games/tictactoedeluxe.html
# state = OXOState(width=15,height=15,inarow=5) #Gomoku
mcts = MCTS(verbosity=0)
while state.GetMoves():
    if state.playerJustMoved==1: # human's turn (the computer moves first; swap 1 -> 2 to move first yourself)
        clear_output()
        # empty cells show their index so the human can refer to them
        board=np.array([f'{("·XO"[s] if s else str(pos)):2s}' for pos,s in enumerate(state.board)]).reshape(state.height,state.width)
        print(board)
        if state.key() in mcts.nodes:
            # statistics gathered by the computer's searches for this position
            print(mcts.nodes[state.key()].ChildrenToString())
        legal = state.GetMoves()
        print(f'choose from: {legal[:10]}...',end='')
        # one line per turn: the cell index (or several, space separated, when moves_per_round > 1);
        # keep asking until the entry is a legal move
        m = None
        while m not in legal:
            try:
                m = tuple(sorted(int(tok) for tok in input().split()))
            except ValueError:
                m = None
    else:
        m = mcts.search(state,timemax=15) # m = UCT(rootstate = state, itermax = 10000, verbosity=1)
    state.DoMove(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1(X) wins!")
elif state.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

[['0 ' 'X ' '2 ' 'O ' '4 ']
 ['5 ' '6 ' 'X ' 'O ' 'X ']
 ['X ' 'O ' 'X ' 'X ' '14']
 ['O ' 'X ' 'O ' 'O ' 'O ']
 ['20' 'X ' 'O ' 'X ' '24']]
(20,)   |[W/V:   1831/  3662 | UnXplrd: 0...] 
(2,)    |[W/V:   1831/  3662 | UnXplrd: 0...] 
(0,)    |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
(5,)    |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
(14,)   |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
(4,)    |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
(6,)    |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
(24,)   |[W/V: 1830.5/  3661 | UnXplrd: 0...] 
choose from: [(0,), (2,), (4,), (5,), (6,), (14,), (20,), (24,)]...

## Connect4

In [39]:
class ConnectNState:
    """ State of a Connect-N game (Connect 4 by default).

    The board is a flat row-major list of width*height cells (0 = empty, 1 = player 1,
    2 = player 2); row 0 is the top row, pieces stack up from row height-1.
    A move is a column index.
    """
    def __init__(self,width=7,height=6,inarow=4,board=None,moves=None,lines=None):
        """
        Args:
            width, height: board dimensions.
            inarow:        length of a winning line; a falsy value means min(width, height).
            board:         existing board to continue from (used as given, not copied).
            moves:         playable columns; derived from the board when None. An empty list
                           is kept as "no moves" (finished game).
            lines:         pre-computed winning lines (used by Clone).
        """
        self.playerJustMoved = 2 #  (1) will have the first move
        self.width = width
        self.height = height
        self.board = [0]*width*height if board is None else board
        # firstEmpty[c] = row of the lowest free cell in column c, -1 when the column is full
        self.firstEmpty = [self._lowest_free_row(col) for col in range(width)]
        # `is None` (not `or`): an empty move list marks a finished game and must survive Clone()
        self.moves = [c for c in range(width) if self.firstEmpty[c] >= 0] if moves is None else moves

        # winning lines
        self.lines = lines
        if self.lines is None:
            self.lines = self._winning_lines(width, height, inarow or min(width,height))

    def _lowest_free_row(self, col):
        """ Row index of the lowest empty cell in `col`, or -1 if the column is full. """
        for row in reversed(range(self.height)):
            if self.board[row*self.width+col] == 0:
                return row
        return -1

    @staticmethod
    def _winning_lines(width, height, inarow):
        """ All horizontal, vertical, diagonal and anti-diagonal runs of `inarow` cells, as index tuples. """
        lines = []
        for dr, dc in ((0,1),(1,0),(1,1),(1,-1)):
            for r in range(height):
                for c in range(width):
                    last_r = r + (inarow-1)*dr
                    last_c = c + (inarow-1)*dc
                    if 0 <= last_r < height and 0 <= last_c < width: # the whole run fits on the board
                        lines.append(tuple((r+k*dr)*width + (c+k*dc) for k in range(inarow)))
        return lines

    def Clone(self):
        """ Independent copy of this state (the immutable `lines` list is shared). """
        st = ConnectNState(self.width,self.height,None,self.board[:],self.moves[:],self.lines)
        st.playerJustMoved = self.playerJustMoved
        st.firstEmpty = self.firstEmpty[:]
        return st

    def DoMove(self, move):  #move=column idx
        """ Drop a piece of the player to move into column `move`; also updates playerJustMoved. """
        assert move in self.moves
        self.playerJustMoved = 3 - self.playerJustMoved
        self.board[self.firstEmpty[move]*self.width+move]=self.playerJustMoved
        self.firstEmpty[move] -= 1
        if self.firstEmpty[move]<0:   # column is now full
            self.moves.remove(move)
        if self.GetResult(self.playerJustMoved) in [0,1]:  #indicate game ended as someone won
            self.moves=[]

    def GetMoves(self):  # legal moves for next player. empty if game is over
        return self.moves[:]

    def RollOut(self):
        """ Play uniformly random moves until no move is left. """
        while self.moves:
            self.DoMove(random.choice(self.moves))

    def GetResult(self, viewpoint):  # reward from `viewpoint`, in the range [0.0, 1.0]
        """ 1 / 0 if the first completed line found belongs to `viewpoint` / the opponent, else 0.5. """
        for l in self.lines:
            owner = self.board[l[0]]
            if owner != 0 and all(self.board[p] == owner for p in l):
                return 1 if owner == viewpoint else 0
        return 0.5

    def __repr__(self): # string representation: 1 (X), 2 = (O)
        s = ''
        for left in range(0,self.width*self.height,self.width):
            for sh in range(self.width):
                s+="·XO"[self.board[left+sh]]
            s+='\n'
        return s.strip()

    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)

    def Reachable(self,key):  #for pruning Tree
        """ True if every piece of this state is also present in `key` (pieces are never removed). """
        mykey = self.key()
        # the leading player digit and empty cells place no constraint on `key`
        return all(ch in '12·' or ch==key[i] for i,ch in enumerate(mykey))

In [41]:
# Start a fresh game and a fresh search tree, so this cell does not depend on whatever
# `state` / `mcts` an earlier cell left behind.
state = ConnectNState(7,6,4) #https://connect4.gamesolver.org/
mcts = MCTS(verbosity=0)
while state.GetMoves():
    if state.playerJustMoved==2: # human's turn (the human is player 1 and moves first)
        clear_output()
        # show column indices in the (still empty) cells of the top row only
        fn = lambda s: s if int(s)<state.width else ''
        board=np.array([f'{("·XO"[s] if s else fn(str(pos))):2s}' for pos,s in enumerate(state.board)])
        print(board.reshape(state.height,state.width))
        if state.key() in mcts.nodes:
            # statistics gathered by the computer's searches for this position
            n=mcts.nodes[state.key()]
            print(n.ChildrenToString())
            print('choose from:',end='')
        legal = state.GetMoves()
        # keep asking until the entry is a playable column
        m = None
        while m not in legal:
            try:
                m = int(input(legal))
            except ValueError:
                m = None
    else:
        m = mcts.search(state,timemax=20) #         m = UCT(rootstate = state, itermax = 10000, verbosity=1)
    state.DoMove(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1 wins!")
elif state.GetResult(2) == 1: print("Player 2 wins!")
else: print("Nobody wins!")

[['0 ' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ']
 ['  ' '  ' '  ' '  ' '  ' '  ' '  ']
 ['  ' '  ' 'O ' 'O ' '  ' '  ' '  ']
 ['  ' '  ' 'X ' 'O ' '  ' '  ' '  ']
 ['  ' '  ' 'X ' 'X ' 'O ' '  ' '  ']
 ['  ' '  ' 'X ' 'O ' 'X ' '  ' '  ']]
4       |[W/V:   69.5/  1305 | UnXplrd: 0] 
3       |[W/V:  128.5/  1783 | UnXplrd: 0] 
0       |[W/V:  160.5/  2029 | UnXplrd: 0] 
2       |[W/V:    172/  2117 | UnXplrd: 0] 
6       |[W/V:    182/  2192 | UnXplrd: 0] 
1       |[W/V:  182.5/  2196 | UnXplrd: 0] 
5       |[W/V:   8577/ 52605 | UnXplrd: 0] 
choose from:

KeyboardInterrupt: 

## Othello

In [3]:
class OthelloState:
    """ State of an Othello (Reversi) game on a size x size board.

    The board is a flat row-major list (0 = empty, 1 = player 1, 2 = player 2).
    A move is a cell index, or None for a forced pass.
    """
    def __init__(self, size = 8, board=None):  # size must be integral and even
        self.playerJustMoved = 2   # player 1 moves first
        self.size = size
        if board is None:
            # standard opening: two stones of each colour crossed in the centre
            board = [0]*(size*size) # 0 = empty, 1 = player 1, 2 = player 2
            mid = size//2
            board[mid*size+mid] = board[(mid-1)*size+mid-1] = 2
            board[mid*size+mid-1] = board[(mid-1)*size+mid] = 1
        self.board = board

    def Clone(self):
        """ Independent copy of this state. """
        st = OthelloState(size=self.size, board=self.board[:])
        st.playerJustMoved = self.playerJustMoved
        return st

    def DoMove(self, move):
        """ Place a stone at cell `move` and flip the sandwiched stones; None is a pass.
            Either way playerJustMoved switches to the player who acted. """
        if move is not None:
            (x,y)=divmod(move,self.size)
            assert self.IsOnBoard(x,y) and self.board[x*self.size+y] == 0
            # must be collected BEFORE switching players: the helpers treat playerJustMoved as the enemy
            m = self.GetAllSandwichedCounters(x,y)
            self.playerJustMoved = 3 - self.playerJustMoved
            self.board[x*self.size+y] = self.playerJustMoved
            for (a,b) in m:
                self.board[a*self.size+b] = self.playerJustMoved
        else:
            self.playerJustMoved = 3 - self.playerJustMoved

    def GetMoves(self):
        """ Legal moves of the player to move: cell indices, [None] if only a pass is
            possible, [] if the game is over (board full or neither player can move). """
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        if not emptypos:
            return []
        viable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))]
        if viable:
            return viable
        #need to check if opponent also has no viable pos: look at the board from the other side
        self.playerJustMoved = 3 - self.playerJustMoved
        oppoViable = [pos for pos in emptypos if self.ExistsSandwiched(*divmod(pos,self.size))]
        self.playerJustMoved = 3 - self.playerJustMoved
        return [None] if oppoViable else []

    def RollOut(self):
        """ Play uniformly random legal moves until the game is over.

        The game is over when the board is full or when both players in a row have no
        legal move. Passes separated by real moves do not end the game.
        """
        emptypos = [pos for pos,e in enumerate(self.board) if e==0]
        justPassed = False   # True when the previous turn was a forced pass
        while emptypos:
            random.shuffle(emptypos)
            for pos in emptypos:
                if self.ExistsSandwiched(*divmod(pos,self.size)):
                    self.DoMove(pos)
                    emptypos.remove(pos)
                    justPassed = False
                    break
            else:
                if justPassed:   # the other player could not move either: game over
                    break
                justPassed = True
                self.DoMove(None)

    def AdjacentEnemyDirections(self,x,y):# Speeds up GetMoves by only considering squares which are adjacent to an enemy-occupied square.
        return [(dx,dy) for (dx,dy) in [(0,+1),(+1,+1),(+1,0),(+1,-1),(0,-1),(-1,-1),(-1,0),(-1,+1)]
                        if self.IsOnBoard(x+dx,y+dy) and self.board[(x+dx)*self.size+y+dy] == self.playerJustMoved]

    def ExistsSandwiched(self,x,y):# Is there at least one counter which would be flipped if my counter was placed at (x,y)?
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            x1=x+dx
            y1=y+dy
            # walk over the run of enemy stones ...
            while self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == self.playerJustMoved:
                x1 += dx
                y1 += dy
            # ... which must be closed by one of the mover's own stones
            if self.IsOnBoard(x1,y1) and self.board[x1*self.size+y1] == 3 - self.playerJustMoved:
                return True
        return False

    def GetAllSandwichedCounters(self, x, y):# All opponent counters that would be flipped by placing my counter at (x,y) (empty list = not a legal move).
        sandwiched = []
        for (dx,dy) in self.AdjacentEnemyDirections(x,y):
            sandwiched.extend(self.SandwichedCounters(x,y,dx,dy))
        return sandwiched

    def SandwichedCounters(self, x, y, dx, dy):# Return the coordinates of all opponent counters sandwiched between (x,y) and my counter.
        x += dx
        y += dy
        sandwiched = []
        while self.IsOnBoard(x,y) and self.board[x*self.size+y] == self.playerJustMoved:
            sandwiched.append((x,y))
            x += dx
            y += dy
        if self.IsOnBoard(x,y) and self.board[x*self.size+y] == 3 - self.playerJustMoved:
            return sandwiched
        else:
            return [] # nothing sandwiched

    def IsOnBoard(self, x, y):
        return 0 <= x < self.size and 0 <= y < self.size

    def GetResult(self, viewpoint): #after gameover
        """ 1.0 if `viewpoint` has more stones than the opponent, 0.0 if fewer, 0.5 on a tie. """
        viewpointscore=oppositescore=0
        for e in self.board:
            if e==viewpoint: viewpointscore+=1
            elif e==3-viewpoint: oppositescore+=1
        if viewpointscore > oppositescore: return 1.0
        elif oppositescore > viewpointscore: return 0.0
        else: return 0.5 # draw

    def __repr__(self):
        s= ""
        for x in range(self.size):
            for y in range(self.size):
                s += "·XO"[self.board[x*self.size+y]]
            s += "\n"
        return s.strip()

    def key(self): #string containing all information for hashing the node
        return str(self.playerJustMoved)+''.join("·XO"[e] for e in self.board)

    def Reachable(self,key):  #for pruning Tree
        """ Cheap necessary condition for `key` being a future position of this state.

        Corner stones can never be flipped, so occupied corners must match; the number
        of empty cells must have shrunk -- or the board is identical (position after a pass).
        """
        mykey = self.key()
        fewerEmpty = key.count('·')<mykey.count('·')
        sameBoard = key[1:]==mykey[1:]
        # key[0] is the player digit, so board cell i sits at key[i+1]
        return (fewerEmpty or sameBoard) and \
               (mykey[1]=='·' or mykey[1]==key[1]) and\
               (mykey[self.size]=='·' or mykey[self.size]==key[self.size]) and\
               (mykey[-self.size]=='·' or mykey[-self.size]==key[-self.size]) and\
               (mykey[-1]=='·' or mykey[-1]==key[-1])

In [10]:
# Start a fresh game and a fresh search tree, so this cell does not depend on whatever
# `state` / `mcts` an earlier cell left behind.
state = OthelloState(8)
mcts = MCTS(verbosity=1)
while state.GetMoves():
    if state.playerJustMoved==1: # human's turn (the computer moves first)
        clear_output()
        # empty cells show their index so the human can refer to them
        board=np.array([f'{("·XO"[e] if e else str(p)):2s}' for p,e in enumerate(state.board)]).reshape(state.size,state.size)
        print(board)
        if state.key() in mcts.nodes:
            # statistics gathered by the computer's searches for this position
            n=mcts.nodes[state.key()]
            print(n.ChildrenToString())
            print('choose from:',end='')
        moves = state.GetMoves()
        if moves[0] is None: m = None   # forced pass, nothing to ask
        else:
            # keep asking until the entry is a legal cell
            m = None
            while m not in moves:
                try:
                    m = int(input(moves))
                except ValueError:
                    m = None
    else:
        m = mcts.search(state,timemax=10)
    state.DoMove(m)
    print(state)
    print(len(mcts.nodes))
if state.GetResult(1) == 1:   print("Player 1(X) wins!")
elif state.GetResult(2) == 1: print("Player 2(O) wins!")
else: print("Nobody wins!")

[['O ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ']
 ['X ' 'X ' 'X ' 'O ' 'O ' 'O ' 'X ' 'X ']
 ['X ' 'X ' 'O ' 'X ' 'O ' 'O ' 'X ' 'X ']
 ['O ' 'X ' 'O ' 'O ' 'X ' 'O ' 'X ' 'X ']
 ['O ' 'X ' 'X ' 'O ' 'O ' 'O ' 'X ' 'O ']
 ['X ' 'X ' 'X ' 'X ' 'X ' 'O ' 'X ' 'X ']
 ['X ' 'X ' 'X ' 'X ' 'X ' 'X ' 'X ' '55']
 ['O ' 'O ' 'O ' 'O ' 'O ' 'O ' 'X ' '63']]
63      |[W/V:      0/140001 | UnXplrd: 0] 
55      |[W/V:      0/140001 | UnXplrd: 0] 
choose from:[55, 63]63
OXXXXXXX
XXXOOOXX
XXOXOOXX
OXOOXOXX
OXXOOOXO
XXXXXOXX
XXXXXXO·
OOOOOOOO
12838
55      |[W/V: 668697/668697 | UnXplrd: 0] 
OXXXXXXX
XXXOOOXX
XXOXOOXX
OXOOXOXX
OXXOOOXO
XXXXXOXX
XXXXXXXX
OOOOOOOO
1145
Player 1(X) wins!


# AlphaGo-style (guide tree search by policy+value network)

## [Connect4](https://github.com/plkmo/AlphaZero_Connect4)

## Othello