# MCTS

In [1]:
def search(obs):
    """Exhaustively score every card the current player can play.

    For each card in the current player's hand, the global ``env`` is stepped
    with that card, the resulting position is scored (recursively when the game
    is not over), and the env is restored with ``env.set_state``.

    Args:
        obs: observation of the position being searched. It is not read here;
            recursive calls receive the observation returned by ``env.step``.

    Returns:
        ``(best, distr)`` where ``distr[i]`` is the score of playing the i-th
        hand card and ``best == distr.max()``. A move that ends the game scores
        1.0 if ``rew[0] > 0`` and 0.0 otherwise; a non-terminal move takes the
        child's best score, as ``1 - best`` when the player to move changed.
        NOTE(review): this treats ``rew[0]`` as the mover's reward - confirm
        against the env.

    Raises:
        ValueError: from ``distr.max()`` when the current player has no cards.
    """
    state = env.get_state()
    current_player = env.current_player
    n = len(env.players[current_player].hand_cards)
    distr = np.zeros((n,))

    for i in range(n):
        card_id = env.players[env.current_player].hand_cards[i].id
        obs, rew, is_done, _ = env.step(card_id)

        if is_done:
            distr[i] = rew[0] > 0
        else:
            # Determine whether the turn passed before recursing into the child.
            player_changed = env.current_player != current_player
            child_best = search(obs)[0]
            distr[i] = (1 - child_best) if player_changed else child_best
        # Undo the move before trying the next card.
        env.set_state(state)

    return distr.max(), distr

In [366]:
class State:
    """A node of the MCTS search tree.

    Attributes:
        childs: child nodes; empty until the node is expanded.
        n: visit count.
        w: sum of the values backed up through this node.
        v: value estimate stored when the node is expanded.
        p: prior probability of the move that leads to this node.
        env_state: env snapshot this node was created from.
        end_v: terminal outcome (nonzero), or 0 for non-terminal nodes.
        current_player: player to move at this node.
        is_root: whether this node is treated as the search root.
    """

    def __init__(self, p, env_state, current_player, end_v=0, is_root=True):
        # Search statistics, all starting empty.
        self.n = 0
        self.w = 0
        self.v = 0
        self.childs = []
        # Description of the position this node stands for.
        self.p = p
        self.env_state = env_state
        self.current_player = current_player
        self.end_v = end_v
        self.is_root = is_root

    def is_leaf_node(self):
        """Return True when the node has no children (not yet expanded)."""
        return not self.childs

    def size(self):
        """Return the number of nodes in this subtree, counting this one."""
        return 1 + sum(child.size() for child in self.childs)
    
def mcts_sample(state, use_model=True):
    """Run one MCTS simulation from ``state`` and back up the resulting value.

    * Terminal leaf (``end_v != 0``): its stored outcome is the value.
    * Non-terminal leaf: the env is restored to ``state.env_state``, priors
      ``p`` and value ``v`` come from ``model.predict_single`` (or uniform
      priors and value 0 when ``use_model`` is False), and one child is
      created per card in the current player's hand.
    * Inner node: the child with the largest PUCT-style score is selected
      (with Dirichlet noise mixed into the priors at the root) and the
      simulation recurses into it.

    Each node's values are kept from its own ``current_player``'s perspective
    and are negated whenever the player to move differs between parent and
    child.

    Args:
        state: node to simulate from; its ``n`` and ``w`` (and on expansion
            ``childs`` and ``v``) are updated in place.
        use_model: query the global ``model`` for priors/value if True.

    Returns:
        The backed-up value, from ``state.current_player``'s perspective.
    """
    if state.is_leaf_node():
        if state.end_v != 0:
            # Terminal node: the game outcome was stored at creation time.
            v = state.end_v
        else:
            env.set_state(state.env_state)

            if use_model:
                p, v = model.predict_single(env.regenerate_obs())
            else:
                # Uniform priors over all 32 card ids and a neutral value.
                p, v = [1] * 32, [0]

            current_player = env.current_player
            # Iterate over a copy of the hand: the env is stepped and restored
            # inside the loop.
            for card in env.players[current_player].hand_cards[:]:
                _, rew, is_done, _ = env.step(card.id)

                end_v = 0
                if is_done:
                    # Sign of the result, negated when the turn passed, so that
                    # end_v is from the child's current_player's perspective.
                    end_v = (1 if rew[0] > 0 else -1) * (-1 if env.current_player != current_player else 1)
                state.childs.append(State(p[card.id], env.get_state(), env.current_player, end_v, False))

                env.set_state(state.env_state)

            v = v[0]
            state.v = v
    else:
        n_sum = sum(child.n for child in state.childs)

        if state.is_root:
            epsilon = EPSILON
            nu = np.random.dirichlet([ALPHA] * len(state.childs))
        else:
            epsilon = 0
            nu = [0] * len(state.childs)

        # NOTE(review): the exploration term scales with sqrt(sum of child
        # visits), which is 0 on the first selection after expansion, so the
        # priors do not influence that first pick.
        exploration_scale = np.sqrt(n_sum)

        max_u = 0
        max_child = None
        for child, noise in zip(state.childs, nu):
            if child.end_v != 0:
                u = child.end_v
            elif child.n > 0:
                u = child.w / child.n
            else:
                u = 0

            # Convert the child's value to this node's perspective.
            u *= (-1 if child.current_player != state.current_player else 1)

            u += ((1 - epsilon) * child.p + epsilon * noise) * exploration_scale / (1 + child.n)

            if max_child is None or u > max_u:
                max_u, max_child = u, child

        v = mcts_sample(max_child, use_model) * (-1 if max_child.current_player != state.current_player else 1)

    state.w += v
    state.n += 1
    return v

