In [24]:
import math
import random
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
import seaborn

from azulbot.azulsim import Azul, Move, AzulState

from mctspy.tree import DecisionNode, MCTS
from mctspy.policies import uct_action
from mctspy.simluator import MCTSSimulator

In [20]:
class AzulMCTS(MCTSSimulator):
    """Adapter that exposes an Azul engine through the MCTSSimulator interface.

    Every method delegates to the wrapped ``azul`` object. The ``full_game``
    flag selects the simulated horizon: when false a simulation ends at the
    end of the current round, when true it continues (dealing new rounds)
    until the engine reports the end of the game.
    """

    def __init__(self, azul, full_game=False):
        """Store the wrapped engine and the horizon flag.

        Args:
            azul: Engine object providing ``apply_move``, ``is_round_end``,
                ``is_game_end``, ``deal_round``, ``get_init_state``,
                ``enumerate_moves``, ``score_round``, ``score_game`` and
                ``PlayerNumber``.
            full_game: If true, simulate to the end of the game instead of
                the end of the round.
        """
        self.azul = azul
        # The attribute must be spelled `full_game`: step(), state_is_terminal()
        # and get_terminal_value() all read `self.full_game`.
        self.full_game = full_game

    def step(self, state, action):
        """Apply ``action`` to ``state`` and return the transition.

        Returns:
            A 3-tuple ``(next_state, reward, next_agent)``; the reward is
            always 0 because scores are only reported by
            ``get_terminal_value``.
        """
        # Apply the `action` argument (the previous code referenced an
        # undefined name here).
        next_game_state = self.azul.apply_move(state, action)

        # In full-game mode a finished round is immediately followed by a
        # fresh deal so the simulation can continue.
        if self.full_game and self.azul.is_round_end(next_game_state):
            next_game_state = self.azul.deal_round(next_game_state)

        # NOTE(review): `.state` is read from the same object that is passed to
        # `is_round_end` / `deal_round` as a state, while get_initial_state()
        # returns the `deal_round` result directly. Confirm against the engine
        # API whether this should be `next_game_state` itself.
        return next_game_state.state, 0, next_game_state.nextPlayer

    def state_is_terminal(self, state):
        """Return whether ``state`` ends the simulated horizon."""
        if self.full_game:
            return self.azul.is_game_end(state)
        return self.azul.is_round_end(state)

    def enumerate_actions(self, state):
        """Return the moves the engine enumerates for ``state`` as a set."""
        return set(self.azul.enumerate_moves(state))

    def get_initial_state(self):
        """Deal the first round and return ``(state, next_player)``."""
        state = self.azul.deal_round(self.azul.get_init_state())
        return state, state.nextPlayer

    def get_agent_num(self):
        """Return the engine's ``PlayerNumber``."""
        return self.azul.PlayerNumber

    def get_current_agent(self, state):
        """Return the player whose turn it is in ``state``."""
        return state.nextPlayer

    def get_terminal_value(self, state):
        """Score a terminal state and return ``{player_index: score}``.

        The round is always scored; in full-game mode the engine's game
        scoring is applied on top of that.
        """
        state = self.azul.score_round(state)

        if self.full_game:
            state = self.azul.score_game(state)

        return {i: player.score for i, player in enumerate(state.players)}

In [None]:
# Wrap the raw Azul engine in the MCTS simulator adapter. The engine gets its
# own name so `env` is not bound to two different kinds of object in one cell.
azul_game = Azul()
env = AzulMCTS(azul_game)

initial_state, agent_id = env.get_initial_state()

# Fixed seed handed to the rollout function so repeated runs are comparable.
seed = 42

# `partial` is functools.partial.
# NOTE(review): `random_rollout_value` is not imported in the visible cells —
# presumably it lives in mctspy; add the import to the imports cell.
# NOTE(review): the meaning of the trailing `50` is not visible here — confirm
# against the MCTS constructor and consider a named constant.
mcts = MCTS(env, uct_action, partial(random_rollout_value, env=env, seed=seed), 50)
mcts_root = DecisionNode(initial_state, 0, {}, agent_id)

# Build tree
mcts.build_tree(mcts_root)