# MCTS

# In [5]:
def search(obs, target_player=None, exhaustive=True):
    """Evaluate the position held in the module-level ``env`` by game-tree search.

    Each candidate card is tried with ``env.step`` and the resulting position
    is scored recursively. ``env.set_state`` then rolls back to the snapshot
    taken on entry, so ``env`` is left as it was found when the call returns.

    Args:
        obs: Observation of the current position. It is only forwarded to
            recursive calls; positions are read from ``env`` itself.
        target_player: Player whose moves are always enumerated. Defaults to
            the player to move when the search starts.
        exhaustive: If True (the default), every ply is enumerated; this is the
            only behaviour the previous ``if True or ...`` version had.
            If False, plies of the other player try only the card ranked
            highest by ``calc_correct_output``.

    Returns:
        ``(value, distr)``.

        - ``distr[i]`` is the value of playing ``hand_cards[i]`` for the
          player to move. A child value is flipped to ``1 - v`` when the turn
          passes to the other player.
        - ``value`` is ``distr.max()``.
        - On a non-enumerated (greedy) ply, ``distr`` is an empty list and
          ``value`` is the value of the single card played.
    """
    mover = env.current_player
    if target_player is None:
        target_player = mover

    state = env.get_state()

    if not exhaustive and target_player != mover:
        # Greedy ply: play only the card calc_correct_output ranks highest.
        scores = calc_correct_output(
            env.players[mover].hand_cards,
            env.table_card,
            env.players[mover].tricks,
            env.players[1 - mover].tricks,
        )
        # Roll back in case calc_correct_output touched env (the original did
        # this too), then read the hand again from the restored state.
        env.set_state(state)
        card_id = env.players[mover].hand_cards[np.argmax(scores)].id
        value = _play_and_evaluate(card_id, mover, target_player, exhaustive)
        env.set_state(state)  # leave env as found, like the exhaustive path
        return value, []

    n_cards = len(env.players[mover].hand_cards)
    distr = np.zeros((n_cards,))
    for i in range(n_cards):
        # Re-read the hand after every rollback instead of caching it;
        # set_state may rebuild the player objects.
        card_id = env.players[mover].hand_cards[i].id
        distr[i] = _play_and_evaluate(card_id, mover, target_player, exhaustive)
        env.set_state(state)

    # NOTE(review): distr.max() raises ValueError if the player to move has no
    # cards in a non-terminal position; confirm that cannot happen.
    return distr.max(), distr


def _play_and_evaluate(card_id, mover, target_player, exhaustive):
    """Play ``card_id`` in ``env`` and return the outcome's value for ``mover``.

    ``env`` is left in the post-move state; the caller rolls back with
    ``env.set_state``.
    """
    obs, rew, is_done, _ = env.step(card_id)
    if is_done:
        # NOTE(review): rew[0] is used whichever player just moved. Confirm
        # that rew[0] is the mover's reward and not always player 0's.
        return rew[0] > 0
    # Decide on the perspective flip before recursing. A greedy recursive ply
    # used to leave env stepped forward, so env.current_player must be read now.
    turn_passed = env.current_player != mover
    child_value = search(obs, target_player, exhaustive)[0]
    return 1 - child_value if turn_passed else child_value

# In [None]:

def draw_tree(root, tree_depth=5, tree_path=None, svg_path='tree.svg'):
    """Render the search tree below ``root`` with Graphviz and show it inline.

    Args:
        root: Root node of the tree. It is passed to ``create_nodes``, which
            reads ``n``, ``current_player``, ``end_v``, ``w``, ``p``, ``v``
            and ``childs``.
        tree_depth: Number of tree levels to draw, counting the root.
        tree_path: Optional per-level child filter forwarded to
            ``create_nodes``. Each level is either one child index or a list
            of indices. ``None`` or an empty list draws every child.
        svg_path: File that an SVG rendering is also written to. Pass
            ``None`` to skip writing. Defaults to ``'tree.svg'`` as before.
    """
    # None instead of a shared mutable [] default.
    if tree_path is None:
        tree_path = []

    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    create_nodes(root, dot, tree_depth, tree_path)

    # Render to PNG bytes in memory by invoking Graphviz's `dot` program.
    png_bytes = dot.create_png(prog='dot')
    # Unlike the PNG, the SVG copy is written to disk.
    if svg_path is not None:
        dot.write_svg(svg_path)

    # Let matplotlib read the in-memory PNG bytes as if they were an image file.
    fig, ax = plt.subplots(figsize=(18, 5))
    ax.imshow(plt.imread(BytesIO(png_bytes)), interpolation="bilinear")
    
def create_nodes(root, dot, tree_depth, tree_path, id=0):    
    text = "N: " + str(root.n) + " (" + str(root.current_player) + ')\n'
    text += "Q: " + str(root.end_v if root.end_v is not 0 or root.n is 0 else root.w / root.n) + '\n'
    text += "P: " + str(root.p) + '\n'
    text += "V: " + str(root.v)
    
    node = pydot.Node(str(id), label=text)
    dot.add_node(node)    
    id += 1

    if tree_depth > 1:
        i = 0
        for child in root.childs:
            if len(tree_path) == 0 or (type(tree_path[0]) == list and i in tree_path[0] or type(tree_path[0]) == int and tree_path[0] == i):
                child_node, id = create_nodes(child, dot, tree_depth - 1, tree_path if len(tree_path) == 0 else tree_path[1:], id)
                dot.add_edge(pydot.Edge(node.get_name(), child_node.get_name()))
            i+=1
        
    return node, id
    
        
def draw_path(root):
    """Print the greedy best line of play from ``root``, one node per line.

    Each printed line reads ``"<current_player>: <Q> / <n>"``. Q is the
    node's terminal value ``end_v`` when it is non-zero or the node is
    unvisited, and its mean value ``w / n`` otherwise.

    From every node the walk follows the highest-scoring child:

    - its ``end_v`` when non-zero,
    - otherwise its mean value ``w / n``,
    - with unvisited children scored -1.

    Ties go to the earliest child. The walk stops at a leaf
    (``is_leaf_node()``) or at a node that has no children.
    """
    def q_value(node):
        # Compare with ==/!= rather than `is`, so float and numpy zeros work.
        return node.end_v if node.end_v != 0 or node.n == 0 else node.w / node.n

    def selection_score(child):
        if child.end_v != 0:
            return child.end_v
        if child.n > 0:
            return child.w / child.n
        return -1

    node = root
    while True:
        print(f"{node.current_player}: {q_value(node)} / {node.n}")
        # Without the childs check, a non-leaf node with no children would
        # make the next iteration dereference None.
        if node.is_leaf_node() or not node.childs:
            break
        # max() returns the first maximal child, matching the strict-">" scan.
        node = max(node.childs, key=selection_score)