In [172]:
#!pip install gym
#!pip install pygame

In [173]:
import gym
import numpy as np

In [174]:
env = gym.make('CartPole-v1', render_mode='human', new_step_api=True)

In [175]:
def epsilon_greedy_policy(state, Q, epsilon=0.1):
    explore = np.random.binomial(1, epsilon)
    if explore:
        action = env.action_space.sample()
        print('explore')
    else:
        action = np.argmax(Q[state])
        print('exploit')
    return action

In [176]:
def q_learning(state, Q, epsilon=0.1):
    '''
    Inicializar Q(s, a) arbitrariamente ∀s ∈ S, a ∈ A(s)
    Repetir:
        Inicializar s
        done ← False
        Repetir hasta done:
            Con probabilidad ε hacer: (* estrategia ε-greedy *)
                explore: a ← sample(A(s))
                exploit: a ← arg m ́ax Q(s, ·)
            s′, r , done ← step(a)
            Q(s, a) ← Q(s, a) + α(r + γ max Q(s′, ·) − Q(s, a))
            s ← s′
    '''
    
    done = False
    alpha = 0.1
    gamma = 0.9
    while not done:
        action = epsilon_greedy_policy(state, Q, epsilon)
        obs, reward, done, info, _ = env.step(action)
        Q[state, action] = Q[state, action] + alpha*(reward + gamma*np.max(Q[obs, :]) - Q[state, action])
        Q

In [177]:
def optimal_policy(state, Q):
    action = np.argmax(Q[state])
    return action

In [178]:
cart_position_bins = np.linspace(-2.4, 2.4, 6)
cart_position_bins
pole_angle_bins = np.linspace(-.2095, .2095, 4)
pole_angle_bins
pole_top_angles_bins = np.linspace(-0.06983333, 0.06983333, 8)
pole_top_angles_bins
acc_bins = np.linspace(-10, 10, 5)
acc_bins.size


5

In [179]:
def get_state(obs):
    car_pos = np.digitize(obs[0], cart_position_bins)
    car_acc = np.digitize(obs[1], acc_bins)
    pole_ang = np.digitize(obs[2], pole_angle_bins)
    pole_acc = np.digitize(obs[3], acc_bins)
    state = tuple([car_pos, car_acc, pole_ang, pole_acc])
    return state

In [180]:
state = get_state(np.array([-1.4, -2., 0.23, 1.2]))
state

(2, 2, 4, 3)

In [181]:
position_bins_count = cart_position_bins.size + 1
acc_bins_count = acc_bins.size + 1
angle_bins_count = pole_angle_bins.size + 1
print("Valid cart position bins: 0 - ", position_bins_count - 1)
print("Valid cart acceleration bins: 0 - ", acc_bins_count - 1)
print("Valid pole angle bins: 0 - ", angle_bins_count - 1)
print("Valid pole acceleration bins: 0 - ", acc_bins_count - 1)
Q = np.random.random((position_bins_count,acc_bins_count,angle_bins_count,acc_bins_count,2))
Q

Valid cart position bins: 0 -  6
Valid cart acceleration bins: 0 -  5
Valid pole angle bins: 0 -  4
Valid pole acceleration bins: 0 -  5


array([[[[[4.04664620e-01, 2.22215653e-01],
          [7.42784308e-02, 4.25866627e-01],
          [2.47720790e-01, 1.51953854e-01],
          [9.15100698e-02, 1.36017740e-01],
          [3.97613522e-01, 6.74327346e-01],
          [1.23067838e-01, 8.56266308e-01]],

         [[4.50582055e-01, 3.79258167e-01],
          [2.42745587e-01, 4.05587360e-01],
          [7.05657097e-01, 3.60203603e-01],
          [1.31822948e-01, 1.04101320e-01],
          [4.19595349e-01, 3.64350692e-01],
          [9.75046734e-01, 6.56565106e-02]],

         [[8.88146241e-01, 6.67640910e-01],
          [6.31776738e-01, 3.13863625e-01],
          [8.29159350e-01, 5.22290898e-01],
          [9.62263938e-01, 3.86198278e-02],
          [9.63918128e-01, 4.93021721e-01],
          [1.24373333e-01, 1.18183024e-01]],

         [[8.30267085e-01, 9.17763526e-01],
          [9.89912229e-01, 8.99778488e-01],
          [2.51862194e-01, 2.00695443e-01],
          [3.80946209e-01, 9.23505863e-01],
          [5.38339246e-02,

In [182]:
obs = env.reset()
print(obs)
done = False
while not done:
    state = get_state(obs)
    action = epsilon_greedy_policy(state, Q, 0.5)
    obs, reward, done, info, _ = env.step(action)
    print('->', state, action, reward, obs, done, info)
env.close()

[ 0.01866666 -0.0439557   0.04268664  0.04304318]
exploit
-> (3, 2, 2, 3) 1 1.0 [ 0.01778754  0.15052897  0.0435475  -0.23587202] False False
exploit
-> (3, 3, 2, 2) 0 1.0 [ 0.02079812 -0.04518723  0.03883006  0.07022287] False False
exploit
-> (3, 2, 2, 3) 1 1.0 [ 0.01989438  0.14935714  0.04023452 -0.20996054] False False
explore
-> (3, 3, 2, 2) 0 1.0 [ 0.02288152 -0.04631632  0.03603531  0.09513786] False False
exploit
-> (3, 2, 2, 3) 1 1.0 [ 0.02195519  0.14827111  0.03793807 -0.18596171] False False
exploit
-> (3, 3, 2, 2) 0 1.0 [ 0.02492062 -0.04737252  0.03421883  0.11844371] False False
exploit
-> (3, 2, 2, 3) 1 1.0 [ 0.02397317  0.14724286  0.03658771 -0.1632501 ] False False
explore
-> (3, 3, 2, 2) 0 1.0 [ 0.02691802 -0.04838324  0.03332271  0.14074704] False False
explore
-> (3, 2, 2, 3) 0 1.0 [ 0.02595036 -0.2439662   0.03613765  0.4437537 ] False False
explore
-> (3, 2, 2, 3) 0 1.0 [ 0.02107104 -0.43958035  0.04501272  0.7476055 ] False False
explore
-> (3, 2, 2, 3) 0 1.0 