In [1]:
from copy import copy, deepcopy
from queue import Queue, LifoQueue
from time import sleep

import gym
import matplotlib.pyplot as plt
import numpy
import numpy as np
from IPython.display import clear_output

%matplotlib inline

def print_frames(frames):
    """Replay recorded episode frames in the notebook output area.

    Each element of ``frames`` is a mapping with the keys 'frame' (something
    printable, e.g. a text render of the env), 'state' and 'reward'. The
    output area is cleared before each frame is shown, and there is a 0.1 s
    pause after each one so the replay is watchable.
    """
    for timestep, record in enumerate(frames, start=1):
        clear_output(wait=True)
        print(record['frame'])
        print(f"Timestep: {timestep}")
        print(f"State: {record['state']}")
        print(f"Reward: {record['reward']}")
        sleep(.1)

In [2]:
env = gym.make('Taxi-v3', render_mode='human')
# In gym >= 0.26 (the API the rest of this notebook assumes, e.g. the
# 5-tuple unpacking of step()), reset() returns (observation, info).
# Keep only the observation so `init_state` is the encoded start state.
init_state, _ = env.reset()
env.render()

action_num = env.action_space.n
state_num = env.observation_space.n

print('Action space:', action_num)
print('Observation space:', state_num)

Action space: 6
Observation space: 500


In [3]:
# Jump to a specific encoded Taxi state and inspect its transitions.
# gym.make() returns a stack of wrappers around the TaxiEnv; assigning `env.s`
# on the outer wrapper only creates a new attribute on that wrapper and
# leaves the real env state unchanged, so write through `env.unwrapped`.
state_id = 328
env.unwrapped.s = state_id
env.render()
# Transition table for this state: {action: [(probability, next_state, reward, done)]}
env.unwrapped.P[state_id]

{0: [(1.0, 428, -1, False)],
 1: [(1.0, 228, -1, False)],
 2: [(1.0, 348, -1, False)],
 3: [(1.0, 328, -1, False)],
 4: [(1.0, 328, -10, False)],
 5: [(1.0, 328, -10, False)]}

In [2]:
class Node:
    """A node of the Monte Carlo search tree.

    Each node owns its own environment instance (``state``), positioned at the
    state this node represents. ``action``, ``observation``, ``reward`` and
    ``is_done`` describe the transition that led *into* this node (taking
    ``action`` from ``parent``); on the root they keep their defaults.
    ``q`` is the sum of backed-up returns and ``visiting_times`` the visit count.
    """

    def __init__(self, env, parent = None):
        self.state = env
        self.parent = parent
        self.children = []
        # Read the action count from the env itself rather than a notebook global.
        self.untried_actions = list(range(env.action_space.n))
        self.visiting_times = 0
        self.q = 0
        self.is_done = False
        self.observation = None
        self.reward = 0
        self.action = None

    def is_fully_expanded(self):
        """Return True once every action from this node has a child."""
        return len(self.untried_actions) == 0

    def is_terminal_node(self):
        """Return True if the transition into this node ended the episode."""
        return self.is_done

    def compute_mean_value(self):
        """Average backed-up return, or 0 for an unvisited node."""
        if self.visiting_times == 0:
            return 0
        return self.q / self.visiting_times

    def compute_score(self, scale = 10, max_score = 10e100):
        """UCB score used during selection.

        Unvisited nodes get ``max_score`` so they are tried first. Must only be
        called on a node that has a parent.
        """
        if self.visiting_times == 0:
            return max_score
        parent_visiting_times = self.parent.visiting_times
        ucb = 2 * np.sqrt(np.log(parent_visiting_times) / self.visiting_times)
        return self.compute_mean_value() + scale * ucb

    def best_child(self):
        """Return the child with the highest UCB score (first one on ties)."""
        scores = [child.compute_score() for child in self.children]
        return self.children[np.argmax(scores)]

    def expand(self):
        """Take one untried action and attach the resulting child node.

        The environment is deep-copied first: a shallow ``copy`` of a gym
        wrapper still shares the wrapped env, so stepping it would mutate this
        node's environment (and, for the root, the caller's environment).
        """
        action = self.untried_actions.pop()
        next_state = deepcopy(self.state)
        observation, reward, terminated, truncated, _ = next_state.step(action)
        child_node = Node(next_state, parent = self)
        child_node.action = action
        # The transition outcome belongs to the child, not to this node.
        child_node.observation = observation
        child_node.reward = reward
        child_node.is_done = terminated or truncated
        self.children.append(child_node)
        return child_node

    def rollout_policy(self, state):
        """Uniformly random default policy."""
        return state.action_space.sample()

    def rollout(self, t_max = 10**8, gamma = 0.6):
        """Simulate from this node with the rollout policy.

        Returns the discounted return sum_t gamma**t * r_t over at most
        ``t_max`` steps, stopping early on termination or truncation. A
        terminal node has no future reward and returns 0. The node's own
        environment is left untouched (the rollout runs on a deep copy).
        """
        if self.is_done:
            return 0
        state = deepcopy(self.state)
        rollout_return = 0
        discount = 1.0
        for _ in range(t_max):
            action = self.rollout_policy(state)
            _, reward, terminated, truncated, _ = state.step(action)
            rollout_return += discount * reward
            discount *= gamma
            if terminated or truncated:
                break
        return rollout_return

    def backpropagate(self, child_value):
        """Propagate a simulated return up to the root.

        At each node the transition reward into that node is added to the
        value coming from below. The sum is accumulated into ``q`` and the
        visit count is incremented. This is done iteratively to avoid
        recursion limits on deep trees.
        """
        node = self
        value = child_value
        while node is not None:
            value = node.reward + value
            node.q += value
            node.visiting_times += 1
            node = node.parent


