In [3]:
import logging
from collections import defaultdict
import numpy as np

from jass.game.game_util import *
from jass.game.game_observation import GameObservation
from jass.game.const import *
from jass.game.rule_schieber import RuleSchieber
from jass.agents.agent import Agent
from jass.arena.arena import Arena
from jass.agents.agent_random_schieber import AgentRandomSchieber
from jass.game.game_rule import GameRule
from jass.game.game_sim import GameSim

In [4]:
class MonteCarloTreeSearchNode():
    def __init__(self, gameState, parent=None, parent_action=None):
        self.state = gameState
        self.parent = parent
        self.parent_action = parent_action
        self.children = []
        self._number_of_visits = 0
        self._results = defaultdict(0)
        self._results[1] = 0
        self._results[-1] = 0
        self._untried_actions = None
        self._untried_actions = self.untried_actions()
        return

    def untried_actions(self):
        self._untried_actions = RuleSchieber.get_valid_actions_from_obs(self.state.get_observation())
        return self._untried_actions

    def differenceWinLoss(self):
        wins = self._results[1]
        loses = self._results[-1]
        return win - loses

    def numberOfVisits(self):
        return self._number_of_visits

    def expand(self):
        action = self._untried_actions.pop()
        next_state = self.state.action_play_card(action)
        child_node = MonteCarloTreeSearchNode(next_state, parent=self, parent_action=action)
        self.children.append(child_node)
        return child_node

    def is_terminal_node(self):
        return self.state.is_done()

    def rollout(self):
        current_rollout_state = self.state

        while not current_rollout_state.state.is_done():
            possible_moves = RuleSchieber.get_valid_actions_from_obs(self.state.get_observation())
            action = self.rollout_policy(possible_moves)
            current_rollout_state = current_rollout_state.state.action_play_card(action)

        return gameResult

    def backpropagate(self, result):
        self._number_of_visits += 1
        self._results[result] += 1
        if self.parent:
            self.parent.backpropagate(result)

    def is_fully_expanded(self):
        return len(self._untried_actions) == 0

    def best_child(self, c_param=0.1):
        choices_weights = [(c.differenceWinLoss() / c.numberOfVisits()) + c_param * np.sqrt(
            (2 * np.log(self.numberOfVisits()) / c.numberOfVisits())) for c in self.children]
        return self.children[np.argmax(choices_weights)]

    def rollout_policy(self, possible_moves):
        return possible_moves[np.random.randint(len(possible_moves))]

    def _tree_policy(self):
        current_node = self
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            else:
                current_node = current_node.best_child()
        return current_node

    def best_action(self):
        simulation_no = 100

        for i in range(simulation_no):
            v = self._tree_policy()
            reward = v.rollout()
            v.backpropagate(reward)

        return self.best_child(c_param=0.)

In [5]:
class DMCTSAgent(Agent):
    def __init__(self):
        self._logger = logging.getLogger("DMTSAgent")
        super().__init__()

        self._rule = RuleSchieber()
        self.round = 0

    def action_trump(self, obs: GameObservation) -> int:
        self._logger.info("Select Trump")

        return 0

    def action_play_card(self, obs: GameObservation) -> int:
        self.round += 1
        self.round = self.round % 9

        _valid_cards = self._rule.get_valid_cards_from_obs(obs)
        print(obs.player_view)
        print(obs.hand)
        print(convert_one_hot_encoded_cards_to_str_encoded_list(obs.hand))
        print(_valid_cards)
        print(convert_one_hot_encoded_cards_to_str_encoded_list(_valid_cards))
        print("-------------------------------------------------------------------------")
        return np.random.choice(np.flatnonzero(_valid_cards))

In [6]:
def main():
    # setup the arena
    arena = Arena(nr_games_to_play=1)
    player = AgentRandomSchieber()
    my_player = DMCTSAgent()

    arena.set_players(my_player, player, my_player, player)
    print('Playing {} games'.format(arena.nr_games_to_play))
    arena.play_all_games()
    print('Average Points Team 0: {:.2f})'.format(arena.points_team_0.mean()))
    print('Average Points Team 1: {:.2f})'.format(arena.points_team_1.mean()))

In [7]:
if __name__ == '__main__':
    main()


Playing 1 games
2
[0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 1 0 1 1 0 1]
['D9', 'D8', 'H7', 'H6', 'S7', 'CJ', 'C9', 'C8', 'C6']
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
['H7', 'H6', 'S7']
-------------------------------------------------------------------------
0
[0 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 1 1 0 0 1 0 0 1 0]
['D10', 'D7', 'HA', 'S9', 'S8', 'CA', 'CK', 'C10', 'C7']
[0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0]
['HA', 'S9', 'S8']
-------------------------------------------------------------------------
2
[0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 1 1 0 1]
['D9', 'D8', 'H7', 'S7', 'CJ', 'C9', 'C8', 'C6']
[0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 1 1 0 1]
['D9', 'D8', 'H7', 'S7', 'CJ', 'C9', 'C8', 'C6']
-------------------------------------------------------------------------
0
[0 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0