In [None]:
import numpy as np 
from scipy.special import softmax 
import time 

from IPython.display import clear_output
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib.pyplot as plt 
import seaborn as sns 
import networkx as nx 

import sys 
sys.path.append("..") 
from utils.env import frozen_lake
from utils.viz import viz 
viz.get_style()

In [None]:
# The frozen lake
# Grid map handed to frozen_lake(): one string per row, 8 rows of 8 characters.
# Presumably 'S' = start, 'G' = goal, 'H' = hole and '.' = free cell --
# TODO confirm against utils.env.frozen_lake.
layout = [
    "S.......",
    "........",
    "...H....",
    ".H...H..",
    "...H....",
    ".HH..H..",
    ".H..H...",
    "...H...G"
]
env = frozen_lake(layout)
env.reset()
# Render the freshly reset environment onto a single 4x4-inch axes.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
env.render(ax)

In [None]:
# Query the model for state s=1 and action a=2. The bare last expression makes
# the notebook display whatever p_s_next returns (presumably the next-state
# distribution p(s'|s,a) -- TODO confirm against utils.env).
s, a = 1, 2
env.p_s_next(s, a)

In [None]:
# env.r(s_next) returns a pair that is unpacked as (r, done); display it for
# state 3.
s_next = 3
r, done = env.r(s_next)
r, done 

## MCTS: basic 

In [None]:
# two types of nodes
class Node_s:
    '''State node of the search tree.

    Attributes:
        s: the state this node stands for.
        n: counter initialised to 1 (presumably the visit count).
        v: value initialised to 0.
        name: label of the form 's=<s>', printed by viz_tree.
        parent: the node above this one, None when not given.
        children: list of child nodes, empty on creation.
    '''
    type = 'state'

    def __init__(self, s, parent=None):
        self.s = s
        self.n, self.v = 1, 0
        self.name = f's={s}'
        self.parent, self.children = parent, []

class Node_a:
    '''Action node of the search tree.

    Attributes:
        a: the action this node stands for.
        n: counter initialised to 1 (presumably the visit count).
        v: value initialised to 0.
        name: label of the form 'a=<a>', printed by viz_tree.
        parent: the node above this one, None when not given.
        children: list of child nodes, empty on creation.
    '''
    type = 'action'

    def __init__(self, a, parent=None):
        self.a = a
        self.n, self.v = 1, 0
        self.name = f'a={a}'
        self.parent, self.children = parent, []

## MCTS: select and expand

In [None]:
class MTCS:
    '''Monte Carlo tree search planner.

    The individual search steps (select_expand, rollout, backprop, uct_policy)
    are not defined in the class body; plan() expects them to be available as
    methods on the instance, e.g. attached to the class afterwards.

    Args:
        model: environment model stored for the search steps.
        c: constant stored as self.c for the search steps; plan() itself
            always uses c=0 for its final action choice.
        rng: random number generator stored for the search steps (may be None).
    '''

    def __init__(self, model, c=1, rng=None):
        self.model = model 
        self.c     = c
        self.rng   = rng

    def plan(self, s, max_iter=100):
        '''Plan with MCTS

        Builds a fresh tree rooted at state `s` and runs `max_iter` iterations
        of select/expand -> rollout -> backprop.

        Args:
            s: state to plan from; becomes the root node.
            max_iter: number of search iterations.

        Returns:
            A tuple (uct_policy(root, c=0), root): the policy's choice at the
            root with c set to 0, and the root node of the tree that was built.
        '''
        root = Node_s(s)
        for _ in range(max_iter):
            node = self.select_expand(root)
            r_sum  = self.rollout(node)
            # Back up from the node the rollout started at. The previous code
            # passed an undefined name (`mode`), which raised NameError on the
            # very first iteration.
            self.backprop(node, r_sum)
        return self.uct_policy(root, c=0), root
    

In [None]:
def viz_tree(node, deep=0):
    '''Print the tree under `node`, one line per node, parents before children.

    Each line is '    |' repeated once per depth level, then '--' and the
    node's `name`. `deep` is the depth assigned to `node` itself.
    '''
    pending = [(node, deep)]
    while pending:
        current, depth = pending.pop()
        print('    |' * depth + '--' + current.name)
        # push in reverse so the first child is popped (and printed) first
        pending.extend((kid, depth + 1) for kid in reversed(current.children))

