# DQN in the CartPole environment

In [None]:
import random

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque

# --- Training schedule ---
EPISODE_COUNT = 10000  # max training episodes; also the epsilon-decay horizon
EPISODE_LENGTH = 200  # max steps per episode (training and evaluation)
MINIBATCH_SIZE = 16  # transitions sampled per gradient step

# --- Replay memory / observation layout ---
EXPERIENCE_CAPACITY = 5000  # replay deque maxlen; oldest transitions evicted first
OBS_SIZE = 4  # length of one observation vector

# --- Target-network updates ---
USE_SOFT_TARGET_UPDATE = True  # True: Polyak-average every step; False: periodic hard copy
TARGET_NET_REFRESH_RATE = 200  # steps between hard copies (used only when soft updates are off)

In [None]:
class Q(nn.Module):
    """Q-network that scores one (action, observation) input with a single value.

    The input has ``OBS_SIZE + 1`` features: one action value followed by the
    observation. Two 64-unit ReLU layers feed a scalar output head.
    """

    def __init__(self):
        super().__init__()
        width = 64
        # Attribute names (fc1, fc2, fc4) and their creation order are kept:
        # they fix the state_dict keys and the parameter-initialisation order.
        self.fc1 = nn.Linear(OBS_SIZE + 1, width)
        self.fc2 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, 1)

    def forward(self, x):
        """Return the Q-value(s) for input ``x``."""
        first = F.relu(self.fc1(x))
        second = F.relu(self.fc2(first))
        return self.fc4(second)


def transfer_state(src, dst):
    """Overwrite all parameters and buffers of ``dst`` with the values of ``src``.

    Used for hard target-network updates. Returns ``None``.
    """
    snapshot = src.state_dict()
    dst.load_state_dict(snapshot)


def argmax_action(net, action_space, state):
    """Return the index of the action that ``net`` scores highest for ``state``.

    Each candidate action tensor is concatenated in front of ``state`` and
    scored by ``net``, which must produce a single-element output.

    Args:
        net: module scoring one concatenated (action, state) vector.
        action_space: non-empty sequence of action tensors.
        state: 1-D state tensor.

    Returns:
        Python ``int`` index into ``action_space``. Ties resolve to the lowest index.

    Raises:
        IndexError: if ``action_space`` is empty.
    """
    if len(action_space) == 0:
        raise IndexError("action_space is empty")
    # Action selection never needs gradients; avoid building an autograd
    # graph on every environment step.
    with torch.no_grad():
        scores = [net(torch.cat((action, state))).item() for action in action_space]
    # max() returns the first maximal element, matching strict '>' tie-breaking.
    return max(range(len(scores)), key=scores.__getitem__)


def argmax_action_batch(net, action_space, state):
    """Return, for every row of ``state``, the index of the highest-scoring action.

    Each action in ``action_space`` is tiled across the batch, concatenated in
    front of ``state`` and scored by ``net``. A per-row running maximum is kept,
    so the result is the true argmax for any number of actions.

    Args:
        net: module mapping a (batch, action_dim + state_dim) tensor to one
            score per row, expected shape (batch, 1).
        action_space: non-empty sequence of equal-length 1-D action tensors.
        state: (batch, state_dim) tensor.

    Returns:
        Float tensor of shape (batch, 1) holding action indices. Ties resolve
        to the lowest index.

    Raises:
        IndexError: if ``action_space`` is empty.
    """
    batch_size = state.size()[0]
    action_vector_size = action_space[0].size()[0]
    best_i = torch.zeros(batch_size, 1)
    best_out = None
    for i, action in enumerate(action_space):
        tiled = torch.ones(batch_size, action_vector_size) * action
        out = net(torch.cat((tiled, state), 1))
        if best_out is None:
            best_out = out
            continue
        better = out > best_out
        best_i[better] = i
        # Keep the running maximum up to date. Otherwise each action is only
        # compared against action 0, and the last action to beat it would win.
        best_out = torch.where(better, out, best_out)
    return best_i


