In [59]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math
from typing import TypeVar, Dict, Callable, Tuple, Union, List, Generic
Fn = Callable

%matplotlib inline

In [186]:
# Type parameters shared by Game and MCTSNode.
State = TypeVar('State')
Action = TypeVar('Action')

# Exploration constant in the PUCT score computed by MCTSNode.select_action.
c_puct: float = 10.0
# Temperature used by MCTSNode.probs: visit counts are raised to the power 1 / temp.
temp: float = 1.0
    
class Game(Generic[State, Action]):
    """Self-play driver: runs MCTS over a two-player game described by callbacks.

    Args:
        gen_root: builds the initial state.
        do_action: returns the state reached by applying an action to a state.
        eval_state: returns (prior probabilities keyed by legal action, value) for a
            state. An empty prior dict marks a terminal state. During backup the value
            is negated for the edge leading into the evaluated state, i.e. it is treated
            as the value for the player to move in that state.
        search_len: number of MCTS simulations run before each move; play_game needs
            at least 1 so the chosen move has an expanded child.
    """

    def __init__(self,
                 gen_root: Fn[[], State],
                 do_action: Fn[[State, Action], State],
                 eval_state: Fn[[State], Tuple[Dict[Action, float], float]],
                 search_len: int):
        self.gen_root = gen_root
        self.do_action = do_action
        self.eval_state = eval_state
        self.search_len = search_len

    # 'MCTSNode' is quoted: that class is defined after Game, and an unquoted
    # annotation is evaluated when this def runs, raising NameError on a fresh kernel.
    def mcts(self, root_node: 'MCTSNode') -> None:
        """Run `search_len` simulations from `root_node`, growing its tree in place.

        Each simulation selects down already-expanded edges, expands at most one new
        child, evaluates the leaf with `eval_state`, and backs the value up the path.
        """
        if not root_node.actions:
            # Terminal root: nothing to search (backing up a None action would KeyError).
            return

        for _ in range(self.search_len):
            cur_node = root_node
            next_action = cur_node.select_action()
            path = [(cur_node, next_action)]

            # Selection: follow expanded edges until reaching an unexpanded edge or a
            # terminal node (select_action() gives None, which is never a child key).
            while next_action in cur_node.children:
                cur_node = cur_node.children[next_action]
                next_action = cur_node.select_action()
                if next_action is not None:
                    path.append((cur_node, next_action))

            if next_action is None:
                # Terminal leaf: score it, nothing to expand.
                _, v = self.eval_state(cur_node.state)
            else:
                # Expansion: create and score the child behind the unexpanded edge.
                new_state = self.do_action(cur_node.state, next_action)
                probs, v = self.eval_state(new_state)
                cur_node.children[next_action] = MCTSNode(new_state, probs)

            # Backup: players alternate, so the edge into the leaf gets -v, the edge
            # above it +v, and so on up to the root.
            for depth, (node, act) in enumerate(reversed(path)):
                node.backup(act, -v if depth % 2 == 0 else v)

    def play_game(self) -> List[Tuple[State, Dict[Action, float], float]]:
        """Play one game against itself, always taking the action with the highest search policy.

        The subtree under the chosen child is kept and reused for the next move.

        Returns:
            A list of (state, policy, value) in play order, ending with the terminal
            state. `policy` is `MCTSNode.probs()` after searching that state ({} for the
            terminal state). `value` is the `eval_state` value of the terminal state,
            sign-flipped on alternate plies so each entry is from the perspective of the
            player to move in that state.
        """
        root_state = self.gen_root()
        root_probs, _ = self.eval_state(root_state)
        cur_node = MCTSNode(root_state, root_probs)
        history: List[Tuple[State, Dict[Action, float]]] = []

        while cur_node.actions:
            self.mcts(cur_node)
            probs = cur_node.probs()
            # max() returns the first maximal action, the same tie-break as np.argmax.
            next_act = max(cur_node.actions, key=lambda a: probs[a])
            history.append((cur_node.state, probs))
            # Present whenever search_len >= 1: the chosen action was visited during
            # search, and visiting a root action expands it.
            cur_node = cur_node.children[next_act]

        # The terminal state has no legal actions, so its policy target is empty.
        history.append((cur_node.state, {}))
        _, r = self.eval_state(cur_node.state)

        # r is the value for the player to move in the terminal state (the last entry);
        # walking back one ply at a time flips the sign.
        last = len(history) - 1
        return [(s, p, r if (last - i) % 2 == 0 else -r)
                for i, (s, p) in enumerate(history)]
        
    