class MonteCarloTreeSearch(object):
    """Monte Carlo tree search driven from a given root node."""

    def __init__(self, node):
        self.root = node

    def best_action(self, simulations_number):
        """Run ``simulations_number`` simulations and return the child to play.

        Each simulation selects and expands a leaf (``_tree_policy``), rolls it
        out and backs the result up. The move actually played is the
        most-visited child: the selection score includes an exploration bonus,
        which should steer the search but not the final choice.

        Raises:
            ValueError: if the root has no children (e.g. it is terminal).
        """
        for _ in range(simulations_number):
            leaf = self._tree_policy()
            leaf.backpropagate(leaf.rollout())
        if not self.root.children:
            raise ValueError("root has no children to choose from")
        return max(self.root.children, key=lambda child: child.visiting_times)

    def _tree_policy(self):
        """Descend by UCB until reaching a node to expand or a terminal node."""
        current_node = self.root
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            current_node = current_node.best_child()
        return current_node

In [None]:
# Plan on a separate env created without a render mode. In 'human' mode gym
# renders on every step(), so thousands of simulated steps per move would be
# extremely slow. The on-screen window is presumably not deep-copyable by the
# search either. The real `env` is only stepped with the chosen action.
n_simulation = 10**4
seed = 0  # same seed for both resets so the planner and the real env start aligned

obs, _ = env.reset(seed=seed)
planning_env = gym.make('Taxi-v3')
planning_env.reset(seed=seed)
planning_env.unwrapped.s = obs  # force the planner onto the real start state

root = Node(planning_env)
terminated = truncated = False
total_reward, penalty, epochs = 0, 0, 0

while not (terminated or truncated):
    mcts = MonteCarloTreeSearch(root)
    best_child = mcts.best_action(n_simulation)
    # step() returns (observation, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(best_child.action)
    total_reward += reward
    if reward == -10:  # -10 is the illegal pickup/drop-off penalty (see env.P above)
        penalty += 1
    epochs += 1
    # Reuse the searched subtree as the next root; detach it so later
    # backpropagation stops at the new root.
    root = best_child
    root.parent = None

print('Timesteps taken:', epochs)
print('Penalty:', penalty)
print('total_reward:', total_reward)