In [2]:
import cProfile
from tqdm import tqdm
from Game import Game
from State import State
from Action import Action
from typing import List, Dict, TYPE_CHECKING, Any

import random
import os
import copy
import math
import queue


In [3]:
from abc import ABC, abstractmethod
class Agent(ABC):
    @abstractmethod
    def getResponse(self, validActions, game=None, maxPlayer=None):
        pass

In [4]:
class Saver_Node:

    ## Node for MCTS_saver, uses action history to recreate states rather than storing each one
    def __init__(self, action:Action, parent:'Saver_Node', maxPlayer: int):
        self.zero_wins = 0
        self.one_wins = 0
        self.visits = 0

        if parent is None:
            self.num_parents = 0
        else:
            self.num_parents = parent.num_parents + 1

        self.terminal = False

        #self.state = state ## The state we are in
        self.action = action ## The action that got us here
        self.parent = parent ## The parent node we are descendent from

        ## will be a dict of {action : childNode } where childnode state is the result of applying action to this node state
        self.children: Dict['Action','Saver_Node'] = {}
        self.player = maxPlayer # needs to be changed to "zero" "one" and "random"

    def __str__(self) -> str:
        if self.parent is None:
            p = "None"
        else:
            p = "TODO: Hash"
            ## p = get_hash(self.parent.state)
        try:
            expected_value = self.get_expected_value()
        except ValueError:
            expected_value = 0
        spacer = "   " * self.num_parents
        return spacer + f'Action {self.action}\n' + spacer + f' Visits={self.visits} Zero Wins={self.zero_wins} One Wins={self.one_wins}\n' + spacer + f' Expected Value={expected_value} UCB={self.get_ucb()}'

    ## given a mutable rootState, reconstructs the state associated with this node
    def construct_state(self, rootState: State) -> State:
        action_hist = []
        curr = self
        while curr.parent is not None:
            action_hist.insert(0,curr.action)
            curr = curr.parent
        
        [rootState.applyAction(action) for action in action_hist]
        return rootState

    ## MAKE SURE TO FLIP PLAYER BEFORE CALLING
    def add_child(self, action:Action, player:int) -> 'Saver_Node':
        if action in self.children.keys():
            raise ValueError('dupe child')
        else:
            self.children[action] = Saver_Node(action, self, player)
            return self.children[action]

    def get_p_win(self, player:int):
        try:
            if player == 0:
                return self.zero_wins / self.visits
            elif player == 1:
                return self.one_wins / self.visits
            else:
                raise ValueError(f'Given {player} for player, need 1 or 0')
        except ZeroDivisionError:
            raise ValueError('need atleast one visit before getting pwin')

    def get_expected_value(self) -> float:
        try:
            return (self.zero_wins - self.one_wins) / self.visits
        except ZeroDivisionError:
            raise ValueError('need atleast one visit before getting expected value')

    def get_explore_term(self, parent:'Saver_Node', c=1) -> float:
        if self.parent is not None:
            return c * (2*math.log(parent.visits) / self.visits) ** (1/2)
        return 0
    
    def get_ucb(self, c=1, default=6) -> float:
        if self.visits:
            p_win = self.get_expected_value()
            if self.player == 0:
                p_win *= -1
            explore_term = self.get_explore_term(self.parent,c)
            return p_win + explore_term
        return default

    def print_tree(self, max_nodes = 50):
        if max_nodes is None:
            max_nodes = len(self.children)+1
        print(f"Printing from node TODO: Hash {self}")
        q = queue.Queue()
        q.put(self)
        node_count = 0
        while not q.empty() and node_count < max_nodes:
            node_count += 1
            n = q.get()
            print(n)
            print()
            for key in n.children.keys():
                q.put(n.children[key])               

In [5]:
game = Game(order=list(range(0,72)))
root = Saver_Node(None,None,0)

a1 = game.state.currentActions[0]
c1 = root.add_child(a1,1)

mState = game.startSim()
reconstructed_state = c1.construct_state(mState)
input()

''

