In [8]:
from MCTS import MCTS
from sogo.SogoGame import SogoGame, display as display_board
import numpy as np
from sogo.keras.NNet import NNetWrapper as NNet

class Config:
    """Hyper-parameters for the MCTS player plus the checkpoint to restore.

    Several values and their trailing comments appear to be carried over from
    the AlphaZero pseudocode defaults. Later notebook cells may override
    individual attributes (e.g. ``numMCTSSims``) on the shared instance.
    """

    def __init__(self):
        # Move-sampling / game-length limits.
        self.num_sampling_moves = 30
        self.max_moves = 512  # for chess and shogi, 722 for Go.
        # Presumably the number of search simulations per move.
        self.numMCTSSims = 200

        # Root prior exploration noise.
        # NOTE(review): with a fraction of 0.0 the Dirichlet noise presumably
        # has no effect -- confirm against how MCTS mixes it in.
        self.root_dirichlet_alpha = 0.3  # for chess, 0.03 for Go and 0.15 for shogi.
        self.root_exploration_fraction = 0.0

        # UCB formula constants.
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # Checkpoint to restore: flag plus (folder, filename) pair.
        self.load_model = True
        self.load_folder_file = ('./temp/', 'best.pth.tar')

# Shared configuration object; it is passed to every MCTS instance built below.
config = Config()
# NOTE(review): 4 appears to be the board edge length (the rendered boards in
# the output are 4x4 with 4 layers) -- confirm against SogoGame.
game = SogoGame(4)
nn = NNet(game)
# Honor the config's load flag instead of always restoring weights; when it is
# False the freshly constructed network is used as-is.
if config.load_model:
    nn.load_checkpoint(*config.load_folder_file)
mcts = MCTS(game, nn, config)

def mcts_player(x, player):
    """Choose a move for ``player`` on board ``x`` using the global ``mcts``.

    The board is converted with ``game.getCanonicalForm`` before searching,
    and the most probable action of the resulting policy is chosen.

    Returns:
        (action, root): ``np.argmax`` of the search policy and the root object
        returned alongside it by ``mcts.get_action_prob``.
    """
    policy, search_root = mcts.get_action_prob(game.getCanonicalForm(x, player))
    best_action = np.argmax(policy)
    return best_action, search_root



In [64]:
from NeuralNet import NeuralNet
from Game import Game

class NN(NeuralNet):
    """Stand-in network with no learned knowledge.

    ``predict`` ignores the board and returns a uniform policy over every
    action together with a value of 0. This lets MCTS run as a pure-search
    baseline.
    """

    def __init__(self, game: Game):
        self.game = game

    def predict(self, board):
        n_actions = self.game.getActionSize()
        uniform_policy = np.ones(n_actions) / n_actions
        return uniform_policy, 0


# Baseline search: same game and same `config` object as `mcts`, but guided
# only by the uniform prior above.
dummy_nn = NN(game)
dummy_mcts = MCTS(game, dummy_nn, config)

In [44]:
import time

class Timer:
    """Context manager that measures the wall-clock time of its ``with`` body.

    ``time.clock`` was removed in Python 3.8, so the high-resolution
    ``time.perf_counter`` is used instead.

    Args:
        msg: label for the report. When truthy, ``"<msg> took X.XXX sec"`` is
            printed on exit. When falsy or omitted, nothing is printed and the
            caller can read ``interval`` itself.

    Attributes set on the instance:
        start, end: ``perf_counter`` readings taken on enter and exit.
        interval: elapsed seconds (``end - start``), available after the block.
    """

    def __init__(self, msg=None):
        self.msg = msg

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.interval = self.end - self.start
        if self.msg:
            print(f"{self.msg} took {self.interval:0.3f} sec")
        # Implicitly returns None, so exceptions raised in the body propagate.

In [38]:
def setup_board(plays, do_print=True):
    """Replay ``plays`` from the initial position.

    Player 1 moves first. Each call to ``game.getNextState`` returns the next
    board together with the player to move after it.

    Args:
        plays: action indices applied in order.
        do_print: when True, render the resulting board with ``display_board``.

    Returns:
        (board, player): the position after the last play and the player to
        move next.
    """
    state, to_move = game.getInitBoard(), 1
    for action in plays:
        state, to_move = game.getNextState(state, to_move, action)
    if do_print:
        display_board(state)
    return state, to_move

