In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import gym
import math

In [None]:
# Training configuration.
EPISODES = 700      # number of training episodes
BATCH_SIZE = 64     # transitions drawn from the replay buffer per update
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Utils

In [None]:
class AverageMeter:
    """
    Running statistics for a stream of values.

    Attributes:
        val:   the most recent value passed to ``update``.
        sum:   weighted sum of all values seen since the last ``reset``.
        count: total weight (number of observations) seen so far.
        avg:   ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += n * val
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class PytorchWrapper(gym.Wrapper):
    """
    Gym wrapper that returns observations as float torch tensors.

    ``step`` returns the 3-tuple ``(obs, reward, done)`` -- the info dict of the
    wrapped environment is dropped -- and replaces the reward with -10 on any
    step where ``done`` is set. ``reset`` returns only the observation tensor.
    """
    def __init__(self, env):
        super().__init__(env)

    def step(self, action):
        """Advance the wrapped env by one action; see the class docstring for the return shape."""
        raw_obs, reward, done, _info = self.env.step(action)
        if done:
            reward = -10 ## Specific to Cartpole env
        return torch.tensor(raw_obs, dtype=torch.float), reward, done

    def reset(self):
        """Reset the wrapped env and return the first observation as a float tensor."""
        return torch.tensor(self.env.reset(), dtype=torch.float)

# Policy Network - FCN

In [None]:
class PolicyFC(nn.Module):
    """
    Fully connected network: in_features -> 64 -> 32 -> 16 -> out_features.

    Every hidden linear layer is followed by a LeakyReLU; the output layer is
    left linear.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden_widths = (in_features, 64, 32, 16)
        layers = []
        for fan_in, fan_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(hidden_widths[-1], out_features))
        ## Same layer order as before, so state_dict keys stay net.0, net.2, net.4, net.6
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of inputs to one output per ``out_features`` unit."""
        return self.net(x)

# Prioritized Experience Replay

[Schaul et al.](https://arxiv.org/pdf/1511.05952.pdf) propose using a binary heap for faster search and retrieval of samples. However, in this implementation we use a list. <br />
The samples are prioritized based on TD error. Instead of sampling purely on TD error (greedy approach), an alpha parameter is used to interpolate between uniform random and greedy sampling.

In [None]:
class PrioritizedExperienceReplay:
    """
    Proportional prioritized replay buffer backed by Python lists.

    Implementation is adapted from: https://github.com/higgsfield/RL-Adventure 

    Transitions are stored in fixed slots; once ``buffer_size`` transitions are
    held, the oldest slot is overwritten in place (ring buffer). This keeps
    ``priorities[i]`` attached to the transition stored at list index ``i``.

    Sampling probability of slot i is ``priorities[i] ** alpha`` normalised over
    the stored transitions; ``alpha = 0`` gives uniform sampling.
    """
    def __init__(self, buffer_size=100000, alpha=0.6):
        """
        Args:
            buffer_size: maximum number of transitions kept (must be >= 1).
            alpha: exponent applied to priorities when building sampling probabilities.

        Raises:
            ValueError: if ``buffer_size`` is smaller than 1.
        """
        if buffer_size < 1:
            raise ValueError("buffer_size must be at least 1")
        self.alpha = alpha ## Determines how much prioritization is used
        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.buffer_size = buffer_size
        ## One priority per slot. np.float was removed from NumPy, so the dtype is spelled out.
        self.priorities = np.full((buffer_size,), 1e-5, dtype=np.float64)
        self.count = 0 ## Number of transitions currently stored
        self._next_idx = 0 ## Slot the next call to store() writes to

    def store(self, state, action, next_state, reward):
        """
        Add one transition, overwriting the oldest one when the buffer is full.

        The new transition receives the current maximum priority.
        """
        slot = self._next_idx
        max_priority = self.priorities.max()

        if len(self.state) < self.buffer_size:
            ## Still filling up: slot == len(self.state), so appending lands in that slot
            self.state.append(state)
            self.action.append(action)
            self.next_state.append(next_state)
            self.reward.append(reward)
        else:
            ## Full: replace the oldest transition in place so indices stay stable
            self.state[slot] = state
            self.action[slot] = action
            self.next_state[slot] = next_state
            self.reward[slot] = reward

        self.priorities[slot] = max_priority
        self._next_idx = (slot + 1) % self.buffer_size
        self.count = len(self.state)

    def sample_batch(self, batch_size, beta=0.4):
        """
        Draw ``batch_size`` transitions (with replacement) according to priority.

        Args:
            batch_size: number of transitions to draw.
            beta: importance-sampling exponent; weights are ``(N * P(i)) ** -beta``
                divided by their maximum within the batch.

        Returns:
            Tuple ``(state, action, next_state, reward, idxs, weights)`` where
            ``idxs`` is the numpy array of sampled slots (to be passed back to
            ``update_priorities``) and the remaining entries are torch tensors
            with one row per sampled transition.

        Raises:
            ValueError: if the buffer is empty.
        """
        total = len(self.state)
        if total == 0:
            raise ValueError("Cannot sample from an empty replay buffer")

        probs  = self.priorities[:total] ** self.alpha
        probs /= probs.sum()

        idxs = np.random.choice(total, batch_size, p=probs)
        ## Gather only the sampled rows instead of stacking the whole buffer
        state = torch.stack([self.state[i] for i in idxs])
        action = torch.tensor([self.action[i] for i in idxs], dtype=torch.long)
        next_state = torch.stack([self.next_state[i] for i in idxs])
        reward = torch.tensor([self.reward[i] for i in idxs], dtype=torch.float)

        ## Beta parameter is used to anneal the amount of importance sampling
        weights  = (total * probs[idxs]) ** (-beta)
        weights /= weights.max()
        weights = torch.tensor(weights, dtype=torch.float)

        return (state, action, next_state, reward, idxs, weights)

    def update_priorities(self, idxs, priorities):
        """Set the priority of each slot in ``idxs`` to the absolute value of its new priority."""
        ## Plain loop: with repeated indices the last value wins, deterministically
        for i, priority in zip(idxs, priorities):
            self.priorities[i] = abs(priority)

    def __len__(self):
        return len(self.state)

# Double DQN Agent

In [None]:
class DQN:
    """
    Double DQN agent that learns from a prioritized replay buffer.

    Two PolicyFC networks are kept: ``policy`` is optimised on every update and
    selects the greedy next action, ``target`` is a periodically refreshed copy
    that evaluates that action when building the bootstrap target.
    """
    def __init__(self, obs_size, action_size, device, gamma=0.99, lr=0.001):
        """
        Args:
            obs_size: number of input features of an observation.
            action_size: number of discrete actions.
            device: torch device (or device string) the networks live on.
            gamma: discount factor applied to the bootstrap value.
            lr: learning rate of the Adam optimizer.
        """
        self.target = PolicyFC(obs_size, action_size).to(device)
        self.target.eval()
        self.policy = PolicyFC(obs_size, action_size).to(device)
        ## Start the target as an exact copy of the policy instead of an unrelated random net
        self.target.load_state_dict(self.policy.state_dict())
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.device = device
        self.buffer = PrioritizedExperienceReplay()
        self.gamma = gamma
        self.action_size = action_size

    def loss_fct(self, target, pred):
        """Element-wise smooth L1 (Huber) loss between ``pred`` and ``target``."""
        return F.smooth_l1_loss(pred, target, reduction="none")

    def forward(self, policy, obs, grad=False):
        """
        Run ``policy`` on ``obs`` and return ``(q_values, greedy_action)``.

        Args:
            policy: network to evaluate (``self.policy`` or ``self.target``).
            obs: a single observation (1-D tensor) or a batch (2-D tensor).
            grad: when False the forward pass runs without gradient tracking.

        Returns:
            ``q_values`` of shape (batch, actions) and ``action`` of shape (batch,).
        """
        obs = obs.to(self.device)
        if obs.dim() == 1: ## Single observation -> batch of one
            obs = obs.unsqueeze(0)
        if grad:
            q_values = policy(obs)
        else:
            with torch.no_grad():
                q_values = policy(obs)
        action = torch.argmax(q_values, 1)
        return q_values, action

    def optimize_policy(self, batch, beta):
        """
        Perform one optimisation step on ``self.policy`` and refresh replay priorities.

        Args:
            batch: tuple ``(state, action, next_state, reward, idxs, weights)`` as
                returned by ``PrioritizedExperienceReplay.sample_batch``.
            beta: unused here; kept for interface compatibility. The weights in
                ``batch`` were already computed by ``sample_batch``.

        Returns:
            The mean weighted loss as a Python float.
        """
        self.optimizer.zero_grad()
        state, action, next_state, reward, idxs, weights = batch
        weights = weights.to(self.device)
        action = action.unsqueeze(1).to(self.device)
        reward = reward.to(self.device)
        Q, _ = self.forward(self.policy, state, grad=True)
        _, next_action = self.forward(self.policy, next_state)
        next_Q, _ = self.forward(self.target, next_state)
        ## Target value estimation is made using both networks. Prevents overestimation
        ## squeeze(-1) rather than squeeze() so a batch of one keeps its batch dimension
        Q_target = next_Q.gather(1, next_action.unsqueeze(-1)).squeeze(-1)
        ## NOTE(review): the replay buffer stores no `done` flag, so this bootstraps from
        ## next_state on terminal transitions as well.
        target = reward + self.gamma * Q_target
        Q = Q.gather(1, action).squeeze(-1)
        loss = self.loss_fct(target, Q)
        loss = loss * weights
        priorities = loss + 1e-5
        self.buffer.update_priorities(idxs, priorities.detach().cpu().numpy())
        loss = loss.mean()

        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Copy the policy weights into the target network and save them to "DQN_Agent.bin"."""
        self.policy.eval()
        self.target.load_state_dict(self.policy.state_dict())
        torch.save(self.target.state_dict(), "DQN_Agent.bin")
        self.policy.train()

    def load_policy(self, path=None):
        """
        Load saved weights into the target network (the one ``evaluate_policy`` uses).

        Args:
            path: checkpoint file; defaults to "DQN_Agent.bin".
        """
        if path is None:
            path = "DQN_Agent.bin"

        ## map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        ## torch.load unpickles the file: only load checkpoints from a trusted source.
        self.target.load_state_dict(torch.load(path, map_location=self.device))
        print("Successfully loaded")

    def evaluate_policy(self, env):
        """
        Run one greedy episode with the target network, rendering and printing every step.

        ``env`` is used the way PytorchWrapper behaves: ``reset()`` returns an
        observation tensor and ``step()`` returns ``(obs, reward, done)``.
        """
        obs = env.reset()
        done = False
        count = 0
        while(not done):
            obs = obs.unsqueeze(0)
            obs = obs.to(self.device)
            env.render()
            with torch.no_grad():
                q_values = self.target(obs)
            action = torch.argmax(q_values, 1).item()
            obs, reward, done = env.step(action)
            print(f"{count}, {action}, {reward}")
            count += 1

    def get_beta(self, curr_eps, total_eps):
        """
        Increase beta linearly from 0.4 towards 1.0 as the trained episodes increase
        (capped at 1.0).
        """
        beta_start = 0.4
        beta = beta_start + curr_eps * (1.0 - beta_start) / total_eps
        beta = min(1.0, beta)
        return beta

    def get_eps(self, i, decay=100):
        """
        Reduce epsilon as training progresses to reduce exploration.

        Epsilon decays exponentially in ``i`` from 1.0 towards 0.05, with
        ``decay`` as the time constant.
        """
        epsilon_start = 1.0
        epsilon_final = 0.05
        eps = epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * i / decay)
        return eps

    def learn(self, env, episodes, batch_size, max_steps=1000, target_update_every=200):
        """
        Train the agent with epsilon-greedy exploration.

        Args:
            env: environment whose ``reset()`` returns an observation tensor and whose
                ``step()`` returns ``(obs, reward, done)``.
            episodes: number of episodes to run.
            batch_size: replay batch size; optimisation starts once the buffer
                holds at least this many transitions.
            max_steps: upper bound on steps per episode.
            target_update_every: number of environment steps between target-network
                refreshes (each refresh also writes the checkpoint file).

        Raises:
            ValueError: if ``max_steps`` is smaller than 1.
        """
        if max_steps < 1:
            raise ValueError("max_steps must be at least 1")

        writer = SummaryWriter()
        global_step = 0
        loss_count = 0
        try:
            for eps in range(episodes):
                obs = env.reset()
                loss_tracker = AverageMeter()
                reward_tracker = AverageMeter()
                ## Both schedules depend only on the episode index
                epsilon = self.get_eps(eps)
                beta = self.get_beta(eps, episodes)
                for t in range(max_steps):
                    if(np.random.rand() <= epsilon): ## Epsilon greedy
                        action = np.random.randint(self.action_size)
                    else:
                        _, action = self.forward(self.policy, obs)
                        action = action.item()
                    next_obs, reward, done = env.step(action)
                    self.buffer.store(obs, action, next_obs, reward)
                    reward_tracker.update(reward)
                    global_step += 1 ## Counts every environment step, terminal ones included

                    if(len(self.buffer) >= batch_size):
                        ## Pass the annealed beta so the importance-sampling weights use it
                        batch = self.buffer.sample_batch(batch_size, beta)
                        loss = self.optimize_policy(batch, beta)
                        loss_tracker.update(loss)
                        writer.add_scalar('Loss', loss, loss_count)
                        loss_count += 1
                        if(global_step % target_update_every == 0): ## Delayed update of target. Promotes exploration
                            self.update_target()

                    if done: break

                    obs = next_obs

                writer.add_scalar("Reward", reward_tracker.sum, eps)

                if((eps + 1) % 10 == 0):
                    print(f"Episode: {eps}/{episodes}, step: {t+1}/{max_steps}, Epsilon: {epsilon}, reward: {reward_tracker.sum}, loss: {loss_tracker.avg}")
        finally:
            ## Flush pending events and release the log file even if training is interrupted
            writer.close()

In [None]:
## Build the CartPole environment, read the network dimensions from its spaces,
## then wrap it so observations come back as torch tensors.
env = gym.make('CartPole-v0')
obs_size, action_size = env.observation_space.shape[0], env.action_space.n
env = PytorchWrapper(env)

In [None]:
## Agent with default discount factor and learning rate
agent = DQN(obs_size=obs_size, action_size=action_size, device=device)

In [None]:
## Load the TensorBoard notebook extension and open it on the `runs` directory
## to visualise the 'Loss' and 'Reward' scalars logged by DQN.learn
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
## Train for EPISODES episodes, sampling BATCH_SIZE transitions per update
agent.learn(env, episodes=EPISODES, batch_size=BATCH_SIZE)

In [None]:
## Roll out one rendered episode with the target network.
# agent.load_policy() # Uncomment to load a trained policy from local disk first
agent.evaluate_policy(env=env)