def mcts_game_step(root, steps=MCTS_SIMS, use_model=True):
    """Run ``steps`` MCTS simulations from ``root`` and sample a move.

    Args:
        root: search-tree node for the current position.
        steps: number of simulations to run.
        use_model: forwarded to ``mcts_sample``.

    Returns:
        ``(a, p)``: ``a`` is the sampled index into ``root.childs`` and ``p``
        is a float array of the children's visit proportions, in
        ``root.childs`` order.

    Raises:
        ValueError: if ``root`` has no children to choose from.
    """
    for _ in range(steps):
        mcts_sample(root, use_model)

    visits = np.array([child.n for child in root.childs], dtype=float)
    if visits.size == 0:
        raise ValueError("root has no children to choose from")

    total = visits.sum()
    if total > 0:
        p = visits / total
    else:
        # No child visited yet (e.g. steps <= 1 on a fresh root): use a uniform
        # distribution instead of dividing by zero and producing NaNs.
        p = np.full(visits.size, 1.0 / visits.size)
    return np.random.choice(len(p), p=p), p

def generate_key():
    """Return a string identifying the current player's position.

    The key combines the sorted ids of the current player's hand cards with
    the current player's and the opponent's ``tricks``, joined by "-".
    """
    player = env.players[env.current_player]
    opponent = env.players[1 - env.current_player]
    sorted_ids = sorted(card.id for card in player.hand_cards)
    return "-".join([str(sorted_ids), str(player.tricks), str(opponent.tricks)])
    
def clean_unfinished_samples():
    """Move finished samples from ``unfinished_samples`` into the ring buffers.

    A sample counts as finished once it holds at least ``SAMPLES_PER_STATE``
    entries. Its inputs are copied to ``sample_inputs``. The averages over its
    entries of the policy (entry index 1) and the value (entry index 0) go to
    ``sample_outputs[0]`` and ``sample_outputs[1]``. The sample is then removed
    from its bucket.

    ``next_index`` wraps modulo ``MEMORY_SIZE``. ``number_of_samples`` tracks how
    many buffer slots hold data, up to ``MEMORY_SIZE``.

    Returns:
        The number of samples written by this call.
    """
    # Only these two module-level names are rebound; the buffers and counters
    # are mutated in place.
    global next_index, number_of_samples
    number = 0
    for bucket_index, bucket in enumerate(unfinished_samples):
        finished_keys = []
        for key, sample in bucket.items():
            inputs, entries = sample[0], sample[1]
            if len(entries) < SAMPLES_PER_STATE:
                continue

            sample_inputs[0][next_index] = inputs[0]
            sample_inputs[1][next_index] = inputs[1]

            sample_outputs[0][next_index].fill(0)
            sample_outputs[1][next_index].fill(0)

            for entry in entries:
                sample_outputs[0][next_index] += entry[1]
                sample_outputs[1][next_index] += entry[0]

            sample_outputs[0][next_index] = np.divide(sample_outputs[0][next_index], len(entries))
            sample_outputs[1][next_index] = np.divide(sample_outputs[1][next_index], len(entries))

            finished_sample_numbers[bucket_index] += 1
            number += 1
            next_index += 1
            # Count the slot just written *before* wrapping; otherwise the last
            # slot (MEMORY_SIZE - 1) is never included in number_of_samples.
            number_of_samples = max(number_of_samples, next_index)
            next_index %= MEMORY_SIZE
            finished_keys.append(key)

        # Delete after iterating so the dict is not mutated during iteration.
        for key in finished_keys:
            del bucket[key]
    return number

