In [2]:
from typing import Dict, Any, List, Union, Optional, Tuple
from dataclasses import dataclass, field

import gymnasium as gym
import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as pyplot
from gymnasium import Env, Space
from gymnasium.spaces import Box
from numpy import ndarray

In [23]:
# Dataclasses are dope: https://docs.python.org/3/library/dataclasses.html
@dataclass
class QLearner:
    """Tabular Q-learning agent for gymnasium's ``CartPole-v1`` environment.

    Continuous observations are binned per dimension (see ``discretize``) and
    the resulting tuple is used as the key into ``q_table``, which stores one
    Q-value per action (two actions, initialised to ``0.0``).

    Hyperparameters:
        num_episodes: number of episodes executed by ``run``.
        progress_capture_rate: a progress sample is recorded every this many
            episodes.
        exploration_min: lower bound for ``exploration_prob``.
        exploration_decay: rate ``k`` of the ``exp(-k * episode)`` schedule.
        learning_rate: step size used when blending old and target Q-values.
        cart_velocity_bound / pole_velocity_bound: symmetric bounds used for
            binning observation indices 1 and 3 in place of the bounds
            reported by the observation space.
        num_bins: number of bin edges per observation dimension.
        discount: discount factor applied to the bootstrapped future value.

    Note:
        ``exploration_prob`` keeps its constructor value for the first episode
        only; ``run`` overwrites it after every episode with the schedule.
    """
    ###############################################
    # Hyperparameters
    ###############################################
    num_episodes: int = 10000
    progress_capture_rate: int = 1000
    exploration_min: float = 0.1
    exploration_decay: float = 0.001
    learning_rate: float = 0.01
    cart_velocity_bound: float = 3
    pole_velocity_bound: float = 3
    num_bins: int = 100
    discount: float = 0.95

    ###############################################
    # Q Learning parameters
    ###############################################
    # Keys are the tuples produced by `discretize`, values are per-action Q-values.
    q_table: Dict[Tuple[int, ...], List[float]] = field(
        default_factory=lambda: defaultdict(lambda: [0.0, 0.0]))
    exploration_prob: float = 0.0
    max_steps: int = 0

    ###############################################
    # Gymnasium parameters
    ###############################################
    # Not a constructor argument; created in __post_init__.
    env: "Env" = field(init=False)

    ###############################################
    # Training analysis parameters
    ###############################################
    episode_history: List[int] = field(default_factory=list)
    progress_x: List[int] = field(default_factory=list)
    progress_avg: List[float] = field(default_factory=list)
    progress_max: List[float] = field(default_factory=list)
    render_mode: Optional[str] = None
    print_progress: bool = True
    max_abs_cart_velocity: float = 0.0
    max_abs_pole_velocity: float = 0.0

    def __post_init__(self):
        """Create the environment and precompute the per-dimension bin edges."""
        self.env = gym.make(
            'CartPole-v1',
            render_mode=self.render_mode,
        )
        self.state_space: Box = self.env.observation_space  # type: ignore

        # Copy before editing: `high` belongs to the observation space, and
        # writing into it directly would silently change the space's bounds.
        self.state_bin_bounds = self.state_space.high.copy()

        # Explicitly set our velocity bounds (observation indices 1 and 3)
        self.state_bin_bounds[1] = self.cart_velocity_bound
        self.state_bin_bounds[3] = self.pole_velocity_bound

        # One row of `num_bins` evenly spaced edges per observation dimension.
        self.state_bins = np.linspace(
            self.state_bin_bounds * -1,
            self.state_bin_bounds,
            self.num_bins,
            endpoint=True).T

    def determine_action(self, discrete_state) -> int:
        """Pick an action epsilon-greedily for ``discrete_state``.

        With probability ``exploration_prob`` a random action is sampled from
        the environment's action space; otherwise the index of the highest
        Q-value for the state is returned.
        """
        if np.random.uniform(0, 1) < self.exploration_prob:
            # Random action
            return int(self.env.action_space.sample())
        return int(np.argmax(self.q_table[discrete_state]))

    def discretize(self, state: np.ndarray) -> Tuple[int, ...]:
        """Given a state, return a discretized version of the state via binning.

        Each component is mapped to its bin index within the matching row of
        ``state_bins``; the indices are returned as a hashable tuple.
        """
        return tuple(
            int(np.digitize(state[i], self.state_bins[i], right=True))
            for i in range(state.shape[0])
        )

    def update_q_table(self, discrete_state, action, reward,
                       next_discrete_state=None, terminated=False):
        """Apply one Q-learning update for ``(discrete_state, action)``.

        The target is ``reward + discount * max_a Q(next_discrete_state, a)``;
        the future term is dropped when ``terminated`` is true because a
        terminal state has no successor value.

        Args:
            discrete_state: key of the state the action was taken in.
            action: index of the action taken.
            reward: reward received for the transition.
            next_discrete_state: key of the state reached. When omitted, the
                current state is used for the bootstrap, which matches the
                previous three-argument behaviour of this method.
            terminated: whether the transition ended the episode.
        """
        old_q = self.q_table[discrete_state][action]
        if terminated:
            future_value = 0.0
        else:
            bootstrap_state = (
                discrete_state if next_discrete_state is None else next_discrete_state
            )
            future_value = max(self.q_table[bootstrap_state])
        target_q = reward + self.discount * future_value
        self.q_table[discrete_state][action] = (
            (1 - self.learning_rate) * old_q + self.learning_rate * target_q
        )

    def _decayed_exploration(self, episode: int) -> float:
        """Return ``max(exploration_min, exp(-exploration_decay * episode))``."""
        return max(self.exploration_min,
                   float(np.exp(-episode * self.exploration_decay)))

    def _run_episode(self, state) -> int:
        """Play one episode starting from ``state``; return the number of steps taken."""
        steps = 0
        terminated = False
        truncated = False
        discrete_state = self.discretize(state)
        while not (terminated or truncated):
            steps += 1
            action = self.determine_action(discrete_state)
            state, reward, terminated, truncated, _info = self.env.step(action)
            next_discrete_state = self.discretize(state)

            # Learn!
            self.update_q_table(discrete_state, action, reward,
                                next_discrete_state, terminated)
            discrete_state = next_discrete_state

            # Episode progress capture
            self.max_abs_cart_velocity = max(self.max_abs_cart_velocity, abs(state[1]))
            self.max_abs_pole_velocity = max(self.max_abs_pole_velocity, abs(state[3]))
        return steps

    def _capture_progress(self, episode: int) -> None:
        """Record (and optionally print) the average/max steps since the last capture."""
        avg = sum(self.episode_history) / len(self.episode_history)
        if self.print_progress:
            print(f"[{episode}] Max: {self.max_steps}, Avg: {avg}")
        self.episode_history = []
        self.progress_x.append(episode)
        self.progress_avg.append(avg)
        self.progress_max.append(self.max_steps)

    def run(self):
        """Train for ``num_episodes`` episodes, then close the environment.

        After every episode the step count is appended to ``episode_history``,
        ``max_steps`` is updated and ``exploration_prob`` is moved along the
        decay schedule. Every ``progress_capture_rate`` episodes a progress
        sample is stored in ``progress_x`` / ``progress_avg`` / ``progress_max``.
        The environment is closed even if training is interrupted.
        """
        try:
            # fresh start, morty!
            state, _info = self.env.reset()

            for episode in range(self.num_episodes):
                steps = self._run_episode(state)
                self.episode_history.append(steps)

                # Track the best episode on every iteration, not only on
                # capture episodes, so the reported maximum is a true maximum.
                self.max_steps = max(steps, self.max_steps)

                # Decay on the episode count; feeding the previous probability
                # back into exp() pins it near 1 instead of decaying.
                self.exploration_prob = self._decayed_exploration(episode + 1)

                if episode % self.progress_capture_rate == 0:
                    self._capture_progress(episode)

                state, _info = self.env.reset()
        finally:
            self.env.close()

    def plot(self):
        """Plot the captured average and maximum steps against the episode index."""
        pyplot.plot(self.progress_x, self.progress_avg, label="avg")
        pyplot.plot(self.progress_x, self.progress_max, label="max")
        pyplot.xlabel("episode")
        pyplot.ylabel("steps")
        # The labels above are only shown once a legend is requested.
        pyplot.legend()

# Short smoke run. Capture progress on every episode: with the default
# capture rate of 1000 a 10-episode run records a single sample, and a
# one-point line plot draws nothing visible.
q = QLearner(
    num_episodes=10,
    progress_capture_rate=1,
)
q.run()
q.plot()

KeyboardInterrupt: 

In [4]:
# q.max_abs_pole_velocity

1.5022974