In [3]:
import math
from typing import Optional, Union
from scipy.optimize import minimize


import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled


class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
    """
    ## Description

    This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in
    ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
    A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
    The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces
     in the left and right direction on the cart.

    ## Action Space

    The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction
     of the fixed force the cart is pushed with.

    - 0: Push cart to the left
    - 1: Push cart to the right

    **Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle
     the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it

    ## Observation Space

    The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

    | Num | Observation           | Min                 | Max               |
    |-----|-----------------------|---------------------|-------------------|
    | 0   | Cart Position         | -4.8                | 4.8               |
    | 1   | Cart Velocity         | -Inf                | Inf               |
    | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
    | 3   | Pole Angular Velocity | -Inf                | Inf               |

    **Note:** While the ranges above denote the possible values for observation space of each element,
        it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
    -  The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates
       if the cart leaves the `(-2.4, 2.4)` range.
    -  The pole angle can be observed between  `(-.418, .418)` radians (or **±24°**), but the episode terminates
       if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)

    ## Rewards
    Since the goal is to keep the pole upright for as long as possible, by default, a reward of `+1` is given for every step taken, including the termination step. The default reward threshold is 500 for v1 and 200 for v0 due to the time limit on the environment.

    If `sutton_barto_reward=True`, then a reward of `0` is awarded for every non-terminating step and `-1` for the terminating step. As a result, the reward threshold is 0 for v0 and v1.

    ## Starting State
    All observations are assigned a uniformly random value in `(-0.05, 0.05)`

    ## Episode End
    The episode ends if any one of the following occurs:

    1. Termination: Pole Angle is greater than ±12°
    2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
    3. Truncation: Episode length is greater than 500 (200 for v0)

    ## Arguments

    Cartpole only has `render_mode` as a keyword for `gymnasium.make`.
    On reset, the `options` parameter allows the user to change the bounds used to determine the new random state.

    ```python
    >>> import gymnasium as gym
    >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
    >>> env
    <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
    >>> env.reset(seed=123, options={"low": -0.1, "high": 0.1})  # default low=-0.05, high=0.05
    (array([ 0.03647037, -0.0892358 , -0.05592803, -0.06312564], dtype=float32), {})

    ```

    | Parameter               | Type       | Default                 | Description                                                                                   |
    |-------------------------|------------|-------------------------|-----------------------------------------------------------------------------------------------|
    | `sutton_barto_reward`   | **bool**   | `False`                 | If `True` the reward function matches the original sutton barto implementation                |

    ## Vectorized environment

    To increase steps per seconds, users can use a custom vector environment or with an environment vectorizor.

    ```python
    >>> import gymnasium as gym
    >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="vector_entry_point")
    >>> envs
    CartPoleVectorEnv(CartPole-v1, num_envs=3)
    >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
    >>> envs
    SyncVectorEnv(CartPole-v1, num_envs=3)

    ```

    ## Version History
    * v1: `max_time_steps` raised to 500.
        - In Gymnasium `1.0.0a2` the `sutton_barto_reward` argument was added (related [GitHub issue](https://github.com/Farama-Foundation/Gymnasium/issues/790))
    * v0: Initial versions release.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 50,
    }

    def __init__(
        self, sutton_barto_reward: bool = False, render_mode: Optional[str] = None
    ):
        self._sutton_barto_reward = sutton_barto_reward

        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

        self.render_mode = render_mode

        self.screen_width = 600
        self.screen_height = 400
        self.screen = None
        self.clock = None
        self.isopen = True
        self.state = None

        self.steps_beyond_terminated = None

    def step(self, action):
        assert self.action_space.contains(
            action
        ), f"{action!r} ({type(action)}) invalid"
        assert self.state is not None, "Call reset before using step method."
        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        terminated = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not terminated:
            if self._sutton_barto_reward:
                reward = 0.0
            elif not self._sutton_barto_reward:
                reward = 1.0
        elif self.steps_beyond_terminated is None:
            # Pole just fell!
            self.steps_beyond_terminated = 0
            if self._sutton_barto_reward:
                reward = -1.0
            elif not self._sutton_barto_reward:
                reward = 1.0
        else:
            if self.steps_beyond_terminated == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned terminated = True. You "
                    "should always call 'reset()' once you receive 'terminated = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_terminated += 1
            if self._sutton_barto_reward:
                reward = -1.0
            elif not self._sutton_barto_reward:
                reward = 0.0

            reward = -1.0

        if self.render_mode == "human":
            self.render()
        # truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
        return np.array(self.state, dtype=np.float32), reward, terminated, False, {}

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ):
        super().reset(seed=seed)
        # Note that if you use custom reset bounds, it may lead to out-of-bound
        # state/observations.
        low, high = utils.maybe_parse_reset_bounds(
            options, -0.005, 0.005  # default low
        )  # default high
        #low = (-x_lim, -x_dot_lim, -theta_lim, -theta_dot_lim)
        #high = (x_lim, x_dot_lim, theta_lim, theta_dot_lim)
        self.state = self.np_random.uniform(low=0, high=0, size=(4,))
        #elf.state=np.array([0.0, 0.0, 0.01, 0.0])
        self.steps_beyond_terminated = None

        if self.render_mode == "human":
            self.render()
        return np.array(self.state, dtype=np.float32), {}

    def render(self):
        if self.render_mode is None:
            assert self.spec is not None
            gym.logger.warn(
                "You are calling render method without specifying any render mode. "
                "You can specify the render_mode at initialization, "
                f'e.g. gym.make("{self.spec.id}", render_mode="rgb_array")'
            )
            return

        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        if self.screen is None:
            pygame.init()
            if self.render_mode == "human":
                pygame.display.init()
                self.screen = pygame.display.set_mode(
                    (self.screen_width, self.screen_height)
                )
            else:  # mode == "rgb_array"
                self.screen = pygame.Surface((self.screen_width, self.screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.state is None:
            return None

        x = self.state

        self.surf = pygame.Surface((self.screen_width, self.screen_height))
        self.surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(self.surf, 0, self.screen_width, carty, (0, 0, 0))

        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if self.render_mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        elif self.render_mode == "rgb_array":
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )

    def close(self):
        if self.screen is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            self.isopen = False

In [8]:
import numpy as np
from tqdm import tqdm
import gymnasium as gym
from itertools import product
from scipy.spatial import KDTree

class PolicyIteration(object):
    """Policy Iteration Algorithm for gymnasium environment"""

    def __init__(self, env: gym.Env, gamma: float = 0.99, bins_space: dict = None):
        """Initializes the Policy Iteration.

        Parameters:
        - env (gym.Env): The environment in which the agent will interact.
        - gamma (float): The discount factor for future rewards. Default is 0.99.
        - bins_space (dict): A dictionary specifying the number of bins for each state variable. Default is None.

        Returns: None"""

        self.env = env
        self.gamma = gamma  # discaunt factor

        self.action_space = range(env.action_space.n)
        self.bins_space = bins_space
        
        self.states_space = list(
            set(product(*bins_space.values()))
        )  # avoid repited states
        
        self.points = np.array([np.array(e) for e in self.states_space])
        self.kd_tree = KDTree(self.points)
        
        self.policy = {state: {0: 0.5, 1: 0.5} for state in self.states_space}
        self.value_function = {state: 0 for state in self.states_space}  # initialize value function


    def barycentric_coordinates(self, point, simplex):
        # Formulate the system of equations
        A = np.vstack([np.array(simplex).T, np.ones(len(simplex))])
        b = np.hstack([point, [1]])
        objective_function = lambda x: np.linalg.norm(A.dot(x) - b)

        # Define the constraint that the solution must be greater than zero
        constraints = ({'type': 'ineq', 'fun': lambda x: x})

        # Initial guess for the solution
        x0 = np.array([0.33, 0.33, 0.33, 0.33, 0.33])

        # Solve the optimization problem
        result = minimize(objective_function, x0, constraints=constraints, tol=1e-3)

        # The approximate solution
        x_approx = result.x
        return x_approx

    def get_transition_reward_function(self) -> dict:
        """Generate a transition reward function table.

        Returns:
            dict: A dictionary representing the transition reward function table.
                The keys are tuples of (state, action), and the values are dictionaries
                with 'reward' and 'next_state' as keys."""

        table = {}
        for state in tqdm(self.states_space):
            for action in range(self.env.action_space.n):
                self.env.reset() # TODO: is this necessary? might be slow
                self.env.state = np.array(state, dtype=np.float64)  # set the state
                obs, _, terminated, done, info = self.env.step(action)
                
                _, neighbors  = self.kd_tree.query([obs], k=5)
                simplex = self.points[neighbors[0]]
                lambdas = self.barycentric_coordinates(state, simplex)
                
                reward = (
                    0 if (-0.2 < obs[2] < 0.2) and (-2.4 < obs[0] < 2.4) else -1
                )  # TODO remove this hardcoded reward
                table[(state, action)] = {"reward": reward, 
                                          "next_state": obs,
                                          "simplex": simplex,
                                          "barycentric_coordinates":lambdas}
                

        return table

    def get_value(self, lambdas, simplex, value_function):
        """Retrieves the value of a given state from the value function.

        Parameters:
            state (any): The state for which the value needs to be retrieved.
            value_function (dict): A dictionary representing the value function.

        Returns:
            float: The value of the given state from the value function."""

        try:
            values           = np.array([value_function[tuple(e)] for e in list(simplex)])    
            next_state_value = np.dot(lambdas, values)
        except (
            KeyError
        ):  # if next_state is not in value_function, assume it's a 'dead' state.
            next_state_value = -500
            
        return next_state_value

    def evaluate_policy(self, transition_and_reward_function: dict) -> dict:
        """Evaluates the given policy using the provided transition and reward function.

        Args:
            transition_and_reward_function (dict): A dictionary representing the transition and reward function.

        Returns:
            dict: A dictionary representing the new value function after evaluating the policy.
        """
        theta = 1e-2 # convergence threshold
        
        while True:
            delta = 0
            new_value_function = {}
            for state in self.states_space:
                new_val = 0
                for action in self.action_space:
                    reward, next_state, simplex, bar_coor = transition_and_reward_function[(state, action)].values()
                    next_state_value = self.get_value(bar_coor, simplex,self.value_function)
                    new_val += self.policy[state][action] * (reward + self.gamma * next_state_value)
                new_value_function[state] = new_val

            delta = max(delta, max(abs(new_value_function[state] - self.value_function[state]) for state in self.states_space))
            print(f"delta: {delta}")
            if delta < theta:
                break

            self.value_function = new_value_function
        return new_value_function

    def improve_policy(self, transition_and_reward_function: dict) -> dict:
        """Improves the current policy based on the given transition and reward function.

        Args:
            transition_and_reward_function (dict): A dictionary representing the transition and reward function.
                The keys are tuples of (state, action) and the values are dictionaries with 'reward' and 'next_state' keys.

        Returns:
            dict: The new policy after improvement."""
        
        policy_stable = True
        new_policy = {}

        for state in self.states_space:
            action_values = {}
            for action in self.action_space:
                reward, next_state, simplex, bar_coor = transition_and_reward_function[(state, action)].values()
                action_values[action] = reward + self.gamma * self.get_value(bar_coor, simplex,self.value_function)
                
            greedy_action, _ = max(action_values.items(), key=lambda pair: pair[1])
            new_policy[state] = {
                action: 1 if action is greedy_action else 0 for action in self.action_space
            }
        if self.policy != new_policy:
            print(f"number of different actions: {sum([self.policy[state][0] != new_policy[state][0] for state in self.states_space])}")
            policy_stable = False

        self.policy = new_policy
        return policy_stable

    def run(self, nsteps):
        """Runs the policy iteration algorithm for a specified number of steps.

        Parameters:
        - nsteps (int): The number of steps to run the algorithm for. Default is 10 steps.
        """

        print("Generating transition and reward function table...")
        transition_and_reward_function = self.get_transition_reward_function()
        print("Running Policy Iteration algorithm...")
        for n in tqdm(range(nsteps)):
            print(f"solving step {n}")
            self.evaluate_policy(transition_and_reward_function)
            if self.improve_policy(transition_and_reward_function):
                break

if __name__ == "__main__":

    x_lim = 2.5
    x_dot_lim = 2.5
    theta_lim = 0.25
    theta_dot_lim = 2.5


    bins_space = {
        "x_space": np.linspace(-x_lim, x_lim, 20),
        "x_dot_space": np.linspace(-x_dot_lim, x_dot_lim, 20),
        "theta_space": np.linspace(-theta_lim, theta_lim, 20),
        "theta_dot_space": np.linspace(-theta_dot_lim, theta_dot_lim, 20),
    }

    pi = PolicyIteration(
        env=CartPoleEnv(sutton_barto_reward=False), bins_space=bins_space
    )
    STEPS = 10000
    # start the policy iteration algorithm
    pi.run(nsteps=STEPS)

Generating transition and reward function table...


100%|██████████| 160000/160000 [07:36<00:00, 350.70it/s]


Running Policy Iteration algorithm...


  0%|          | 0/10000 [00:00<?, ?it/s]

solving step 0
delta: 1.0
delta: 1.0243957523147484
delta: 1.0312483499915857
delta: 1.037626088463456
delta: 1.039609361421367
delta: 1.041090393421829
delta: 1.0403455756003304
delta: 1.0390943228718292
delta: 1.0364080609313415
delta: 1.0332598377578606
delta: 1.0290941730639904
delta: 1.024532599239567
delta: 1.0192120083867628
delta: 1.013555733548733
delta: 1.007310373197539
delta: 1.0007790269298376
delta: 0.9937778189916635
delta: 0.9865338857429506
delta: 0.9789103203978833
delta: 0.9710840291767546
delta: 0.9629506702770065
delta: 0.9546521880092165
delta: 0.946107108168917
delta: 0.9374318509384238
delta: 0.92856124094385
delta: 0.9195923211709776
delta: 0.9104717822558435
delta: 0.9012814977048933
delta: 0.8919770242463798
delta: 0.8826280669479907
delta: 0.8731969921164016
delta: 0.8637435492427947
delta: 0.8542354742936205
delta: 0.8447242398034547
delta: 0.8351819163835756
delta: 0.8256530056468776
delta: 0.8161131474711922
delta: 0.8066009122256759
delta: 0.797094926903

  0%|          | 0/10000 [14:52<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def get_optimal_action(state, optimal_policy):
    """Returns the optimal action for a given state based on the optimal policy.

    Parameters:
    state (int): The current state.
    optimal_policy (dict): The optimal policy containing the action-value pairs for each state.

    Returns:
    int: The optimal action for the given state."""
    
    _, neighbors  = optimal_policy.kd_tree.query([state], k=5)
    simplex = optimal_policy.points[neighbors[0]]
    lambdas = optimal_policy.barycentric_coordinates(state, simplex)
    
    zero = 0 
    one = 0
    
    for i,l in enumerate(lambdas):
    
        if optimal_policy.policy[tuple(simplex[i])][0] > 0:
            zero +=l
        else:
            one +=l
                                 
    return 0 if zero > one else 1


num_episodes = 10000
cartpole = CartPoleEnv(render_mode="human")
max_obs = np.array([0.0, 0.0, 0.0, 0.0])
min_obs = np.array([0.0, 0.0, 0.0, 0.0])
limits = np.array([x_lim, x_dot_lim, theta_lim, theta_dot_lim])
for episode in range(0, num_episodes):
    observation, _ = cartpole.reset()
    for timestep in range(1, 1000):
        action = get_optimal_action(observation, pi)
        observation, reward, done, terminated, info = cartpole.step(action)
        max_obs = np.maximum(max_obs, observation)
        min_obs = np.minimum(min_obs, observation)
        if done:
            #print(f"max_obs: {max_obs}")
            #print(f"min_obs: {min_obs}")
            #check limits
            #if np.all(max_obs <= limits) and np.all(min_obs >= -limits):
            #    print(f"Episode {episode} finished after {timestep} timesteps")
            #else:
            #    print(f"Episode {episode} finished after {timestep} timesteps with out of limits observations")
            break