In [5]:
def test_mcts(plays, expected):
    """Check whether the MCTS player chooses ``expected`` from a given position.

    The function:
    - replays ``plays`` and prints the starting board;
    - asks ``mcts_player`` for a move and prints the board after that move;
    - reports whether the move matched ``expected``, together with the search time.

    Args:
        plays: action indices replayed from the initial position.
        expected: the action index considered correct.

    Returns:
        The root object returned by ``mcts_player``, for further inspection.
    """
    board, player = setup_board(plays)
    # Silent timer: this function prints its own report below. `Timer()` with
    # no argument raised TypeError because `msg` was a required parameter.
    with Timer(None) as t:
        play, root = mcts_player(board, player)
    new_board, _next_player = game.getNextState(board, player, play)
    display_board(new_board)
    print(f"MCTS made {'correct' if play == expected else 'incorrect' } play in  {t.interval:0.3f} sec")
    return root

In [66]:
def nn_pred(plays, net=None):
    """Print a network's raw policy and value for the position after ``plays``.

    The position is built with ``setup_board`` (the board is printed) and
    converted with ``game.getCanonicalForm`` before prediction.

    Args:
        plays: action indices replayed from the initial position.
        net: object exposing ``predict(board) -> (pi, v)``. Defaults to the
            trained ``nn``; a stub such as ``dummy_nn`` can be passed instead.
    """
    b, p = setup_board(plays, True)
    b = game.getCanonicalForm(b, p)
    predictor = nn if net is None else net
    with Timer("NN prediction"):
        pi, v = predictor.predict(b)
    # `v` may be a length-1 array or a plain scalar (the NN stub returns 0).
    # `v[0]` crashed on the latter; ravel handles both.
    value = float(np.ravel(v)[0])
    probs = np.array2string(np.asarray(pi), precision=2, separator=',', suppress_small=True, max_line_width=200)
    print(f"Probs: {probs} Value: {value:0.2f}")

In [67]:
def mcts_pred(plays):
    """Print the policy from the network-guided ``mcts`` after replaying ``plays``.

    The board itself is not printed. The search time is reported by ``Timer``.
    """
    board, to_move = setup_board(plays, do_print=False)
    root_state = game.getCanonicalForm(board, to_move)
    with Timer("MCTS prediction"):
        policy, _search_root = mcts.get_action_prob(root_state)
    formatted = np.array2string(
        np.array(policy),
        precision=2,
        separator=',',
        suppress_small=True,
        max_line_width=200,
    )
    print(f"Probs: {formatted}")

In [68]:
def mcts_only_pred(plays):
    """Print the policy from ``dummy_mcts`` (uniform prior) after replaying ``plays``.

    This is the pure-search baseline for ``mcts_pred``. The board is not
    printed, and the search time is reported by ``Timer``.
    """
    board, to_move = setup_board(plays, do_print=False)
    root_state = game.getCanonicalForm(board, to_move)
    with Timer("MCTS only prediction"):
        policy, _search_root = dummy_mcts.get_action_prob(root_state)
    formatted = np.array2string(
        np.array(policy),
        precision=2,
        separator=',',
        suppress_small=True,
        max_line_width=200,
    )
    print(f"Probs: {formatted}")

In [69]:
# `mcts` and `dummy_mcts` were both constructed with this same `config` object,
# so raising the budget here presumably affects both searches -- assuming MCTS
# reads `numMCTSSims` at search time rather than copying it at construction.
config.numMCTSSims = 1000
play = [0,1,0,1,0]
# Compare the raw network, network-guided MCTS and uniform-prior MCTS on the
# same position.
for predictor in (nn_pred, mcts_pred, mcts_only_pred):
    predictor(play)


z3+--------+
3 |- - - - |
2 |- - - - |
1 |- - - - |
0 |- - - - |
z3+--------+
   0 1 2 3 
z2+--------+
3 |- - - - |
2 |- - - - |
1 |- - - - |
0 |O - - - |
z2+--------+
   0 1 2 3 
z1+--------+
3 |- - - - |
2 |- - - - |
1 |- - - - |
0 |O X - - |
z1+--------+
   0 1 2 3 
z0+--------+
3 |- - - - |
2 |- - - - |
1 |- - - - |
0 |O X - - |
z0+--------+
   0 1 2 3 
--
NN prediction took 0.013 sec
Probs: [0.12,0.05,0.06,0.08,0.04,0.05,0.05,0.05,0.04,0.04,0.04,0.07,0.09,0.06,0.09,0.07] Value: 0.96
MCTS prediction took 15.910 sec
Probs: [0.76,0.03,0.01,0.01,0.01,0.02,0.01,0.01,0.02,0.02,0.01,0.01,0.03,0.03,0.01,0.02]
MCTS only prediction took 1.844 sec
Probs: [0.69,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02,0.02]
