In [179]:
#!pip install gym
#!pip install pygame

In [180]:
import gym
import numpy as np
import time

In [200]:
def epsilon_greedy_policy(state, Q, epsilon=0.1):
    explore = np.random.binomial(1, epsilon)
    if explore:
        action = env.action_space.sample()
        print('explore')
    else:
        action = np.argmax(Q[state])
        print('exploit')
    return action

In [182]:
def q_learning(state, Q, epsilon=0.1):
    '''
    Inicializar Q(s, a) arbitrariamente ∀s ∈ S, a ∈ A(s)
    Repetir:
        Inicializar s
        done ← False
        Repetir hasta done:
            Con probabilidad ε hacer: (* estrategia ε-greedy *)
                explore: a ← sample(A(s))
                exploit: a ← arg m ́ax Q(s, ·)
            s′, r , done ← step(a)
            Q(s, a) ← Q(s, a) + α(r + γ max Q(s′, ·) − Q(s, a))
            s ← s′
    '''
    
    done = False
    alpha = 0.05
    gamma = 0.9
    while not done:
        action = epsilon_greedy_policy(state, Q, epsilon)
        obs, reward, done, info, _ = env.step(action)
        discrete_obs = get_state(obs)
        Q[state][action] = Q[state][action] + alpha*(reward + gamma*np.max(Q[discrete_obs][:]) - Q[state][action])
        state = discrete_obs

In [183]:
def optimal_policy(state, Q):
    action = np.argmax(Q[state])
    return action

# Discretizacion de variables de la observación

Como la posicion del carrito y su aceleracion no son tan relevantes como la posicion de la barra y su velocidad, se discretizan utilizando contenedores de mayor tamaño, siguiendo steps de tamaño 0.25. Por otro lado, tanto la aceleración angular como el ángulo del poste son las variables de mayor importancia, son los que van a variar con una mayor velocidad, por lo que optamos por usar los steps de 0.004 para el angulo, y 0.02 para la aceleracion angular.

In [184]:
cart_position_bins, cart_pos_step = np.linspace(-2.4, 2.4, 20, retstep=True)
cart_acc_bins, cart_acc_step = np.linspace(-10, 10, 80, retstep=True)
pole_angle_bins, pole_angle_step = np.linspace(-.2095,.2095, 100, retstep=True)
angular_acc_bins, angular_acc_step = np.linspace(-10, 10, 1000, retstep=True)
print("cart_pos_step: ", cart_pos_step)
print("cart_acc_step: ", cart_acc_step)
print("pole_angle_step: ", pole_angle_step)
print("angular_acc_step: ", angular_acc_step)

cart_pos_step:  0.25263157894736843
cart_acc_step:  0.25316455696202533
pole_angle_step:  0.004232323232323232
angular_acc_step:  0.02002002002002002


In [185]:
def get_state(obs):
    cart_pos = np.digitize(obs[0], cart_position_bins)
    cart_acc = np.digitize(obs[1], cart_acc_bins)
    pole_ang = np.digitize(obs[2], pole_angle_bins)
    ang_acc = np.digitize(obs[3], angular_acc_bins)
    state = tuple([cart_pos, cart_acc, pole_ang, ang_acc])
    return state

In [186]:
state = get_state(np.array([-1.4, -2., 0.1, -11]))
state

(4, 32, 74, 0)

In [187]:
position_bins_count = cart_position_bins.size + 1
acc_bins_count = cart_acc_bins.size + 1
angle_bins_count = pole_angle_bins.size + 1
angular_acc_bins_count = angular_acc_bins.size + 1
print("Valid cart position bins: 0 - ", position_bins_count - 1)
print("Valid cart acceleration bins: 0 - ", acc_bins_count - 1)
print("Valid pole angle bins: 0 - ", angle_bins_count - 1)
print("Valid pole acceleration bins: 0 - ", angular_acc_bins_count - 1)
Q = np.random.random((position_bins_count,acc_bins_count,angle_bins_count,angular_acc_bins_count,2))
Q.shape

Valid cart position bins: 0 -  20
Valid cart acceleration bins: 0 -  80
Valid pole angle bins: 0 -  100
Valid pole acceleration bins: 0 -  1000


(21, 81, 101, 1001, 2)

In [199]:
env = gym.make('CartPole-v1', render_mode="rgb_array", new_step_api=True)
done = False
max_episodes = 1000
for i in range(max_episodes):
    obs = env.reset()
    q_learning(get_state(obs), Q)
env.close()

explore
explore
explore
explore
explore
explore
explore
state (10, 41, 38, 481)
exploit
explore
state (10, 42, 33, 451)
exploit
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
state (10, 41, 55, 494)
exploit
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
state (10, 40, 49, 499)
exploit
explore
explore
explore
explore
explore
explore
explore
explore
explore
state (10, 43, 32, 439)
exploit
explore
explore
explore
explore
explore
state (10, 40, 44, 500)
exploit
explore
explore
explore
state (10, 39, 52, 528)
exploit
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
explore
state (10, 41, 46, 483)
exploit
explore
explore
explore
state (10, 41, 39, 481)
exploit
explore
explore
explore
explore
explore
explore
explore
state (10, 41, 34, 475)
exploit
state (10, 41, 32, 460)
exploit
explore
explore
explore
explore
state (10, 41, 6

# Ejecución con la policy óptima

In [203]:
time.sleep(1)
max_reward = 500
final_reward= 0
env = gym.make('CartPole-v1', render_mode='human', new_step_api=True)
print("Playing optimal policy")
obs = env.reset()
done = False
while not done and final_reward < max_reward:
    state = get_state(obs)
    action = optimal_policy(state, Q)
    obs, reward, done, info, _ = env.step(action)
    final_reward += reward
    print('->', reward, obs, done)
    time.sleep(0.05)
env.close()
print("Final reward: ", final_reward)

Playing optimal policy
-> 1.0 [ 0.02122001 -0.14896968 -0.03539339  0.23638113] False
-> 1.0 [ 0.01824061  0.04663959 -0.03066576 -0.06725249] False
-> 1.0 [ 0.0191734   0.24218747 -0.03201081 -0.36945072] False
-> 1.0 [ 0.02401715  0.04753461 -0.03939983 -0.0870306 ] False
-> 1.0 [ 0.02496785 -0.14700107 -0.04114044  0.19296592] False
-> 1.0 [ 0.02202782 -0.3415111  -0.03728112  0.4723922 ] False
-> 1.0 [ 0.0151976  -0.5360872  -0.02783328  0.75309545] False
-> 1.0 [ 0.00447586 -0.34059277 -0.01277137  0.45178545] False
-> 1.0 [-0.002336   -0.5355318  -0.00373566  0.74041545] False
-> 1.0 [-0.01304663 -0.34035847  0.01107265  0.44655922] False
-> 1.0 [-0.0198538  -0.1453949   0.02000383  0.1573871 ] False
-> 1.0 [-0.0227617  -0.34079745  0.02315158  0.456313  ] False
-> 1.0 [-0.02957765 -0.53623897  0.03227784  0.75620264] False
-> 1.0 [-0.04030243 -0.7317906   0.04740189  1.0588653 ] False
-> 1.0 [-0.05493824 -0.5373275   0.0685792   0.78142935] False
-> 1.0 [-0.0656848  -0.34321192 