def mcts_game():
    """Play one MCTS self-play game and record training samples.

    The game starts from the hand-size bucket with the fewest finished
    samples. If that bucket has no unfinished samples, a new game is started
    with ``env.reset()``. Otherwise the game resumes from the stored env state
    of one unfinished sample, with the opponent's hand re-dealt at random from
    ``env.cards_left``.

    At every move, a ``[value, policy]`` entry is appended to the unfinished
    sample for the current position. The value is provisionally +1/-1 (player
    0 / player 1 to move). After the game it is multiplied by player 0's
    outcome. The policy is the root visit distribution indexed by card id.

    Returns:
        The number of samples moved to the buffers by
        ``clean_unfinished_samples()``.
    """
    global unfinished_samples, finished_sample_numbers

    next_hand_card_size = finished_sample_numbers.index(min(finished_sample_numbers))

    if len(unfinished_samples[next_hand_card_size]) == 0:
        obs = env.reset()
    else:
        # Resume from the env state stored with an unfinished sample.
        env.set_state(next(iter(unfinished_samples[next_hand_card_size].values()))[0][-1])

        # Re-deal the opponent's (hidden) hand at random from the undealt cards.
        opponent_cards = env.players[1 - env.current_player].hand_cards
        number_of_opponent_cards = len(opponent_cards)
        env.cards_left.extend(opponent_cards)
        random.shuffle(env.cards_left)

        opponent_cards.clear()
        for _ in range(number_of_opponent_cards):
            opponent_cards.append(env.cards_left.pop())

        obs = env.regenerate_obs()

    root = State(1, env.get_state(), env.current_player)
    is_done = False
    values = []

    while not is_done:
        key = generate_key()
        # Buckets are indexed by how many cards are missing from a 5-card hand.
        bucket = unfinished_samples[5 - len(env.players[env.current_player].hand_cards)]
        if key not in bucket:
            bucket[key] = [[np.array(obs[0]), np.array(obs[1]), env.get_state()], []]
        current_sample = bucket[key]

        # Provisional value; fixed up with the game outcome after the loop.
        entry = [1 if env.current_player == 0 else -1]
        current_sample[1].append(entry)
        values.append(entry)

        game_state = env.get_state()
        a, p = mcts_game_step(root)
        env.set_state(game_state)

        output = np.zeros((32,))
        for i, card in enumerate(env.players[env.current_player].hand_cards):
            output[card.id] = p[i]
        entry.append(output)

        last_player = env.current_player
        obs, rew, is_done, _ = env.step(env.players[env.current_player].hand_cards[a].id)
        # Reuse the already-searched subtree for the next move.
        root = root.childs[a]
        root.is_root = True

    # +1 if player 0 won, -1 otherwise (rew[0] is read as last_player's result).
    outcome = (1 if last_player == 0 else -1) * (1 if rew[0] > 0 else -1)
    for val in values:
        val[0] *= outcome

    return clean_unfinished_samples()
        
def mcts_generate():
    """Play MCTS self-play games until ``EPISODES * 5`` samples are collected.

    Each game contributes the count returned by ``mcts_game()``; progress is
    shown with a tqdm bar.
    """
    target = EPISODES * 5
    collected = 0
    with tqdm(total=target) as bar:
        while collected < target:
            collected += mcts_game()
            # A game can finish several samples at once; jump the bar to the total.
            bar.update(collected - bar.n)

    # Post-processing step currently disabled: postprocess_samples()

