# RL Gymnasium

In [101]:
import gymnasium as gym
import jax.numpy as jnp
from jax import random
import numpy as np
from tqdm import tqdm
    
class cartpole:
    qtable = []
    n_actions = []
    PRNGKey = 0
    
    def __init__(self):
        self.env = gym.make("CartPole-v1", render_mode="human")
        self.n_actions = np.zeros(env.action_space.n)
        self.qtable = np.zeros((4, 4, 2)) + 0.5
        # self.PRNGKey = random.PRNGKey(0)
        self.rng = np.random.default_rng()

    def get_state(self, obs):
        state_pos = [self.env.observation_space.low[0]/2, self.env.observation_space.high[0]/2]
        state_ang = [self.env.observation_space.low[2]/2, self.env.observation_space.high[2]/2]
        pos, ang = [int((n_pos) * (obs[0] - state_pos[0]) / (state_pos[1] - state_pos[0])),
                    int((n_ang) * (obs[2] - state_ang[0]) / (state_ang[1] - state_ang[0]))]
        return [0 if pos < 0 else 3 if pos > 3 else pos, 0 if ang < 0 else 3 if ang > 3 else ang]
        
    def choose_action(self, observation, eps=0.3):
        state = self.get_state(observation)
        if self.get_random() < eps:
            return env.action_space.sample()
        return int(np.argmax(self.qtable[state[0]][state[1]]))

    def get_random(self):
        # old_key, new_key = random.split(self.PRNGKey)
        # self.PRNGKey = new_key
        # return random.uniform(old_key)
        return self.rng.random()
        
        
class episodic(cartpole):
    def __init__(self):
        super().__init__()
        
    def train(self, episodes=1000):
        for _ in tqdm(range(episodes)):
            state_action = []
            rewards = 0
            observation, info = env.reset()
            terminated, truncated = 0, 0
            while(not (terminated or truncated)):
                action = self.choose_action(observation)
                state_action.append((self.get_state(observation), action))
                observation, reward, terminated, truncated, info = env.step(action)
                rewards += reward
            self.update_qtable(state_action, reward)
        env.close()

    def update_qtable(self, state_action, reward):
        for s, a in state_action:
            self.n_actions += 1
            self.qtable[*s, a] += (reward - self.qtable[*s, a]) / self.n_actions[a]
            # self.n_actions = self.n_actions.at[a].set(self.n_actions[a] + 1)
            # self.qtable = self.qtable.at[*s, a].set(self.qtable[*s, a] + (reward - self.qtable[*s, a]) / self.n_actions[a])

class continuous(cartpole):
    def __init__(self):
        super().__init__()

    def train(self, max_iter=100000, gamma=0.9):
        observation, info = env.reset()
        for _ in tqdm(range(max_iter)):
            action = self.choose_action(observation)
            observation, reward, terminated, truncated, info = env.step(action)
            self.update_qtable(observation, action, -1 if terminated else reward, gamma)
            if terminated or truncated:
                observation, info = env.reset()

    def update_qtable(self, observation, action, reward, gamma):
        state = self.get_state(observation)
        s, a = state, action
        self.n_actions += 1
        self.qtable[*s, a] += gamma * (reward - self.qtable[*s, a]) / self.n_actions[a]
        # self.n_actions = self.n_actions.at[a].set(self.n_actions[a] + 1)
        # self.qtable = self.qtable.at[*s, a].set(self.qtable[*s, a] + gamma * (reward - self.qtable[*s, a]) / self.n_actions[a])

In [102]:
episodic_pole = episodic()
continuous_pole = continuous()

In [None]:
episodic_pole.train()

In [104]:
env = gym.make("CartPole-v1", render_mode="human")
continuous_pole.train()

100%|███████████████████████████████████████████| 100000/100000 [36:08<00:00, 46.11it/s]


In [85]:
obs, info = episodic_pole.env.reset()
s = episodic_pole.get_state(obs)

In [99]:
qtable1 = episodic_pole.qtable

In [88]:
episodic_pole.qtable[*s, 0]

1.0

In [106]:
qtable2 = continuous_pole.qtable

In [107]:
print(qtable1)
print(qtable2)

[[[0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.5        0.5       ]
  [0.82676477 0.58989917]
  [0.97249807 0.6827857 ]
  [0.90831967 0.59039693]]

 [[0.50056495 0.50127997]
  [1.         0.53738949]
  [0.76266717 0.53865689]
  [0.50133477 0.5       ]]

 [[0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]]]
[[[0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.5052258  0.50103514]
  [0.98918956 0.5394394 ]
  [0.97538865 0.71112578]
  [0.52468836 0.52467087]]

 [[0.51463924 0.5016634 ]
  [0.76027548 0.56012859]
  [0.80643053 0.56768878]
  [0.50852782 0.51707979]]

 [[0.50063688 0.50005795]
  [0.5004586  0.50001064]
  [0.50021607 0.50001063]
  [0.4999857  0.4998891 ]]]


In [114]:
observation, info = env.reset()
for _ in range(1000):
    # action = episodic_pole.choose_action(observation)
    action = continuous_pole.choose_action(observation)
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()
        break