In [1]:
import sys
sys.path.append('../classes')

In [2]:
import logging
import os
import sys
from collections import deque
from pickle import Pickler, Unpickler
from random import shuffle

import numpy as np
from tqdm import tqdm

from arena import Arena
from mcts import MCTS

In [3]:
import logging

import coloredlogs

from coach import Coach
from beck.beck_game import BeckGame as Game
from beck.beck_nnet import NNetWrapper as nn
from utils import *


In [4]:
# Hyper-parameters shared by the training/search objects in this notebook.
# The live values are very small; where a line has a bare trailing `#<number>`,
# that number is the commented-out previous setting.
args = dotdict({
    'numIters': 2,#1000,
    'numEps': 2,#100,              # Number of complete self-play games to simulate during a new iteration.
    'tempThreshold': 15,        # NOTE(review): not used in this notebook — presumably a move-count cutoff for the sampling temperature; confirm in Coach.
    'updateThreshold': 0.6,     # During arena playoff, the new neural net is accepted if this fraction or more of the games are won.
    'maxlenOfQueue': 2,#200000,    # Number of game examples to train the neural networks.
    'numMCTSSims': 3,#80,          # Number of game moves for MCTS to simulate.
    'arenaCompare': 3,#60,         # Number of games to play during arena play to determine if new net will be accepted.
    'cpuct': 3,                 # Weight of the prior/exploration term in the UCB score computed by MCTS.get_action_post.

    'checkpoint': './temp/',
    'load_model': False,
    'load_folder_file': ('/dev/models/8x100x50','best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,

})

In [65]:
# del(g)
# Drop a network left over from an earlier run so the next cell rebuilds it.
# pop() with a default is used instead of `del` so that this cell is a no-op
# (rather than a NameError) when the kernel is fresh and `nnet` does not exist yet.
globals().pop('nnet', None)

In [68]:
# Build the game object and a fresh network wrapper bound to it.
# NOTE(review): the meaning of the three 4s is not visible in this notebook
# (presumably board dimensions plus a win length) — confirm against BeckGame.
g = Game(4, 4, 4)
nnet = nn(g)
# c = Coach(g, nnet, args)

In [74]:
import logging
import math

EPS = 1e-8  # keeps the exploration term of an unvisited edge non-zero while Ns[s] is still 0

log = logging.getLogger(__name__)


class MCTS:
    """Monte Carlo tree search driven by a policy/value network.

    NOTE: defining this class in the notebook shadows the ``MCTS`` name that an
    earlier cell imports from the ``mcts`` module.

    All statistics are plain dicts keyed by ``s`` (the string produced by
    ``np.array_str`` for a canonical board) or by ``(s, a)`` with ``a`` an
    action index.

    Parameters
    ----------
    game : object providing ``getGameEnded``, ``getValidMoves``,
        ``getNextState``, ``getCanonicalForm`` and ``getActionSize``.
    nnet : object whose ``predict(board)`` returns ``(policy, value)``.
    args : object exposing ``cpuct`` (weight of the exploration term).
    """

    def __init__(self,game,nnet,args):
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # running mean of the values backed up through edge (s, a)
        self.Nsa = {}  # number of times edge (s, a) was traversed
        self.Ns = {}   # number of times a search continued past the expansion of s
        self.Ps = {}   # prior policy for s, masked by valid moves and renormalised
        self.Es = {}   # cached game.getGameEnded(board, 1) for s
        self.Vs = {}   # cached game.getValidMoves(board, 1) for s

    def search(self, canonicalBoard):
        """Run one simulation starting at ``canonicalBoard``.

        The simulation stops at the first terminal state or the first state
        that has no prior yet (which is then expanded with ``nnet.predict``).
        On the way back every traversed edge gets its ``Qsa`` / ``Nsa``
        updated and every traversed state its ``Ns`` incremented.

        Returns
        -------
        The negated value of ``canonicalBoard``: ``-getGameEnded`` for a
        terminal state, ``-v`` from the network for a newly expanded state,
        otherwise the negation of the value backed up from the chosen child.
        The negation lets the caller one ply up use the result directly as its
        own sample; this presumes the side to move flips on every ply — TODO
        confirm for the game in use.
        """
        # Alternative key: self.game.stringRepresentation(canonicalBoard).
        # NOTE(review): np.array_str follows numpy's print options; arrays above
        # the print threshold are summarised with "...", which could make
        # different large boards share a key. Not an issue for small boards.
        s = np.array_str(canonicalBoard)

        # Terminal state: report the game result without touching the tree.
        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            return -self.Es[s]

        # New leaf: presence in self.Ps marks a state as expanded, because the
        # prior is what action selection needs.
        if s not in self.Ps:
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # zero out invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s
            else:
                # The network put all of its mass on invalid moves: fall back
                # to a uniform prior over the valid ones.
                log.error("all valid moves masked")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])
            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        # Already expanded: pick an action and descend.
        a = self.get_action_post(s, typ='ucb')
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)
        v = self.search(next_s)  # child's value, already negated by the recursive call

        # Incremental mean of the backed-up values for this edge.
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1
        self.Ns[s] += 1
        return -v

    def get_action_post(self,s,typ='ucb'):
        """Choose an action for the already expanded state key ``s``.

        With ``typ='ucb'`` (the only supported rule) each valid action is scored
        with ``Q + cpuct * P * sqrt(Ns) / (1 + Nsa)``; edges without a ``Qsa``
        entry are scored with ``cpuct * P * sqrt(Ns + EPS)``. The first action
        with the highest score wins; ``-1`` is returned when ``Vs[s]`` marks no
        action as valid.

        Raises
        ------
        ValueError
            If ``typ`` is not a supported selection rule.
        """
        if typ != 'ucb':
            raise ValueError("unsupported action selection type: %r" % (typ,))

        valids = self.Vs[s]
        best_a = -1
        best_u = -float('inf')

        for a in range(self.game.getActionSize()):
            if not valids[a]:
                continue
            if (s, a) in self.Qsa:
                u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
            else:
                u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)

            if u > best_u:
                best_u = u
                best_a = a

        return best_a
            
            
        
        
        
        
  

