In [1]:
import numpy as np
import random
import time
import copy
import gymnasium

import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
class QNet(nn.Module):
    """Small MLP that maps a state vector to one Q-value per discrete action.

    Architecture: Linear(n_s -> hidden_dim) -> ReLU -> Linear(hidden_dim -> n_a).

    Args:
        hidden_dim: Width of the single hidden layer.
        n_a: Number of discrete actions, i.e. the size of the output layer.
        n_s: Size of the input state vector. Defaults to 4, the value that
            was previously hard-coded.
    """

    def __init__(self, hidden_dim=48, n_a=2, n_s=4):
        super().__init__()

        # int() so that numpy integer types (e.g. a gym action-space size)
        # are accepted as layer dimensions.
        self.hidden = nn.Linear(int(n_s), int(hidden_dim))
        # The output width follows n_a instead of being fixed at 2, so the
        # constructor argument is actually honoured.
        self.output = nn.Linear(int(hidden_dim), int(n_a))

    def forward(self, s):
        """Return the Q-values for `s`; the last dimension of `s` must be n_s."""
        outs = self.hidden(s)
        outs = F.relu(outs)
        outs = self.output(outs)
        return outs

In [3]:
# Single environment instance shared by the data-collection loop and by evaluate().
env = gymnasium.make("CartPole-v1")

In [4]:
# Online Q-network: its parameters are the ones handed to the optimizer.
q_net = QNet(hidden_dim=32, n_a=env.action_space.n)
# Target network: starts as an exact copy of q_net and has gradients disabled.
# It is re-synchronised from q_net periodically inside the training loop.
q_target= QNet(hidden_dim=32, n_a=env.action_space.n)
q_target.load_state_dict(q_net.state_dict())
q_target.requires_grad_(False)

QNet(
  (hidden): Linear(in_features=4, out_features=32, bias=True)
  (output): Linear(in_features=32, out_features=2, bias=True)
)

