# In [3]:
import numpy as np
from copy import deepcopy
from sklearn.preprocessing import LabelBinarizer
from model import create_net
from tqdm import tqdm
from tsp import evaluate

def dirichlet_noise(size, eta=None):
    """Return one sample from a symmetric Dirichlet distribution.

    Args:
        size: number of components, i.e. the length of the returned vector.
        eta: concentration parameter shared by every component; defaults to
            ``8 / size`` when omitted.

    Returns:
        A 1-D NumPy array of ``size`` non-negative values summing to 1.
    """
    concentration = (8 / size) if eta is None else eta
    alphas = [concentration] * size
    # np.random.dirichlet(..., 1) yields a (1, size) array; unpack its only row.
    (sample,) = np.random.dirichlet(alphas, 1)
    return sample

def convert_to_string(arr):
    """Serialize a sequence as the ``str()`` of its elements joined by underscores.

    ``[3, 1, 2]`` becomes ``"3_1_2"``; an empty sequence becomes ``""``.
    MCTS uses this to flatten the board and the visit probabilities before
    inserting them into the results table.
    """
    # str.join is linear; the previous `+=` loop was quadratic and needed a
    # trailing-separator trim.
    return "_".join(str(num) for num in arr)

def create_node(parent, action, proba):
    """Create a child search-tree node under ``parent``.

    Args:
        parent: parent node dict; must contain a "depth" key.
        action: the move that leads from ``parent`` to this node.
        proba: prior probability assigned to ``action``.

    Returns:
        A new node dict with zeroed statistics and
        ``depth == parent["depth"] + 1``. The node is not appended to
        ``parent["children"]``; the caller is responsible for that.
    """
    print("in create node")
    return {
        "children": [],
        "visits": 0,
        "total_value": 0,
        "mean_value": 0,
        "probability": proba,
        "parent": parent,
        "action": action,
        "depth": parent["depth"] + 1,
        # expand() reads node["is_root"]. Only the search root is ever marked
        # True, so every freshly created node starts as a non-root.
        "is_root": False,
    }

def select(node, is_stochastic=False, most_visits=False, T=1):
    """Pick a child of ``node``; also return the children's visit distribution.

    Args:
        node: tree node with a non-empty "children" list. Each child must carry
            "visits", "probability" and "mean_value".
        is_stochastic: sample a child from the visit distribution.
        most_visits: if not stochastic, take the child with the largest
            visit-distribution weight (the most-visited child).
        T: temperature. Weights are ``visits ** (1 / T)``, normalised to sum
            to 1. Must be non-zero.

    With neither flag set, the child maximising ``mean_value + U`` is chosen.

    Returns:
        ``(child, probabilities)``. ``probabilities`` is a NumPy array aligned
        with ``node["children"]``; it is uniform when no child has been visited.

    Raises:
        ValueError: if ``node`` has no children.
    """
    print("in select")
    children = node["children"]
    if not children:
        raise ValueError("select() needs a node with at least one child")

    c = np.log(2)
    visit_counts = np.array([child["visits"] for child in children], dtype=float)
    all_visits = visit_counts.sum()

    scores = []
    for child in children:
        # Exploration term: the log of the *siblings'* visit total, scaled by
        # the prior. This is a variant, not the sqrt(N) term of AlphaZero PUCT.
        visits = all_visits - child["visits"]
        if visits == 0:
            visits = 1
        U = c * child["probability"] * np.log(visits) / (1 + child["visits"])
#         U = child["probability"]/(1 + child["visits"])
        print("U: {}, child visits: {}, child proba: {}, child mean value: {}".format(U, child["visits"],
                                                                                     child["probability"],
                                                                                     child["mean_value"]))
        scores.append(child["mean_value"] + U)

    # Normalise by the sum of the tempered counts. The previous
    # `all_visits ** (1/T)` denominator only sums to 1 when T == 1, and
    # np.random.choice rejects a `p` that does not sum to 1.
    tempered = visit_counts ** (1 / T)
    total = tempered.sum()
    if total == 0:
        probabilities = np.full(len(children), 1 / len(children))
    else:
        probabilities = tempered / total

    if is_stochastic:
        chosen = children[np.random.choice(len(children), p=probabilities)]
    elif most_visits:
        chosen = children[np.argmax(probabilities)]
    else:
        chosen = children[np.argmax(scores)]

    return chosen, probabilities
    
def _transition(state, node):
    """Apply ``node``'s move to ``state`` in place.

    ``node["action"]`` is written into ``state.board`` at index
    ``node["depth"]``, so a node at depth d fills board slot d.
    """
    print("in transition")
    move = node["action"]
    state.board[node["depth"]] = move
            