In [75]:
# Starting position of the game; used as the root for the searches below.
board = g.getInitBoard()

In [76]:
# Fresh search tree (all statistics empty) using the MCTS class defined in this notebook.
mcts = MCTS(g,nnet,args)

In [83]:
# Run a single simulation from the root; each execution of this cell adds one more to the tree.
# The displayed value is search()'s return value for the root.
mcts.search(board)

array([0.00074608], dtype=float32)

In [61]:
# Visit counts of all edges seen so far.
# NOTE(review): this cell's execution count (61) is lower than that of the MCTS
# class cell (74), so the saved output predates the current class — re-run to refresh.
mcts.Nsa.values()

dict_values([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [84]:
# Visit count per (state key, action) edge; the state key is the np.array_str of the canonical board.
mcts.Nsa

{('[[0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]]', 0): 1,
 ('[[0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]]', 1): 1,
 ('[[0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]\n [0 0 0 0]]', 2): 1}

In [94]:
# Mean backed-up value per (state key, action) edge.
# NOTE(review): the saved output shows bytes keys, whereas the MCTS class as saved
# builds str keys with np.array_str — the output appears to be stale; re-run to refresh.
mcts.Qsa

{(b'\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00

In [4]:
import tournament

{'mcts100_cpuct1': [1, 2, 3, 4, 6, 7, 9, 11, 12, 16, 18, 21, 25, 28, 29, 30, 37, 38, 39, 45, 47], 'mcts100_cpuct2': [1, 2, 3, 4, 6, 8, 12, 14, 15, 16, 21, 22, 24, 26, 27, 30, 32, 35, 39], 'mcts100_cpuct3': [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 16, 17, 18, 20, 21, 25, 26, 28, 32, 34, 35, 36], 'mcts25_cpuct1': [1, 2, 3, 4, 9, 11, 13, 14, 16, 19, 22, 25, 28, 29, 31, 37, 39, 45, 59, 61], 'mcts50_cpuct1': [1, 2, 3, 4, 5, 6, 7, 9, 10, 13, 17, 18, 20, 21, 23, 25, 28, 29, 31, 37, 39, 41, 44, 45, 47, 48, 53], 'mcts80_cpuct1': [1, 2, 3, 4, 5, 7, 8, 9, 10, 13, 16, 19, 20, 21, 23, 24, 28, 34, 35, 36, 40], 'mcts80_cpuct2': [1, 2, 5, 6, 8, 10, 12, 13, 15, 17, 18, 19, 22, 25, 28, 30, 31, 34, 37, 38, 39, 40, 43, 45, 46, 51], 'mcts80_cpuct3': [1, 2, 5, 6, 7, 9, 12, 13, 14, 19, 21, 22, 23, 24, 26, 27, 28, 35, 36, 42, 43, 47, 49, 50, 51, 53, 54, 55, 56, 57]}
189 participant iterations!


In [8]:
# Number of entries stored under each key of tournament.iters.
[len(entries) for entries in tournament.iters.values()]

[21, 19, 23, 20, 27, 21, 26, 30]

In [10]:
# Scratch check of np.linspace: 10 evenly spaced points from 0 to 1, both endpoints included.
np.linspace(0,1,10)

array([0.        , 0.11111111, 0.22222222, 0.33333333, 0.44444444,
       0.55555556, 0.66666667, 0.77777778, 0.88888889, 1.        ])

In [41]:
def select_n_instances_each_from_iters(iters,n):
    '''
    Pick n entries from every sequence in ``iters`` with a bias towards the start.

    For a sequence of length N:
      * int(2/3 * n) positions are spread evenly (np.linspace, truncated to
        int) between index 1 and N/3 - 1;
      * the remaining positions are spread evenly between N/3 - 1 and N - 1,
        leaving out the first point because the early segment already ends there.

    The early part is sampled more densely because the model changes quickly in
    early iterations; index 0 is skipped since it is probably easy to beat.

    iters : dict mapping a name to a sequence of values.
    n     : number of values to keep per name.
    Returns a dict with the same keys whose values are numpy arrays holding the
    selected entries.
    '''
    n_early = int(2/3 * n)
    n_late = n - n_early

    chosen = {}
    for name, values in iters.items():
        total = len(values)
        split = total / 3 - 1  # float position where the dense early segment ends
        early_idx = np.linspace(1, split, n_early).astype(int)
        late_idx = np.linspace(split, total - 1, n_late + 1).astype(int)[1:]
        positions = np.concatenate((early_idx, late_idx)).astype(int)
        chosen[name] = np.array(values)[positions]
    return chosen

In [42]:
# Sub-sample 6 entries per key of tournament.iters.
# NOTE(review): the saved output below (each full list followed by an index array) is not
# produced by the function as saved, which prints nothing — presumably it comes from an
# earlier version that printed the list and the chosen indices.
iters = tournament.iters
iters_human = select_n_instances_each_from_iters(iters,6)

[1, 2, 3, 4, 6, 7, 9, 11, 12, 16, 18, 21, 25, 28, 29, 30, 37, 38, 39, 45, 47]
[ 1  2  4  6 13 20]
[1, 2, 3, 4, 6, 8, 12, 14, 15, 16, 21, 22, 24, 26, 27, 30, 32, 35, 39]
[ 1  2  3  5 11 18]
[1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 16, 17, 18, 20, 21, 25, 26, 28, 32, 34, 35, 36]
[ 1  2  4  6 14 22]
[1, 2, 3, 4, 9, 11, 13, 14, 16, 19, 22, 25, 28, 29, 31, 37, 39, 45, 59, 61]
[ 1  2  4  5 12 19]
[1, 2, 3, 4, 5, 6, 7, 9, 10, 13, 17, 18, 20, 21, 23, 25, 28, 29, 31, 37, 39, 41, 44, 45, 47, 48, 53]
[ 1  3  5  8 17 26]
[1, 2, 3, 4, 5, 7, 8, 9, 10, 13, 16, 19, 20, 21, 23, 24, 28, 34, 35, 36, 40]
[ 1  2  4  6 13 20]
[1, 2, 5, 6, 8, 10, 12, 13, 15, 17, 18, 19, 22, 25, 28, 30, 31, 34, 37, 38, 39, 40, 43, 45, 46, 51]
[ 1  3  5  7 16 25]
[1, 2, 5, 6, 7, 9, 12, 13, 14, 19, 21, 22, 23, 24, 26, 27, 28, 35, 36, 42, 43, 47, 49, 50, 51, 53, 54, 55, 56, 57]
[ 1  3  6  9 19 29]
