In [1]:
import numpy as np
import random
import datetime
import os
from collections import deque

In [2]:
from hexboard import HexBoard
from MCTS import MCTS, getLegalMoves, copy, goalValue
from nextki import HexKI
from Game import Game

In [3]:
import cProfile
from numba import jit
import timeit

In [16]:
# Profile one full search of the imported MCTS (from MCTS.py) on a fresh HexBoard(11, 11).
# NOTE(review): the meaning of the positional HexKI arguments ('inter', 'MCTS')
# is not visible here -- confirm against nextki.HexKI.
KI = HexKI(11,11,'inter','MCTS',HexBoard(11,11))
# goalValue compares Board.starter against this string to decide whose win counts.
KI.string = 'KI'
cProfile.run('MCTS(KI)')

         77606024 function calls (77586024 primitive calls) in 52.408 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   52.408   52.408 <string>:1(<module>)
    30992    0.020    0.000    0.020    0.000 MCTS.py:119(Backprop)
        1    0.095    0.095   52.408   52.408 MCTS.py:133(MCTS)
    10000    0.022    0.000    0.048    0.000 MCTS.py:16(addChildren)
      121    0.000    0.000    0.000    0.000 MCTS.py:175(<lambda>)
    20992    0.008    0.000    0.008    0.000 MCTS.py:28(isExpanded)
    20992    0.017    0.000    0.212    0.000 MCTS.py:34(isTerminal)
    10992    0.023    0.000    3.821    0.000 MCTS.py:42(UCT)
    10992    0.053    0.000    0.053    0.000 MCTS.py:43(<listcomp>)
  1328919    3.265    0.000    3.265    0.000 MCTS.py:46(<lambda>)
        1    0.000    0.000    0.000    0.000 MCTS.py:51(__init__)
    20001    0.038    0.000    1.479    0.000 MCTS.py:56(getLegalMoves)
       

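The profile shows that `UCT` is a hotspot. It accounts for 3.8 s of cumulative time, and 3.3 s of that is spent in the lambda on line 46 of `MCTS.py` (inside `UCT`). The lambda is simply called very often: about 1.3 million times. Even 400k plain-Python calls cost about 1 s (see the benchmark below). Let's see if we can speed this up with Numba's JIT compiler!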

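Two versions of the UCT scoring function:
- The original lambda. It only works inside a `Node` method, where `self` (the parent node) is defined.
- An equivalent plain function of three numbers, which Numba can compile.

`Node.UCT` further below uses the jitted plain function instead of the lambda.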

In [6]:
# Lambda form of the UCT score, kept for reference (presumably mirrors MCTS.py:46).
# NOTE: `self` is not defined at notebook top level, so calling UCB(...) here raises
# NameError unless a global named `self` exists; inside a Node method `self` is the
# node whose children are being scored.
UCB = lambda x: x.reward/x.visits + np.sqrt(np.log(self.visits)/x.visits)
def func(x, y, z):
    """UCT score written as a plain function of numbers so Numba can compile it.

    Args:
        x: accumulated reward of the child node.
        y: visit count of the child node (a zero here divides by zero).
        z: visit count of the parent node.

    Returns:
        x / y + sqrt(ln(z) / y)
    """
    exploitation = x / y
    exploration = np.sqrt(np.log(z) / y)
    return exploitation + exploration

In [7]:
# Wrap `func` with Numba's jit. Compilation happens lazily on the first call and is
# specialised to the argument types seen (int64 x3 for the call below).
jitu = jit()(func)

In [8]:
# First call: triggers compilation for (int64, int64, int64) and displays the result.
jitu(2, 3, 4)

1.3464446601125393

In [9]:
# Show Numba's typed intermediate representation for each compiled signature
# (the output below is truncated).
jitu.inspect_types()

func (int64, int64, int64)
--------------------------------------------------------------------------------
# File: <ipython-input-6-f1befe73a90c>
# --- LINE 2 --- 
# label 0
#   del x
#   del $0.4
#   del $0.6
#   del z
#   del $0.7
#   del y
#   del $0.9
#   del $0.5
#   del $0.11
#   del $0.3
#   del $0.12
#   del $0.13

def func(x, y, z):

    # --- LINE 3 --- 
    #   x = arg(0, name=x)  :: int64
    #   y = arg(1, name=y)  :: int64
    #   z = arg(2, name=z)  :: int64
    #   $0.3 = x / y  :: float64
    #   $0.4 = global(np: <module 'numpy' from 'C:\\Users\\Ernst\\Anaconda3\\lib\\site-packages\\numpy\\__init__.py'>)  :: Module(<module 'numpy' from 'C:\\Users\\Ernst\\Anaconda3\\lib\\site-packages\\numpy\\__init__.py'>)
    #   $0.5 = getattr(value=$0.4, attr=sqrt)  :: Function(<ufunc 'sqrt'>)
    #   $0.6 = global(np: <module 'numpy' from 'C:\\Users\\Ernst\\Anaconda3\\lib\\site-packages\\numpy\\__init__.py'>)  :: Module(<module 'numpy' from 'C:\\Users\\Ernst\\Anaconda3\\lib\\site

In [10]:
def bad(x, y, z, n_calls=400000):
    """Benchmark helper: call the pure-Python `func(x, y, z)` `n_calls` times.

    The results are discarded; only the elapsed time matters.

    Args:
        x, y, z: arguments forwarded to `func`.
        n_calls: number of calls (defaults to the 400,000 previously hard-coded).
    """
    for _ in range(n_calls):
        func(x, y, z)

def bad1(x, y ,z, n_calls=400000):
    """Benchmark helper: call the Numba-jitted `jitu(x, y, z)` `n_calls` times.

    The results are discarded; only the elapsed time matters.

    Args:
        x, y, z: arguments forwarded to `jitu`.
        n_calls: number of calls (defaults to the 400,000 previously hard-coded).
    """
    for _ in range(n_calls):
        jitu(x, y, z)

In [11]:
# Time the pure-Python loop against the jitted loop. `-o` makes %timeit return
# its result object, so the best times can be compared in the next cell.
plain = %timeit -o bad(2, 3, 4)
jitted = %timeit -o bad1(2, 3, 4)

1.03 s ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
113 ms ± 383 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
# Best run time (s) of the plain loop, of the jitted loop, and the resulting speed-up.
print(plain.best, jitted.best, plain.best / jitted.best, sep="\n")

1.0323849999999997
0.11214745999999991
9.205603051553735


In [13]:
from copy import deepcopy

class Node:
    """A node of the MCTS search tree.

    Attributes:
        state: the board position this node represents (HexBoard-like object),
            or None for a placeholder node.
        parent: parent Node, or None for the root.
        visits: number of backpropagated playouts that passed through this node.
        reward: sum of the playout rewards backpropagated through this node.
        children: Nodes already expanded from this one.
        unusedmoves: legal moves of `state` not yet expanded into a child.
    """

    def __init__(self, Board = None, pa = None, vis=0.0, rew = 0.0):
        self.state = Board
        self.parent = pa
        self.visits = vis
        self.reward = rew
        self.children = []
        # Board defaults to None, which used to crash inside getLegalMoves;
        # a node without a board simply has no moves left to expand.
        self.unusedmoves = getLegalMoves(Board) if Board is not None else []

    def addChildren(self, ls, move):
        """Attach child node `ls` (reached by `move`) and mark `move` as used.

        A node that is already a child is ignored, and `move` then stays in
        `unusedmoves`. Raises ValueError if `move` is not in `unusedmoves`.
        """
        if ls not in self.children:
            self.children.append(ls)
            self.unusedmoves.remove(move)

    def isLeaf(self):
        """Return True if no child has been expanded from this node yet."""
        return not self.children

    def isExpanded(self):
        """Return True once every legal move has been tried at least once."""
        return not self.unusedmoves

    def isTerminal(self):
        """Return True if `self.state.finished()` reports the position as finished."""
        return bool(self.state.finished())

    def UCT(self):
        """Pick the child to descend into, balancing exploitation and exploration.

        Each child is scored with jitu(reward, visits, parent_visits), i.e.
        reward/visits + sqrt(ln(parent_visits)/visits). The first term favours
        children with good known outcomes; the second gives a bonus to rarely
        visited ones. There is no separate exploration constant C (effectively
        C = 1). Ties go to the earliest child.

        Returns self if there are no children.

        NOTE(review): rewards come from goalValue, which always scores from the
        KI's side, and this method maximises at every depth -- including nodes
        where the opponent is to move. Confirm that is intended.
        """
        if not self.children:
            return self
        parent_visits = self.visits  # loop-invariant, read once
        return max(self.children,
                   key=lambda child: jitu(child.reward, child.visits, parent_visits))

##  Container for Nodes
class SearchTree:
    """Container for an MCTS search tree; only the root Node is stored."""

    def __init__(self, r=None):
        # `r` is the board the search starts from; it becomes the root's state.
        root_node = Node(r)
        self.root = root_node


def getLegalMoves(board):
    """Return the playable cells of `board` as (row, column) tuples in row-major order.

    When board.zug == 0 every cell is returned (presumably the opening move --
    TODO confirm); otherwise only cells whose value is 0 are returned.
    """
    grid = board.board
    n_rows = len(grid)
    n_cols = len(grid[0])
    every_cell = board.zug == 0
    return [
        (row, col)
        for row in range(n_rows)
        for col in range(n_cols)
        if every_cell or grid[row][col] == 0
    ]

def copy(Board):
    """Return a copy of a HexBoard whose grid is deep-copied.

    The attributes starter, win, zug, lastmove, no_filled and swap are copied over
    by reference. Any other attribute keeps whatever HexBoard(m, m) initialises it
    to. The copy is built with m = number of columns, so a square board is assumed.
    """
    size = len(Board.board[0])
    clone = HexBoard(size, size)
    clone.board = deepcopy(Board.board)
    for attr in ('starter', 'win', 'zug', 'lastmove', 'no_filled', 'swap'):
        setattr(clone, attr, getattr(Board, attr))
    return clone

def Expand(node, Board):
    """Expand `node` by one randomly chosen untried move and return the new child.

    The move is played on the working board `Board`, which is mutated. A child
    Node holding a copy of the resulting position is created and registered on
    `node`, keyed by Board.lastmove.
    """
    Board.receiveMove(random.choice(node.unusedmoves))
    played = Board.lastmove
    child = Node(copy(Board), node)
    # Give the child its own copy of the move that led to it.
    child.state.lastmove = deepcopy(played)
    node.addChildren(child, played)
    return child

def Simulate(legalmoves, Board, KI):
    """Play uniformly random moves on `Board` until it is finished, then score it.

    Args:
        legalmoves: list of moves still available. It is consumed in place:
            each played move is removed from it.
        Board: the working board; mutated by every move played.
        KI: the player whose perspective goalValue scores from.

    Returns:
        goalValue(Board, KI) for the finished board (1 or -1).
    """
    # NOTE(review): profiling shows Board.finished() dominates this loop. If
    # Board.winner() is valid on a completely filled board, filling it first
    # and evaluating once would avoid a finished() call after every move.
    # Confirm against HexBoard before changing.
    while True:
        if Board.finished():
            break
        pick = random.choice(legalmoves)
        Board.receiveMove(pick)
        legalmoves.remove(pick)
    return goalValue(Board, KI)

def goalValue(Board, KI):
    """Return +1 or -1 for a finished board (presumably +1 means the KI won).

    The result depends on three inputs:
      * whether Board.starter equals KI.string,
      * whether Board.swap is set,
      * whether Board.winner() == 1.

    It is +1 exactly when (winner() == 1) equals (KI started) XOR (swap),
    and -1 otherwise. This is the same four-case table as the nested-branch
    version. winner() is called exactly once.
    """
    ki_started = bool(Board.starter == KI.string)
    swapped = bool(Board.swap)
    player_one_won = bool(Board.winner() == 1)
    # Exactly one of "KI started" / "a swap happened" makes a player-1 win score +1.
    ki_side_is_one = ki_started != swapped
    return 1 if player_one_won == ki_side_is_one else -1

def Backprop(node, R):
    """Add reward `R` and one visit to `node`.

    Returns the node's parent, or `node` itself when it is the root, so a
    caller can walk upwards one step at a time.
    """
    node.reward += R
    node.visits += 1
    parent = node.parent
    return node if parent is None else parent

##  MCTS driver: Selection -> Expansion -> Simulation -> Backpropagation.
def MCTS1(KI, iterations=10000):
    """Run Monte Carlo Tree Search from KI.startState and return the chosen move.

    Each of the `iterations` rounds performs four steps:
      * Selection: from the root, follow Node.UCT while the current node is
        fully expanded and not terminal. Each chosen move is replayed on a
        working copy of the root board.
      * Expansion: if that node still has untried moves and is not terminal,
        add one random child (Expand).
      * Simulation: play random moves on the working board until it is
        finished, and score the result with goalValue (Simulate).
      * Backpropagation: add the reward and one visit to every node on the
        path from the reached node up to and including the root.

    Args:
        KI: the AI player. Its `startState` is the root position, and its
            `string` attribute is what goalValue compares Board.starter with.
        iterations: number of search rounds (default 10000, the value that was
            previously hard-coded).

    Returns:
        (move, rounds): the `lastmove` of the root child with the most visits
        (ties go to the earliest child), and the number of rounds performed.

    Raises:
        ValueError: if the root has no children after the search, i.e. there
            is no move to return.
    """
    S = SearchTree(KI.startState)
    i = 0
    while i < iterations:
        # Every round starts at the root, on a fresh copy of the root's board.
        last = S.root
        Board = copy(last.state)

        # Selection
        while last.isExpanded() and not last.isTerminal():
            last = last.UCT()
            Board.receiveMove(last.state.lastmove)

        # Expansion
        if last.unusedmoves and not last.isTerminal():
            last = Expand(last, Board)

        # Simulation (Simulate consumes `legalmoves` and mutates `Board`)
        legalmoves = getLegalMoves(Board)
        R = Simulate(legalmoves, Board, KI)

        # Backpropagation: Backprop returns the parent, or the root itself.
        while last.parent is not None:
            last = Backprop(last, R)
        Backprop(last, R)
        i += 1

    if not S.root.children:
        raise ValueError("MCTS1: the root has no expanded moves to choose from")
    best = max(S.root.children, key=lambda child: child.visits)
    return best.state.lastmove, i

In [17]:
# Fresh player with the same setup as the earlier profiling cell, so both
# profiles start from the same configuration.
KI = HexKI(11,11,'inter','MCTS',HexBoard(11,11))
# goalValue compares Board.starter against this string to decide whose win counts.
KI.string = 'KI'

In [18]:
# Profile the notebook's MCTS1 (UCT scoring via the jitted `jitu`), for comparison
# with the earlier profile of MCTS from MCTS.py.
cProfile.run('MCTS1(KI)')

         77871955 function calls (77851955 primitive calls) in 50.791 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    31038    0.021    0.000    0.021    0.000 <ipython-input-13-d55827235358>:115(Backprop)
    10000    0.022    0.000    0.049    0.000 <ipython-input-13-d55827235358>:12(addChildren)
        1    0.101    0.101   50.791   50.791 <ipython-input-13-d55827235358>:129(MCTS1)
      121    0.000    0.000    0.000    0.000 <ipython-input-13-d55827235358>:161(<lambda>)
    21038    0.008    0.000    0.008    0.000 <ipython-input-13-d55827235358>:24(isExpanded)
    21038    0.017    0.000    0.208    0.000 <ipython-input-13-d55827235358>:30(isTerminal)
    11038    0.026    0.000    1.527    0.000 <ipython-input-13-d55827235358>:38(UCT)
    11038    0.053    0.000    0.053    0.000 <ipython-input-13-d55827235358>:39(<listcomp>)
    10001    0.014    0.000    0.706    0.000 <ipython-input-13-d55827235358>:4(__init_

OK, our jitted function takes 0.146 s, whereas the plain lambda version took 3.265 s — a big improvement. Next we need to fix the simulation phase. To be more precise:
- `Simulate` takes 46.655 s.
- `finished()`, called inside `Simulate`, takes 40.783 s.
- `adjazenz`, called inside `finished()`, takes 20.842 s.

Can we fix this with Numba? Not directly, because those functions use Python types that Numba cannot compile. `finished()` uses a breadth-first search (BFS) over the board. Changing the board representation might make a Numba-friendly solution possible.