def max_action_batch(net, action_space, state):
    """Return, for every row of ``state``, the best score over all actions.

    Each action in ``action_space`` is tiled across the batch, concatenated in
    front of ``state`` and scored by ``net.forward``. The element-wise maximum
    over actions is returned, with the same shape as one ``net`` output.
    """
    rows = state.size()[0]
    width = action_space[0].size()[0]

    def score(action):
        # Repeat the single action vector for every row, then evaluate.
        tiled = torch.ones(rows, width) * action
        return net.forward(torch.cat((tiled, state), 1))

    best = score(action_space[0])
    for action in action_space[1:]:
        candidate = score(action)
        best = torch.where(candidate > best, candidate, best)
    return best


def empty_transition_block():
    """Return a zero-filled (1, 2 * OBS_SIZE + 3) row for one transition.

    The training loop fills it as: column 0 = action, next OBS_SIZE columns =
    observation, next OBS_SIZE = next observation, then reward, then done flag.
    """
    width = 2 * OBS_SIZE + 3
    return torch.zeros((1, width))


def polyak(src, dst, p):
    """Soft-update ``dst`` towards ``src`` in place (Polyak averaging).

    Each parameter of ``dst`` becomes ``p * dst + (1 - p) * src``. ``p = 1``
    leaves ``dst`` unchanged and ``p = 0`` copies ``src``. Parameters are
    paired positionally, so both modules should share the same architecture.
    """
    with torch.no_grad():
        for src_p, dst_p in zip(src.parameters(), dst.parameters()):
            # torch.no_grad() replaces legacy ``.data`` access. The update stays
            # outside autograd but still bumps the version counter, so misuse of
            # a stale saved tensor is reported instead of silently ignored.
            dst_p.copy_(p * dst_p + (1 - p) * src_p)


def dqn_target(target_network, action_space, batch):
    """Return max_a Q_target(s', a) for every transition row in ``batch``.

    The next observation s' is read from columns
    ``OBS_SIZE + 1 .. 2 * OBS_SIZE`` of each row. The maximum over
    ``action_space`` is computed by ``max_action_batch``.

    The value is a bootstrap target, so it is computed under ``torch.no_grad()``.
    Otherwise ``loss.backward()`` would also traverse the target network and
    accumulate ``.grad`` on its parameters every step.
    """
    next_states = batch[:, OBS_SIZE + 1: 2 * OBS_SIZE + 1]
    with torch.no_grad():
        return max_action_batch(target_network, action_space, next_states)


def double_dqn_target(target_network, q_network, action_space, batch):
    """Return the Double-DQN value Q_target(s', argmax_a Q_online(s', a)) per row.

    ``q_network`` selects the next action and ``target_network`` evaluates it.
    The next observation s' is read from columns ``OBS_SIZE + 1 .. 2 * OBS_SIZE``
    of each row of ``batch``. Everything runs under ``torch.no_grad()``: the
    result is a bootstrap target, and neither network should receive gradients
    from it.
    """
    next_states = batch[:, OBS_SIZE + 1: 2 * OBS_SIZE + 1]
    with torch.no_grad():
        best_idx = argmax_action_batch(q_network, action_space, next_states)
        # Map the chosen indices back to their action encodings rather than
        # feeding the raw index in as the action input. This is identical when
        # action i is encoded as [i], and correct for any other encoding.
        encodings = torch.stack(list(action_space))
        next_actions = encodings[best_idx.long().squeeze(1)]
        return target_network(torch.cat((next_actions, next_states), 1))


In [None]:
# Networks, loss fn, optimizer
# q is the online network updated by the optimizer; q_ is the target network
# that supplies bootstrap values. Both start from identical weights.
q = Q()
q_ = Q()
transfer_state(q, q_)

criterion = nn.MSELoss()
optimizer = optim.Adam(q.parameters(), lr=0.001)

# Environment
env = gym.make('CartPole-v0')

# Discount factor
gamma = 0.99

# Replay memory; the oldest transitions are evicted once EXPERIENCE_CAPACITY is reached
transitions = deque(maxlen=EXPERIENCE_CAPACITY)

# List of possible actions: action i is encoded as the 1-element float tensor [i]
action_values = [torch.tensor([i], dtype=torch.float32) for i in range(env.action_space.n)]

