In [4]:
import numpy as np
import matplotlib.pyplot as plt
from catvsmonsters import catVsMonsters
from grid_world import GridWorld

In [5]:
import time
from tqdm import tqdm
class Node:
    """A search-tree node holding an environment state and its statistics.

    Attributes:
        state: Environment state represented by this node.
        parent: Parent ``Node``, or ``None`` for the root.
        children: Child nodes, in the order they were appended.
        visits: Number of backups that have passed through this node.
        value: Value estimate of this node. It is treated as the *running
            mean* of backed-up returns, which is how
            ``MonteCarloSearchTree._backpropagation`` maintains it.
    """

    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0

    def get_child_ucb1(self, child, c):
        """Return the UCB1 score of ``child`` as seen from this node.

        Args:
            child: A node whose ``visits`` and ``value`` are scored.
            c: Exploration weight multiplying the confidence term.

        Returns:
            ``inf`` for an unvisited child (so it is tried first), otherwise
            ``child.value + c * sqrt(ln(self.visits) / child.visits)``.
        """
        if child.visits == 0:
            return float("inf")
        # `value` is already a mean, so it is used as-is. Dividing it by
        # `visits` again would shrink the exploitation term towards zero as
        # the child accumulates visits.
        exploitation = child.value
        # Clamp to 1 so an unvisited parent yields a zero bonus instead of
        # log(0) = -inf followed by sqrt(negative) = nan.
        parent_visits = max(self.visits, 1)
        exploration = c * np.sqrt(np.log(parent_visits) / child.visits)
        return exploitation + exploration

    def get_max_ucb1_child(self, c):
        """Return ``(child, index)`` of the child with the highest UCB1 score.

        Ties are resolved in favour of the earliest child. Returns
        ``(None, None)`` when this node has no children.
        """
        if not self.children:
            return None, None

        max_i = 0
        max_ucb1 = float("-inf")

        for i, child in enumerate(self.children):
            ucb1 = self.get_child_ucb1(child, c)

            # Strict comparison keeps the first child among equal scores.
            if ucb1 > max_ucb1:
                max_ucb1 = ucb1
                max_i = i

        return self.children[max_i], max_i