In [9]:
class MCTS_Saver(Agent):
    
    def getLegalMoves(state: State) -> List[Action]:
        return state.getActions()

    def nextPlayer(currentPlayer:int) -> int:
        return (currentPlayer + 1) % 2

    def getResult(state: State) -> int:
        if state.gameOver() is False:
            return None
        return state.scoreDelta()

    ####*******####
    #             #  
    # ENTRY POINT #
    #             #
    ####*******####

    def getResponse(self, validActions: List[Action], game:Game=None, maxPlayer:int=None) -> Action:

        
        self.headState = game.state      ## <- a reference of the root state for rollback
        self.muteState = game.startSim() ## <- make a deepcopy of the headstate that we can play around with

        root = Saver_Node(None,None,maxPlayer)
        for iteration in range(60):
            v = MCTS_Saver.tree_policy(root, maxPlayer, self.muteState)
            game.refresh(self.muteState) ## refresh the state to where it was before TP
            score = MCTS_Saver.default_policy(v, game, self.muteState)
            game.refresh(self.muteState)
            MCTS_Saver.backProp(v, score)
        move = MCTS_Saver.bestChild(root, 0)

        root.print_tree()
        print(f"\n\nBEST NODE: {move.action} UCB: {move.get_ucb()}")

        return move.action


    ####*******####
    #             #  
    # POLICYFUNCS #
    #             #
    ####*******####

    ## MCTS tree policy (selects child node to examine)
    def tree_policy(node: Saver_Node, player:int, muteState: State) -> Saver_Node:
        ## return the node if its a terminal
        if node.terminal:
            return node 
        
        ## otherwise expand a new possible childNode <--- This is where the random will come in but for now its deterministic
        if node.action is not None:
            muteState.applyAction(node.action)
        moves = MCTS_Saver.getLegalMoves(muteState)
        if len(moves) > len(node.children):
            return MCTS_Saver.expand(node, player, moves)
        
        ## if all children have been expanded, go down the tree by what we think is the best candidate and recurs
        return MCTS_Saver.tree_policy(MCTS_Saver.bestChild(node), MCTS_Saver.nextPlayer(player), muteState)

    ## Adds a random successor node
    def expand(node:Saver_Node, player:int, actions: List[Action]) -> Saver_Node:
        child = None
        for action in actions:
            if action not in node.children.keys():
                child = node.add_child(action,MCTS_Saver.nextPlayer(player))
                return child
        raise ValueError("Ran out of children when we shouldn't")

    ## Selection heuristic for following tree and finally move choice
    def bestChild(node:Saver_Node, c=1) -> Saver_Node:
        bestNode = list(node.children.values())[0]
        for action, node in node.children.items():
            if node.get_ucb(c) > bestNode.get_ucb(c):
                bestNode = node
        return bestNode

    ## Recurse up the tree now, incrementing vists and accumulating score
    def backProp(node:Saver_Node,score:int) -> None:
        while node is not None:
            node.visits += 1
            if score > 0:
                node.zero_wins += score
            elif score < 0:
                node.one_wins += score
            node = node.parent

    ## Rollout randomly from a gamestate to game end
    def default_policy(node:Saver_Node,game:Game, muteState: State, print_final=False) -> int:
        current_state = node.construct_state(muteState)
        
        while(current_state.gameOver() is False):
            moves = MCTS_Saver.getLegalMoves(current_state)
            current_state.applyAction(random.choice(moves), quiet=True)
        return MCTS_Saver.getResult(current_state)

######################
#  Save a head state copy for the mcts search
#  when you get to default policy instead of copying the current node's state
#  create a new state by applying actions from the original state
#  when done, shallow-copy-restore the head state copy


In [10]:
def launch():
    players = [MCTS_Saver(), MCTS_Saver()]
    carcassonne = Game(players,order=list(range(0,72)))

    actions = carcassonne.getActions()
    currPlayer = carcassonne.currentPlayer()

    response = currPlayer.agent.getResponse(actions,game=carcassonne,maxPlayer=currPlayer.id)
    carcassonne.applyAction(response)

    carcassonne.render()

launch()
# import cProfile, pstats
# profiler = cProfile.Profile()
# profiler.enable()
# launch()
# profiler.disable()
# stats = pstats.Stats(profiler).sort_stats('tottime')
# stats.print_stats()

Printing from node TODO: Hash Action None
 Visits=60 Zero Wins=119 One Wins=-14
 Expected Value=2.216666666666667 UCB=-2.216666666666667
Action None
 Visits=60 Zero Wins=119 One Wins=-14
 Expected Value=2.216666666666667 UCB=-2.216666666666667

   Action Location: [-1, 0] Orientation: 0 
    Visits=6 Zero Wins=0 One Wins=-14
    Expected Value=2.3333333333333335 UCB=3.501571973655036

   Action Location: [-1, 0] Orientation: 0 Meeple: Feature: CITY on edges [2]
    Visits=2 Zero Wins=2 One Wins=0
    Expected Value=1.0 UCB=3.023448680402372

   Action Location: [-1, 0] Orientation: 0 Meeple: Feature: GRASS on edges [0, 1, 6, 7]
    Visits=1 Zero Wins=0 One Wins=0
    Expected Value=0.0 UCB=2.8615885665909766

   Action Location: [-1, 0] Orientation: 0 Meeple: Feature: CITY on edges [1]
    Visits=8 Zero Wins=22 One Wins=0
    Expected Value=2.75 UCB=3.761724340201186

   Action Location: [-1, 0] Orientation: 3 
    Visits=1 Zero Wins=0 One Wins=0
    Expected Value=0.0 UCB=2.8615885665