In [71]:
import math
from typing import Optional, Union
from scipy.optimize import minimize


import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled


class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """
    ## Description

    This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in
    ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
    The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces
     in the left and right direction on the cart.

    ## Action Space

    The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction
     of the fixed force the cart is pushed with.

    - 0: Push cart to the left
    - 1: Push cart to the right

    **Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle
     the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it

    ## Observation Space

    The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

    | Num | Observation           | Min                 | Max               |
    |-----|-----------------------|---------------------|-------------------|
    | 0   | Cart Position         | -4.8                | 4.8               |
    | 1   | Cart Velocity         | -Inf                | Inf               |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
    | 3   | Pole Angular Velocity | -Inf                | Inf               |

    **Note:** While the ranges above denote the possible values for observation space of each element,
        it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
    -  The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates
       if the cart leaves the `(-2.4, 2.4)` range.
    -  The pole angle can be observed between  `(-.418, .418)` radians (or **±24°**), but the episode terminates
       if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)

    ## Rewards
    Since the goal is to keep the pole upright for as long as possible, by default, a reward of `+1` is given for every step taken, including the termination step. The default reward threshold is 500 for v1 and 200 for v0 due to the time limit on the environment.

    If `sutton_barto_reward=True`, then a reward of `0` is awarded for every non-terminating step and `-1` for the terminating step. As a result, the reward threshold is 0 for v0 and v1.

    ## Starting State
    All observations are assigned a uniformly random value in `(-0.05, 0.05)`

    ## Episode End
    The episode ends if any one of the following occurs:

    1. Termination: Pole Angle is greater than ±12°
    2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
    3. Truncation: Episode length is greater than 500 (200 for v0)

    ## Arguments

    Cartpole only has `render_mode` as a keyword for `gymnasium.make`.
    On reset, the `options` parameter allows the user to change the bounds used to determine the new random state.

    ```python
    >>> import gymnasium as gym
    >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
    >>> env
    <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
    >>> env.reset(seed=123, options={"low": -0.1, "high": 0.1})  # default low=-0.05, high=0.05
    (array([ 0.03647037, -0.0892358 , -0.05592803, -0.06312564], dtype=float32), {})

    ```

    | Parameter               | Type       | Default                 | Description                                                                                   |
    |-------------------------|------------|-------------------------|-----------------------------------------------------------------------------------------------|
    | `sutton_barto_reward`   | **bool**   | `False`                 | If `True` the reward function matches the original sutton barto implementation                |

    ## Vectorized environment

    To increase steps per seconds, users can use a custom vector environment or with an environment vectorizor.

    ```python
    >>> import gymnasium as gym
    >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="vector_entry_point")
    >>> envs
    CartPoleVectorEnv(CartPole-v1, num_envs=3)
    >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    >>> envs
    SyncVectorEnv(CartPole-v1, num_envs=3)

    ```

    ## Version History
    * v1: `max_time_steps` raised to 500.
        - In Gymnasium `1.0.0a2` the `sutton_barto_reward` argument was added (related [GitHub issue](https://github.com/Farama-Foundation/Gymnasium/issues/790))
    * v0: Initial versions release.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 50,
    }

    def __init__(
        self, sutton_barto_reward: bool = False, render_mode: Optional[str] = None
    ):
        self._sutton_barto_reward = sutton_barto_reward

        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.render_mode = render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        self.steps_beyond_terminated = None

    def step(self, action):
        assert self.action_space.contains(
            action
        ), f"{action!r} ({type(action)}) invalid"
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        terminated = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not terminated:
            if self._sutton_barto_reward:
                reward = 0.0
            elif not self._sutton_barto_reward:
                reward = 1.0
        elif self.steps_beyond_terminated is None:
            # Pole just fell!
            self.steps_beyond_terminated = 0
            if self._sutton_barto_reward:
                reward = -1.0
            elif not self._sutton_barto_reward:
                reward = 1.0
        else:
            if self.steps_beyond_terminated == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned terminated = True. You "
                    "should always call 'reset()' once you receive 'terminated = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_terminated += 1
            if self._sutton_barto_reward:
                reward = -1.0
            elif not self._sutton_barto_reward:
                reward = 0.0

            reward = -1.0

        if self.render_mode == "human":
            self.render()
        # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
        return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        super().reset(seed=seed)
        # Note that if you use custom reset bounds, it may lead to out-of-bound
        # state/observations.
        low, high = utils.maybe_parse_reset_bounds(
            options, -0.005, 0.005  # default low
        )  # default high
        #low = (-x_lim, -x_dot_lim, -theta_lim, -theta_dot_lim)
        #high = (x_lim, x_dot_lim, theta_lim, theta_dot_lim)
        self.state = self.np_random.uniform(low=low, high=high, size=(4,))
        #elf.state=np.array([0.0, 0.0, 0.01, 0.0])
        self.steps_beyond_terminated = None

        if self.render_mode == "human":
            self.render()
        return np.array(self.state, dtype=np.float32), {}

    def render(self):
        if self.render_mode is None:
            assert self.spec is not None
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
            )
            return

        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        if self.screen is None:
            pygame.init()
            if self.render_mode == "human":
                pygame.display.init()
                self.screen = pygame.display.set_mode(
                    (self.screen_width, self.screen_height)
                )
            else:  # mode == "rgb_array"
                self.screen = pygame.Surface((self.screen_width, self.screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        self.surf = pygame.Surface((self.screen_width, self.screen_height))
        self.surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, self.screen_width, carty, (0, 0, 0))

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if self.render_mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        elif self.render_mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )

    def close(self):
        if self.screen is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            self.isopen = False

In [65]:
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from itertools import product
from scipy.spatial import KDTree

class PolicyIteration(object):
    """Policy Iteration Algorithm for gymnasium environment"""

    def __init__(self, env: gym.Env, gamma: float = 0.99, bins_space: dict = None):
        """Initializes the Policy Iteration.

        Parameters:
        - env (gym.Env): The environment in which the agent will interact.
        - gamma (float): The discount factor for future rewards. Default is 0.99.
        - bins_space (dict): A dictionary specifying the number of bins for each state variable. Default is None.

        Returns: None"""

        self.env = env
        self.gamma = gamma  # discaunt factor

        self.action_space = env.action_space
        self.bins_space = bins_space

        self.states_space = list(
            set(product(*bins_space.values()))
        )  # avoid repited states
        
        self.points = np.array([np.array(e) for e in self.states_space])
        self.kd_tree = KDTree(self.points)
        
        self.policy = {state: {0: 0.5, 1: 0.5} for state in self.states_space}
        self.value_function = {state: 0 for state in self.states_space}  # initialize value function

    def get_state(self, np_state: np.ndarray) -> tuple:
        """Discretizes the given state values based on the provided bins dictionary.

        Parameters:
        state (tuple): The state values to be discretized.
        bins_dict (dict): A dictionary containing the bins for each state value.

        Returns:
        tuple: The discretized states space values."""

        state = tuple(np_state)
        discretized_state = []
        for s_i, (_, bins) in zip(state, self.bins_space.items()):
            # # Digitize the value and adjust the index to be 0-based
            up_index = min(np.digitize(s_i, bins), len(bins) - 1)
            discretized_value = discretized_state.append(up_index)
            
        return tuple(discretized_state)
    
    
    def barycentric_coordinates(self, point, simplex):
        # Formulate the system of equations
        A = np.vstack([np.array(simplex).T, np.ones(len(simplex))])
        b = np.hstack([point, [1]])
        objective_function = lambda x: np.linalg.norm(A.dot(x) - b)

        # Define the constraint that the solution must be greater than zero
        constraints = ({'type': 'ineq', 'fun': lambda x: x})

        # Initial guess for the solution
        x0 = np.array([0.33, 0.33, 0.33, 0.33, 0.33])

        # Solve the optimization problem
        result = minimize(objective_function, x0, constraints=constraints, tol=1e-3)

        # The approximate solution
        x_approx = result.x
        return x_approx

    def get_transition_reward_function(self) -> dict:
        """Generate a transition reward function table.

        Returns:
            dict: A dictionary representing the transition reward function table.
                The keys are tuples of (state, action), and the values are dictionaries
                with 'reward' and 'next_state' as keys."""

        table = {}
        for state in tqdm(self.states_space):
            for action in range(self.env.action_space.n):
                self.env.reset() # TODO: is this necessary? might be slow
                self.env.state = np.array(state, dtype=np.float64)  # set the state
                obs, _, terminated, done, info = self.env.step(action)
                
                _, neighbors  = self.kd_tree.query([obs], k=5)
                simplex = self.points[neighbors[0]]
                lambdas = self.barycentric_coordinates(state, simplex)
                
                reward = (
                    0 if (-0.2 < obs[2] < 0.2) and (-2.4 < obs[0] < 2.4) else -1
                )  # TODO remove this hardcoded reward
                table[(state, action)] = {"reward": reward, 
                                          "next_state": obs,
                                          "simplex": simplex,
                                          "barycentric_coordinates":lambdas}
                

        return table

    def get_value(self, lambdas, simplex, value_function):
        """Retrieves the value of a given state from the value function.

        Parameters:
            state (any): The state for which the value needs to be retrieved.
            value_function (dict): A dictionary representing the value function.

        Returns:
            float: The value of the given state from the value function."""

        try:
            values           = np.array([value_function[tuple(e)] for e in list(simplex)])    
            next_state_value = np.dot(lambdas, values)
        except (
            KeyError
        ):  # if next_state is not in value_function, assume it's a 'dead' state.
            next_state_value = -500
            
        return next_state_value

    def evaluate_policy(self, transition_and_reward_function: dict) -> dict:
        """Evaluates the given policy using the provided transition and reward function.

        Args:
            transition_and_reward_function (dict): A dictionary representing the transition and reward function.

        Returns:
            dict: A dictionary representing the new value function after evaluating the policy.
        """
        theta =1e-2 # convergence threshold
        
        while True:
            delta = 0
            new_value_function = {}
            for state in self.states_space:
                new_val = 0
                for action in [0, 1]:
                    reward, next_state, simplex, bar_coor = transition_and_reward_function[(state, action)].values()
                    next_state_value = self.get_value(bar_coor, simplex,self.value_function)
                    new_val += self.policy[state][action] * (reward + self.gamma * next_state_value)
                new_value_function[state] = new_val

            delta = max(delta, max(abs(new_value_function[state] - self.value_function[state]) for state in self.states_space))
            print(f"delta: {delta}")
            if delta < theta:
                break

            self.value_function = new_value_function
        return new_value_function

    def improve_policy(self, transition_and_reward_function: dict) -> dict:
        """Improves the current policy based on the given transition and reward function.

        Args:
            transition_and_reward_function (dict): A dictionary representing the transition and reward function.
                The keys are tuples of (state, action) and the values are dictionaries with 'reward' and 'next_state' keys.

        Returns:
            dict: The new policy after improvement."""
        
        policy_stable = True
        new_policy = {}

        for state in self.states_space:
            action_values = {}
            for action in [0, 1]:
                reward, next_state, simplex, bar_coor = transition_and_reward_function[(state, action)].values()
                action_values[action] = reward + self.gamma * self.get_value(bar_coor, simplex,self.value_function)
                
            greedy_action, _ = max(action_values.items(), key=lambda pair: pair[1])
            new_policy[state] = {
                action: 1 if action is greedy_action else 0 for action in [0, 1]
            }
        if self.policy != new_policy:
            print(f"number of different actions: {sum([self.policy[state][0] != new_policy[state][0] for state in self.states_space])}")
            policy_stable = False

        self.policy = new_policy
        return policy_stable

    def run(self, nsteps):
        """Runs the policy iteration algorithm for a specified number of steps.

        Parameters:
        - nsteps (int): The number of steps to run the algorithm for. Default is 10 steps.
        """

        print("Generating transition and reward function table...")
        transition_and_reward_function = self.get_transition_reward_function()
        print("Running Policy Iteration algorithm...")
        for n in tqdm(range(nsteps)):
            print(f"solving step {n}")
            self.evaluate_policy(transition_and_reward_function)
            if self.improve_policy(transition_and_reward_function):
                break

if __name__ == "__main__":
    x_lim = 2.5
    x_dot_lim = 2.5
    theta_lim = 0.25
    theta_dot_lim = 2.5


    bins_space = {
        "x_space": np.linspace(-x_lim, x_lim, 20),
        "x_dot_space": np.linspace(-x_dot_lim, x_dot_lim, 20),
        "theta_space": np.linspace(-theta_lim, theta_lim, 20),
        "theta_dot_space": np.linspace(-theta_dot_lim, theta_dot_lim, 20),
    }

    pi = PolicyIteration(
        env=CartPoleEnv(sutton_barto_reward=False), bins_space=bins_space
    )
    STEPS = 10000
    # start the policy iteration algorithm
    pi.run(nsteps=STEPS)

Generating transition and reward function table...


100%|██████████| 160000/160000 [06:06<00:00, 436.73it/s]


Running Policy Iteration algorithm...


  0%|          | 0/10000 [00:00<?, ?it/s]

solving step 0
delta: 1.0
delta: 1.0243957523147484
delta: 1.0312483499915857
delta: 1.0376260884638726
delta: 1.0396093614202817
delta: 1.041090393418724
delta: 1.040345575595925
delta: 1.0390943228656653
delta: 1.03640802475452
delta: 1.0332597848520528
delta: 1.029094045465671
delta: 1.0245324258091308
delta: 1.0192117376793526
delta: 1.013555397864554
delta: 1.0073099408838662
delta: 1.0007785314491198
delta: 0.9937772467000556
delta: 0.9865332679364087
delta: 0.9789096536953856
delta: 0.9710833396068388
delta: 0.9629499577953453
delta: 0.9546514720976731
delta: 0.9461063881954814
delta: 0.9374311403302187
delta: 0.9285605370988641
delta: 0.9195916328833853
delta: 0.9104711049930039
delta: 0.9012808372129726
delta: 0.8919763746394871
delta: 0.8826274322533614
delta: 0.8731963657259989
delta: 0.8637429344085312
delta: 0.8542348641787108
delta: 0.844723637470878
delta: 0.8351813150885476
delta: 0.8256524085784918
delta: 0.8161125481537965
delta: 0.8066003141079321
delta: 0.7970943238

delta: 0.019872188647596545
delta: 0.01959709109223695
delta: 0.019325802173227657
delta: 0.01905826121635812
delta: 0.018794424460210735
delta: 0.018534233018812074
delta: 0.01827764419039113
delta: 0.018024600824560366
delta: 0.017775061255306923
delta: 0.01752897001641429
delta: 0.017286286454492483
delta: 0.017046956737601704
delta: 0.01681094120202431
delta: 0.016578187601766103
delta: 0.01634865724042811
delta: 0.0161222994109238
delta: 0.0158990763622171
delta: 0.015678938880512305
delta: 0.015461850138763111
delta: 0.015247762371899398
delta: 0.015036639656045736
delta: 0.014828435631756065
delta: 0.01462311525776272
delta: 0.014420633538350103
delta: 0.01422095629484943
delta: 0.014024039854575676
delta: 0.013829850881734274
delta: 0.013638346987377759
delta: 0.013449495659230593
delta: 0.013263255753685144
delta: 0.01307959556328342
delta: 0.012898475152553601
delta: 0.012719863600324288
delta: 0.012543722143163905
delta: 0.012370020628168277
delta: 0.012198721428958947
delta

  0%|          | 1/10000 [22:44<3791:08:29, 1364.95s/it]

number of different actions: 160000
solving step 1
delta: 2.6215958234333243
delta: 2.89228437620147
delta: 3.212432639639662
delta: 3.5910927855244275
delta: 3.974981046250619
delta: 4.353627528431424
delta: 4.6018895114579905
delta: 4.6945401256464905
delta: 3.9187172526059157
delta: 2.8984519088851215
delta: 2.060265708174306
delta: 1.929130452628904
delta: 1.7941394698065523
delta: 1.6031723218424005
delta: 1.580706029815964
delta: 1.4981864389973154
delta: 1.3801807135444406
delta: 1.347143732494672
delta: 1.305741146586925
delta: 1.2387217541713795
delta: 1.1806125958537308
delta: 1.1624271080474458
delta: 1.1350920495765493
delta: 1.1085567487299173
delta: 1.0760660597940799
delta: 1.0518612175174624
delta: 1.03791322191978
delta: 1.017579229135137
delta: 0.9922678257898312
delta: 0.9670984199085666
delta: 0.9560146254323385
delta: 0.9395607377759205
delta: 0.9186423448342289
delta: 0.8947596586021476
delta: 0.885337222512895
delta: 0.8713467964809496
delta: 0.8534700150205694
d

  0%|          | 2/10000 [40:18<3281:51:47, 1181.71s/it]

number of different actions: 30397
solving step 2
delta: 8.019702038007473
delta: 2.828666359262442
delta: 2.75064264809005
delta: 2.694867166719604
delta: 2.099949094806604
delta: 2.0407601876471944
delta: 1.9963471902416217
delta: 1.9654923992446527
delta: 1.9467814867487192
delta: 1.9389224814133996
delta: 1.940759878079902
delta: 1.951226560148612
delta: 1.9694215549360052
delta: 1.9230699210190778
delta: 1.8947971177849183
delta: 1.8755177841932351
delta: 0.23170821038121758
delta: 0.16819010885267538
delta: 0.16776835180254857
delta: 0.1514439583503986
delta: 0.13943815138635784
delta: 0.14744872910664242
delta: 0.1357532086927904
delta: 0.1435487753125102
delta: 0.13216065583077352
delta: 0.13974850870434352
delta: 0.1286612471893278
delta: 0.136047835365531
delta: 0.1252540218591207
delta: 0.132444913337892
delta: 0.12193691345072466
delta: 0.12893734853031447
delta: 0.11870762565106929
delta: 0.12552266182724736
delta: 0.11556385386452916
delta: 0.12219840425096606
delta: 0.11

  0%|          | 3/10000 [54:03<2829:57:27, 1019.09s/it]

number of different actions: 12598
solving step 3
delta: 7.787851959539971
delta: 7.229091488213925
delta: 6.406380525612022
delta: 3.4896329051642048
delta: 3.309963776313639
delta: 3.195690696232325
delta: 3.2259827466812876
delta: 3.2896303673719594
delta: 3.3738739486404485
delta: 3.4737125510577522
delta: 3.584971760682347
delta: 3.7065566147709195
delta: 1.6207039824211407
delta: 0.34565026437108815
delta: 0.2100717139137025
delta: 0.17882489666950852
delta: 0.14901417265507844
delta: 0.1352329144266542
delta: 0.11518398315851508
delta: 0.10588556342766564
delta: 0.09144450297714979
delta: 0.08506732959161578
delta: 0.07358257651390421
delta: 0.07832792114466836
delta: 0.06595802761792458
delta: 0.0713708294400699
delta: 0.059765405667280724
delta: 0.06464662535331467
delta: 0.05396555677156334
delta: 0.05836151099081377
delta: 0.0486367255106126
delta: 0.052593427668107484
delta: 0.04379360243125818
delta: 0.04735455153560686
delta: 0.039419305521370074
delta: 0.0426246050742404

  0%|          | 4/10000 [58:10<1982:50:19, 714.11s/it] 

number of different actions: 5031
solving step 4
delta: 5.726188866911599
delta: 5.009205622042693
delta: 0.1005775068694601
delta: 0.06078050249359235
delta: 0.040870634705432174
delta: 0.03867166053446747
delta: 0.037031142767026015
delta: 0.03572531553916525
delta: 0.03577385075168138
delta: 0.036113262607274876
delta: 0.036633124780554915
delta: 0.03729997313404976
delta: 0.03808140342458355
delta: 0.03897589855601069
delta: 0.02495251273948984
delta: 0.025482619339877388
delta: 0.01251234810677282
delta: 0.01330098078941866
delta: 0.011970732795981043
delta: 0.01272501818109184
delta: 0.010648968780122559
delta: 0.011320786469656952
delta: 0.009580405115744206


  0%|          | 5/10000 [59:42<1359:46:47, 489.77s/it]

number of different actions: 1971
solving step 5
delta: 4.568844664427701
delta: 0.03492032195747896
delta: 0.010995585359137916
delta: 0.009650718272016023


  0%|          | 6/10000 [1:00:01<914:59:50, 329.60s/it]

number of different actions: 599
solving step 6
delta: 3.1623935514519843
delta: 0.00951232276290681


  0%|          | 7/10000 [1:00:12<626:00:09, 225.52s/it]

number of different actions: 184
solving step 7
delta: 3.0200223868629195
delta: 0.008738788436426004


  0%|          | 8/10000 [1:00:23<436:29:42, 157.26s/it]

number of different actions: 94
solving step 8
delta: 0.008738788436426004


  0%|          | 8/10000 [1:00:31<1259:46:02, 453.88s/it]


In [72]:
def get_optimal_action(state, optimal_policy):
    """Returns the optimal action for a given state based on the optimal policy.

    Parameters:
    state (int): The current state.
    optimal_policy (dict): The optimal policy containing the action-value pairs for each state.

    Returns:
    int: The optimal action for the given state."""
    
    _, neighbors  = optimal_policy.kd_tree.query([state], k=5)
    simplex = optimal_policy.points[neighbors[0]]
    lambdas = optimal_policy.barycentric_coordinates(state, simplex)
    
    zero = 0 
    one = 0
    
    for i,l in enumerate(lambdas):
    
        if optimal_policy.policy[tuple(simplex[i])][0] > 0:
            zero +=l
        else:
            one +=l
                                 
    return 0 if zero > one else 1


num_episodes = 10000
cartpole = CartPoleEnv(render_mode="human")
max_obs = np.array([0.0, 0.0, 0.0, 0.0])
min_obs = np.array([0.0, 0.0, 0.0, 0.0])
limits = np.array([x_lim, x_dot_lim, theta_lim, theta_dot_lim])
for episode in range(0, num_episodes):
    observation, _ = cartpole.reset()
    for timestep in range(1, 1000):
        action = get_optimal_action(observation, pi)
        observation, reward, done, terminated, info = cartpole.step(action)
        max_obs = np.maximum(max_obs, observation)
        min_obs = np.minimum(min_obs, observation)
        if done:
            #print(f"max_obs: {max_obs}")
            #print(f"min_obs: {min_obs}")
            #check limits
            #if np.all(max_obs <= limits) and np.all(min_obs >= -limits):
            #    print(f"Episode {episode} finished after {timestep} timesteps")
            #else:
            #    print(f"Episode {episode} finished after {timestep} timesteps with out of limits observations")
            break

KeyboardInterrupt: 

In [42]:
table = pi.get_transition_reward_function()

  1%|▏         | 2305/160000 [00:05<05:59, 438.26it/s]


KeyboardInterrupt: 

In [36]:
for e in table.keys():
    
    if abs(table[e]["barycentric_coordinates"].sum() - 1) > 0.2:
        print(table[e]["barycentric_coordinates"], table[e]["barycentric_coordinates"].sum())

In [1]:
import numpy as np
from scipy.interpolate import LinearNDInterpolator
from itertools import product

x_lim = 2.5
x_dot_lim = 2.5
theta_lim = 0.25
theta_dot_lim = 2.5

bins_space = {
        "x_space": np.linspace(-x_lim, x_lim, 12),
        "x_dot_space": np.linspace(-x_dot_lim, x_dot_lim, 12),
        "theta_space": np.linspace(-theta_lim, theta_lim,12),
        "theta_dot_space": np.linspace(-theta_dot_lim, theta_dot_lim, 12),

    }

states_space = list(set(product(*bins_space.values())))  # avoid repited states

states_space = [np.array([a[0],a[1],a[2],a[3]]) for a in states_space]
print(states_space[0])
values = np.random.rand(len(states_space))
print(values[0])
# Create an interpolator
interpolator = LinearNDInterpolator(states_space, values)



[ 2.5         0.68181818 -0.15909091  1.59090909]
0.9503234898122915


In [5]:
# Evaluate the function at any point


point_to_evaluate = np.array([-0.13157895, 2.43684211,  0.14, 1.7])  # Point to evaluate, shape: (1, 4)
interpolated_value = interpolator(point_to_evaluate)

print("Interpolated Value:", interpolated_value)


Interpolated Value: [0.51545684]


In [14]:
import numpy as np

class BarycentricInterpolatorND:
    def __init__(self, points, values):
        self.points = np.array(points)
        self.values = np.array(values)
        self.weights = self.calculate_weights()

    def calculate_weights(self):
        n = len(self.points)
        w = np.zeros(n)
        for j in range(n):
            distances = np.linalg.norm(self.points[j] - np.delete(self.points, j, axis=0), axis=1)
            w[j] = 1.0 / np.prod(distances)
        return w

    def interpolate(self, point):
        distances = np.linalg.norm(point - self.points, axis=1)
        numerator = np.sum(self.weights * self.values / distances)
        denominator = np.sum(self.weights / distances)
        return numerator / denominator

# Example usage:
points = [np.array([1, 2, 3]), 
          np.array([4, 5, 6]), 
          np.array([7, 8, 9]),
          np.array([11, 4,7])]  # Example n-dimensional points
values = [0.1, 0.5, 0.8, 0.6]   # Example corresponding values

interp = BarycentricInterpolatorND(points, values)
interp_point = np.array([2, 3, 4])  # Example n-dimensional interpolation point
approx_value = interp.interpolate(interp_point)
print("Approximate value at point", interp_point, ":", approx_value)


Approximate value at point [2 3 4] : 0.4144559288588087


In [12]:
interp.weights * 

array([0.0016905 , 0.00518622, 0.00308642, 0.00213046])

In [16]:
distances = np.linalg.norm(interp_point - points, axis=1)

In [17]:
distances

array([1.73205081, 3.46410162, 8.66025404, 9.53939201])

In [68]:
s       = np.array([1.7105263157894735, 0.6578947368421053, -0.11842105263157895, 0.9210526315789473])  
s_prime = np.array([ 1.7236842,   0.46455497, -0.1,         1.1742967 ])

In [69]:
pi.get_state(s)

(16, 12, 5, 13)

In [70]:
s_prime_indexes = np.array(list(pi.get_state(s_prime)))

In [92]:
neighbours  = []
for i in  [1,-1]:
    for j in [1,-1]:
        for k in [1, -1]:
            for l in [1,-1]:
                ix, ix_dot, itheta, itheta_dot = tuple(s_prime_indexes + np.array([i,j,k,l]))
                x = pi.bins_space['x_space'][ix]
                x_dot = pi.bins_space['x_dot_space'][ix_dot]
                theta = pi.bins_space['theta_space'][itheta]
                theta_dot = pi.bins_space['theta_dot_space'][itheta_dot]
                dist = np.linalg.norm(np.array([x,x_dot,theta,theta_dot]) - s_prime) 
                neighbours.append(([x,x_dot,theta,theta_dot], dist))
                

In [94]:
neighbours.sort(key=lambda x: x[1])
for e in neighbours:
    print(e)

([1.973684210526316, 0.6578947368421053, -0.09210526315789475, 1.1842105263157894], 0.3162923987827456)
([1.973684210526316, 0.6578947368421053, -0.14473684210526316, 1.1842105263157894], 0.3193429812904748)
([1.4473684210526314, 0.6578947368421053, -0.09210526315789475, 1.1842105263157894], 0.3374778303235288)
([1.4473684210526314, 0.6578947368421053, -0.14473684210526316, 1.1842105263157894], 0.3403385727942421)
([1.973684210526316, 0.1315789473684208, -0.09210526315789475, 1.1842105263157894], 0.4165737002413015)
([1.973684210526316, 0.1315789473684208, -0.14473684210526316, 1.1842105263157894], 0.4188946238667607)
([1.4473684210526314, 0.1315789473684208, -0.09210526315789475, 1.1842105263157894], 0.4328787961599214)
([1.4473684210526314, 0.1315789473684208, -0.14473684210526316, 1.1842105263157894], 0.4351127558881812)
([1.973684210526316, 0.6578947368421053, -0.09210526315789475, 0.6578947368421053], 0.605486238595768)
([1.973684210526316, 0.6578947368421053, -0.14473684210526316

In [88]:
def barycentric_coordinates(point, simplex):
    # Formulate the system of equations
    A = np.vstack([np.array(simplex).T, np.ones(len(simplex))])
    b = np.hstack([point, [1]])

    # Solve the system of equations
    lambdas = np.linalg.solve(A, b)

    # Normalize the barycentric coordinates
    lambdas /= np.sum(lambdas)

    return lambdas

In [89]:
bary_coords = barycentric_coordinates(list(s_prime), simplex)
print("Barycentric coordinates:", bary_coords, np.abs(bary_coords).sum())

Barycentric coordinates: [-0.01118185  0.47500002  0.36734556  0.15        0.01883627] 1.0223636939999978


In [67]:
for e in [neighbours[i][0] for i in range(5)]:
    print(e)

[1.973684210526316, 0.6578947368421053, -0.09210526315789475, 1.1842105263157894]
[1.973684210526316, 0.6578947368421053, -0.14473684210526316, 1.1842105263157894]
[1.4473684210526314, 0.6578947368421053, -0.09210526315789475, 1.1842105263157894]
[1.4473684210526314, 0.6578947368421053, -0.14473684210526316, 1.1842105263157894]
[1.973684210526316, 0.1315789473684208, -0.09210526315789475, 1.1842105263157894]