class MonteCarloSearchTree:
    """Monte Carlo tree search with UCB1 selection and uniform random rollouts.

    The environment passed in must provide the members this class uses:
    ``actions``, ``goal``, ``current_state``, ``get_next_state(state, action)``
    and ``step(action)`` returning ``(next_state, reward, done)``.

    NOTE(review): only rollout returns are backed up. The reward of the tree
    edge leading *into* a node is never added (``_expansion`` only asks the
    environment for the next state), so a goal node is valued at 0 --
    confirm this matches the intended reward model.
    """

    def __init__(self, env):
        self.env = env
        self.root = Node((0,0))
        self.depth_limit = 200  # maximum number of steps in one rollout
        self.c = 100  # UCB1 exploration weight handed to Node.get_max_ucb1_child
        self.gamma = 0.9  # discount factor for rollouts and backups

    def _selection(self, node):
        """Descend from ``node`` along max-UCB1 children until reaching a leaf."""
        while len(node.children) > 0:
            child, _ = node.get_max_ucb1_child(self.c)
            node = child
        return node

    def _expansion(self, node):
        """Attach one child per action to ``node`` and return ``node`` itself.

        Children are appended in ``env.actions`` order, so child index ``i``
        corresponds to ``env.actions[i]``.
        """
        for action in self.env.actions:
            next_state = self.env.get_next_state(node.state, action)
            child = Node(next_state, parent=node)
            node.children.append(child)
        return node

    def _simulation(self, node):
        """Run one random rollout from ``node.state`` and return its discounted return.

        The rollout stops at the goal, when ``env.step`` reports ``done``, or
        after ``depth_limit`` steps. ``env.current_state`` is put back to its
        previous value afterwards, even if ``env.step`` raises.
        """
        original_state = self.env.current_state

        # The environment is stepped in place, so point it at the node's state.
        self.env.current_state = node.state
        current_state = node.state
        total_reward = 0
        depth = 0

        try:
            while current_state != self.env.goal and depth < self.depth_limit:
                action = np.random.choice(self.env.actions)
                next_state, reward, done = self.env.step(action)
                total_reward += (self.gamma ** depth) * reward
                current_state = next_state
                depth += 1
                if done:
                    break
        finally:
            # Always restore, otherwise a failing step leaves the shared
            # environment parked on a rollout state.
            self.env.current_state = original_state
        return total_reward

    def _backpropagation(self, node, reward):
        """Back ``reward`` up from ``node`` to the root.

        Each node on the path gets one more visit and its ``value`` is updated
        as an incremental mean. The reward is multiplied by ``gamma`` at every
        step towards the root.
        """
        while node is not None:
            node.visits += 1
            # Incremental mean: value holds the average of all backed-up rewards.
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            reward = reward * self.gamma
            node = node.parent

    def get_best_action(self, iterations, min_visits=1000, verbose=True):
        """Search from ``self.root`` and report the value of each root action.

        Batches of ``iterations`` MCTS iterations are repeated until every
        root child has at least ``min_visits`` visits.

        Args:
            iterations: Number of MCTS iterations per batch.
            min_visits: Visit count every root child must reach before stopping.
            verbose: When true (the default), print per-child diagnostics.

        Returns:
            ``(values, best_value, best_action)`` where ``values[i]`` is the
            mean backed-up return of the child for ``env.actions[i]``
            (``-inf`` if that child does not exist or was never visited),
            ``best_value`` is ``max(values)`` and ``best_action`` is
            ``np.argmax(values)``. If the root is the goal, all values are
            ``-inf`` and ``best_action`` is 0.
        """

        def run_mcts():
            if verbose and self.root.children:
                print(f'child.visits: {self.root.children[0].visits}')
            for _ in range(iterations):
                leaf = self._selection(self.root)
                if leaf.state == self.env.goal:
                    # Terminal node: nothing to expand or roll out, and a
                    # rollout from the goal would return 0. The backup must
                    # still happen: without it no statistic changes, selection
                    # returns this same leaf on every later iteration, and the
                    # outer loop can never reach min_visits.
                    self._backpropagation(leaf, 0.0)
                    continue
                if leaf.visits == 0:
                    leaf = self._expansion(leaf)
                simulation_result = self._simulation(leaf)
                self._backpropagation(leaf, simulation_result)

        while True:
            visits_before = self.root.visits
            run_mcts()
            if all(child.visits >= min_visits for child in self.root.children):
                break
            if self.root.visits == visits_before:
                # A whole batch made no progress (e.g. iterations <= 0);
                # repeating it would loop forever.
                break

        # Unvisited children keep -inf so an unexplored action can never
        # outrank an explored one with a negative value.
        values = [float('-inf')] * len(self.env.actions)
        for i, child in enumerate(self.root.children):
            if child.visits > 0:
                if verbose:
                    print(f'child.value: {child.value}, child.visits: {child.visits}')
                values[i] = child.value

        return values, max(values), np.argmax(values)


In [6]:
def plot_graph_for_test(env,min_visits=1000):
    """Collect best-action value statistics for several MCTS iteration budgets.

    For each budget in ``[500, 1000, 10000, 100000]`` the search is run five
    times from a fresh root at state ``(0, 4)``, and the minimum, maximum and
    mean of the returned best value are recorded.

    The printed ``optimal_value`` is offset by +6; the stored statistics use
    the raw value.

    NOTE(review): a function with this same name is defined again in a later
    cell of the notebook; whichever cell ran last is the one in effect.

    Returns:
        ``(mins, maxs, avgs, iterations)`` -- three dicts keyed by budget and
        the list of budgets.
    """
    searcher = MonteCarloSearchTree(env)
    budgets = [500, 1000, 10000, 100000]
    mins, maxs, avgs = {}, {}, {}

    for budget in budgets:
        best_values = []
        for _ in tqdm(range(5)):
            # Start every repetition from an empty tree.
            searcher.root = Node((0,4))
            _, best_value, best_action = searcher.get_best_action(budget, min_visits=min_visits)
            print(f'iteration: {budget}, optimal_value: {best_value+6}, optimal_action: {best_action}')
            best_values.append(best_value)

        # Seed the folds with +/-inf, matching a running min/max that starts there.
        mins[budget] = min([float('inf')] + best_values)
        maxs[budget] = max([float('-inf')] + best_values)
        avgs[budget] = sum(best_values) / 5

    return mins, maxs, avgs, budgets


