In [1]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
import gym
from IPython import display as ipythondisplay
from PIL import Image

env = gym.make('CartPole-v1', render_mode="rgb_array")

In [3]:
observation = env.reset()
print(observation[0])

[-0.04176844 -0.03451551  0.01290812  0.02217063]


In [4]:
env.reset()
images = []

for _ in range(50):
    screen = env.render()
    images += [Image.fromarray(screen)]
    action = env.action_space.sample()
    observation, reward, terminated, _, info = env.step(action)

    print(observation, reward, terminated, info)

[ 0.04518548  0.23516513 -0.02960301 -0.30295032] 1.0 False {}
[ 0.04988878  0.4306962  -0.03566202 -0.60482043] 1.0 False {}
[ 0.0585027   0.62629825 -0.04775843 -0.9085194 ] 1.0 False {}
[ 0.07102867  0.43185416 -0.06592882 -0.6312214 ] 1.0 False {}
[ 0.07966575  0.23771104 -0.07855324 -0.36000845] 1.0 False {}
[ 0.08441997  0.43385664 -0.08575341 -0.67639047] 1.0 False {}
[ 0.09309711  0.6300588  -0.09928122 -0.99479294] 1.0 False {}
[ 0.10569828  0.43639493 -0.11917708 -0.7348683 ] 1.0 False {}
[ 0.11442618  0.2431033  -0.13387445 -0.48194262] 1.0 False {}
[ 0.11928824  0.43983564 -0.14351329 -0.8136423 ] 1.0 False {}
[ 0.12808496  0.63660073 -0.15978615 -1.147804  ] 1.0 False {}
[ 0.14081697  0.83340645 -0.18274222 -1.4860294 ] 1.0 False {}
[ 0.1574851   1.0302241  -0.21246281 -1.8297678 ] 1.0 True {}
[ 0.17808959  1.2269733  -0.24905817 -2.1802506 ] 0.0 True {}
[ 0.20262904  1.0349877  -0.2926632  -1.9736264 ] 0.0 True {}
[ 0.2233288   0.84389484 -0.3321357  -1.7839953 ] 0.0 True

  if not isinstance(terminated, (bool, np.bool8)):
  logger.warn(


In [5]:
def show_gif(images: list) -> None:
    image_file = 'test.gif'
    images[0].save(image_file, save_all=True, append_images=images[1:], loop=0, duration=1)

In [6]:
env.reset()
images = []

for _ in range(50):
    screen = env.render()
    images += [Image.fromarray(screen)]
    action = env.action_space.sample()
    observation, reward, terminated, _, info = env.step(action)
    if terminated:
        break

show_gif(images)

In [21]:
print(env.observation_space.low)
print(env.observation_space.high)
print(env.action_space.n)
print(env.action_space.sample())

[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38]
[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]
2
0


In [8]:
n_bins_pos = 10
n_bins_vel = 10
n_bins_ang = 10
n_bins_anv = 10
n_states = n_bins_pos * n_bins_vel * n_bins_ang * n_bins_anv
n_actions = env.action_space.n

Q_table = np.random.uniform(0, 1, (n_states, n_actions))

In [9]:
Q_table.shape

(10000, 2)

In [10]:
env.reset()
images = []
observation, _, _, _, _ = env.step(0)
env.close()

pos, vel, ang, anv = observation
print(observation)

[ 0.04778298 -0.219906   -0.00724872  0.24261995]


In [24]:
def map_discrete_state(state):
    pos, vel, ang, anv = state
    idx_pos = np.where(np.histogram(np.clip(pos, -4, 4), bins=n_bins_pos, range=(-4, 4))[0] == 1)[0][0]
    idx_vel = np.where(np.histogram(np.clip(vel, -2, 2), bins=n_bins_vel, range=(-2, 2))[0] == 1)[0][0]
    idx_ang = np.where(np.histogram(np.clip(ang, -0.4, 0.4), bins=n_bins_ang, range=(-0.4, 0.4))[0] == 1)[0][0]
    idx_anv = np.where(np.histogram(np.clip(anv, -2, 2), bins=n_bins_anv, range=(-2, 2))[0] == 1)[0][0]

    states = np.zeros([n_bins_pos, n_bins_vel, n_bins_ang, n_bins_anv])
    states[idx_pos, idx_vel, idx_ang, idx_anv] = 1
    states = states.reshape(-1, 1)

    s = np.where(states == 1)[0][0]
    return s

In [31]:
alpha = 0.3
gamma = 0.9

Q_table = np.random.uniform(0, 1, (n_states, n_actions))

for episode in range(2001):
    state = env.reset()
    state = state[0]

    count = 0

    terminated = False

    while not terminated:
        s = map_discrete_state(state)
        count += 1
        epsilon = 0.1
        if np.random.uniform() < epsilon:
            a = env.action_space.sample()
        else:
            a = np.argmax(Q_table[s, :])
        
        next_state, reward, terminated, _, _ = env.step(a)

        if terminated:
            reward = -100
            Q_table[s, a] = reward
        else:
            next_s = map_discrete_state(next_state)
            Q_table[s, a] = alpha * Q_table[s, a] + alpha * (reward + np.max(Q_table[next_s, :]))
        state = next_state
    if episode % 100 == 0:
        print(episode, count)

env.close()



0 14
100 99
200 149
300 160
400 739
500 203
600 311
700 152
800 606
900 128
1000 523
1100 106
1200 253
1300 263
1400 1022
1500 252
1600 149
1700 69
1800 294
1900 185
2000 207


In [32]:
state = env.reset()
state = tuple(state[0])
images = []

terminated = False
while not terminated:
    screen = env.render()
    images += [Image.fromarray(screen)]
    s = map_discrete_state(state)
    a = np.argmax(Q_table[s, :])
    next_state, _, terminated, _, _ = env.step(a)

    state = next_state

env.close()
show_gif(images=images)
    