# Monte Carlo Tree Search

## Explanation
- https://www.analyticsvidhya.com/blog/2019/01/monte-carlo-tree-search-introduction-algorithm-deepmind-alphago/
- https://github.com/hayoung-kim/mcts-tic-tac-toe
- https://nestedsoftware.com/2019/08/07/tic-tac-toe-with-mcts-2h5k.152104.html


In [1]:
import numpy as np
import time
from IPython.display import clear_output
import unittest

In [56]:
class Node(object):
    """
    Monte Carlo Tree Search Node
    """
    def __init__(self,env,turn):

        # Attrib
        self.win = 0 #reward
        self.visited = 0
        self.turn = turn
        self.isRoot = False
        self.value = float("inf")
        self.isAllChildVisitedOnce = False # node fully expanded ? when all uct is not inf
        self.isTerminal = False # is terminal node
        self.childNodesExtracted = False
        self.limit = 1
        self.isLeafNode = True

        self.backPropDoneOnce = False
        self.winUpdated = False

        self.move = None
        self.parent = None # brackprop
        self.childNodes = [] 

        self.state = None # state
        self.env = env 
        self.ifcount = 0

    def findChilds(self):
        # find all childs and add them
        self.childNodesExtracted = True

        if self.isTerminal:
            return self
        moves = self.env.getMoves(self.state, self.turn)

        for i in moves:
            temp_state = self.env.executeMove(self.state.copy(),i, self.turn)
            n =  Node(self.env,not self.turn)
            self.addChild(n,i,temp_state)



    def addChild(self,obj,move,temp_state = None):
        obj.parent = self
        obj.state = temp_state
        obj.move = move
        obj.isTerminal = self.env.isComplete(temp_state) != -1
        self.isExpanded = True 
        self.childNodes.append(obj)


    def print(self):
        self.env.printState(self.state)
        print("----")

    # env 
    # -> printState
    # -> executeMove
    # -> getMoves

In [57]:
def prRed(skk): print("\033[91m {}\033[00m" .format(skk)) 
def prGreen(skk): print("\033[92m {}\033[00m" .format(skk)) 
prGreen("hello")