In [None]:
def select_expand(self, s_node):
    '''Selection + expansion step of MCTS (exercise skeleton).

    Starting at `s_node`, repeatedly pick an action (via self.uct_policy when
    self.fully_expanded(s_node) is true, otherwise via self.expand), move to a
    next-state node, and stop either at a terminal state or right after an
    expansion.

    The boxed placeholders still have to be filled in. The code around them
    relies on the following names being bound by the placeholders:
      * a_node      -- the action node for `a` under `s_node`
      * s_next      -- the sampled next state
      * s_next_node -- the state node for `s_next` under `a_node`
    Until then, calling this on a non-terminal node raises NameError.

    Args:
        s_node: state node to start from; its state is read from `s_node.s`.

    Returns:
        The state node at which the descent stopped.
    '''
    done = False
    # Test the node's *state* for termination. model.s_termination is assumed
    # to hold states (not node objects), so testing `s_node` itself would never
    # match and the descent would not stop at a terminal state.
    while s_node.s not in self.model.s_termination:
        if self.fully_expanded(s_node):
            a = self.uct_policy(s_node, self.c)
            #### index the node for action a ###
            #                                  #
            ####################################
        else:
            a = self.expand(s_node)
            ## add node a to s_node's children ###
            #                                    #
            ######################################
            done = True  # a new action was expanded: stop after this step
        # sample the next state 
        ###  sample the s_next using model ###
        #                                    #
        ######################################
        s_next_lst = [child.s for child in a_node.children]
        if s_next in s_next_lst:
            ######  index the s_next node  ######
            #                                   #
            #####################################
            pass
        else:
            #### construct the s_next node  #####
            #                                   #
            #####################################
            pass
        s_node = s_next_node
        if done: break  
    return s_node

# Attach the function to the class so that self.select_expand(...) resolves on
# MTCS instances.
MTCS.select_expand = select_expand

In [None]:
# check your answer 
def test_select_expand(self, s=0, max_iter=20):
    '''Grow a tree from state `s` with `max_iter` select/expand calls, then print it.'''
    tree_root = Node_s(s)
    for _ in range(max_iter):
        # only the side effect on the tree matters; the returned node is unused
        self.select_expand(tree_root)
    viz_tree(tree_root)

# Expose the checker as a method, then run 10 select/expand iterations from the
# default start state s=0 using a RandomState seeded with 0.
MTCS.test_select_expand = test_select_expand
rng = np.random.RandomState(0)
MTCS(env, rng=rng).test_select_expand(max_iter=10)

##  Rollout

Sample a trajectory of transitions $(s, a, r, s')$ until the episode ends, using a random rollout policy.

In [None]:
def rollout(self, s_node):
    '''Rollout step of MCTS (exercise skeleton).

    Starts from the state stored in `s_node`, adds the reward `r` of every
    step to a running sum, and returns that sum once `done` is true.

    NOTE(review): the boxed placeholder still has to be filled in. On every
    pass it must bind `r` and `s_next` and update `done`; as written, `r` and
    `s_next` are undefined and `done` never changes from False.
    '''
    s = s_node.s 
    r_sum = 0 
    done = False
    while True:
        ##  sample a trajectory using random rollout policy #
        #                                                   #
        #####################################################
        r_sum += r
        if done: break
        s = s_next 
    return r_sum
# Attach the function to the class so that self.rollout(...) resolves.
MTCS.rollout = rollout

In [None]:
def backprop(self, node, r_sum):
    '''Backpropagation step of MCTS (exercise skeleton).

    Receives the node returned by select_expand and the reward sum returned by
    rollout (see MTCS.plan).

    NOTE(review): the body is still the placeholder below followed by `pass`,
    so calling this currently does nothing and returns None.
    '''
    ##  backpropagate the reward to the root ##
    #                                         #
    ###########################################
    pass
# Attach the function to the class so that self.backprop(...) resolves.
MTCS.backprop = backprop

### Test your MCTS

In [None]:
def train(env, max_epi=1, seed=1234, max_iter=20):
    '''Run the MCTS agent in `env`, redrawing the environment after every step.

    Args:
        env: environment to act in; also given to the agent as its model.
        max_epi: number of episodes to run.
        seed: seed of the RandomState handed to the agent.
        max_iter: value passed as `max_iter` to every agent.plan() call.

    Returns:
        None.
    '''
    agent = MTCS(model=env, c=1, rng=np.random.RandomState(seed))

    for _episode in range(max_epi):
        # reward/done from reset are not used; the loop always takes one step
        s, _r0, _done0 = env.reset()
        n_steps, ep_return = 0, 0
        finished = False
        while not finished:
            # plan from the current state, then act in the environment
            action, _tree = agent.plan(s, max_iter=max_iter)
            s, reward, finished = env.step(action)

            n_steps += 1
            ep_return += reward

            # replace the previous cell output with a fresh rendering
            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            clear_output(True)
            env.render(ax)
            time.sleep(.1)
            plt.show()

In [None]:
# Fresh environment with eps=0, then one episode with max_iter=20 per plan() call.
# NOTE(review): this module-level `rng` is not passed to train(), which builds
# its own RandomState from `seed`; train() does not use it.
env = frozen_lake(layout=layout, eps=0)
rng = np.random.RandomState(12434)
train(env, seed=224, max_iter=20)

In [None]:
# Fresh environment with eps=0, then one episode with max_iter=200 per plan() call.
# NOTE(review): this module-level `rng` is not passed to train(), which builds
# its own RandomState from `seed`; train() does not use it.
env = frozen_lake(layout=layout, eps=0)
rng = np.random.RandomState(12434)
train(env, seed=224, max_iter=200)

In [None]:
# Fresh environment with eps=0, then one episode with max_iter=1000 per plan() call.
# NOTE(review): this module-level `rng` is not passed to train(), which builds
# its own RandomState from `seed`; train() does not use it.
env = frozen_lake(layout=layout, eps=0)
rng = np.random.RandomState(12434)
train(env, seed=224, max_iter=1000)