# Run the iteration-budget experiment on a `catVsMonsters` environment.
cat_dynamics = catVsMonsters()
# cat_dynamics.current_state = (0,4)
# Results are dicts keyed by iteration budget (min / max / mean of the best
# root value over the repeated runs) plus the list of budgets.
mins_cat,maxs_cat, avgs_cat, iterations_cat = plot_graph_for_test(cat_dynamics,min_visits=1000)


 20%|██        | 1/5 [00:00<00:02,  1.67it/s]

iteration: 500, optimal_value: -inf, optimal_action: 0
iteration: 500, optimal_value: -inf, optimal_action: 0


 60%|██████    | 3/5 [00:01<00:01,  1.67it/s]

iteration: 500, optimal_value: -inf, optimal_action: 0


 80%|████████  | 4/5 [00:02<00:00,  1.66it/s]

iteration: 500, optimal_value: -inf, optimal_action: 0
iteration: 500, optimal_value: -inf, optimal_action: 0


100%|██████████| 5/5 [00:03<00:00,  1.66it/s]
 20%|██        | 1/5 [00:01<00:04,  1.18s/it]

iteration: 1000, optimal_value: -inf, optimal_action: 0


 40%|████      | 2/5 [00:02<00:03,  1.27s/it]

iteration: 1000, optimal_value: -inf, optimal_action: 0


 60%|██████    | 3/5 [00:04<00:02,  1.41s/it]

iteration: 1000, optimal_value: -inf, optimal_action: 0


 80%|████████  | 4/5 [00:05<00:01,  1.51s/it]

iteration: 1000, optimal_value: -inf, optimal_action: 0


100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


iteration: 1000, optimal_value: -inf, optimal_action: 0


 20%|██        | 1/5 [00:14<00:59, 14.86s/it]

iteration: 10000, optimal_value: -inf, optimal_action: 0


 40%|████      | 2/5 [00:31<00:47, 15.72s/it]

iteration: 10000, optimal_value: -inf, optimal_action: 0


