# Monte Carlo Tree Search

## Explanation
- https://www.analyticsvidhya.com/blog/2019/01/monte-carlo-tree-search-introduction-algorithm-deepmind-alphago/
- https://github.com/hayoung-kim/mcts-tic-tac-toe



In [1]:
import numpy as np
import time
from IPython.display import clear_output
import unittest

In [2]:
class Node(object):
    """
    Monte Carlo Tree Search Node
    """
    def __init__(self,env,turn):

        # Attrib
        self.win = 0 #reward
        self.visited = 0
        self.turn = turn
        self.isRoot = False
        self.value = float("inf")
        self.isAllChildVisitedOnce = False # node fully expanded ? when all uct is not inf
        self.isTerminal = False # is terminal node
        self.childNodesExtracted = False

        self.move = None
        self.parent = None
        self.childNodes = []

        self.state = None
        self.env = env

    def findChilds(self):
        # find all childs and add them
        self.childNodesExtracted = True

        if self.isTerminal:
            return None
        moves = self.env.getMoves(self.state,not self.turn)

        for i in moves:
            temp_state = self.env.executeMove(self.state.copy(),i,not self.turn)
            n =  Node(self.env,not self.turn)
            self.addChild(n,i,temp_state)


    def best(self):
        if self.isTerminal:
            return self
        
        past = self.childNodes[0].value
        obj = self.childNodes[0]
        for i in self.childNodes:
            j = i.value
            if j > past:
                obj = i
                past = j
        return obj

    def addChild(self,obj,move,temp_state = None):
        obj.parent = self
        obj.state = temp_state
        obj.move = move
        obj.isTerminal = self.env.isComplete(temp_state) != -1
        self.isExpanded = True
        self.childNodes.append(obj)


    def print(self):
        self.env.printState(self.state)
        print("----")

In [25]:
class Mcts:
    def __init__(self,env,turn,root):
        self.MaxTime = 1
        self.iniTime = time.time()
        self.env = env
        self.turn = turn
        self.root = root
        self.maxVal = float('inf')
        # print("mcts turn ",self.turn)
        # pass
    
    def resourceAvailable(self):
        # within the time limit
        if time.time() - self.iniTime  <= self.MaxTime:
            return True
        return False

    def selection(self,node):
        """Return the leaf node To be expanded"""
        # print("sel")
        while not node.isTerminal and  node.isAllChildVisitedOnce:
            # print("allchildVisitedOnce")
            node = node.best() # higher UCT value

        if not node.childNodesExtracted:  
            node.findChilds()
        return node.best()

    def rollout(self,node):
        # simulate the game with 
        temp_turn = self.turn
        temp_state = node.state.copy()

        while self.env.isComplete(temp_state) == -1: # unless the Game is complete
            moves = self.env.getMoves(temp_state,self.turn)
            np.random.shuffle(moves)   # policy is to pick randomly
            temp_state = self.env.executeMove(temp_state, moves[0], temp_turn)
            temp_turn = not temp_turn
            # Debug purpose
            # self.env.printState(temp_state)
            # time.sleep(0.5)
            # clear_output(True)
        winner = self.env.isComplete(temp_state)

        if self.turn == True:
            wint = 1
        else:
            wint = 2

        if winner == 0:
            return 1
        elif wint == winner:
            return 1
        return -1
    
    def nodeValue(self,win,nodeVisit,rootVisit):
        return win/nodeVisit +  np.sqrt(2*np.log(rootVisit)/nodeVisit)

    def backpropagate(self,node,result):
        # update the values to the root node
        assert node!=self.root

        while node != None :
            node.visited += 1
            # if result >= 0:
            node.win += result
            # try:
            node.value = self.nodeValue(node.win,node.visited,self.root.visited) 

            if not node.isAllChildVisitedOnce and node.childNodesExtracted:
                flag = True
                for i in node.childNodes:
                    if i.value == self.maxVal:
                        flag = False
                        break
                if flag :
                    node.isAllChildVisitedOnce = True
            node = node.parent
        
        

    def bestMove(self,node):
        # find the best move with UCT
        assert node.childNodesExtracted == True

        past = node.childNodes[0].win
        for i in node.childNodes:
            j = i.win
            if j >= past:
                obj = i
                past = j
        return obj.move


    def execute(self):
        """ the core of Mcts """
        # node = root
        self.iniTime = time.time() # restart time 

        while self.resourceAvailable():
            node = self.selection(self.root) # find the leaf nodes to be expanded
            simulation_result = self.rollout(node)
            self.backpropagate(node,simulation_result)
        
        return self.bestMove(self.root)
    
    # bestMove ??

