In [2]:
from vinemcts.mcts import *
import numpy as np
import random

In [3]:
# Correlation matrix handed to CorrMat below: 6 x 6, symmetric, unit diagonal.
corr = np.array(
    [
        [1.00, 0.40, 0.15, 0.41, 0.32, 0.62],
        [0.40, 1.00, 0.45, 0.54, 0.76, 0.48],
        [0.15, 0.45, 1.00, 0.20, 0.51, 0.26],
        [0.41, 0.54, 0.20, 1.00, 0.42, 0.41],
        [0.32, 0.76, 0.51, 0.42, 1.00, 0.36],
        [0.62, 0.48, 0.26, 0.41, 0.36, 1.00],
    ]
)

# Arguments for CorrMat(corr, n_sample) and VineState(ntrunc=...).
n_sample, ntrunc = 1000, 2

# Search loop: number of iterations and how often progress is printed.
itermax, log_freq = 3000, 500
# Used as the prefix of the progress lines printed by the search loop.
output_dir = "temp"

# Values stored under the 'FPU' (First Play Urgency) and
# 'PB' (Progressive Bias) keys of the MCTS config.
FPU, PB = 1.5, 0.1

In [4]:
# Seed Python's and NumPy's global RNGs before building the tree.
random.seed(0)
np.random.seed(0)
corr_mat = CorrMat(corr, n_sample)
root_state = VineState(ntrunc=ntrunc, corr_mat=corr_mat)

transpos_table = {}
config = {
    # a dictionary: state -> node.
    'transpos_table': transpos_table,
    # Exploration constant of the UCB1 rule. Standard UCB1 index:
    #   \bar{x}_j + UCT_const * \sqrt{\log(n) / n_j}
    # with n = parent visits and n_j = child visits.
    # NOTE(review): the exact expression is implemented in
    # MctsNode.select_child -- confirm against the library.
    'UCT_const': (-corr_mat.log_det()),
    # First Play Urgency
    'FPU': FPU,
    # Progressive Bias
    'PB': PB
}
root_node = MctsNode(config, root_state)

# We want to maximize the score. Starting from -inf (rather than 0) means the
# first rollout always becomes the incumbent, so best_vine cannot stay None
# when no rollout scores above zero.
best_score = float("-inf")
best_vine = None

for i in range(itermax):
    node = root_node
    # Nodes visited in this iteration, root first; all are updated at the end.
    temp_node_list = [node]

    # Select: descend until a leaf is reached.
    while not node.is_leaf():
        node = node.select_child()
        temp_node_list.append(node)

    # Expand
    if node.visits > 0:
        # Only expand the leaf node if it has been visited.
        add_children_success = node.add_children()

        if add_children_success:
            node = node.select_child()
            temp_node_list.append(node)

    # Rollout: returns the score and the vine it was computed for.
    score, vine = node.roll_out()

    if score > best_score:
        best_score = score
        best_vine = vine

    if i % log_freq == 0 and i > 0:
        print(output_dir + ', Iter %d: ' % i)
        print("best_score: " + str(best_score))

    # Backpropagate: a plain loop, since update() is called only for its
    # side effect and no result list is needed.
    for visited_node in temp_node_list:
        visited_node.update(score)


temp, Iter 500: 
best_score: 2.361811995744089
temp, Iter 1000: 
best_score: 2.361811995744089
temp, Iter 1500: 
best_score: 2.361811995744089
temp, Iter 2000: 
best_score: 2.361811995744089
temp, Iter 2500: 
best_score: 2.361811995744089


In [5]:
def show_node(node):
    """Print a node's state, visit count and mean score on one line.

    The mean score is ``node.sum_score / node.visits``. A node that has not
    been visited yet (``visits == 0``) has no mean, so ``nan`` is printed in
    its place instead of raising ``ZeroDivisionError``.

    Args:
        node: object exposing ``state``, ``visits`` and ``sum_score``.
    """
    if node.visits:
        mean_score = node.sum_score / node.visits
    else:
        # Unvisited node: nothing to average over.
        mean_score = float("nan")
    print(node.state, node.visits, mean_score)

## Level 1

In [6]:
# Root of the search tree: state, visit count and mean score (sum_score / visits).
show_node(root_node)

[] 3000 2.177163048297819


In [7]:
# Visit counts recorded on the root for its children
# (presumably one entry per child, in child_nodes order -- confirm against MctsNode).
root_node.child_visits

[277, 142, 210, 182, 233, 169, 176, 229, 170, 204, 180, 252, 147, 210, 218]

In [8]:
# First child of the root.
show_node(root_node.child_nodes[0])

['0,1'] 277 2.2766059087961725


In [12]:
# Change in state score when moving from the root state to its first child.
root_node.child_nodes[0].state.score - root_node.state.score # child state score - parent state score

0.1743533871447778

In [8]:
# Second child of the root.
show_node(root_node.child_nodes[1])

