# Duelling DQN with Importance Sampling in the Cartpole Environment

In [None]:
import random

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

EPISODE_COUNT = 10000
EPISODE_LENGTH = 200
EXPERIENCE_CAPACITY = 5000
OBS_SIZE = 4
USE_SOFT_TARGET_UPDATE = True
TARGET_NET_REFRESH_RATE = 200
MINIBATCH_SIZE = 16

In [None]:
class DuelQ(nn.Module):
    """Dueling Q-network: a shared hidden layer feeding a value head and an advantage head.

    The two heads are combined as
    ``Q(s, a) = V(s) + A(s, a) - max_a' A(s, a')``.
    """

    def __init__(self, action_count):
        """Create the layers.

        Args:
            action_count: number of discrete actions, i.e. the width of the
                advantage head and of the network output.
        """
        super(DuelQ, self).__init__()
        self.fc1 = nn.Linear(OBS_SIZE, 64)
        self.v1 = nn.Linear(64, 64)
        self.v2 = nn.Linear(64, 1)
        self.a1 = nn.Linear(64, 64)
        self.a2 = nn.Linear(64, action_count)

    def forward(self, x):
        """Compute Q-values for every action.

        Args:
            x: observations whose last dimension is ``OBS_SIZE``. Any number of
                leading batch dimensions is accepted; a single unbatched
                observation of shape ``(OBS_SIZE,)`` is treated as a batch of one.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``action_count`` (``(1, action_count)`` for an unbatched input).
        """
        # Promote a lone observation to a batch of one so that callers which
        # reduce over dim 1 keep receiving a 2-D result.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        hidden = F.relu(self.fc1(x))
        v_out = self.v2(F.relu(self.v1(hidden)))
        a_out = self.a2(F.relu(self.a1(hidden)))
        # keepdim=True keeps a trailing singleton axis so the subtraction
        # broadcasts correctly for any number of leading batch dimensions.
        a_max = torch.max(a_out, dim=-1, keepdim=True)[0]
        return v_out + a_out - a_max


def transfer_state(src, dst):
    """Overwrite ``dst``'s state with a copy of ``src``'s state dict."""
    snapshot = src.state_dict()
    dst.load_state_dict(snapshot)


def duel_argmax_action(net, state):
    """Return the greedy action for a single state as a Python int.

    Args:
        net: module mapping a state to per-action values.
        state: one state, either unbatched or as a batch of exactly one.

    Returns:
        Index of the highest-valued action.
    """
    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        q_values = net(state)
        # Reduce over the action (last) axis so both a (1, A) and an (A,)
        # output collapse to a single index.
        best_action = torch.max(q_values, -1)[1]
        return int(best_action.item())


def duel_argmax_action_batch(net, state):
    """Return the greedy action index of every row of ``state`` as an int32 column tensor."""
    q_values = net.forward(state)
    _, best_actions = torch.max(q_values, 1)
    as_column = best_actions.unsqueeze(1)
    return as_column.int()


def empty_transition_block():
    """Allocate a zero-filled ``(1, 2 * OBS_SIZE + 3)`` tensor to hold one transition."""
    row_width = 2 * OBS_SIZE + 3
    return torch.zeros(1, row_width)


def polyak(src, dst, p):
    """Soft-update ``dst`` in place: ``dst <- p * dst + (1 - p) * src`` for each parameter pair."""
    paired_params = zip(src.parameters(), dst.parameters())
    for source_param, target_param in paired_params:
        blended = p * target_param.data + (1 - p) * source_param.data
        target_param.data.copy_(blended)


def duel_dqn_target(target_network, batch):
    """Return, as a column, the maximum target-network Q-value of each next state in ``batch``.

    The next state is read from columns ``OBS_SIZE + 1`` up to (not including)
    ``2 * OBS_SIZE + 1`` of ``batch``.
    """
    with torch.no_grad():
        next_states = batch[:, OBS_SIZE + 1: 2 * OBS_SIZE + 1]
        q_next = target_network.forward(next_states)
    best_q, _ = torch.max(q_next, 1)
    return best_q.unsqueeze(1)


def duel_double_dqn_target(target_network, q_network, batch):
    """Double-DQN bootstrap value for each next state in ``batch``.

    ``q_network`` picks the greedy action for every next state and
    ``target_network`` supplies the value of that action. The next state is
    read from columns ``OBS_SIZE + 1`` up to (not including)
    ``2 * OBS_SIZE + 1`` of ``batch``. Returns one value per row, as a column.
    """
    with torch.no_grad():
        next_states = batch[:, OBS_SIZE + 1: 2 * OBS_SIZE + 1]
        greedy_actions = duel_argmax_action_batch(q_network, next_states)
        target_q = target_network.forward(next_states)
    # gather() needs int64 indices; the action helper returns int32.
    action_index = greedy_actions.long()
    return target_q.gather(1, action_index)


