# References
- [SumTree introduction in Python](https://adventuresinmachinelearning.com/sumtree-introduction-python/)
- [A Deeper Look at Experience Replay](https://arxiv.org/abs/1712.01275)
- [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)
- [Double Learning and Prioritized Experience Replay](https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/)
- [GitHub rlcode: Prioritized Experience Replay](https://github.com/rlcode/per)
- https://github.com/google/dopamine/tree/a9911ec0a322f38b9c0ea447e61caa2756a383cf

# Import

In [1]:
import time
import random
import math
from copy import deepcopy
from collections import namedtuple, deque
from itertools import count, product
import numpy as np
import matplotlib.pyplot as plt

import gym
from gym import logger
logger.set_level(gym.logger.DISABLED)
from replay_buffer import ReplayBuffer

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Replay Buffer

In [10]:
"""A sum tree data structure.
Used for prioritized experience replay. See prioritized_replay_buffer.py
and Schaul et al. (2015).
"""
import math
import random

import numpy as np


Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

class SumTree(object):
    """A sum tree data structure for storing replay priorities.
    A sum tree is a complete binary tree whose leaves contain values called
    priorities. Internal nodes maintain the sum of the priorities of all leaf
    nodes in their subtree.
    For capacity = 4, the tree may look like this:
               +---+
               |2.5|
               +-+-+
                 |
         +-------+--------+
         |                |
       +-+-+            +-+-+
       |1.5|            |1.0|
       +-+-+            +-+-+
         |                |
    +----+----+      +----+----+
    |         |      |         |
    +-+-+     +-+-+  +-+-+     +-+-+
    |0.5|     |1.0|  |0.5|     |0.5|
    +---+     +---+  +---+     +---+
    This is stored in a list of numpy arrays:
    self.nodes = [ [2.5], [1.5, 1], [0.5, 1, 0.5, 0.5] ]
    For conciseness, we allocate arrays as powers of two, and pad the excess
    elements with zero values.
    This is similar to the usual array-based representation of a complete binary
    tree, but is a little more user-friendly.
    """

    def __init__(self, capacity):
        """Creates the sum tree data structure for the given replay capacity.
        Args:
            capacity: int, the maximum number of elements that can be stored in this
            data structure.
        Raises:
            ValueError: If requested capacity is not positive.
        """
        assert isinstance(capacity, int)
        if capacity <= 0:
            raise ValueError('Sum tree capacity should be positive. Got: {}'.format(capacity))

        self.nodes = []
        tree_depth = int(math.ceil(np.log2(capacity)))
        level_size = 1
        for _ in range(tree_depth + 1):
            nodes_at_this_depth = np.zeros(level_size)
            self.nodes.append(nodes_at_this_depth)
            
            level_size *= 2

        self.max_recorded_priority = 1

    def total(self):
        """Returns the sum of all priorities stored in this sum tree.
        Returns:
            float, sum of priorities stored in this sum tree.
        """
        return self.nodes[0][0]
    
    def sample(self, query_value=None):
        """Samples an element from the sum tree.
        Each element has probability p_i / sum_j p_j of being picked, where p_i is
        the (positive) value associated with node i (possibly unnormalized).
        Args:
            query_value: float in [0, 1], used as the random value to select a
            sample. If None, will select one randomly in [0, 1).ipynb_checkpoints/
        Returns:
            int, a random element from the sum tree.
        Raises:
            Exception: If the sum tree is empty (i.e. its node values sum to 0), or if
            the supplied query_value is larger than the total sum.
        """
        if self.total() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')
        
        if query_value and (query_value < 0. or query_value > 1.):
            raise ValueError('query_value must be in [0, 1].')
        
        # Sample a value in range [0, R), where R is the value stored at the root.
        query_value = random.random() if query_value is None else query_value
        query_value *= self.total()
        
        # Now traverse the sum tree.
        node_index = 0
        for nodes_at_this_depth in self.nodes[1:]:
            # Compute children of previous depth's node.
            left_child = node_index * 2
            
            left_sum = nodes_at_this_depth[left_child]
            # Each subtree describes a range [0, a), where a is its value.
            if query_value < left_sum:  # Recurse into left subtree.
                node_index = left_child
            else:  # Recurse into right subtree.
                node_index = left_child + 1
                # Adjust query to be relative to right subtree.
                query_value -= left_sum
        
        return node_index

    def stratified_sample(self, batch_size):
        """Performs stratified sampling using the sum tree.
        Let R be the value at the root (total value of sum tree). This method will
        divide [0, R) into batch_size segments, pick a random number from each of
        those segments, and use that random number to sample from the sum_tree. This
        is as specified in Schaul et al. (2015).
        Args:
            batch_size: int, the number of strata to use.
        Returns:
            list of batch_size elements sampled from the sum tree.
        Raises:
            Exception: If the sum tree is empty (i.e. its node values sum to 0).
        """
        if self.total() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')
        
        bounds = np.linspace(0., 1., batch_size + 1)
        assert len(bounds) == batch_size + 1
        segments = [(bounds[i], bounds[i+1]) for i in range(batch_size)]
        query_values = [random.uniform(x[0], x[1]) for x in segments]
        return [self.sample(query_value=x) for x in query_values]

    def get(self, node_index):
        """Returns the value of the leaf node corresponding to the index.
        Args:
            node_index: The index of the leaf node.
        Returns:
            The value of the leaf node.
        """
        return self.nodes[-1][node_index]
    
    def set(self, node_index, value):
        """Sets the value of a leaf node and updates internal nodes accordingly.
        This operation takes O(log(capacity)).
        Args:
            node_index: int, the index of the leaf node to be updated.
            value: float, the value which we assign to the node. This value must be
            nonnegative. Setting value = 0 will cause the element to never be
            sampled.
        Raises:
            ValueError: If the given value is negative.
        """
        if value < 0.0:
            raise ValueError('Sum tree values should be nonnegative. Got {}'.
                             format(value))
        self.max_recorded_priority = max(value, self.max_recorded_priority)
        
        self.nodes[-1][node_index] = value
        for nodes_in_parent_layer, nodes_in_child_layer in zip(reversed(self.nodes[:-1]), reversed(self.nodes[1:])):
            # Note: Adding a delta leads to some intolerable numerical inaccuracies.
            node_index //= 2
            nodes_in_parent_layer[node_index] = nodes_in_child_layer[2*node_index] + nodes_in_child_layer[2*node_index+1]
        
        assert node_index == 0, ('Sum tree traversal failed, final node index '
                                 'is not 0.')

class ReplayBuffer:
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001
    
    def __init__(self, buffer_size):
        self.buffer_size = int(buffer_size)
        self.buffer = []
        self.index = 0
        self.sum_tree = SumTree(self.buffer_size)
        
    def __len__(self):
        return len(self.buffer)
    
    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a
    
    def add(self, error, state, action, reward, next_state, done):
        priority = self._get_priority(error)
        self.sum_tree.set(self.index, priority)
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(None)
        self.buffer[self.index] = (state, action, reward, next_state, done)
        self.index = (self.index + 1) % self.buffer_size
    
    def sample(self, batch_size):
        indices = self.sum_tree.stratified_sample(batch_size)
        priorities = [self.sum_tree.get(index) for index in indices]
        batch = [self.buffer[index] for index in indices]
        
        self.beta = np.min([1, self.beta + self.beta_increment_per_sampling])
        
        sampling_probabilities = priorities / self.sum_tree.total()
        is_weight = np.power(self.buffer_size * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()
        
        return indices, batch, is_weight
    
    def update(self, index, error):
        priority = self._get_priority(error)
        self.sum_tree.set(index, priority)

# CartPole Environment

In [11]:
env = gym.make('CartPole-v1')

# Neural Network

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(4, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 2)
    
    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# DQN Agent

In [13]:
class Agent:
    def __init__(self, config, nn):
        self.step = 0
        self.gamma = config["gamma"]
        self.batch_size = config["batch_size"]
        self.replay_buffer = ReplayBuffer(config["buffer_size"])
        self.n_gradient_steps = config["n_gradient_steps"]
        self.n_actions = config["n_actions"]
        self.epsilon = config["epsilon_max"]
        self.epsilon_max = config["epsilon_max"]
        self.epsilon_min = config["epsilon_min"]
        self.epsilon_decay = config["epsilon_decay"]
        self.nn = nn.to(device)
        self.target_nn = deepcopy(self.nn).to(device)
        self.criterion = torch.nn.MSELoss()
        self.optimizer = optim.RMSprop(self.nn.parameters(), lr=config["learning_rate"])
        self.option = config["option"]
        if "beta" in config:
            self.beta = config["beta"]
            self.update_target_nn = self.smooth_update_target_nn
        elif "C" in config:
            self.C = config["C"]
            self.update_target_nn = self.periodic_update_target_nn
        else:
            self.update_target_nn = (lambda: None)
            self.option = 0
    
    def append_sample(self, state, action, reward, next_state, done):
        state_action_values = self.nn(state.unsqueeze(0)).gather(1, action.unsqueeze(0))
        if self.option == 2:
            # Double Q-Learning
            next_state_action_values = self.target_nn(next_state.unsqueeze(0)).gather(1, self.nn(next_state.unsqueeze(0)).max(1)[1].unsqueeze(1)).detach()
        elif self.option == 1:
            # Target
            next_state_action_values = self.target_nn(next_state.unsqueeze(0)).max(1)[0].unsqueeze(1).detach()
        else:
            # Vanilla
            next_state_action_values = self.nn(next_state.unsqueeze(0)).max(1)[0].unsqueeze(1).detach()
        expected_state_action_values = reward.unsqueeze(0) + self.gamma * next_state_action_values * (1 - done.unsqueeze(0))
        error = (state_action_values - expected_state_action_values).data.numpy()
        
        self.replay_buffer.add(error, state, action, reward, next_state, done)
    
    def update_epsilon(self):
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * math.exp(-1. * self.step / self.epsilon_decay)
    
    def epsilon_greedy_action(self, state):
        if random.random() > self.epsilon:
            with torch.no_grad():
                return torch.argmax(self.nn(torch.Tensor(state))).item()
        else:
            return torch.tensor(random.randrange(self.n_actions), device=device, dtype=torch.long).item()
    
    def greedy_action(self, state):
        with torch.no_grad():
            return torch.argmax(self.nn(torch.Tensor(state))).item()
    
    def gradient_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        indices, batch, is_weights = self.replay_buffer.sample(self.batch_size)
        batch = Transition(*zip(*batch))
        
        state_batch = torch.stack(batch.state)
        action_batch = torch.stack(batch.action)
        reward_batch = torch.stack(batch.reward)
        next_state_batch = torch.stack(batch.next_state)
        done_batch = torch.stack(batch.done)
        
        state_action_values = self.nn(state_batch).gather(1, action_batch)
        if self.option == 2:
            # Double Q-Learning
            next_state_action_values = self.target_nn(next_state_batch).gather(1, self.nn(next_state_batch).max(1)[1].unsqueeze(1)).detach()
        elif self.option == 1:
            # Target
            next_state_action_values = self.target_nn(next_state_batch).max(1)[0].unsqueeze(1).detach()
        else:
            # Vanilla
            next_state_action_values = self.nn(next_state_batch).max(1)[0].unsqueeze(1).detach()
        expected_state_action_values = reward_batch + self.gamma * next_state_action_values * (1 - done_batch)
        
        errors = (state_action_values - expected_state_action_values).data.numpy()
        for i in range(self.batch_size):
            index = indices[i]
            self.replay_buffer.update(index, errors[i])
        
        loss = self.criterion(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.update_target_nn()
    
    def periodic_update_target_nn(self):
        if self.step % self.C == 0:
            self.target_nn.load_state_dict(self.nn.state_dict())  
    
    def smooth_update_target_nn(self):
        for param_nn, param_target_nn in zip(self.nn.parameters(), self.target_nn.parameters()):
            param_target_nn.data.copy_(param_nn * self.beta + param_target_nn * (1 - self.beta))
    
    def train(self, env, n_episodes):
        episode_return_list = []
        for i_episode in range(1, n_episodes+1):
            episode_return = 0
            state = env.reset()
            for t in count():
                self.update_epsilon()
                action = self.epsilon_greedy_action(state)
                next_state, reward, done, _ = env.step(action)
                episode_return += reward
                
                self.append_sample(torch.Tensor(state), torch.tensor([action], dtype=torch.long), torch.tensor([reward], dtype=torch.float), torch.Tensor(next_state), torch.tensor([int(done)], dtype=torch.long))
                state = next_state
                self.step += 1
                
                for _ in range(self.n_gradient_steps):
                    self.gradient_step()
                
                if done:
                    episode_return_list.append(episode_return)
                    print("Episode {:4d} : {:4d} steps | epsilon = {:4.2f} | return = {:.1f}".format(i_episode, t+1, self.epsilon, episode_return))
                    break
        return episode_return_list
    
    def test(self, env, step_max):
        episode_return = 0
        state = env.reset()
        for t in count():
            env.render()
            action = self.greedy_action(state)
            next_state, reward, done, _ = env.step(action)
            episode_return += reward
            
            state = next_state

            if done or t+1 >= step_max:
                env.close()
                return episode_return
    
    def save(self, name):
        torch.save(self.nn.state_dict(), "trained_nn/{}.pt".format(name))
    
    def load(self, name):
        self.nn.load_state_dict(torch.load("trained_nn/{}.pt".format(name)))

config = {"gamma": 0.95,
          "batch_size": 8,
          "buffer_size": 1e6,
          "n_gradient_steps": 1,
          "n_actions": 2,
          "learning_rate": 0.001,
          "epsilon_max": 1.,
          "epsilon_min": 0.01,
          "epsilon_decay": 5000,
          "C": 100,
          "option": 1}

dqn = DQN()
agent = Agent(config, dqn)

In [14]:
episode_return_list = agent.train(env, 100)

Episode    1 :   15 steps | epsilon = 1.00 | return = 15.0
Episode    2 :   12 steps | epsilon = 0.99 | return = 12.0
Episode    3 :   31 steps | epsilon = 0.99 | return = 31.0
Episode    4 :   14 steps | epsilon = 0.99 | return = 14.0
Episode    5 :   16 steps | epsilon = 0.98 | return = 16.0
Episode    6 :   38 steps | epsilon = 0.98 | return = 38.0
Episode    7 :   18 steps | epsilon = 0.97 | return = 18.0
Episode    8 :   23 steps | epsilon = 0.97 | return = 23.0
Episode    9 :   34 steps | epsilon = 0.96 | return = 34.0
Episode   10 :   20 steps | epsilon = 0.96 | return = 20.0
Episode   11 :   27 steps | epsilon = 0.95 | return = 27.0
Episode   12 :   40 steps | epsilon = 0.94 | return = 40.0
Episode   13 :   17 steps | epsilon = 0.94 | return = 17.0
Episode   14 :   31 steps | epsilon = 0.94 | return = 31.0
Episode   15 :   27 steps | epsilon = 0.93 | return = 27.0
Episode   16 :   20 steps | epsilon = 0.93 | return = 20.0
Episode   17 :   32 steps | epsilon = 0.92 | return = 32

In [15]:
class CartPoleSwingUp(gym.Wrapper):
    def __init__(self, env, **kwargs):
        super(CartPoleSwingUp, self).__init__(env, **kwargs)
        self.theta_dot_threshold = 4*np.pi

    def reset(self):
        self.env.env.state = [0, 0, np.pi, 0] + super().reset()
        self.env.env.steps_beyond_done = None
        return np.array(self.env.env.state)

    def step(self, action):
        state, reward, done, _ = super().step(action)
        x, x_dot, theta, theta_dot = state
        
        done = x < -self.x_threshold \
               or x > self.x_threshold \
               or theta_dot < -self.theta_dot_threshold \
               or theta_dot > self.theta_dot_threshold
        
        if done:
            # game over
            reward = -10.
            if self.steps_beyond_done is None:
                self.steps_beyond_done = 0
            else:
                logger.warn("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
                self.steps_beyond_done += 1
        else:
            if -self.theta_threshold_radians < theta and theta < self.theta_threshold_radians:
                # pole upright
                reward = 1.
            else:
                # pole swinging
                reward = 0.

        return np.array(self.state), reward, done, {}

env = CartPoleSwingUp(gym.make('CartPole-v1'))

In [16]:
config = {"gamma": 0.95,
          "batch_size": 64,
          "buffer_size": 1e6,
          "n_gradient_steps": 1,
          "n_actions": 2,
          "learning_rate": 0.001,
          "epsilon_max": 1.,
          "epsilon_min": 0.01,
          "epsilon_decay": 5000,
          "C": 100,
          "option": 2}

dqn = DQN()
agent = Agent(config, dqn)

In [17]:
episode_return_list = agent.train(env, 500)

Episode    1 :   79 steps | epsilon = 0.98 | return = -10.0
Episode    2 :  134 steps | epsilon = 0.96 | return = -10.0
Episode    3 :  253 steps | epsilon = 0.91 | return = -10.0
Episode    4 :  139 steps | epsilon = 0.89 | return = -10.0
Episode    5 :   70 steps | epsilon = 0.88 | return = -10.0
Episode    6 :  307 steps | epsilon = 0.82 | return = -10.0
Episode    7 :   83 steps | epsilon = 0.81 | return = -10.0
Episode    8 :  411 steps | epsilon = 0.75 | return = -10.0
Episode    9 :  239 steps | epsilon = 0.71 | return = 3.0
Episode   10 :  182 steps | epsilon = 0.69 | return = -10.0
Episode   11 :  121 steps | epsilon = 0.67 | return = -10.0
Episode   12 :  190 steps | epsilon = 0.65 | return = -10.0
Episode   13 :  240 steps | epsilon = 0.62 | return = -10.0
Episode   14 :  179 steps | epsilon = 0.60 | return = -10.0
Episode   15 :  108 steps | epsilon = 0.58 | return = -10.0
Episode   16 :  229 steps | epsilon = 0.56 | return = -10.0
Episode   17 :  219 steps | epsilon = 0.53

Episode  139 :  392 steps | epsilon = 0.01 | return = -10.0
Episode  140 :  494 steps | epsilon = 0.01 | return = -1.0
Episode  141 :  125 steps | epsilon = 0.01 | return = -5.0
Episode  142 :  132 steps | epsilon = 0.01 | return = -1.0
Episode  143 :  143 steps | epsilon = 0.01 | return = -2.0
Episode  144 :  140 steps | epsilon = 0.01 | return = -2.0
Episode  145 :  118 steps | epsilon = 0.01 | return = 3.0
Episode  146 :  135 steps | epsilon = 0.01 | return = -1.0
Episode  147 :  442 steps | epsilon = 0.01 | return = -10.0
Episode  148 :  157 steps | epsilon = 0.01 | return = -2.0
Episode  149 :  375 steps | epsilon = 0.01 | return = -10.0
Episode  150 :  316 steps | epsilon = 0.01 | return = -4.0
Episode  151 :  414 steps | epsilon = 0.01 | return = -10.0
Episode  152 :  525 steps | epsilon = 0.01 | return = -1.0
Episode  153 :  142 steps | epsilon = 0.01 | return = 2.0
Episode  154 :  682 steps | epsilon = 0.01 | return = -2.0
Episode  155 :  504 steps | epsilon = 0.01 | return = 

Episode  278 :  489 steps | epsilon = 0.01 | return = 40.0
Episode  279 :  355 steps | epsilon = 0.01 | return = 18.0
Episode  280 :  338 steps | epsilon = 0.01 | return = -2.0
Episode  281 :  257 steps | epsilon = 0.01 | return = -3.0
Episode  282 :  244 steps | epsilon = 0.01 | return = 95.0
Episode  283 :  239 steps | epsilon = 0.01 | return = 60.0
Episode  284 :  325 steps | epsilon = 0.01 | return = 4.0
Episode  285 : 1061 steps | epsilon = 0.01 | return = 0.0
Episode  286 :  809 steps | epsilon = 0.01 | return = 76.0
Episode  287 :  560 steps | epsilon = 0.01 | return = -10.0
Episode  288 :  702 steps | epsilon = 0.01 | return = 220.0
Episode  289 :  498 steps | epsilon = 0.01 | return = 45.0
Episode  290 :  229 steps | epsilon = 0.01 | return = -10.0
Episode  291 : 1479 steps | epsilon = 0.01 | return = 71.0
Episode  292 :  773 steps | epsilon = 0.01 | return = -10.0
Episode  293 :  894 steps | epsilon = 0.01 | return = -10.0
Episode  294 :  915 steps | epsilon = 0.01 | return =

Episode  417 :  170 steps | epsilon = 0.01 | return = 15.0
Episode  418 :  218 steps | epsilon = 0.01 | return = 38.0
Episode  419 :  180 steps | epsilon = 0.01 | return = 7.0
Episode  420 : 1525 steps | epsilon = 0.01 | return = 44.0
Episode  421 :  222 steps | epsilon = 0.01 | return = 44.0
Episode  422 :  202 steps | epsilon = 0.01 | return = 24.0
Episode  423 :  269 steps | epsilon = 0.01 | return = 54.0
Episode  424 :  257 steps | epsilon = 0.01 | return = 52.0
Episode  425 :  243 steps | epsilon = 0.01 | return = 55.0
Episode  426 :  221 steps | epsilon = 0.01 | return = 32.0
Episode  427 :  226 steps | epsilon = 0.01 | return = 20.0
Episode  428 :  243 steps | epsilon = 0.01 | return = 27.0
Episode  429 :  258 steps | epsilon = 0.01 | return = 75.0
Episode  430 :  202 steps | epsilon = 0.01 | return = 36.0
Episode  431 :  220 steps | epsilon = 0.01 | return = 48.0
Episode  432 :  235 steps | epsilon = 0.01 | return = 58.0
Episode  433 :  229 steps | epsilon = 0.01 | return = 50.

In [22]:
agent.test(env, 500)

-10.0

In [61]:
import pickle

with open('agent.pkl', 'wb') as output:
    pickle.dump(agent, output, pickle.HIGHEST_PROTOCOL)

In [10]:
import pickle

with open('agent.pkl', 'rb') as input:
    agent_ = pickle.load(input)