In [26]:



class Env:
  """ 
  AI playing TicTacToe as your opponent 
  Solving it using MINMAX Algorithm
  state 1 = X
  state 2 = O
  state 0 = empty block

  """
  def __init__(self):
    self.size = 3
    self.state = np.zeros( (self.size,self.size),dtype = np.int8)
    self.bot = 1
    self.pos = (2,2)
    self.turn = True
    # print("ini")


  def isComplete(self,state):
    """ Check wheather game is over and return the winner
    0 if its Draw
    -1 not complete """

    li = [1,2]
    for i in li:
      # horizontal check
      for k in range(self.size):
        res = [True if h == i else False for h in state[k,:] ]
        if all(res):
          return i
      # vertical check
      for k in range(self.size):
        res = [True if h == i else False for h in state[:,k] ]
        if all(res):
          return i

      for k in range(self.size):
        res = [True if state[k][k] == i else False for k in range(self.size) ]
        res1 = [True if state[k][self.size-k-1] == i else False for k in range(self.size) ]
        if all(res) or all(res1):
          return i
        
    flag = False
    if(self.countSquare(state) == 0):
        # Draw condition
      return 0

    return -1

  def getMoves(self, state,turn):
    """
    Return possible moves from the given state
    """

    move = []
    for i in range(self.size):
      for j in range(self.size):
        if(state[i][j] == 0): 
          move.append([i,j])
    np.random.shuffle( move)
    return move


  def countSquare(self,state):
    count = 0
    for i in range(self.size):
      for j in range(self.size):
        if(state[i][j] == 0): 
          count+=1
    return count

  def executeMove(self,state,move,turn):
    #   print("execute move ",turn)
      player = 2 if turn else 1
      state[move[0]][move[1]] = player
      return state


  def engine(self,debug = False):
      print("turn = ",self.turn)
      n = Node(self,turn=not self.turn)
      n.state = self.state.copy()
      n.visited = 1
      n.isRoot = True
      mcts = Mcts(self,not self.turn,n)
      move =  mcts.execute()
    #   del mcts
      if debug:
          return mcts
      return move
    #   return mcts
    #   return n

  def printState(self,state):

    """ Display the board state node.state"""

    for i in range(self.size-1,-1,-1): # printing states in reverse
      for j in range(self.size): 
        if state[i][j] == 0 :
          curr = ' '
        elif state[i][j] == 1:
          curr = 'X'
        else:
          curr = 'O'  
        if j != self.size-1:
          print(f"{curr} | ",end = '')
        else:
          print(f"{curr}")
      if i != 0:
        print("--"*self.size*2)
                                                                                                                                     

  def play(self,youFirst = True):
    print("\t TicTacToe \n Bot = X \n You = O ")
    self.turn = youFirst
    while True:
      clear_output(True)
      res = self.isComplete(self.state)
      if res in [0,1,2]:
        self.printState(self.state)
        if res == 0:
          print("Draw")
          break
        print(f"Player {res} Won")
        break

      print(f"Player {self.turn} thinking ... ")
      
      self.printState(self.state)
      if self.turn:
          a = int(input("intput number"))-1
          x,y = a//3, a%3
          self.state = self.executeMove(self.state,[x,y],self.turn)
      else:
          move = self.engine()
          self.state = self.executeMove(self.state,move,self.turn)
      
      
      
      self.turn = not self.turn
    #   time.sleep(0.5)
      


