In [5]:
from MCTS import MCTS

from sogo.SogoGame import SogoGame, display as display_board
import numpy as np
from sogo.keras.NNet import NNetWrapper as NNet
from Timer import Timer

def display_probs(pi, prefix="Probs"):
    """Print a probability vector on one line as ``"<prefix>: [p0,p1,...]"``.

    Entries are rendered with at most two decimals, tiny values are
    suppressed, and wrapping only kicks in beyond 200 characters, so a full
    policy vector normally stays on a single line.

    Args:
        pi: sequence (or array) of probabilities to show.
        prefix: label printed before the vector.
    """
    as_array = np.array(pi)
    rendered = np.array2string(
        as_array,
        max_line_width=200,
        precision=2,
        suppress_small=True,
        separator=',',
    )
    print(f"{prefix}: {rendered}")

class Config(object):
    """Hyper-parameters plus per-process game/network/search state for Sogo MCTS.

    The instance doubles as the configuration object handed to ``MCTS`` (see
    ``initialize``). The game, network and search objects exist only after
    ``initialize`` has run, either called directly or used as a
    ``multiprocessing.Pool`` initializer. Call it before ``setup_board`` or
    ``mcts_pred``.
    """

    def __init__(self):
        # Search settings. The trailing comments are carried over from the
        # AlphaZero pseudocode and are not Sogo-specific.
        # NOTE(review): which of these MCTS actually reads is not visible here.
        self.num_sampling_moves = 30
        self.max_moves = 512  # for chess and shogi, 722 for Go.
        self.num_mcts_sims = 50

        # Root prior exploration noise.
        self.root_dirichlet_alpha = 0.3  # for chess, 0.03 for Go and 0.15 for shogi.
        # NOTE(review): 0.0 presumably disables root noise in MCTS -- confirm.
        self.root_exploration_fraction = 0.0

        # UCB formula constants.
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # Model checkpoint loaded by ``initialize`` when ``load_model`` is true,
        # given as (directory, filename).
        self.load_model = True
        self.load_folder_file = ('./save/', 'mixed3.h5')

    def initialize(self):
        """Create the game, network and MCTS objects for the current process.

        Also rebinds the module-level ``config`` to ``self``. Module-level
        helpers that read ``config`` (such as the pool task function) then see
        this initialized instance inside worker processes.

        The checkpoint named by ``load_folder_file`` is loaded only when
        ``load_model`` is true.
        """
        global config
        config = self
        self.game = SogoGame(4)

        self.nn = NNet(self.game)
        # Honour the ``load_model`` flag instead of always loading weights.
        if self.load_model:
            self.nn.load_checkpoint(*self.load_folder_file)

        # ``self`` is passed as MCTS's configuration object.
        self.mcts = MCTS(self.game, self.nn, self)

    def setup_board(self, plays, verbose=True):
        """Replay ``plays`` from the initial position, starting with player 1.

        Args:
            plays: sequence of actions applied in order via ``game.next_state``.
            verbose: currently unused; kept for interface compatibility.

        Returns:
            ``(board, player)`` as returned by the last ``next_state`` call.
            This is presumably the side to move; with no plays it is the
            initial board and player 1.
        """
        board = self.game.init_board()
        player = 1
        for play in plays:
            board, player = self.game.next_state(board, player, play)
        return board, player

    def mcts_pred(self, plays, root=None, verbose=False):
        """Run MCTS on the position reached by ``plays`` and return its policy.

        The search is wrapped in ``Timer("MCTS prediction")``. The resulting
        policy is printed with the ``"MCTS"`` prefix.

        Args:
            plays: action sequence passed to ``setup_board``.
            root: optional search root forwarded to ``get_action_prob``.
            verbose: forwarded to ``setup_board``.

        Returns:
            The action probabilities ``pi``. The root returned by
            ``get_action_prob`` is discarded, so callers cannot reuse the tree.
        """
        board, player = self.setup_board(plays, verbose=verbose)
        with Timer("MCTS prediction"):
            pi, root = self.mcts.get_action_prob(board, player, root)
        display_probs(pi, "MCTS")
        return pi

    def test(self, blah):
        """Debug helper: return ``np.sum(blah)``."""
        return np.sum(blah)

          
# Module-level configuration read by mcts_task. Config.initialize rebinds this
# name to the initialized instance in whichever process runs it (e.g. each pool
# worker via the Pool initializer).
config = Config()
          
def mcts_task(value):
    """Pool task: return the MCTS policy for the position reached by ``value``.

    ``value`` is the list of plays forwarded to ``config.mcts_pred``. The
    module-level ``config`` is only read here; ``Config.initialize`` is what
    rebinds it to the initialized instance in each worker.
    """
    cfg = config
    return cfg.mcts_pred(value)

In [6]:
from multiprocessing import Pool

In [7]:
# Four worker processes; each calls config.initialize when it starts, so every
# worker builds its own game, network and MCTS objects.
# NOTE(review): the pool is never closed/joined in this notebook, and both the
# initializer and mcts_task live in __main__, which presumably only works with
# the "fork" start method -- confirm on macOS/Windows ("spawn").
pool = Pool(processes=4, initializer=config.initialize)

MCTS prediction took 2.522 sec
MCTS prediction took 2.540 sec
MCTS prediction took 2.551 sec
MCTS: [0.02,0.04,0.  ,0.02,0.02,0.2 ,0.  ,0.02,0.02,0.  ,0.  ,0.02,0.31,0.02,0.  ,0.31]
MCTS: [0.  ,0.  ,0.  ,0.  ,0.  ,0.  ,0.06,0.  ,0.  ,0.1 ,0.  ,0.  ,0.  ,0.  ,0.  ,0.84]
MCTS: [0.16,0.  ,0.  ,0.16,0.  ,0.08,0.1 ,0.  ,0.  ,0.08,0.1 ,0.  ,0.16,0.  ,0.  ,0.14]
MCTS prediction took 1.385 sec
MCTS: [0.02,0.04,0.  ,0.02,0.02,0.2 ,0.  ,0.02,0.02,0.  ,0.  ,0.02,0.31,0.02,0.  ,0.31]
MCTS prediction took 1.387 sec
MCTS: [0.  ,0.  ,0.  ,0.  ,0.  ,0.  ,0.06,0.  ,0.  ,0.1 ,0.  ,0.  ,0.  ,0.  ,0.  ,0.84]
MCTS prediction took 2.213 sec
MCTS: [0.16,0.  ,0.  ,0.16,0.  ,0.08,0.1 ,0.  ,0.  ,0.08,0.1 ,0.  ,0.16,0.  ,0.  ,0.14]


In [9]:
# Evaluate three positions in parallel; Pool.map returns results in input order.
positions = [[], [0, 1, 0], [0, 0, 0, 0]]
pool.map(mcts_task, positions)

[[0.16326530612244897,
  0.0,
  0.0,
  0.16326530612244897,
  0.0,
  0.08163265306122448,
  0.10204081632653061,
  0.0,
  0.0,
  0.08163265306122448,
  0.10204081632653061,
  0.0,
  0.16326530612244897,
  0.0,
  0.0,
  0.14285714285714285],
 [0.02040816326530612,
  0.04081632653061224,
  0.0,
  0.02040816326530612,
  0.02040816326530612,
  0.20408163265306123,
  0.0,
  0.02040816326530612,
  0.02040816326530612,
  0.0,
  0.0,
  0.02040816326530612,
  0.30612244897959184,
  0.02040816326530612,
  0.0,
  0.30612244897959184],
 [0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.061224489795918366,
  0.0,
  0.0,
  0.10204081632653061,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.8367346938775511]]