In [5]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each stored item is indexable as
    ``(state, action, reward, next_state, done)``.

    Attributes:
        buffer_size: Maximum number of items retained.
        buffer: ``collections.deque`` holding the items, oldest first. Once
            ``buffer_size`` items are stored, appending drops the oldest one.
    """

    def __init__(self, buffer_size: int):
        # Local import keeps this cell self-contained.
        from collections import deque

        self.buffer_size = buffer_size
        # deque(maxlen=...) evicts the oldest entry in O(1); the previous
        # list.pop(0) shifted every remaining element on each insertion.
        self.buffer = deque(maxlen=buffer_size)

    def add(self, item):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append(item)

    def sample(self, sample_size):
        """Draw up to `sample_size` distinct transitions uniformly at random.

        If fewer than `sample_size` items are stored, all of them are returned
        (in random order).

        Returns:
            Tuple ``(states, actions, rewards, n_states, dones)`` of
            ``torch.float`` tensors, each with the sample as its first
            dimension.
        """
        sample_size = min(len(self.buffer), sample_size)
        items = random.sample(self.buffer, sample_size)

        # Transpose the list of transitions into one list per field.
        states, actions, rewards, n_states, dones = (
            [item[field] for item in items] for field in range(5)
        )

        # Everything (including actions) is returned as float, matching what
        # existing callers expect; callers cast actions to an integer type.
        states = torch.tensor(states, dtype=torch.float)
        actions = torch.tensor(actions, dtype=torch.float)
        n_states = torch.tensor(n_states, dtype=torch.float)
        rewards = torch.tensor(rewards, dtype=torch.float)
        dones = torch.tensor(dones, dtype=torch.float)

        return states, actions, rewards, n_states, dones

# Replay memory filled by the training loop; keeps at most 10000 transitions.
memory = ReplayBuffer(buffer_size=10000)

In [6]:
# Discount factor applied to the bootstrapped next-state value in optimize().
gamma = 0.99

# Adam updates q_net only; q_target is never passed to an optimizer.
opt = torch.optim.Adam(q_net.parameters(), lr=0.0001)

def optimize(states, actions, rewards, next_states, dones):
    """Run one DQN gradient step on a minibatch of transitions.

    Reads the module-level `q_net`, `q_target`, `opt` and `gamma`.

    Args:
        states: Float tensor of states, one row per transition.
        actions: Tensor of the action indices taken (any numeric dtype; it is
            cast to int64 here).
        rewards: Float tensor of rewards, one per transition.
        next_states: Float tensor of successor states, one row per transition.
        dones: Float tensor, 1.0 where bootstrapping from the successor state
            must be disabled and 0.0 otherwise.

    Returns:
        The scalar MSE loss of this step as a Python float (for logging;
        callers may ignore it).
    """
    # TD target: r + gamma * max_a' Q_target(s', a'), with the bootstrap term
    # zeroed where dones == 1. No gradient flows through the target.
    with torch.no_grad():
        next_q_max = q_target(next_states).max(dim=1).values
        td_target = rewards + gamma * ((1.0 - dones) * next_q_max)

    # Q(s, a) for the actions actually taken. gather() picks the entry directly,
    # so a non-finite value in a non-selected action cannot leak into the result
    # the way it could with a one-hot multiply-and-sum.
    action_idx = actions.long().unsqueeze(1)
    q_taken = q_net(states).gather(1, action_idx).squeeze(1)

    # mse_loss(input, target): the prediction goes first, the fixed target second.
    loss = F.mse_loss(q_taken, td_target, reduction="mean")

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

In [7]:
# Transitions drawn from replay memory per training iteration; the training
# loop reshapes them into 30 minibatches of `batch_size`.
sampling_size = 64 * 30
batch_size = 64

# Epsilon-greedy schedule: start fully random and subtract `epsilon_decay`
# once per training iteration, never going below `epsilon_final`.
epsilon = 1.0
epsilon_decay = epsilon / 3000
epsilon_final = 0.1

In [8]:
def pick_sample(s, epsilon):
    """Choose an action for state `s` with an epsilon-greedy policy.

    A uniform random number is drawn first; if it exceeds `epsilon` the action
    with the highest Q-value under `q_net` is returned, otherwise a uniformly
    random action index from `env.action_space` is returned.

    Args:
        s: A single state (array-like accepted by ``torch.tensor``).
        epsilon: Exploration probability; 0.0 makes the policy (almost surely)
            greedy.

    Returns:
        The chosen action index as an int.
    """
    exploit = np.random.random() > epsilon
    if not exploit:
        # Exploration: ignore the network and act uniformly at random.
        return np.random.randint(0, env.action_space.n)

    # Exploitation: score all actions for a batch containing just this state.
    with torch.no_grad():
        state_batch = torch.tensor(s, dtype=torch.float).unsqueeze(dim=0)
        best_action = torch.argmax(q_net(state_batch), 1)
    return best_action.squeeze(dim=0).tolist()

In [9]:
def evaluate():
    """Play one full episode on `env` without exploration and return its total reward.

    Actions come from ``pick_sample(state, 0.0)``. The episode ends when the
    environment reports either termination or truncation.
    """
    episode_return = 0
    with torch.no_grad():
        state, _ = env.reset()
        while True:
            action = pick_sample(state, 0.0)
            state, reward, terminated, truncated, _ = env.step(action)
            episode_return += reward
            if terminated or truncated:
                return episode_return

In [10]:
# Main DQN training loop. Each iteration:
#   1. collects 500 environment steps with the current epsilon-greedy policy,
#   2. trains q_net on `sampling_size` replayed transitions in minibatches,
#   3. runs one greedy evaluation episode and records its reward.
reward_records = []
try:
    for _iter in range(15000):
        # Collect 500 transitions (possibly spanning several episodes) into
        # replay memory. Starting with done=True forces a fresh reset, which is
        # required because evaluate() below steps the same `env`.
        done = True
        for _step in range(500):
            if done:
                s, _ = env.reset()
                done = False
                cum_reward = 0

            a = pick_sample(s, epsilon)
            s_next, r, term, trunc, _ = env.step(a)
            done = term or trunc
            # Store `term`, not `done`: optimize() uses this flag to switch off
            # bootstrapping, so a truncated (time-limited) episode still
            # bootstraps from its next state.
            memory.add([s.tolist(), a, r, s_next.tolist(), float(term)])
            cum_reward += r
            s = s_next

        # Warm-up: fill replay memory (without updates) until it holds 2000 samples
        if len(memory.buffer) < 2000:
            continue

        # Optimize Q-network with minibatches drawn from replay memory
        states, actions, rewards, n_states, dones = memory.sample(sampling_size)
        states = torch.reshape(states, (-1, batch_size, 4))
        actions = torch.reshape(actions, (-1, batch_size))
        rewards = torch.reshape(rewards, (-1, batch_size))
        n_states = torch.reshape(n_states, (-1, batch_size, 4))
        dones = torch.reshape(dones, (-1, batch_size))
        for j in range(actions.size(dim=0)):
            optimize(states[j], actions[j], rewards[j], n_states[j], dones[j])

        # Track progress with one exploration-free episode
        total_reward = evaluate()
        reward_records.append(total_reward)
        iteration_num = len(reward_records)
        print("Run iteration {} rewards {:3} epsilon {:1.5f}".format(iteration_num, total_reward, epsilon), end="\r")

        # Clone Q-network to obtain target
        if iteration_num % 50 == 0:
            q_target.load_state_dict(q_net.state_dict())

        # Update epsilon (linear decay, clamped at epsilon_final)
        if epsilon - epsilon_decay >= epsilon_final:
            epsilon -= epsilon_decay

        # stop if reward mean > 495.0 over the last 200 evaluations
        if np.average(reward_records[-200:]) > 495.0:
            break
finally:
    # Release the environment even if training is interrupted
    # (e.g. KeyboardInterrupt from stopping the cell).
    env.close()

print("\nDone")

Run iteration 2505 rewards 356.0 epsilon 0.16533

KeyboardInterrupt: 