<a href="https://colab.research.google.com/github/nicoRomeroCuruchet/DynamicProgramming/blob/main/testing_bary.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pickle
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from itertools import product
from scipy.spatial import KDTree
from scipy.optimize import minimize
from PolicyIteration import PolicyIteration
from classic_control.cartpole import CartPoleEnv 
from classic_control.continuous_mountain_car import Continuous_MountainCarEnv

In [None]:
# Build a discretised CartPole state grid and solve it with policy iteration.
env = CartPoleEnv(sutton_barto_reward=True)

# Number of grid points along each of the four state dimensions
# (the full grid therefore has N_BINS ** 4 points).
N_BINS = 12

# Position limits: the env's x / theta thresholds plus a 0.5 margin, so the grid
# reaches a little beyond them (presumably the episode-termination limits — confirm).
x_lim = env.x_threshold + 0.5
theta_lim = env.theta_threshold_radians + 0.5
# Velocity limits: hand-picked constants.
x_dot_lim = 2.5
theta_dot_lim = 2.5

# One axis per state dimension; the (n) annotations give the observation index.
bins_space = {
    "x_space": np.linspace(-x_lim, x_lim, N_BINS),                         # position space         (0)
    "x_dot_space": np.linspace(-x_dot_lim, x_dot_lim, N_BINS),             # velocity space         (1)
    "theta_space": np.linspace(-theta_lim, theta_lim, N_BINS),             # angle space            (2)
    "theta_dot_space": np.linspace(-theta_dot_lim, theta_dot_lim, N_BINS), # angular velocity space (3)
}

# CartPole has two discrete actions, 0 and 1.
pi = PolicyIteration(
    env=env,
    bins_space=bins_space,
    gamma=0.95,
    action_space=[0, 1],
)

# NOTE(review): this is the expensive step. The evaluation cell loads a pickled
# policy, presumably written by run() as "<EnvClassName>.pkl" — TODO confirm.
pi.run()

In [None]:
def get_optimal_action(state, optimal_policy, n_neighbors=None):
    """Return the discrete action (0 or 1) for ``state`` by barycentric voting.

    The ``n_neighbors`` grid points nearest to ``state`` are taken as a simplex.
    Each vertex votes with its barycentric weight:
    - for action 0 if the first entry of its policy value is positive;
    - otherwise for action 1.
    The action with the larger total weight wins; ties go to action 1.

    Parameters:
    state (array-like): The continuous observation, one value per grid dimension.
    optimal_policy: A solved PolicyIteration-like object exposing
        ``kd_tree`` (its ``query`` indices address rows of ``points``),
        ``points`` (grid points, one per row),
        ``barycentric_coordinates(state, simplex)`` (one weight per simplex row),
        and ``policy`` (maps a grid-point tuple to an indexable action entry).
    n_neighbors (int, optional): Number of grid points forming the simplex.
        Defaults to ``len(state) + 1``, the vertex count of a simplex in
        ``len(state)`` dimensions. This is 5 for the 4-D CartPole observation,
        the value that was previously hard-coded.

    Returns:
    int: The chosen action, 0 or 1.
    """
    if n_neighbors is None:
        n_neighbors = len(state) + 1

    _, neighbor_idx = optimal_policy.kd_tree.query([state], k=n_neighbors)
    simplex = optimal_policy.points[neighbor_idx[0]]
    lambdas = optimal_policy.barycentric_coordinates(state, simplex)

    weight_zero = 0
    weight_one = 0
    for vertex, weight in zip(simplex, lambdas):
        # NOTE(review): entry [0] is presumably P(action 0) at this vertex — confirm.
        if optimal_policy.policy[tuple(vertex)][0] > 0:
            weight_zero += weight
        else:
            weight_one += weight

    return 0 if weight_zero > weight_one else 1

# Evaluate the saved policy in a rendered CartPole environment.
# (No `del pi` here: it raised NameError on a fresh kernel when the training
# cell was skipped, and the pickle load below rebinds `pi` anyway.)
NUM_EPISODES = 10000
MAX_STEPS_PER_EPISODE = 1000

# Continuous_MountainCarEnv(render_mode="human") is the alternative mentioned originally;
# note get_optimal_action only ever returns the discrete actions 0 / 1.
eval_env = CartPoleEnv(render_mode="human")

# NOTE(review): assumes the training step pickled the solved policy as
# "<EnvClassName>.pkl" in the working directory — confirm against PolicyIteration.
# SECURITY: pickle.load can execute arbitrary code; only load files you created.
policy_path = type(eval_env).__name__ + ".pkl"
with open(policy_path, "rb") as f:
    pi = pickle.load(f)

try:
    for episode in range(NUM_EPISODES):
        observation, _ = eval_env.reset()
        total_reward = 0
        for timestep in range(1, MAX_STEPS_PER_EPISODE + 1):
            action = get_optimal_action(observation, pi)
            # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
            observation, reward, terminated, truncated, _ = eval_env.step(action)
            total_reward += reward
            if terminated or truncated:
                print(f"Episode {episode} finished after {timestep} timesteps")
                print(f"Total reward: {total_reward}")
                break
finally:
    # Release the environment's resources (e.g. the human-mode render window),
    # even if the loop is interrupted.
    eval_env.close()