Continuous Plain MCTS (w/ wandb)
================================

### 1. Log in to wandb

In [None]:
import wandb
# Authenticate with Weights & Biases; later cells create runs (wandb.init) and
# log images, tables and results to them.
wandb.login()

### 2. Import packages

In [None]:
import math
import json
from itertools import product
import os
import sys


#### Import Plain_MCTS, the continuous grid environment, and project utilities

In [None]:
# Make the notebook's parent directory importable so the `src.*` imports in
# this cell resolve; only add it once if the cell is re-run.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from src.Env.Grid.Cont_Grid import Continuous_Grid
from src.Planners.H_MCTS_continuous.Plain_MCTS_cont import Plain_MCTS_Cont

from utils import *

In [None]:
# Basic setup for the environment.
# NOTE(review): the l1_* values presumably describe the level-1 grid layout
# (row/column counts and cell width/height) — confirm against Continuous_Grid.
l1_rows, l1_cols = 8, 8
l1_width, l1_height = 2, 2
goal_distance = 3

# Passed to Plain_MCTS_Cont as its first positional argument.
grid_setting = (l1_rows, l1_cols, l1_width, l1_height, goal_distance)

# Forwarded to Plain_MCTS_Cont as `num_barrier=` inside test_plain_mcts.
num_barrier = 15

In [None]:
def test_plain_mcts(param, grid_setting, folder_name, num_seeds=100):
    """Run Plain_MCTS_Cont on a range of seeded environments and log the results.

    For each seed in ``range(num_seeds)`` a planner is built, searched, its
    trajectory plotted, and both the tree and trajectory images are logged to
    the active wandb run. Aggregate results are written to
    ``<folder_name>/result.json`` and a cumulative success plot is logged.

    Args:
        param: sequence of hyperparameters, in order:
            0. explorationConstant -- passed to the planner as
               ``1 / sqrt(explorationConstant)``
            1. iter_Limit -- iteration budget per search
            2. alpha (PW)
            3. constant_c (PW)
            4. gamma
        grid_setting: tuple forwarded unchanged to Plain_MCTS_Cont.
        folder_name: local save directory; expected to already contain
            ``tree/`` and ``traj/`` subdirectories.
        num_seeds: number of seeds (environments) to evaluate; defaults to 100.

    Returns:
        dict with ``"iter_cnt"`` (seed -> iteration count reported by the
        search, plus one) and ``"success_rate"`` (number of successful seeds
        out of ``num_seeds``).

    Note:
        Reads the module-level ``num_barrier`` and requires an active wandb run.
    """
    explorationConstant = param[0]
    iter_Limit = param[1]
    alpha = param[2]
    constant_c = param[3]
    gamma = param[4]

    mcts_result = {"iter_cnt": {}}
    mcts_success_rate = 0

    for random_seed in range(num_seeds):
        print("Random_seed", random_seed)
        tree_path = f"{folder_name}/tree/{random_seed}.png"
        traj_path = f"{folder_name}/traj/{random_seed}.png"

        MCTS = Plain_MCTS_Cont(grid_setting, random_seed=random_seed,
                               num_barrier=num_barrier,
                               explorationConstant=1 / math.sqrt(explorationConstant),
                               alpha=alpha,
                               constant_c=constant_c,
                               gamma=gamma,
                               iter_Limit=iter_Limit)
        traj, success, iter_cnt = MCTS.search(tree_path)
        # NOTE(review): +1 assumes search() returns a 0-based iteration index — confirm.
        iter_cnt += 1
        mcts_result["iter_cnt"][random_seed] = iter_cnt

        # The trajectory is plotted whether or not the goal was reached.
        MCTS.env.plot_grid(0, traj, traj_path)
        if success:
            mcts_success_rate += 1
            print(f'success with {iter_cnt}')
        else:
            print('Failed')

        # Distinct keys so the tree and trajectory images don't collide in
        # one media series; a single call keeps them on the same step.
        wandb.log({
            f"tree/{random_seed}": wandb.Image(tree_path),
            f"traj/{random_seed}": wandb.Image(traj_path),
        })

    mcts_result["success_rate"] = mcts_success_rate
    with open(f"{folder_name}/result.json", 'w') as result_file:
        json.dump(mcts_result, result_file)

    # NOTE(review): cumul_plot (from utils) is presumably what writes
    # success_rate.png logged below — confirm.
    x_values, y_values = cumul_plot(iter_Limit, mcts_result, folder_name)
    wandb.log({"iteration_plot": wandb.Image(f"{folder_name}/success_rate.png")})

    table = wandb.Table(data=[[x, y] for x, y in zip(x_values, y_values)],
                        columns=["x", "y"])
    wandb.log(
        {"Iteration_Plot": wandb.plot.line(table, "x", "y",
                                           title="Iteration vs Success Rate Plot")})
    return mcts_result


In [None]:
# Launch experiment: one wandb run per hyperparameter combination (and per `run`).
total_runs = 1

# Hyperparameter grid; the order matches the `param` tuple test_plain_mcts expects.
explorationConstant = [0.25, 0.5, 1.0]
iter_Limit = [10000]
alpha = [0.01, 0.025, 0.05]
constant_c = [2]
gamma = [1]
parameters = [explorationConstant, iter_Limit, alpha, constant_c, gamma]

# All combinations, materialized as a list so every `run` can iterate them again.
param_combinations = list(product(*parameters))

for run in range(total_runs):
    for param in param_combinations:
        param_tag = "_".join(str(value) for value in param)
        folder_name = os.path.join(PLAIN_MCTS_EXPERIMENT_DIR, f"plain_mcts_{param_tag}")
        # NOTE(review): make_param_dir's result is treated as "folder already
        # existed", and existing folders are skipped. folder_name does not
        # include `run`, so with total_runs > 1 every run after the first is
        # skipped — confirm this is intended.
        folder_exists = make_param_dir(folder_name)
        print("Param to check", param)

        if folder_exists:
            continue

        # Start a new wandb run to track this combination.
        wandb.init(
            # Project_name
            project='plain-mcts',
            # Run_name
            name=f"plain_mcts_{run}_{param[0]}_{param[1]}_{param[2]}_{param[3]}_{param[4]}",
            # Track hyperparameters and run metadata
            config={
                "goal_distance": goal_distance,
                "num_barrier": num_barrier,
                "H_level": 2,
                "explorationConstant": param[0],
                "iter_Limit": param[1],
                "alpha": param[2],
                "constant_c": param[3],
                "gamma": param[4],
            }
        )
        try:
            wandb.run.log_code(".")
            for sub_dir in ("traj", "tree", "found_path"):
                make_param_dir(f"{folder_name}/{sub_dir}")
            mcts_result = test_plain_mcts(param, grid_setting, folder_name)
            wandb.log(mcts_result)
        finally:
            # Always close the run, even if the experiment raised, so the next
            # combination starts a fresh run instead of logging into this one.
            wandb.finish()