In [1]:
import numpy as np 
from scipy.special import softmax 
import time 

from IPython.display import clear_output
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib.pyplot as plt 
import seaborn as sns 

import sys 
sys.path.append("..") 
from utils.env import frozen_lake
from utils.viz import viz 
viz.get_style()

In [None]:
class Node_s:
    """State node of the search tree.

    Attributes:
        s: the state this node stands for.
        n: visit counter. It starts at 1 -- presumably so that
            count-based ratios such as `v / n` never divide by zero.
        v: running sum of the returns passed to `update`.
        parent: the node above this one, or None for the root.
        children: dict holding the child nodes.
    """
    type = 'state'

    def __init__(self, s, parent=None):
        self.s = s
        self.n = 1
        self.v = 0
        self.parent = parent
        self.children = {}

    def is_leaf(self):
        """Return True when the node has no children yet."""
        return len(self.children) == 0

    def update(self, r):
        """Add the return `r` to the value sum and count one more visit."""
        # `v` is the accumulator created in __init__; there is no `r` field.
        self.v += r
        self.n += 1

In [None]:
class Node_a:
    """Action node of the search tree.

    Attributes:
        a: the action this node stands for.
        n: visit counter. It starts at 1 -- presumably so that
            count-based ratios such as `q / n` never divide by zero.
        q: running sum of the returns passed to `update`.
        parent: the node above this one, or None.
        children: dict holding the child nodes.
    """
    type = 'action'

    def __init__(self, a, parent=None):
        self.a = a
        self.n = 1
        self.q = 0
        self.parent = parent
        self.children = {}

    def update(self, r):
        """Add the return `r` to the value sum and count one more visit."""
        # `q` is the accumulator created in __init__; there is no `r` field.
        self.q += r
        self.n += 1

In [None]:
def MCTS(s, model, n_iter=10, c=1, seed=1234):
    """Plan one action from state `s` with Monte Carlo tree search (UCT).

    Every iteration walks down the tree (UCT choice at fully expanded
    state nodes, a random untried action otherwise) until a new state node
    is created or a terminal state is reached, runs a uniformly random
    rollout from that node, and propagates the return back to the root.

    Args:
        s: the state to plan from; it becomes the root of the tree.
        model: the environment model. Only these members are used:
            `nS`, `nA`, `p_s_next(s, a)` (passed to `rng.choice` as the
            probabilities over the `nS` next states), `r(s)` (used here as
            the reward collected on arriving in `s`) and `s_termination`
            (container of terminal states; a zero-argument callable that
            returns the container is accepted as well, because both
            spellings occur in this notebook).
        n_iter: number of search iterations.
        c: weight of the UCT exploration bonus.
        seed: seed of the private random number generator.

    Returns:
        The root action with the highest mean return (ties are broken at
        random), or None when no action was tried (terminal `s` or
        `n_iter <= 0`).
    """
    # private random stream, so repeated calls are reproducible
    rng = np.random.RandomState(seed)

    term = model.s_termination
    terminal_states = term() if callable(term) else term

    def sample_next(state, a):
        # draw the successor state from the model's transition distribution
        p_next = model.p_s_next(state, a)
        return int(rng.choice(model.nS, p=p_next))

    def uct_policy(node, c_explore=c):
        # pick the action maximising mean return + exploration bonus
        assert node.type == 'state', 'input a state to the policy'
        actions = list(node.children)
        score = []
        for a in actions:
            child = node.children[a]
            exploit = child.q / child.n
            # the bonus compares the parent's visits with the child's visits
            explore = np.sqrt(2*np.log(node.n) / child.n)
            score.append(exploit + c_explore*explore)
        score = np.asarray(score)
        best = np.flatnonzero(score == score.max())
        return actions[rng.choice(best)]

    def fully_expanded(node):
        # every action has an action node below this state node
        return len(node.children) == model.nA

    def expand(s_node):
        # attach one untried action below `s_node` and return its node
        a_unselected = [a for a in range(model.nA)
                        if a not in s_node.children]
        a = int(rng.choice(a_unselected))
        a_node = Node_a(a, parent=s_node)
        s_node.children[a] = a_node
        return a_node

    def select(s_node):
        # descend until a new state node is created or a terminal is hit
        while s_node.s not in terminal_states:
            if fully_expanded(s_node):
                a_node = s_node.children[uct_policy(s_node)]
            else:
                a_node = expand(s_node)
            s_next = sample_next(s_node.s, a_node.a)
            if s_next not in a_node.children:
                s_next_node = Node_s(s_next, parent=a_node)
                a_node.children[s_next] = s_next_node
                return s_next_node
            s_node = a_node.children[s_next]
        return s_node

    def rollout(s_node):
        # follow a uniformly random policy to termination, summing rewards
        state = s_node.s
        r = 0
        while state not in terminal_states:
            a = rng.choice(model.nA)
            state = sample_next(state, a)
            r += model.r(state)
        return r

    def backprop(node, r):
        # credit the return to every node on the path back to the root
        while node is not None:
            node.n += 1
            if node.type == 'state':
                node.v += r
                # The rollout only counts rewards collected *after* the
                # leaf, so the reward for arriving in each state is added
                # before moving up to the action that led there.
                r += model.r(node.s)
            else:
                node.q += r
            node = node.parent

    root = Node_s(s)
    for _ in range(n_iter):
        leaf = select(root)
        backprop(leaf, rollout(leaf))

    if not root.children:
        return None
    # final decision is greedy: no exploration bonus
    return int(uct_policy(root, c_explore=0))
        

    


In [None]:
class MTCS:

    def __init__(self, model, c=1, seed=1234):
        self.model = model 
        self.c     = c
        self.rng   = np.random.RandomState(seed)

    def plan(self, s, max_iter=20):
        root = Node_s(s)
        for _ in range(max_iter):
            leaf = self.selection(root)
            rew  = self.rollout(leaf)
            self.backprop(leaf, rew)
        return self.rng.choice(self.uct_policy(root))
    
    def selection(self, s_node):
        while not s_node.s in self.model.s_termination:
            if fully_expanded(s_node):
                a = uct_policy(s_node)
                a_node = s_node.children[a]
            else:
                a = expand(s_node)
                a_node = Node_a(a, parent=s_node)
        # sample the next state 
        p_next = model.p_s_next(s_node.s, a)
        s_next = rng.choice(model.nS, p=p_next)
        s_next_lst = [child.s for child in a_node.children]
        if s_next in s_next_lst:
            s_next_node = a_node.children[s_next_lst.index(s_next)]
        else:
            s_next_node = Node_s(s_next, parent=a_node)  
        return s_next_node 