def reset_samples():
    """Reinitialise all module-level sample storage.

    Allocates zeroed input/output ring buffers of ``MEMORY_SIZE`` rows,
    empties the five hand-size buckets of unfinished samples, and resets the
    finished counts, the write index and the stored-sample count.
    """
    global sample_outputs, sample_inputs, next_index, number_of_samples, unfinished_samples, finished_sample_numbers
    sample_inputs = [
        np.zeros((MEMORY_SIZE, 4, 8, 2), dtype=int),
        np.zeros((MEMORY_SIZE, 4), dtype=int),
    ]
    sample_outputs = [
        np.zeros((MEMORY_SIZE, 32), dtype=float),
        np.zeros((MEMORY_SIZE, 1), dtype=float),
    ]
    unfinished_samples = [{} for _ in range(5)]
    finished_sample_numbers = [0] * 5
    next_index = number_of_samples = 0
    
def draw_tree(root):
    """Render the search tree below ``root`` with Graphviz and plot it.

    The graph is laid out top-to-bottom by the ``dot`` program and converted
    to PNG in memory (no file is written), then shown with matplotlib.
    """
    graph = pydot.Dot()
    graph.set('rankdir', 'TB')
    graph.set('concentrate', True)
    graph.set_node_defaults(shape='record')

    create_nodes(root, graph)

    # Wrap the in-memory PNG bytes in a file-like object for plt.imread.
    image_buffer = BytesIO(graph.create_png(prog='dot'))

    _, axes = plt.subplots(figsize=(18, 5))
    axes.imshow(plt.imread(image_buffer), interpolation="bilinear")
    
def create_nodes(root, dot, id=0):
    """Recursively add ``root`` and its subtree to the pydot graph ``dot``.

    Each node is labelled with its visit count and player to move, its mean
    value (``end_v`` when unvisited), its prior and its stored value.

    Args:
        root: tree node to render.
        dot: pydot graph that receives the nodes and edges.
        id: integer used as the name of the first node created (the name
            shadows the builtin but is kept for backward compatibility).

    Returns:
        ``(node, next_id)``: the pydot node for ``root`` and the next unused
        integer id.
    """
    q = root.end_v if root.n == 0 else root.w / root.n
    text = "N: " + str(root.n) + " (" + str(root.current_player) + ')\n'
    text += "Q: " + str(q) + '\n'
    text += "P: " + str(root.p) + '\n'
    text += "V: " + str(root.v)

    node = pydot.Node(str(id), label=text)
    dot.add_node(node)

    next_id = id + 1
    for child in root.childs:
        child_node, next_id = create_nodes(child, dot, next_id)
        dot.add_edge(pydot.Edge(node.get_name(), child_node.get_name()))

    return node, next_id
    
        
def draw_path(root):
    """Print the greedy path from ``root`` down to a leaf.

    Each printed line is ``"<current_player>: <value> / <visits>"``. The value
    is the node's mean backed-up value, or ``end_v`` when it was never
    visited. Node statistics are from that node's ``current_player``'s
    perspective.

    At each node, the child that is best for the player to move at that node
    is followed. A child's value is negated when the turn passes, matching
    the convention in ``mcts_sample``. Unvisited non-terminal children score
    -1, as low as a loss.
    """
    state = root
    while True:
        value = state.end_v if state.n == 0 else state.w / state.n
        print(str(state.current_player) + ": " + str(value) + " / " + str(state.n))

        if state.is_leaf_node():
            break

        max_u = 0
        max_child = None
        for child in state.childs:
            if child.end_v != 0 or child.n > 0:
                q = child.end_v if child.end_v != 0 else child.w / child.n
                # Convert from the child's perspective to this node's.
                u = -q if child.current_player != state.current_player else q
            else:
                u = -1

            if max_child is None or u > max_u:
                max_u, max_child = u, child

        state = max_child


In [343]:
# Allocate fresh, zeroed sample buffers and clear all unfinished samples.
reset_samples()

In [344]:
# Run self-play until EPISODES * 5 samples are collected, then display the
# finished-sample count per hand-size bucket.
mcts_generate()
finished_sample_numbers

  7%|▋         | 25/375 [00:06<01:39,  3.53it/s]


KeyboardInterrupt: 

In [359]:
# Current write position in the sample ring buffer (wraps at MEMORY_SIZE).
next_index

11577

In [69]:
# Number of hand-size buckets of unfinished samples.
len(unfinished_samples)

5

In [308]:
# Finished-sample count per hand-size bucket.
finished_sample_numbers

[6, 6, 6, 6, 6]