[92m hello[00m


In [88]:
class Mcts:
    def __init__(self,env,turn,root,debug =False):
        self.MaxTime = 3
        self.iniTime = time.time()
        self.env = env
        self.turn = turn
        self.root = root
        self.maxVal = float('inf')
        self.debug = debug

        # iteration resouce
        self.maxIter = 1000
        self.iter = 0

    
    def resourceAvailable(self,UseTimeResource = True):
        # within the time limit

        # with iterations
        if  not UseTimeResource:
            if self.iter < self.maxIter:
                self.iter += 1
                return True
            return False


        if time.time() - self.iniTime  <= self.MaxTime:
            return True
        return False

    def best(self,node):
        """ UCB selection """
        if node.isTerminal:
            return node

        past = -100
        obj = node
        
        for i in node.childNodes: 
            val = self.nodeValue(i)
            if val > past: # maxi.
                obj = i
                past = val
        return obj


    def selection(self,node):
        """Return the leaf node To be expanded """

        while not node.isTerminal and not node.isLeafNode :
            node = self.best(node) # higher UCT value
            # node.print()

        if node.isLeafNode:  
            node.findChilds()
            node.isLeafNode = False

        return self.best(node)



    def rollout(self,node):
        # simulate the game with env

        temp_turn = node.turn
        temp_state = node.state.copy()
        # invert_reward = True

        while self.env.isComplete(temp_state) == -1: # unless the Game is complete

            moves = self.env.getMoves(temp_state,temp_turn) # gives random moves
            temp_state = self.env.executeMove(temp_state.copy(), moves[0], temp_turn)
            temp_turn = not temp_turn
            # invert_reward = not invert_reward


        winner = self.env.isComplete(temp_state) # 0, 1,2

        if node.turn == True:
            wint = 2
        else:
            wint = 1

        if winner == 0:
            return 1 #draw
        elif wint == winner:
            return 1  # favourable condition
        return -1
    

    def nodeValue(self,node):
        # UCB funciton

        win = node.win
        nodeVisit = node.visited
        rootVisit = self.root.visited

        if nodeVisit==0:
            return 10000
        return win/nodeVisit +  np.sqrt(2*np.log(rootVisit)/nodeVisit)

    def backpropagate(self,node,result):
        # update the values to the root node
        assert node!=self.root

        while node != None :
            result = -1*result
            node.visited += 1
            node.win += result

            # for simulation
            node.winUpdated = True
            node = node.parent
        

    def bestMove(self,node):
        # find the best move with UCT
        # assert node.childNodesExtracted == True

        past = -self.maxVal
        obj = None
        assert node.isLeafNode == False

        for i in node.childNodes:
            if i.win >= past: # max
                obj = i
                past = i.win
            
        return obj.move


    def execute(self,UseTimeResource = True):
        """ the core of Mcts """

        self.iniTime = time.time() # restart time 

        while self.resourceAvailable( UseTimeResource):
            node = self.selection(self.root) # find the leaf nodes to be expanded
            simulation_result = self.rollout(node)
            self.backpropagate(node,simulation_result)

            # simulation display root node
            if self.debug == True:
                time.sleep(0.5)
                clear_output(True)
                for j in self.root.childNodes:
                    j.print()
                    print(j.win,j.visited)
                    if j.winUpdated == True:
                        prGreen(self.nodeValue(j))
                        j.winUpdated = False
                    else:
                        print(self.nodeValue(j))


        return self.bestMove(self.root)
    
    # bestMove ??

In [91]:
# M# 2 - True (you)
# 1 - False (bot)

class Env:
  """ 
  AI playing TicTacToe as your opponent 
  Solving it using MINMAX Algorithm
  state 1 = X
  state 2 = O
  state 0 = empty block

  """
  def __init__(self):
    self.size = 3
    self.state = np.zeros( (self.size,self.size),dtype = np.int8)
    self.bot = 1
    self.pos = (2,2)
    self.turn = True
    # print("ini")


  def isComplete(self,state):
    """ Check wheather game is over and return the winner
    0 if its Draw
    -1 not complete """

    li = [1,2]
    for i in li:
      # horizontal check
      for k in range(self.size):
        res = [True if h == i else False for h in state[k,:] ]
        if all(res):
          return i
      # vertical check
      for k in range(self.size):
        res = [True if h == i else False for h in state[:,k] ]
        if all(res):
          return i

      for k in range(self.size):
        res = [True if state[k][k] == i else False for k in range(self.size) ]
        res1 = [True if state[k][self.size-k-1] == i else False for k in range(self.size) ]
        if all(res) or all(res1):
          return i
        
    flag = False
    if(self.countSquare(state) == 0):
        # Draw condition
      return 0

    return -1

  def getMoves(self, state,turn):
    """
    Return possible moves from the given state
    """

    move = []
    for i in range(self.size):
      for j in range(self.size):
        if(state[i][j] == 0): 
          move.append([i,j])
    np.random.shuffle(move)
    return move


  def countSquare(self,state):
    count = 0
    for i in range(self.size):
      for j in range(self.size):
        if(state[i][j] == 0): 
          count+=1
    return count

  def executeMove(self,state,move,turn):
    #   print("execute move ",turn)
      assert len(state) == 3
      player = 2 if turn else 1
      state[move[0]][move[1]] = player
      return state


  def engine(self,debug = False):

      print("turn = ",self.turn)

      n = Node(self,turn= self.turn)
      n.state = self.state.copy()
      n.visited = 1
      n.isRoot = True # can be remobed
      n.isLeafNode = True


      mcts = Mcts(self,self.turn,n,debug = False)

      move =  mcts.execute(True)

      # del mcts
      assert move != None
      if debug:
          return mcts
      return move

  def printState(self,state):

    """ Display the board state node.state"""

    for i in range(self.size-1,-1,-1): # printing states in reverse
      for j in range(self.size): 
        if state[i][j] == 0 :
          curr = ' '
        elif state[i][j] == 1:
          curr = 'X'
        else:
          curr = 'O'  
        if j != self.size-1:
          print(f"{curr} | ",end = '')
        else:
          print(f"{curr}")
      if i != 0:
        print("--"*self.size*2)
                                                                                                                                     

  def play(self,youFirst = True):
    print("\t TicTacToe \n Bot = X \n You = O ")
    self.turn = youFirst
    while True:
      clear_output(True)
      res = self.isComplete(self.state)
      if res in [0,1,2]:
        self.printState(self.state)
        if res == 0:
          print("Draw")
          break
        print(f"Player {res} Won")
        break

      print(f"Player {self.turn} thinking ... ")
      
      self.printState(self.state)
      if self.turn:
        #   move = self.engine()
        #   self.state = self.executeMove(self.state,move,self.turn)
        
          a = int(input("intput number"))-1
          x,y = a//3, a%3
          self.state = self.executeMove(self.state,[x,y],self.turn)
      else:
          move = self.engine()
          self.state = self.executeMove(self.state,move,self.turn)

      
      self.turn = not self.turn
    #   time.sleep(0.5)
      


In [92]:
g = Env()

g.play(youFirst=True)

X |   | O
------------
O | X | X
------------
O | O | X
Player 1 Won


In [None]:
t = Env()
# t.state = np.array([[2,0,0],[1,1,2],[0,0,2]],np.int8)

t.state = np.array([[0,0,0],[0,0,0],[2,0,0]],np.int8)
# t.state = np.array([[2,0,0],[0,1,0],[1,0,2]],np.int8)
t.printState(t.state)

In [22]:
t.turn = False
m = t.engine(debug = True)
node = m.root 

turn =  False


In [23]:
for i in node.childNodes:
    i.print()
    print(i.value)
    # print(i.ifcount)
    print(i.visited)
    print(i.win)

O |   |  
------------
  |   |  
------------
  | X |  
----
0.08468394816985292
18
-14
O |   |  
------------
  |   |  
------------
  |   | X
----
0.11351231111838872
64
-22
O |   |  
------------
X |   |  
------------
  |   |  
----
0.10978336159461677
62
-22
O |   |  
------------
  |   |  
------------
X |   |  
----
0.11258372219991766
110
-26
O |   | X
------------
  |   |  
------------
  |   |  
----
0.10865189140598086
36
-18
O |   |  
------------
  | X |  
------------
  |   |  
----
0.1774228175973489
635
19
O |   |  
------------
  |   | X
------------
  |   |  
----
0.09346178815684536
29
-17
O | X |  
------------
  |   |  
------------
  |   |  
----
0.1046745930668937
46
-20


In [None]:
n1 = node.childNodes[1]
n1.print()
n1.turn

  |   | O
------------
X | X | O
------------
O | X |  
----


True

In [None]:
for i in n1.childNodes:
    i.print()
    print(i.value)
    # print(i.ifcount)
    print(i.visited)
    print(i.win)

  |   | O
------------
X | X | O
------------
O | X | O
----
-9.777106127839177
278
-2780
  | O | O
------------
X | X | O
------------
O | X |  
----
-9.777089965906356
278
-2780
O |   | O
------------
X | X | O
------------
O | X |  
----
-9.777042896267139
439
-4370


In [None]:
# node.findChilds()
node.print()
n1 = m.selection(node)
n1.print()

X |   |  
------------
O | O | O
------------
X |   |  
----
X |   |  
------------
O | O | O
------------
X | X |  
----


In [None]:
# n1 == node

False

In [None]:
for j in range(10):
    val = m.rollout(n1)
    m.backpropagate(n1,val)
    print(val)

10
-10
10
10
10
10
10
10
-10
10


In [None]:
n1.value 

2.5748662386428496

In [None]:
print(n1.win,n1.visited,node.visited)

8 10 1826


In [None]:
m.nodeValue(n1.win,n1.visited,node.visited)

2.5331916294691608

In [None]:
n2 = n2.best()
n2.print()

X |   |  
------------
  | X |  
------------
O |   |  
----


In [None]:
for i in range(50):
    n3 = m.selection(n2)
    

  |   | O
------------
  | X | O
------------
X | X |  
----
None
0.0
1
  |   | O
------------
  | X | O
------------
  | X | X
----
None
1.0601855883790088
9393
  |   | O
------------
X | X | O
------------
  | X |  
----
None
1.0600437880513882
2425
X |   | O
------------
  | X | O
------------
  | X |  
----
None
1.0601200790995824
5826
  | X | O
------------
  | X | O
------------
  | X |  
----
None
1.0601832494631187
11347


  |   |  
------------
  |   |  
------------
  |   |  
  |   |  
------------
  |   |  
------------
  |   |  


In [None]:
node.isExpanded

True

In [None]:
n1 =  m.selection(node)
n1.print()

X |   | O
------------
X |   |  
------------
X | O |  


In [None]:
# m.rollout(n1)
print(n1.win/n1.visited,n1.value)

1.0 1.0


In [None]:
! ls


sample_data


In [None]:
h = node.childNodes[7]
print(h.childNodesExtracted)
print(h.childNodes)

False
[]


In [None]:
class TestNotebook(unittest.TestCase):

    def test_Node(self):
        n = Node(Env,True)
        self.assertEqual(n.value, float("inf"))

    def test_add_child(self):
        n = Node(Env,True)
        n.value = 9

        n1 = Node(Env,True)
        n1.value = 5
        n.addChild(n1)
        child = n.childNodes[0]
        self.assertEqual(child.value, 5)
        parent = child.parent
        self.assertEqual(parent.value, 9)

unittest.main(argv=[''], verbosity=0, exit=False)

In [None]:
class base:
    def __init__(self):
        self.p = 9
        print("ini")

class derived(base):
    def __init__(self):
        base.__init__(self)
    # self.q = 4

d = derived()
d.p

# class Game(Mcts,Env):
#     def __init__(self):
#         Env.__init__(self)
#         Mcts.__init__(self)

#     def findChilds(self):
#         # find all childs and add them
#         # self.isExpanded = True
#         print(Env.getMove(self,self.state))



ini


9

In [None]:
if float('inf') > 7:
    print("t")

t
