# The PTAN CartPole solver

In [1]:
import gym
import ptan
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


HIDDEN_SIZE = 128
BATCH_SIZE = 16
TGT_NET_SYNC = 10
GAMMA = 0.9
REPLAY_SIZE = 1000
LR = 1e-3
EPS_DECAY=0.99


class Net(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        obs_size: number of input features per observation.
        hidden_size: width of the single hidden layer.
        n_actions: number of outputs (one value per action).
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Cast the input to float32 and run it through the layer stack."""
        as_float = x.float()
        return self.net(as_float)

In [2]:
@torch.no_grad()
def unpack_batch(batch, net, gamma):
    """Turn a batch of first/last experience records into training tensors.

    Args:
        batch: iterable of objects exposing ``state``, ``action``, ``reward``
            and ``last_state`` (``last_state is None`` marks an episode end).
        net: callable that maps a batch of states to per-action values; it is
            evaluated on the last states to build the bootstrap term.
        gamma: discount applied to the best next-state value.

    Returns:
        Tuple ``(states_v, actions_v, target_q_v)`` where ``target_q_v`` is
        ``reward + gamma * max_a net(last_state)[a]``, with the bootstrap
        term zeroed for terminal transitions.
    """
    # Single pass over the batch; everything else is derived from these rows.
    rows = [(e.state, e.action, e.reward, e.last_state) for e in batch]

    terminal = [nxt is None for _, _, _, nxt in rows]
    # Terminal rows have no successor state: feed the current state as a
    # placeholder so the tensor is rectangular; its value is masked below.
    successors = [cur if nxt is None else nxt for cur, _, _, nxt in rows]

    states_v = torch.tensor([cur for cur, _, _, _ in rows])
    actions_v = torch.tensor([act for _, act, _, _ in rows])
    rewards_v = torch.tensor([rew for _, _, rew, _ in rows])

    successor_q = net(torch.tensor(successors))
    best_successor_q = successor_q.max(dim=1).values
    best_successor_q[terminal] = 0.0

    return states_v, actions_v, best_successor_q * gamma + rewards_v

In [3]:
# Environment and the sizes the network needs.
env = gym.make("CartPole-v0")
n_actions = env.action_space.n
obs_size = env.observation_space.shape[0]

# Trained network plus the PTAN target-network wrapper around it.
net = Net(obs_size, HIDDEN_SIZE, n_actions)
tgt_net = ptan.agent.TargetNet(net)

# Epsilon-greedy selection on top of argmax; epsilon starts at 1 and is
# decayed by the training loop.
selector = ptan.actions.EpsilonGreedyActionSelector(
    epsilon=1, selector=ptan.actions.ArgmaxActionSelector())
agent = ptan.agent.DQNAgent(net, selector)

# Data pipeline: experience source feeding a fixed-size replay buffer.
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env, agent, gamma=GAMMA)
buffer = ptan.experience.ExperienceReplayBuffer(
    exp_source, buffer_size=REPLAY_SIZE)

optimizer = optim.Adam(net.parameters(), lr=LR)

# Loop bookkeeping: iteration counter, finished-episode counter, stop flag.
step, episode, solved = 0, 0, False

In the beginning, we create the NN (the simple two-layer feed-forward NN that
we used for CartPole before), the target NN, the epsilon-greedy action selector, and
the DQNAgent. Then the experience source and replay buffer are created. With those few
lines, we have finished with our data pipeline. Now we just need to call populate()
on the buffer and sample training batches from it.

In the beginning of every training loop iteration, we ask the buffer to fetch one
sample from the experience source and then check for the finished episode. The
method pop_rewards_steps() in the ExperienceSource class returns the list of
tuples with information about episodes completed since the last call to the method.

Later in the training loop, we convert a batch of ExperienceFirstLast objects into
tensors suitable for DQN training, calculate the loss, and do a backpropagation
step. Finally, we decay epsilon in our action selector (with the hyperparameters
used, epsilon is multiplied by 0.99 on every training step, so it never reaches
exactly zero but drops below 0.01 by about training step 500) and ask the target
network to sync every 10 training iterations.

In [8]:
# Main training loop: one environment step per iteration, then (once the
# buffer is warm) one optimisation step on a sampled batch.
while True:
    step += 1
    buffer.populate(1)

    # Report each episode that finished since the previous iteration.
    for reward, steps in exp_source.pop_rewards_steps():
        episode += 1
        print("%d: episode %d done, reward=%.3f, epsilon=%.2f" % (
            step, episode, reward, selector.epsilon))
        solved = reward > 150
    if solved:
        print("Congrats!")
        break

    # Train only after the buffer holds at least two batches of transitions;
    # epsilon decay and target sync are skipped until then as well.
    if len(buffer) >= 2*BATCH_SIZE:
        batch = buffer.sample(BATCH_SIZE)
        # Targets are computed with the target network, not the trained one.
        states_v, actions_v, tgt_q_v = unpack_batch(
            batch, tgt_net.target_model, GAMMA)

        optimizer.zero_grad()
        q_v = net(states_v)
        # Pick, per row, the value of the action that was actually taken.
        action_idx = actions_v.unsqueeze(-1)
        q_v = q_v.gather(1, action_idx).squeeze(-1)
        loss_v = F.mse_loss(q_v, tgt_q_v)
        loss_v.backward()
        optimizer.step()

        selector.epsilon *= EPS_DECAY

        if step % TGT_NET_SYNC == 0:
            tgt_net.sync()

14: episode 1 done, reward=13.000, epsilon=1.00
30: episode 2 done, reward=16.000, epsilon=1.00
55: episode 3 done, reward=25.000, epsilon=0.79
67: episode 4 done, reward=12.000, epsilon=0.70
80: episode 5 done, reward=13.000, epsilon=0.62
91: episode 6 done, reward=11.000, epsilon=0.55
104: episode 7 done, reward=13.000, epsilon=0.48
116: episode 8 done, reward=12.000, epsilon=0.43
140: episode 9 done, reward=24.000, epsilon=0.34
153: episode 10 done, reward=13.000, epsilon=0.30
164: episode 11 done, reward=11.000, epsilon=0.27
174: episode 12 done, reward=10.000, epsilon=0.24
185: episode 13 done, reward=11.000, epsilon=0.21
194: episode 14 done, reward=9.000, epsilon=0.20
203: episode 15 done, reward=9.000, epsilon=0.18
213: episode 16 done, reward=10.000, epsilon=0.16
221: episode 17 done, reward=8.000, epsilon=0.15
229: episode 18 done, reward=8.000, epsilon=0.14
237: episode 19 done, reward=8.000, epsilon=0.13
250: episode 20 done, reward=13.000, epsilon=0.11
259: episode 21 done

In [11]:
??torch.gather