In [2]:
import random
import numpy as np
import gymnasium as gym

In [3]:
def Q(state, action, q_weights):
    return np.dot(np.append(state, action), q_weights)

In [4]:
def epsilon_greedy(env, obs, q_weights, epsilon):
    action = None
    e = random.random()
    if e < epsilon:
        action = env.action_space.sample()
    else:
        possible_actions = range(env.action_space.n)
        q_values = {Q(obs, a, q_weights) : a for a in possible_actions}
        action = q_values[max(q_values)]
    return action

In [5]:
def calculate_target(env, transition, q_weights, discount_rate):
    _, _, t_reward, t_new_obs, t_terminal = transition
    if t_terminal:
        target = t_reward
    else:
        possible_actions = range(env.action_space.n)
        max_Q = max([Q(t_new_obs, a, q_weights) for a in possible_actions])
        target = transition[2] + discount_rate * max_Q
    return target

In [57]:
def gradient_descent(target, q_weights, obs, action, learning_rate):
    q_grad = np.append(obs, action)
    loss_grad = (target - Q(obs, action, q_weights)) * q_grad
    new_weights = np.asarray(q_weights + learning_rate * loss_grad).clip(-1, 1)
    return new_weights

In [58]:
MAX_EPISODES = 100
MAX_TIMESTEPS = 100
MINIBATCH_SIZE = 5
EPSILON = .1
DISCOUNT_RATE = .9
LEARNING_RATE = .1

In [59]:
rng = np.random.default_rng()
env = gym.make("CartPole-v1", render_mode="human")
obs, info = env.reset()

In [60]:
import time

replay_memory = []
q_weights = rng.random(env.observation_space.shape[0] + 1)

for episode in range(MAX_EPISODES):
    states = [obs]
    for t in range(MAX_TIMESTEPS):
        # sample and take action
        action = epsilon_greedy(env, obs, q_weights, EPSILON)
        new_obs, reward, terminated, truncated, info = env.step(action)
        # sample minibatch and update q_weights
        replay_memory.append([obs, action, reward, new_obs, terminated or truncated])
        minibatch = random.sample(replay_memory, min(MINIBATCH_SIZE, len(replay_memory)))
        for transition in minibatch:
            target = calculate_target(env, transition, q_weights, DISCOUNT_RATE)
            print(f'old weights: {q_weights}')
            print(f'old loss: {(target - Q(obs, action, q_weights))**2}')
            t_obs, t_action, _,_, _ = transition
            q_weights = gradient_descent(target, q_weights, t_obs, t_action, LEARNING_RATE / (episode+1))
            print(f'new weights: {q_weights}')
            print(f'new loss: {(target - Q(obs, action, q_weights))**2}')
            time.sleep(.1)
        obs = new_obs

old weights: [0.3687442  0.77854844 0.61591571 0.23296671 0.84789422]
old loss: 0.9852925673320838
new weights: [0.37225355 0.77369605 0.61452958 0.23463102 0.94715613]
new loss: 0.7973572017831401
old weights: [0.37225355 0.77369605 0.61452958 0.23463102 0.94715613]
old loss: 0.9433569220449237
new weights: [0.37559244 0.78791875 0.61320584 0.2074075  1.        ]
new loss: 0.8255099551288028
old weights: [0.37559244 0.78791875 0.61320584 0.2074075  1.        ]
old loss: 0.7989144091480342
new weights: [0.37907818 0.78309899 0.61182903 0.20906061 1.        ]
new loss: 0.80075769873861
old weights: [0.37907818 0.78309899 0.61182903 0.20906061 1.        ]
old loss: 0.9265045706842542
new weights: [0.38266901 0.81599402 0.60997759 0.15349839 1.        ]
new loss: 0.8446841323546554
old weights: [0.38266901 0.81599402 0.60997759 0.15349839 1.        ]
old loss: 0.6089843612814918
new weights: [0.38622383 0.81107874 0.6085735  0.15518426 1.        ]
new loss: 0.6128820187876683
old weights:

  logger.warn(


old weights: [ 0.49758395  0.39521798  0.55077937 -0.02978586  0.60556904]
old loss: 0.1342333600313975
new weights: [ 0.50104852  0.40997605  0.5494058  -0.05803414  0.70635169]
new loss: 0.02574064843176874
old weights: [ 0.50104852  0.40997605  0.5494058  -0.05803414  0.70635169]
old loss: 2.33716094387094
new weights: [0.47546756 0.14813761 0.58442887 0.3694334  0.55347393]
new loss: 0.07846818737677524
old weights: [0.47546756 0.14813761 0.58442887 0.3694334  0.55347393]
old loss: 1.7266320747321875
new weights: [0.48173023 0.23175033 0.57795688 0.23564042 0.64353737]
new loss: 0.49577057930475216
old weights: [0.48173023 0.23175033 0.57795688 0.23564042 0.64353737]
old loss: 1.508761055049875
new weights: [0.48495801 0.22728726 0.57668197 0.23717119 0.73483525]
new loss: 1.3181579385906679
old weights: [0.48495801 0.22728726 0.57668197 0.23717119 0.73483525]
old loss: 0.1261516966888942
new weights: [0.4780977  0.15706736 0.58607446 0.35180955 0.69383645]
new loss: 0.034184622867

KeyboardInterrupt: 

In [29]:
env = gym.make("CartPole-v1", render_mode="human")
obs, _ = env.reset()
for _ in range(100):
    action = epsilon_greedy(env, obs, q_weights, EPSILON)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, _ = env.reset()

env.close()