['0,2'] 142 2.0468721622720474


In [13]:
# Child state score - parent state score, for the root's second child.
root_node.child_nodes[1].state.score - root_node.state.score

0.02275698712261618

In [9]:
# Last child of the root.
show_node(root_node.child_nodes[-1])

['4,5'] 218 2.2040281172297607


In [14]:
# Child state score - parent state score, for the root's last child.
root_node.child_nodes[-1].state.score - root_node.state.score

0.1388024028804588

## Level 2

In [15]:
# Visit counts recorded on the root's first child for its own children.
root_node.child_nodes[0].child_visits

[31, 34, 31, 37, 35, 36, 37, 35]

In [16]:
# First grandchild along the root's first child.
# NOTE(review): the node's own `visits` need not equal the matching entry of the
# parent's `child_visits`; presumably the transposition table shares one node
# between several parents -- confirm against MctsNode.
show_node(root_node.child_nodes[0].child_nodes[0])

['0,1', '0,2'] 59 2.195470337082897


In [18]:
# Child state score - parent state score, one level below the root's first child.
root_node.child_nodes[0].child_nodes[0].state.score - root_node.child_nodes[0].state.score

0.022756987122616174

In [12]:
# Second child of the root's first child.
show_node(root_node.child_nodes[0].child_nodes[1])

['0,1', '0,3'] 60 2.26671243107355


In [19]:
# Child state score - parent state score, for the second child of the root's first child.
root_node.child_nodes[0].child_nodes[1].state.score - root_node.child_nodes[0].state.score

0.18404303769229488

In [13]:
# Last child of the root's first child.
show_node(root_node.child_nodes[0].child_nodes[-1])

['0,1', '1,5'] 70 2.289627133104494


In [20]:
# Child state score - parent state score, for the last child of the root's first child.
root_node.child_nodes[0].child_nodes[-1].state.score - root_node.child_nodes[0].state.score

0.26188437963064026

## Level 6

In [22]:
# Follow a fixed path of five child indices down from the root and inspect the
# node reached. The path is expressed as positions in `child_nodes`, so which
# node it lands on depends on the child ordering -- check the printed state is
# the intended one after re-running.
new_root = root_node.child_nodes[2].child_nodes[1].child_nodes[0].child_nodes[1].child_nodes[1]
show_node(new_root)

['0,1', '0,5', '1,2', '1,3', '1,4'] 93 2.3471707563756627


In [23]:
# Visit counts recorded on new_root for its children.
new_root.child_visits

[0, 32, 0, 25, 32, 0, 0]

In [24]:
# Child of new_root at index 1.
show_node(new_root.child_nodes[1])

['0,1', '0,5', '1,2', '1,3', '1,4', '1,5|0'] 32 2.351193313514363


In [25]:
# Child state score - parent state score, for new_root's child at index 1.
new_root.child_nodes[1].state.score - new_root.state.score

0.10991239190241187

In [19]:
# One rollout from this child; the result unpacks as (score, vine), as in the search loop.
# NOTE(review): presumably stochastic, so the displayed value depends on the RNG
# state at execution time -- re-seed first if it must be reproducible.
new_root.child_nodes[1].roll_out()

(2.2075903978906553,
 ['0,1', '0,5', '1,2', '1,3', '1,4', '0,2|1', '0,4|1', '1,5|0', '2,3|1'])

In [17]:
# Child of new_root at index 3.
show_node(new_root.child_nodes[3])

['0,1', '0,5', '1,2', '1,3', '1,4', '0,3|1'] 25 2.349568267360866


In [26]:
# Child state score - parent state score, for new_root's child at index 3.
new_root.child_nodes[3].state.score - new_root.state.score

0.06533652348426289

In [30]:
# One rollout from this child; the result unpacks as (score, vine), as in the search loop.
# NOTE(review): presumably stochastic -- output depends on RNG state when the cell runs.
new_root.child_nodes[3].roll_out()

(2.26964869204995,
 ['0,1', '0,5', '1,2', '1,3', '1,4', '0,2|1', '0,3|1', '0,4|1', '1,5|0'])

In [31]:
# Child of new_root at index 4.
show_node(new_root.child_nodes[4])

['0,1', '0,5', '1,2', '1,3', '1,4', '2,4|1'] 32 2.3540433833923533


In [32]:
# Child state score - parent state score, for new_root's child at index 4.
new_root.child_nodes[4].state.score - new_root.state.score

0.08750373891149543

In [21]:
# One rollout from this child; the result unpacks as (score, vine), as in the search loop.
# NOTE(review): presumably stochastic -- output depends on RNG state when the cell runs.
new_root.child_nodes[4].roll_out()

(2.290779559933444,
 ['0,1', '0,5', '1,2', '1,3', '1,4', '0,4|1', '1,5|0', '2,4|1', '3,4|1'])