def expand(state, node, noise_weight=0.25):
    """Expand a leaf with one child per network-policy entry; return the value.

    Args:
        state: game state at ``node``; must provide ``get_ohe_board()``,
            ``net.predict(...)`` and ``move_list_numeric``.
        node: leaf node to expand; its "children" list is appended to.
        noise_weight: mixing weight for Dirichlet exploration noise, which is
            applied to the priors only when ``node`` is the search root.

    Returns:
        The network's value estimate for ``state``.
    """
    print("in expand")
    ohe_board = state.get_ohe_board()
    # Assumes predict() returns (policy, value) batches of size 1, presumably
    # a multi-output Keras model -- TODO confirm.
    (probabilities, value) = state.net.predict(ohe_board)
    probabilities = probabilities[0]
    value = value[0][0]
    # Nodes other than the search root may lack an "is_root" key; indexing
    # directly raised KeyError on the first non-root expansion.
    if node.get("is_root", False):
        noise = dirichlet_noise(size=len(probabilities))
        probabilities = probabilities * (1 - noise_weight) + noise_weight * noise
    for i, proba in enumerate(probabilities):
        # Policy index i corresponds to move_list_numeric[i].
        action = state.move_list_numeric[i]
        node["children"].append(create_node(node, action, proba))

    return value
        
def backup(node, value, orig_depth):
    """Propagate a leaf value up the parent chain; return the chain's top node.

    Starting at ``node``, every node that has a "parent" key and whose depth
    exceeds ``orig_depth`` gets ``visits += 1``, ``total_value += value`` and
    a recomputed ``mean_value``. The walk continues until a node without a
    "parent" key is reached, and that node is returned. It can be an ancestor
    of the node at ``orig_depth`` when the tree above it is still attached.
    """
    print("in backup")
    cursor = node
    while "parent" in cursor:
        # Nodes at or above orig_depth (the search root and its ancestors)
        # keep their statistics; only the searched path below is credited.
        if cursor["depth"] > orig_depth:
            cursor["visits"] = cursor["visits"] + 1
            cursor["total_value"] = cursor["total_value"] + value
            cursor["mean_value"] = cursor["total_value"] / cursor["visits"]
        cursor = cursor["parent"]
    return cursor

# TODO: consider storing (state, action) records in an in-memory SQLite database instead.

def MCTS(root_state=None, root_node=None, is_stochastic=False, num_simulations=1, 
         max_steps=None, resign_threshold=None, resign_disabled=False):
    """Run ``num_simulations`` tree-search simulations and record the outcome.

    Args:
        root_state: state to search from; must provide ``clone()`` plus what
            ``expand``/``_transition`` use (``board``, ``get_ohe_board()``,
            ``net``, ``move_list_numeric``).
        root_node: tree node to resume from, e.g. the node returned by a
            previous call; ``None`` starts a fresh tree. The search runs on a
            deep copy, so the caller's tree is not mutated.
        is_stochastic: forwarded to ``select`` while descending the tree.
        num_simulations: number of select/expand/backup rounds. Must be >= 1,
            because ``state`` is only bound inside the loop.
        max_steps: currently unused.
        resign_threshold: resignation cut-off on mean value; ``None`` disables
            resignation.
        resign_disabled: disable resignation even when a threshold is given.

    Returns:
        ``(state, node)``: the state reached by the last simulation and the
        most-visited child of the search root.
    """
    if root_node is None:
        root_node = {"children": [], "depth": -1, "mean_value": None}

    # Work on a private copy so the caller's tree is left untouched.
    root = deepcopy(root_node)
    root["is_root"] = True

    for _ in tqdm(range(num_simulations)):
        state = root_state.clone()
        # Always restart from the search root. backup() returns the top of the
        # whole (copied) tree, which is an ancestor of `root` whenever a
        # subtree from a previous call is passed in as root_node. Reusing its
        # return value would make later simulations start above the root.
        node = root
        while node["children"] != []:
            node, probas = select(node, is_stochastic)
            _transition(state, node)

        value = expand(state, node)
        backup(node, value, root_node["depth"])

    node, probas = select(root, most_visits=True)
    should_resign = (
        not resign_disabled
        # Comparing a number with None raises TypeError in Python 3.
        and resign_threshold is not None
        and root_node["mean_value"] is not None
        and root_node["mean_value"] < resign_threshold
        and node["mean_value"] < resign_threshold
    )
    if should_resign:
        result = -1
    else:
        # NOTE(review): `state` is the leaf state of the LAST simulation, not
        # root_state. evaluate(), the stored board and the returned state all
        # use it -- confirm this is intended.
        result = evaluate(state)

    board = convert_to_string(state.board)
    probas = convert_to_string(probas)

    # NOTE(review): `db` is not defined in this view; presumably a module-level
    # database wrapper created elsewhere -- confirm.
    db.insert("""INSERT INTO results VALUES (?, ?, ?)""", (board, probas, result))

    return state, node