In [None]:
import jax
import jax.numpy as jnp
import numpy as np

import gym
from common.jupyter_animation import animate, animation_table
from typing import Iterable, Union, Callable, Tuple, List, Dict, Any
import time

In [None]:
# Create the CartPole-v1 environment. render_mode="rgb_array" is requested so
# that env.render() output can be collected as frames and animated.
env = gym.make('CartPole-v1', render_mode="rgb_array")
# Reset once so the environment is ready to be stepped.
env.reset()
# Buffer for rendered frames; the rollout cell appends env.render() output to it.
image_seq = []

In [None]:
# Play one episode with uniformly random actions and animate the rendered frames.
#
# The episode and the frame buffer are (re)initialised here so the cell is
# idempotent: re-running it neither steps an already finished environment nor
# appends to frames left over from a previous run.
env.reset()
image_seq = []
done = False
truncated = False
# env.step returns separate "terminated" (here: done) and "truncated" flags;
# the episode is over as soon as either one is set.
while not (done or truncated):
  observation, reward, done, truncated, info = env.step(env.action_space.sample())
  image_seq.append(env.render())

animate(image_seq)

In [None]:
# Reset the environment and display the first element of what reset() returns;
# the training cell passes this same element to the discretizer as the initial state.
env.reset()[0]

In [None]:
def discretize_state(space_obj, num_bins:Iterable[int], return_indexi=False, max_value=1024) -> Callable:
  """Build a function that snaps a continuous state vector onto a fixed grid.

  For every dimension a grid of ``num_bin`` evenly spaced points is created
  between that dimension's bounds. Bounds are first clamped to
  ``[-max_value, max_value]`` so that (near-)unbounded dimensions still get a
  finite grid.

  Args:
    space_obj: object exposing per-dimension ``low`` and ``high`` sequences.
    num_bins: number of grid points per dimension, aligned with ``low``/``high``.
    return_indexi: if True the returned function yields grid indices,
      otherwise it yields the grid values themselves.
    max_value: magnitude used to clamp the lower and upper bounds.

  Returns:
    A function mapping a 1-D state array to a tuple with one entry per
    dimension: the index (or value) of the grid point at or below the input.
    Inputs outside the grid are clamped to the first / last grid point.
  """
  bins = [
    np.linspace(max(low, -max_value), min(high, max_value), num=num_bin)
    for low, high, num_bin in zip(space_obj.low, space_obj.high, num_bins)
  ]

  def discretize_state_fn(input_state:np.ndarray) -> Tuple:
    discretized_state = np.zeros(input_state.shape)
    discretized_indexi = np.zeros(input_state.shape, dtype=np.int32)
    for i, bin_space in enumerate(bins):
      # np.digitize returns 0 for values below the first edge, so the "- 1"
      # would produce -1, which Python indexing silently wraps around to the
      # LAST bin. Clamp so out-of-range values land on the nearest edge bin.
      discrete_index = np.clip(np.digitize(input_state[i], bin_space) - 1, 0, len(bin_space) - 1)
      discretized_state[i] = bin_space[discrete_index]
      discretized_indexi[i] = discrete_index

    ret = discretized_indexi if return_indexi else discretized_state
    # A tuple (rather than an array) can be used directly as a multi-axis index.
    return tuple(ret)

  return discretize_state_fn

# discret_func = discretize_state(env.observation_space, num_bins=(10, 10, 10, 10))
# discret_func(np.array([-4.8, -100, 1, -10]))
# discret_func(np.array([5, 2000, 0.1, 0]))

In [None]:
# Hyperparameters
alpha = 0.1          # learning rate of the tabular update
gamma = 0.6          # discount factor applied to the bootstrapped value
epsilon = 0.1        # probability of taking a random (exploratory) action
max_epoch = 100000   # number of training episodes

# Grid resolution per observation dimension; the Q-table shape is derived from
# it so the discretizer and the table cannot drift apart.
num_bins = (30, 30, 50, 50)
discret_func = discretize_state(env.observation_space, num_bins=num_bins, return_indexi=True)
q_table = np.zeros(num_bins + (int(env.action_space.n),))
all_steps_taken = []


for i in range(1, max_epoch + 1):
  state = discret_func(env.reset()[0])
  # Both end-of-episode flags must be cleared for every episode; a stale
  # `truncated = True` would otherwise skip all following episodes.
  done = False
  truncated = False
  steps_taken = 0

  # Roll out one episode with an epsilon-greedy policy.
  while not (done or truncated):
    if np.random.uniform(0, 1) < epsilon:
      action = env.action_space.sample()
    else:
      action = np.argmax(q_table[state])

    next_state, reward, done, truncated, info = env.step(action)
    next_state = discret_func(next_state)

    # Index with ONE flat tuple. `q_table[state, action]` would be read by
    # NumPy as advanced indexing (all state indices on axis 0, action on
    # axis 1) and touch a whole slab of the table instead of a single cell.
    state_action = state + (int(action),)
    q_value = q_table[state_action]
    # No future return exists after a terminal transition, so do not bootstrap.
    next_max = 0.0 if done else np.max(q_table[next_state])

    new_q_value = (1 - alpha) * q_value + alpha * (reward + gamma * next_max)
    q_table[state_action] = new_q_value

    state = next_state
    steps_taken += 1

  all_steps_taken.append(steps_taken)

  if i % 100 == 0:
    print(f"Average steps taken for the past 10 epochs: {sum(all_steps_taken[-10:]) / 10}")

print("Training finished")


In [None]:
# Report the size of the Q-table and how many entries fall on either side of 0.5.
q_table_summary = (
  ("total elements: ", q_table.size),
  ("q vlaue >= 0.5: ", np.count_nonzero(q_table >= 0.5)),
  ("q vlaue < 0.5: ", np.count_nonzero(q_table < 0.5)),
)
for summary_label, summary_count in q_table_summary:
  print(summary_label, summary_count)