class MCTSNode(Generic[State, Action]):
    """A search-tree node holding per-action PUCT statistics and its expanded children.

    Attributes:
        actions: legal actions, in the key order of the prior dict passed in.
        N: visit count per action.
        W: sum of backed-up values per action.
        Q: mean backed-up value per action (W / N).
        P: prior per action; the dict passed in, stored without copying.
        state: the game state this node represents.
        total_visits: number of backups through this node (equals the sum of N).
        children: expanded child nodes keyed by the action leading to them.
    """

    def __init__(self, state: State, probs: Dict[Action, float]) -> None:
        self.actions: List[Action] = list(probs.keys())

        self.N: Dict[Action, int] = {action: 0 for action in self.actions}
        self.W: Dict[Action, float] = {action: 0 for action in self.actions}
        self.Q: Dict[Action, float] = {action: 0 for action in self.actions}
        self.P: Dict[Action, float] = probs

        self.state: State = state
        self.total_visits: int = 0
        self.children: Dict[Action, 'MCTSNode'] = {}

    def select_action(self, exploration: Union[float, None] = None) -> Union[Action, None]:
        """Return the action maximising Q[a] + U[a], or None if this node has no actions.

        U[a] = c * P[a] * sqrt(total_visits) / (1 + N[a]), where c is `exploration` or,
        when that is None, the module-level `c_puct`. Ties go to the earliest action.

        NOTE(review): while total_visits == 0 every U and every Q is 0, so the first pick
        from a fresh node is simply actions[0], ignoring P.
        """
        if not self.actions:
            return None

        c = c_puct if exploration is None else exploration
        root_total = math.sqrt(self.total_visits)

        def score(a):
            U = c * self.P[a] * root_total / float(1 + self.N[a])
            return self.Q[a] + U

        # max() keeps the first maximal action, the same tie-break as np.argmax.
        return max(self.actions, key=score)

    def backup(self, a: Action, v: float) -> None:
        """Record one simulation that took action `a` from this node and backed up value `v`."""
        self.total_visits += 1
        self.N[a] += 1
        self.W[a] += v
        self.Q[a] = self.W[a] / self.N[a]

    def probs(self, temperature: Union[float, None] = None) -> Dict[Action, float]:
        """Search policy: N[a] ** (1 / t), normalised to sum to 1.

        t is `temperature` or, when that is None, the module-level `temp`. Returns {} for
        a node with no actions. If no action has been visited yet, the visit counts carry
        no information, so the normalised prior P (or a uniform policy if P sums to 0) is
        returned instead of dividing by zero.
        """
        t = temp if temperature is None else temperature
        weights = {act: self.N[act] ** (1 / t) for act in self.actions}
        total = sum(weights.values())
        if total == 0:
            weights = {act: self.P[act] for act in self.actions}
            total = sum(weights.values())
            if total == 0:
                # Empty comprehension (no division) when there are no actions.
                return {act: 1.0 / len(self.actions) for act in self.actions}
        return {act: weights[act] / total for act in self.actions}

In [198]:
# Tic-tac-toe move: index of the board cell to mark.
TTTAction = int
# Tic-tac-toe state: (board, player to move). Board cells hold 0 when empty,
# otherwise the mark of the player (+1 / -1) who played there.
TTTState = Tuple[List[int], int]

def ttt_gen_root() -> TTTState:
    """Return a fresh empty board (nine 0 cells) with player 1 to move."""
    empty_board = [0] * 9
    return empty_board, 1

def ttt_do_action(s: TTTState, a: TTTAction) -> TTTState:
    """Return the state after the player to move marks cell `a`.

    The board is copied, so `s` is left untouched. The turn passes to -player.
    Cell `a` is not checked for being empty.
    """
    board, player = s[0], s[1]
    next_board = list(board)
    next_board[a] = player
    return next_board, -player

def ttt_eval_state(s: TTTState) -> Tuple[Dict[TTTAction, float], float]:
    """Score a tic-tac-toe state for the player to move.

    Returns:
        - ({}, owner * player_to_move) if some row, column or diagonal is filled with one
          nonzero mark. That is -1 when the line belongs to the opponent of the player to
          move, +1 when it belongs to the player to move. Lines are checked in the order
          row 0, column 0, row 1, column 1, row 2, column 2, anti-diagonal, diagonal,
          and the first complete line decides.
        - ({}, 0) for a full board with no complete line (draw).
        - Otherwise a uniform prior over the empty cells and value 0.
    """
    board, player = s[0], s[1]

    lines = []
    for k in range(3):
        lines.append((3 * k, 3 * k + 1, 3 * k + 2))  # row k
        lines.append((k, k + 3, k + 6))              # column k
    lines.append((2, 4, 6))  # anti-diagonal
    lines.append((0, 4, 8))  # main diagonal

    winner = 0
    for i, j, m in lines:
        if board[i] != 0 and board[i] == board[j] == board[m]:
            winner = board[i]
            break

    if winner:
        return {}, winner * player

    empty_cells = [idx for idx, cell in enumerate(board) if cell == 0]
    return {idx: 1.0 / len(empty_cells) for idx in empty_cells}, 0