In [None]:
class ImportanceWeighedSample(object):
    """A minibatch drawn from a replay buffer, plus a handle to report results back to it.

    Attributes:
        batch: the sampled transitions.
        weights: the importance weights paired with ``batch``.
        i: the buffer's write position at the time the sample was taken.
    """

    def __init__(self, _isb, samples, weights, indices, i):
        self.__isb = _isb
        self.__indices = indices
        self.batch = samples
        self.weights = weights
        self.i = i

    def update(self, losses):
        """Send ``losses`` back to the buffer as new priorities, then advance its beta schedule.

        Asserts that the buffer's write position has not changed since this
        sample was taken.
        """
        source_buffer = self.__isb
        assert source_buffer.i == self.i
        source_buffer.update_priorities(losses, self.__indices)
        source_buffer.update_beta()
        

class ImportanceSamplingBuffer(object):
    """Fixed-capacity prioritised replay buffer with importance-sampling weights.

    Transitions live in a ring buffer of ``EXPERIENCE_CAPACITY`` rows. Each row
    has a priority; rows are sampled with probability proportional to
    ``priority ** alpha`` and returned with weights ``probability ** -beta``
    normalised by the largest weight in the batch. ``beta`` is annealed
    linearly from ``initial_beta`` to 1 as ``update_beta`` is called.
    """

    def __init__(self, alpha, initial_beta):
        """Create an empty buffer.

        Args:
            alpha: exponent applied to priorities when forming sampling
                probabilities (0 gives uniform sampling).
            initial_beta: starting importance-sampling exponent, in [0, 1].
        """
        assert 0 <= initial_beta <= 1
        self.capacity = EXPERIENCE_CAPACITY
        self.alpha = alpha
        self.transitions = torch.zeros((self.capacity, OBS_SIZE * 2 + 3))
        self.priorities = torch.zeros(self.capacity)
        # Priority given to the very first transition, when no priority exists yet.
        self.initial_max_priority = 1.0
        self.size = 0
        self.i = 0
        self.initial_beta = initial_beta
        self.beta = initial_beta
        self.beta_i = 0
        # Number of update_beta() calls after which beta reaches 1. Clamped to
        # at least 1 so the schedule never divides by zero.
        self.beta_i_at_max = max(1, int(EPISODE_COUNT * EPISODE_LENGTH / 10))

    def push(self, transition):
        """Store ``transition`` at the write position, overwriting the oldest row when full.

        The new row receives the current maximum priority, or
        ``initial_max_priority`` if no positive priority exists yet.
        """
        # Flatten so both a (1, width) block and a (width,) row are accepted.
        self.transitions[self.i] = transition.reshape(-1)
        max_priority = torch.max(self.priorities).item()
        self.priorities[self.i] = max_priority if max_priority > 0 else self.initial_max_priority
        self.size = min(self.size + 1, self.capacity)
        self.i = (self.i + 1) % self.capacity

    def update_beta(self):
        """Advance the beta schedule by one step of its linear ramp towards 1."""
        self.beta_i += 1
        progress = min(self.beta_i, self.beta_i_at_max) / float(self.beta_i_at_max)
        # Linear interpolation initial_beta -> 1. Dividing by (1 - initial_beta)
        # instead would overshoot 1 and fail when initial_beta == 1.
        self.beta = min(1.0, self.initial_beta + progress * (1.0 - self.initial_beta))

    def update_priorities(self, losses, indices):
        """Set the priorities at ``indices`` to ``losses`` plus a small constant that keeps them positive."""
        with torch.no_grad():
            self.priorities[indices] = losses + 1e-4

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions according to their priorities.

        Returns:
            An ``ImportanceWeighedSample`` carrying the transitions, their
            normalised importance weights and the sampled row indices.
        """
        # Only the filled rows take part, so empty slots never receive
        # probability mass (they would when alpha == 0, because 0 ** 0 == 1).
        filled_priorities = self.priorities[:self.size]
        scaled_priorities = torch.pow(filled_priorities, self.alpha)
        probabilities = scaled_priorities / torch.sum(scaled_priorities)
        indices = np.random.choice(self.size, batch_size, replace=False, p=probabilities.numpy())

        samples = self.transitions[indices]
        weights = torch.pow(probabilities[indices], -self.beta)
        normalized_weights = weights / torch.max(weights)
        return ImportanceWeighedSample(self, samples, normalized_weights, indices, self.i)


In [None]:
# Environment
env = gym.make('CartPole-v0')

# Networks, loss fn, optimizer
# q is the online network that the optimizer trains; q_ is the target network,
# initialised as an exact copy of q.
q = DuelQ(env.action_space.n)
q_ = DuelQ(env.action_space.n)
transfer_state(q, q_)

# NOTE(review): criterion is not referenced by the training loop in this file,
# which computes an importance-weighted squared error by hand.
criterion = nn.MSELoss()
optimizer = optim.Adam(q.parameters(), lr=0.005)

# Discount factor
gamma = 0.99

# Prioritised Experience Replay (alpha = 1, initial beta = 0.01)
isb = ImportanceSamplingBuffer(1, 0.01)

# List of possible actions, one single-element float tensor per action index
# NOTE(review): not referenced anywhere else in the visible code.
action_values = [torch.tensor([i], dtype=torch.float32) for i in range(env.action_space.n)]

# next_i counts every transition stored so far and gates the start of training;
# c counts steps since the last hard target-network refresh.
next_i = 0
c = 0

# Train, test, visualise

# Gradient updates (and periodic evaluation) start only once more than this
# many transitions have been stored.
MIN_STORED_TRANSITIONS = 500

for i_episode in range(EPISODE_COUNT):
    observation = torch.Tensor(env.reset())
    # Epsilon decays linearly from 0.99 with the episode index, floored at 0.01.
    eps = max(0.01, ((EPISODE_COUNT - i_episode) / EPISODE_COUNT) * 0.99)
    for t in range(EPISODE_LENGTH):

        # With probability epsilon select a random action a_t
        # otherwise select a_t = argmax_a Q(x_t, a, net parameters)
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            action = duel_argmax_action(q, observation)

        # Execute action a_t in emulator and observe reward r_t and observation x_t+1
        new_observation, reward, done, info = env.step(action)
        new_observation = torch.Tensor(new_observation)

        # Store transition laid out as [a_t, x_t, x_t+1, r_t, d]
        transition_block = empty_transition_block()
        transition_block[0][0] = action
        transition_block[0][1:OBS_SIZE + 1] = observation
        transition_block[0][OBS_SIZE + 1: 2 * OBS_SIZE + 1] = new_observation
        transition_block[0][2 * OBS_SIZE + 1] = reward
        # The last allowed step of an episode is also stored as terminal.
        transition_block[0][2 * OBS_SIZE + 2] = 1 if (done or t == EPISODE_LENGTH - 1) else 0
        isb.push(transition_block)
        observation = new_observation
        next_i += 1
        c += 1

        # If enough memory stored
        if next_i > MIN_STORED_TRANSITIONS:

            # Sample a prioritised minibatch of transitions (a_j, x_j, x_j+1, r_j, d_j)
            sample = isb.sample(MINIBATCH_SIZE)
            batch = sample.batch
            # Set targets
            rs = batch[:, 2 * OBS_SIZE + 1].view(MINIBATCH_SIZE, 1)
            ds = batch[:, 2 * OBS_SIZE + 2].view(MINIBATCH_SIZE, 1)

            # # Choose from :
            # # # qs = duel_dqn_target(q_, batch) for Duelling DQN
            # # # qs = duel_double_dqn_target(q_, q, batch) for Double Duelling DQN
            qs = duel_double_dqn_target(q_, q, batch)
            # Terminal transitions (d = 1) do not bootstrap from the next state.
            ys = rs + (1 - ds) * gamma * qs

            # Perform a gradient descent step on w_j * (y_j - Q(x_j, a_j))^2
            optimizer.zero_grad()
            outputs = q(batch[:, 1:OBS_SIZE + 1]).gather(1, batch[:, 0].unsqueeze(1).long())
            squared_td_errors = torch.pow((outputs - ys), 2)
            losses = squared_td_errors * sample.weights.unsqueeze(1)

            loss = losses.mean()
            loss.backward()
            optimizer.step()

            # Update priorities and beta. Priorities come from the unweighted
            # error: the importance weights only correct the gradient step, and
            # feeding the weighted loss back would apply that correction twice.
            sample.update(squared_td_errors.detach().squeeze(1))

            # Update target network q_
            if USE_SOFT_TARGET_UPDATE:
                polyak(q, q_, 0.9)
            elif c > TARGET_NET_REFRESH_RATE:
                transfer_state(q, q_)
                c = 0

        if done:
            break

    # Every 10th episode, once training has started, evaluate the greedy policy.
    if i_episode % 10 == 0 and next_i > MIN_STORED_TRANSITIONS:
        all_rewards = 0
        count = 100
        for j in range(count):
            observation = torch.Tensor(env.reset())
            rewards = 0
            for t in range(EPISODE_LENGTH):
                # Uncomment for rendering:
                # env.render()

                action = duel_argmax_action(q, observation)
                observation, reward, done, info = env.step(action)
                rewards += reward
                observation = torch.Tensor(observation)
                if done:
                    break
            # Accumulate the mean episode return over the evaluation runs.
            all_rewards += rewards / count
        print("trained using data from %d episodes, current avg test reward: %d, current beta: %f" % (i_episode, all_rewards, isb.beta))
        # Stop and save once the average evaluation return exceeds 195.
        if all_rewards > 195:
            torch.save(q.state_dict(), "ISDuelDQNCartpole")
            print("Cartpole solved, saved net")
            break

env.close()