# next_i counts every environment step (learning starts after 500 of them);
# c counts steps since the last hard target-network copy.
next_i = 0
c = 0

# Train, test, visualise
for i_episode in range(EPISODE_COUNT):
    observation = torch.Tensor(env.reset())
    # Epsilon decays linearly from 0.99 over EPISODE_COUNT episodes, floored at 0.01
    eps = max(0.01, ((EPISODE_COUNT - i_episode) / EPISODE_COUNT) * 0.99)
    for t in range(EPISODE_LENGTH):

        # With probability epsilon select a random action a_t
        # otherwise select a_t = argmax_a Q(x_t, a, net parameters)
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            action = argmax_action(q, action_values, observation)

        # Execute action a_t in emulator and observe reward r_t and observation x_t+1
        new_observation, reward, done, info = env.step(action)
        new_observation = torch.Tensor(new_observation)

        # Store transition as one row: [a_t | x_t | x_t+1 | r_t | d].
        # Reaching the step limit is also recorded as terminal (d = 1).
        transition_block = empty_transition_block()
        transition_block[0][0] = action
        transition_block[0][1:OBS_SIZE + 1] = observation
        transition_block[0][OBS_SIZE + 1: 2 * OBS_SIZE + 1] = new_observation
        transition_block[0][2 * OBS_SIZE + 1] = reward
        transition_block[0][2 * OBS_SIZE + 2] = 1 if (done or t == EPISODE_LENGTH - 1) else 0
        transitions.append(transition_block)
        observation = new_observation
        next_i += 1
        c += 1

        # If enough memory stored
        if next_i > 500:

            # Sample random minibatch of transitions (x_j, a_j, r_j, x_j+1)
            batch = torch.cat(random.sample(transitions, MINIBATCH_SIZE), 0)

            # Set targets
            rs = batch[:, 2 * OBS_SIZE + 1].view(MINIBATCH_SIZE, 1)
            ds = batch[:, 2 * OBS_SIZE + 2].view(MINIBATCH_SIZE, 1)

            # Choose from:
            #   qs = dqn_target(q_, action_values, batch)            # normal DQN
            #   qs = double_dqn_target(q_, q, action_values, batch)  # Double DQN
            qs = dqn_target(q_, action_values, batch)
            # Targets are fixed regression labels for this step. Detach them so
            # backward() never flows into (or accumulates .grad on) the target network.
            ys = (rs + (1 - ds) * gamma * qs).detach()

            # Perform a gradient descent step on (y_j - Q(x_j, a_j))^2
            optimizer.zero_grad()
            outputs = q(batch[:, :OBS_SIZE + 1])
            loss = criterion(outputs, ys)
            loss.backward()
            optimizer.step()

            # Update target network q_
            if USE_SOFT_TARGET_UPDATE:
                polyak(q, q_, 0.9)
            elif c > TARGET_NET_REFRESH_RATE:
                transfer_state(q, q_)
                c = 0

        if done:
            break

    # Every 10 episodes (once learning has started) evaluate the greedy policy
    if i_episode % 10 == 0 and next_i > 500:
        count = 100
        total_reward = 0.0
        # Evaluation only selects actions, so no gradients are needed
        with torch.no_grad():
            for j in range(count):
                observation = torch.Tensor(env.reset())
                rewards = 0
                for t in range(EPISODE_LENGTH):
                    # Uncomment for rendering:
                    # env.render()

                    action = argmax_action(q, action_values, observation)
                    observation, reward, done, info = env.step(action)
                    rewards += reward
                    observation = torch.Tensor(observation)
                    if done:
                        break
                total_reward += rewards
        # Sum first and divide once. Adding rewards / count 100 times accumulates
        # rounding error that can push an exact 195.0 average just below 195.
        all_rewards = total_reward / count
        # '%.1f' rather than '%d', which truncated the average (e.g. 194.9 -> 194)
        print("trained from %d episodes, current avg test reward: %.1f" % (i_episode, all_rewards))
        # CartPole-v0 counts as solved at an average return >= 195.0 over 100 trials
        if all_rewards >= 195:
            torch.save(q.state_dict(), "DQNCartpole")
            print("Cartpole solved, saved net")
            break

env.close()