In [4]:
def plot_graph_for_test(env, min_visits=1000, iterations=None, n_runs=5, start_state=(0,0)):
    """Collect best-action value statistics for several MCTS iteration budgets.

    For every budget the environment is reset and the search is run
    ``n_runs`` times from a fresh root at ``start_state``. The minimum,
    maximum and mean of the returned best value are recorded.

    Args:
        env: Environment handed to ``MonteCarloSearchTree``; must provide
            ``reset()``.
        min_visits: Forwarded to ``get_best_action``.
        iterations: Iterable of iteration budgets to evaluate. Defaults to
            ``[500, 1000, 10000, 100000]``.
        n_runs: Number of independent repetitions per budget (default 5).
        start_state: State used for the root node of every run.

    Returns:
        ``(mins, maxs, avgs, iterations)`` -- three dicts keyed by budget and
        the list of budgets that were evaluated.

    Raises:
        ValueError: If ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError("n_runs must be a positive integer")
    iterations = [500,1000,10000,100000] if iterations is None else list(iterations)

    mcts = MonteCarloSearchTree(env)
    mins={}
    maxs={}
    avgs={}
    for iteration in iterations:
        min_value = float('inf')
        max_value = float('-inf')
        avg_value = 0
        for _ in tqdm(range(n_runs)):
            env.reset()
            # Start every repetition from an empty tree.
            mcts.root = Node(start_state)
            probs_pi,optimal_value,optimal_action = mcts.get_best_action(iteration,min_visits=min_visits)
            print(f'iteration: {iteration}, optimal_value: {optimal_value}, optimal_action: {optimal_action}')
            min_value = min(min_value, optimal_value)
            max_value = max(max_value, optimal_value)
            avg_value += optimal_value
        avg_value /= n_runs
        mins[iteration] = min_value
        maxs[iteration] = max_value
        avgs[iteration] = avg_value

    return mins,maxs,avgs,iterations


# Run the iteration-budget experiment on a `catVsMonsters` environment.
cat_dynamics = catVsMonsters()
# cat_dynamics.current_state = (0,0)
# Results are dicts keyed by iteration budget (min / max / mean of the best
# root value over the repeated runs) plus the list of budgets.
mins_cat,maxs_cat, avgs_cat, iterations_cat = plot_graph_for_test(cat_dynamics,min_visits=1000)


 20%|██        | 1/5 [00:01<00:04,  1.08s/it]

child.value: -3.814399801947498, child.visits: 125
child.value: -4.260777201702612, child.visits: 124
child.value: -3.315671392312747, child.visits: 125
child.value: -3.7295268008528346, child.visits: 125
iteration: 500, optimal_value: -3.315671392312747, optimal_action: 2


 60%|██████    | 3/5 [00:03<00:01,  1.02it/s]

child.value: -4.209917192329113, child.visits: 125
child.value: -3.5079612925614114, child.visits: 125
child.value: -3.311218419922749, child.visits: 125
child.value: -4.940803725088387, child.visits: 124
iteration: 500, optimal_value: -3.311218419922749, optimal_action: 2
child.value: -3.933572401274527, child.visits: 125
child.value: -4.545344322409392, child.visits: 125
child.value: -4.24745674633098, child.visits: 125
child.value: -4.671081031855238, child.visits: 124
iteration: 500, optimal_value: -3.933572401274527, optimal_action: 0


 80%|████████  | 4/5 [00:03<00:00,  1.05it/s]

child.value: -3.821802054096597, child.visits: 125
child.value: -3.6891946673470124, child.visits: 125
child.value: -4.347180464314795, child.visits: 125
child.value: -4.776524807723827, child.visits: 124
iteration: 500, optimal_value: -3.6891946673470124, optimal_action: 1
child.value: -3.6541282833969726, child.visits: 125
child.value: -4.408422807511806, child.visits: 125
child.value: -3.114189841852566, child.visits: 125
child.value: -4.7450667799653115, child.visits: 124
iteration: 500, optimal_value: -3.114189841852566, optimal_action: 2


100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
 20%|██        | 1/5 [00:01<00:07,  1.85s/it]

child.value: -3.157556169029629, child.visits: 250
child.value: -4.003453147552813, child.visits: 249
child.value: -3.478421819171639, child.visits: 250
child.value: -3.7781519941433, child.visits: 250
iteration: 1000, optimal_value: -3.157556169029629, optimal_action: 0


 20%|██        | 1/5 [00:03<00:14,  3.52s/it]


KeyboardInterrupt: 

In [37]:

# import matplotlib.pyplot as plt
# # plot with time slots
# #  and range  in lighter blue and avg in darker blue
# x_values = [500,1000,10000,100000]
# y_values =list(avgs_cat.values())
# plt.plot(x_values,y_values,label='avg')
# plt.fill_between(x_values,list(mins_cat.values()),list(maxs_cat.values()),alpha=0.5)
# plt.legend()
# # plt.xscale('log')
# plt.show()

In [None]:
from tqdm import tqdm
# Estimate a policy for every free cell of the 5x5 catVsMonsters grid:
# run MCTS five times per cell and keep the action chosen most often.
env_cat = catVsMonsters()
mcts_cat = MonteCarloSearchTree(env_cat)
probs_pi_cat = {}   # state -> root action values from the latest run
pi_cat = {}         # state -> majority-vote action index
for k in tqdm(range(25)):
    i, j = divmod(k, 5)
    action_count = [0, 0, 0, 0]
    # Furniture cells and the goal get no policy entry.
    if (i, j) in env_cat.furniture or (i, j) == env_cat.goal:
        continue
    for p in range(5):
        env_cat.current_state = (i, j)
        mcts_cat.root = Node((i, j))   # fresh tree rooted at this cell
        probs_pi_cat[(i, j)], optimal_value, optimal_action = mcts_cat.get_best_action(20000, min_visits=1000)
        print(f'state: {(i,j)}, optimal_value: {optimal_value}, optimal_action: {optimal_action}')
        action_count[optimal_action] += 1
    # index(max(...)) resolves ties in favour of the lowest action index.
    pi_cat[(i, j)] = action_count.index(max(action_count))

In [None]:
# Run the iteration-budget experiment on a `GridWorld` environment, using the
# default min_visits of plot_graph_for_test.
grid_world = GridWorld()
mins_grid,maxs_grid, avgs_grid, iterations_grid = plot_graph_for_test(grid_world)


In [None]:
import matplotlib.pyplot as plt
# Plot the mean best-action value against the MCTS iteration budget, with the
# min-max range over the repeated runs shaded behind it.
# x comes from the result dict's own keys so it always has the same length as
# y. The previous hard-coded list had 7 budgets while plot_graph_for_test
# evaluates 4, which makes plt.plot fail with a shape mismatch.
x_values = list(avgs_grid.keys())
y_values = list(avgs_grid.values())
plt.plot(x_values, y_values, label='avg')
plt.fill_between(x_values, list(mins_grid.values()), list(maxs_grid.values()), alpha=0.5, label='min-max range')
plt.xlabel('MCTS iterations per search')
plt.ylabel('best root action value')
plt.legend()
plt.xscale('log')
plt.show()


In [None]:
from tqdm import tqdm
# Estimate a policy for every free cell of the 5x5 GridWorld:
# run MCTS five times per cell and keep the action chosen most often.
env = GridWorld()
mcts = MonteCarloSearchTree(env)
probs_pi = {}   # state -> root action values from the latest run
pi = {}         # state -> majority-vote action index
for k in tqdm(range(25)):
    i, j = divmod(k, 5)
    action_count = [0, 0, 0, 0]
    # Obstacles and the goal get no policy entry.
    if (i, j) in env.obstacles or (i, j) == env.goal:
        continue
    for p in range(5):
        env.current_state = (i, j)
        mcts.root = Node((i, j))   # fresh tree rooted at this cell
        probs_pi[(i, j)], optimal_value, optimal_action = mcts.get_best_action(20000)
        print(f'state: {(i,j)}, optimal_value: {optimal_value}, optimal_action: {optimal_action}')
        action_count[optimal_action] += 1
    # index(max(...)) resolves ties in favour of the lowest action index.
    pi[(i, j)] = action_count.index(max(action_count))

    

In [None]:
def print_policy(policy):
    """Print a 5x5 grid showing the action chosen in each cell.

    The goal cell is shown as ``G`` and obstacle cells as ``O`` (both read
    from the notebook-level ``env``). Actions 0-3 are drawn as up, down,
    left and right arrows; any other action value is drawn as a blank.

    Args:
        policy: Mapping from ``(row, col)`` to an action index. It must
            contain every cell that is neither the goal nor an obstacle.
    """
    arrows = {0: "↑", 1: "↓", 2: "←", 3: "→"}
    for row in range(5):
        for col in range(5):
            cell = (row, col)
            if cell == env.goal:
                symbol = "G"
            elif cell in env.obstacles:
                symbol = "O"
            else:
                symbol = arrows.get(policy[cell], " ")
            print(symbol, end=" ")
        print()

# Show the majority-vote policy `pi` as a grid of arrows.
print_policy(pi)

In [None]:
def get_greedy_policy(probs_pi):
    """Return the greedy action for each state from its action values.

    Args:
        probs_pi: Mapping from state to a sequence of per-action values.

    Returns:
        Dict mapping each state to ``np.argmax`` of its values. States that
        are obstacles or the goal of the notebook-level ``env`` are omitted.
    """
    return {
        state: np.argmax(action_values)
        for state, action_values in probs_pi.items()
        if not (state in env.obstacles or state == env.goal)
    }
# Dump the per-state action values, then show the greedy policy derived from them.
print(probs_pi)
print_policy(get_greedy_policy(probs_pi))