In [None]:
g = Env()
# g.state = np.array([[0,0,0],[0,1,0],[2,0,2]],np.int8)
# g.printState(g.state)

g.play(youFirst=False)

In [31]:
t = Env()
t.state = np.array([[0,2,0],[0,1,0],[0,2,1]],np.int8)
t.turn = False
m = t.engine(debug = True)
node = m.root 

turn =  False


In [32]:
for i in node.childNodes:
    i.print()
    print(i.value)
    print(i.visited)
    print(i.win)

  | O | X
------------
  | X | X
------------
  | O |  
----
1.0
1
1
  | O | X
------------
X | X |  
------------
  | O |  
----
1.0996787044197527
1267
1245
X | O | X
------------
  | X |  
------------
  | O |  
----
1.0996571086172202
1115
1087
  | O | X
------------
  | X |  
------------
X | O |  
----
1.099676819140401
1747
1747
  | O | X
------------
  | X |  
------------
  | O | X
----
1.0996758417451507
1747
1747


In [20]:
# node.findChilds()
node.print()
n1 = m.selection(node)
n1.print()

X |   |  
------------
O | O | O
------------
X |   |  
----
X |   |  
------------
O | O | O
------------
X | X |  
----


In [None]:
# n1 == node

False

In [None]:
for j in range(10):
    val = m.rollout(n1)
    m.backpropagate(n1,val)
    print(val)

10
-10
10
10
10
10
10
10
-10
10


In [None]:
n1.value 

2.5748662386428496

In [None]:
print(n1.win,n1.visited,node.visited)

8 10 1826


In [None]:
m.nodeValue(n1.win,n1.visited,node.visited)

2.5331916294691608

In [None]:
n2 = n2.best()
n2.print()

X |   |  
------------
  | X |  
------------
O |   |  
----


In [None]:
for i in range(50):
    n3 = m.selection(n2)
    

  |   | O
------------
  | X | O
------------
X | X |  
----
None
0.0
1
  |   | O
------------
  | X | O
------------
  | X | X
----
None
1.0601855883790088
9393
  |   | O
------------
X | X | O
------------
  | X |  
----
None
1.0600437880513882
2425
X |   | O
------------
  | X | O
------------
  | X |  
----
None
1.0601200790995824
5826
  | X | O
------------
  | X | O
------------
  | X |  
----
None
1.0601832494631187
11347


  |   |  
------------
  |   |  
------------
  |   |  
  |   |  
------------
  |   |  
------------
  |   |  


In [None]:
node.isExpanded

True

In [None]:
n1 =  m.selection(node)
n1.print()

X |   | O
------------
X |   |  
------------
X | O |  


In [None]:
# m.rollout(n1)
print(n1.win/n1.visited,n1.value)

1.0 1.0


In [34]:
! ls


sample_data


In [None]:
h = node.childNodes[7]
print(h.childNodesExtracted)
print(h.childNodes)

False
[]


In [None]:
class TestNotebook(unittest.TestCase):

    def test_Node(self):
        n = Node(Env,True)
        self.assertEqual(n.value, float("inf"))

    def test_add_child(self):
        n = Node(Env,True)
        n.value = 9

        n1 = Node(Env,True)
        n1.value = 5
        n.addChild(n1)
        child = n.childNodes[0]
        self.assertEqual(child.value, 5)
        parent = child.parent
        self.assertEqual(parent.value, 9)

unittest.main(argv=[''], verbosity=0, exit=False)

In [None]:
class base:
    def __init__(self):
        self.p = 9
        print("ini")

class derived(base):
    def __init__(self):
        base.__init__(self)
    # self.q = 4

d = derived()
d.p

# class Game(Mcts,Env):
#     def __init__(self):
#         Env.__init__(self)
#         Mcts.__init__(self)

#     def findChilds(self):
#         # find all childs and add them
#         # self.isExpanded = True
#         print(Env.getMove(self,self.state))



ini


9

In [None]:
if float('inf') > 7:
    print("t")

t
