# References
- [SumTree introduction in Python](https://adventuresinmachinelearning.com/sumtree-introduction-python/)
- [A Deeper Look at Experience Replay](https://arxiv.org/abs/1712.01275)
- [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)
- [Double Learning and Prioritized Experience Replay](https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/)
- [GitHub rlcode: Prioritized Experience Replay](https://github.com/rlcode/per)

# Import

In [2]:
import time
import random
import math
from copy import deepcopy
from collections import namedtuple, deque
from itertools import count, product
import numpy as np
import matplotlib.pyplot as plt

import gym
from gym import logger
logger.set_level(gym.logger.DISABLED)
from replay_buffer import ReplayBuffer

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Replay Buffer

In [3]:
class SumTree:
    write = 0
    
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.n_entries = 0
    
    # update to the root node
    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        
        self.tree[parent] += change
        
        if parent != 0:
            self._propagate(parent, change)
    
    # find sample on leaf node
    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1
        
        if left >= len(self.tree):
            return idx
            #return min(idx, self.n_entries + self.capacity - 2) # idx can be greater than n_entries + capacity - 2 which means dataIdx >= n_entries
        
        if s < self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])
    
    def total(self):
        return self.tree[0]
    
    # store priority and sample
    def add(self, p, data):
        idx = self.write + self.capacity - 1
        
        self.data[self.write] = data
        self.update(idx, p)
        
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        
        if self.n_entries < self.capacity:
            self.n_entries += 1
    
    # update priority
    def update(self, idx, p):
        change = p - self.tree[idx]
        
        self.tree[idx] = p
        self._propagate(idx, change)
    
    # get priority and sample
    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        
        return (idx, self.tree[idx], self.data[dataIdx])

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))
    
class ReplayBuffer:
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(int(capacity))
        self.capacity = int(capacity)
    
    def __len__(self):
        return self.tree.n_entries
    
    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a
    
    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, Transition(*sample))
    
    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []
        
        self.beta = np.min([1, self.beta + self.beta_increment_per_sampling])
        
        for i in range(n):
            a = segment * i
            b = segment * (i + 1)
            
            s = random.uniform(a, b)
            idx, p, data = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)
        
        sampling_probabilities = priorities / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()
        
        return batch, idxs, is_weight
    
    def update(self, idx, error):
        p = self._get_priority(error)
        self.tree.update(idx, p)


# CartPole Environment

In [4]:
env = gym.make('CartPole-v1')

# Neural Network

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(4, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 2)
    
    def forward(self, x):
        x = x.to(device)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# DQN Agent

In [6]:
class Agent:
    def __init__(self, config, nn):
        self.step = 0
        self.gamma = config["gamma"]
        self.batch_size = config["batch_size"]
        self.replay_buffer = ReplayBuffer(config["buffer_size"])
        self.n_gradient_steps = config["n_gradient_steps"]
        self.n_actions = config["n_actions"]
        self.epsilon = config["epsilon_max"]
        self.epsilon_max = config["epsilon_max"]
        self.epsilon_min = config["epsilon_min"]
        self.epsilon_decay = config["epsilon_decay"]
        self.nn = nn.to(device)
        self.target_nn = deepcopy(self.nn).to(device)
        self.criterion = torch.nn.MSELoss()
        self.optimizer = optim.RMSprop(self.nn.parameters(), lr=config["learning_rate"])
        self.option = config["option"]
        if "beta" in config:
            self.beta = config["beta"]
            self.update_target_nn = self.smooth_update_target_nn
        elif "C" in config:
            self.C = config["C"]
            self.update_target_nn = self.periodic_update_target_nn
        else:
            self.update_target_nn = (lambda: None)
            self.option = 0
    
    def append_sample(self, state, action, reward, next_state, done):
        state_action_values = self.nn(state.unsqueeze(0)).gather(1, action.unsqueeze(0))
        if self.option == 2:
            # Double Q-Learning
            next_state_action_values = self.target_nn(next_state.unsqueeze(0)).gather(1, self.nn(next_state.unsqueeze(0)).max(1)[1].unsqueeze(1)).detach()
        elif self.option == 1:
            # Target
            next_state_action_values = self.target_nn(next_state.unsqueeze(0)).max(1)[0].unsqueeze(1).detach()
        else:
            # Vanilla
            next_state_action_values = self.nn(next_state.unsqueeze(0)).max(1)[0].unsqueeze(1).detach()
        expected_state_action_values = reward.unsqueeze(0) + self.gamma * next_state_action_values * (1 - done.unsqueeze(0))
        error = torch.abs(state_action_values - expected_state_action_values).data.numpy()
        
        self.replay_buffer.add(error, (state, action, reward, next_state, done))
    
    def update_epsilon(self):
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * math.exp(-1. * self.step / self.epsilon_decay)
    
    def epsilon_greedy_action(self, state):
        if random.random() > self.epsilon:
            with torch.no_grad():
                return torch.argmax(self.nn(torch.Tensor(state))).item()
        else:
            return torch.tensor(random.randrange(self.n_actions), device=device, dtype=torch.long).item()
    
    def greedy_action(self, state):
        with torch.no_grad():
            return torch.argmax(self.nn(torch.Tensor(state))).item()
    
    def gradient_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        transitions, idxs, is_weights = self.replay_buffer.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        
        state_batch = torch.stack(batch.state)
        action_batch = torch.stack(batch.action)
        reward_batch = torch.stack(batch.reward)
        next_state_batch = torch.stack(batch.next_state)
        done_batch = torch.stack(batch.done)
        
        state_action_values = self.nn(state_batch).gather(1, action_batch)
        if self.option == 2:
            # Double Q-Learning
            next_state_action_values = self.target_nn(next_state_batch).gather(1, self.nn(next_state_batch).max(1)[1].unsqueeze(1)).detach()
        elif self.option == 1:
            # Target
            next_state_action_values = self.target_nn(next_state_batch).max(1)[0].unsqueeze(1).detach()
        else:
            # Vanilla
            next_state_action_values = self.nn(next_state_batch).max(1)[0].unsqueeze(1).detach()
        expected_state_action_values = reward_batch + self.gamma * next_state_action_values * (1 - done_batch)
        
        errors = torch.abs(state_action_values - expected_state_action_values).data.numpy()
        for i in range(self.batch_size):
            idx = idxs[i]
            self.replay_buffer.update(idx, errors[i])
        
        loss = self.criterion(state_action_values, expected_state_action_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.update_target_nn()
    
    def periodic_update_target_nn(self):
        if self.step % self.C == 0:
            self.target_nn.load_state_dict(self.nn.state_dict())  
    
    def smooth_update_target_nn(self):
        for param_nn, param_target_nn in zip(self.nn.parameters(), self.target_nn.parameters()):
            param_target_nn.data.copy_(param_nn * self.beta + param_target_nn * (1 - self.beta))
    
    def train(self, env, n_episodes):
        episode_return_list = []
        for i_episode in range(1, n_episodes+1):
            episode_return = 0
            state = env.reset()
            for t in count():
                self.update_epsilon()
                action = self.epsilon_greedy_action(state)
                next_state, reward, done, _ = env.step(action)
                episode_return += reward
                
                self.append_sample(torch.Tensor(state), torch.tensor([action], dtype=torch.long), torch.tensor([reward], dtype=torch.float), torch.Tensor(next_state), torch.tensor([int(done)], dtype=torch.long))
                state = next_state
                self.step += 1
                
                for _ in range(self.n_gradient_steps):
                    self.gradient_step()
                
                if done:
                    episode_return_list.append(episode_return)
                    print("Episode {:4d} : {:4d} steps | epsilon = {:4.2f} | return = {:.1f}".format(i_episode, t+1, self.epsilon, episode_return))
                    break
        return episode_return_list
    
    def test(self, env, step_max):
        episode_return = 0
        state = env.reset()
        for t in count():
            env.render()
            action = self.greedy_action(state)
            next_state, reward, done, _ = env.step(action)
            episode_return += reward
            
            state = next_state

            if done or t+1 >= step_max:
                env.close()
                return episode_return
    
    def save(self, name):
        torch.save(self.nn.state_dict(), "trained_nn/{}.pt".format(name))
    
    def load(self, name):
        self.nn.load_state_dict(torch.load("trained_nn/{}.pt".format(name)))

config = {"gamma": 0.95,
          "batch_size": 8,
          "buffer_size": 1e6,
          "n_gradient_steps": 1,
          "n_actions": 2,
          "learning_rate": 0.001,
          "epsilon_max": 1.,
          "epsilon_min": 0.01,
          "epsilon_decay": 5000,
          "C": 100,
          "option": 1}

dqn = DQN()
agent = Agent(config, dqn)

In [6]:
episode_return_list = agent.train(env, 100)

Episode    1 :   29 steps | epsilon = 0.99 | return = 29.0
Episode    2 :   26 steps | epsilon = 0.99 | return = 26.0
Episode    3 :   21 steps | epsilon = 0.99 | return = 21.0
Episode    4 :   13 steps | epsilon = 0.98 | return = 13.0
Episode    5 :   13 steps | epsilon = 0.98 | return = 13.0
Episode    6 :   42 steps | epsilon = 0.97 | return = 42.0
Episode    7 :   12 steps | epsilon = 0.97 | return = 12.0
Episode    8 :   28 steps | epsilon = 0.96 | return = 28.0
Episode    9 :   23 steps | epsilon = 0.96 | return = 23.0
Episode   10 :   21 steps | epsilon = 0.96 | return = 21.0
Episode   11 :   24 steps | epsilon = 0.95 | return = 24.0
Episode   12 :   22 steps | epsilon = 0.95 | return = 22.0
Episode   13 :   71 steps | epsilon = 0.93 | return = 71.0
Episode   14 :   15 steps | epsilon = 0.93 | return = 15.0
Episode   15 :   11 steps | epsilon = 0.93 | return = 11.0
Episode   16 :   26 steps | epsilon = 0.92 | return = 26.0
Episode   17 :   12 steps | epsilon = 0.92 | return = 12

In [7]:
class CartPoleSwingUp(gym.Wrapper):
    def __init__(self, env, **kwargs):
        super(CartPoleSwingUp, self).__init__(env, **kwargs)
        self.theta_dot_threshold = 4*np.pi

    def reset(self):
        self.env.env.state = [0, 0, np.pi, 0] + super().reset()
        self.env.env.steps_beyond_done = None
        return np.array(self.env.env.state)

    def step(self, action):
        state, reward, done, _ = super().step(action)
        x, x_dot, theta, theta_dot = state
        
        done = x < -self.x_threshold \
               or x > self.x_threshold \
               or theta_dot < -self.theta_dot_threshold \
               or theta_dot > self.theta_dot_threshold
        
        if done:
            # game over
            reward = -10.
            if self.steps_beyond_done is None:
                self.steps_beyond_done = 0
            else:
                logger.warn("You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior.")
                self.steps_beyond_done += 1
        else:
            if -self.theta_threshold_radians < theta and theta < self.theta_threshold_radians:
                # pole upright
                reward = 1.
            else:
                # pole swinging
                reward = 0.

        return np.array(self.state), reward, done, {}

env = CartPoleSwingUp(gym.make('CartPole-v1'))

In [8]:
config = {"gamma": 0.95,
          "batch_size": 64,
          "buffer_size": 1e6,
          "n_gradient_steps": 1,
          "n_actions": 2,
          "learning_rate": 0.001,
          "epsilon_max": 1.,
          "epsilon_min": 0.01,
          "epsilon_decay": 5000,
          "C": 100,
          "option": 2}

dqn = DQN()
agent = Agent(config, dqn)

In [9]:
episode_return_list = agent.train(env, 500)

Episode    1 :  387 steps | epsilon = 0.93 | return = -10.0
Episode    2 :  193 steps | epsilon = 0.89 | return = -10.0
Episode    3 :   89 steps | epsilon = 0.88 | return = -10.0
Episode    4 :  128 steps | epsilon = 0.85 | return = -10.0
Episode    5 :  175 steps | epsilon = 0.83 | return = -10.0
Episode    6 :   53 steps | epsilon = 0.82 | return = -10.0
Episode    7 :  151 steps | epsilon = 0.79 | return = -10.0
Episode    8 :   57 steps | epsilon = 0.78 | return = -10.0
Episode    9 :   99 steps | epsilon = 0.77 | return = -10.0
Episode   10 :  156 steps | epsilon = 0.75 | return = -10.0
Episode   11 :   67 steps | epsilon = 0.74 | return = -10.0
Episode   12 :   94 steps | epsilon = 0.72 | return = -10.0
Episode   13 :   95 steps | epsilon = 0.71 | return = -10.0
Episode   14 :  252 steps | epsilon = 0.67 | return = -10.0
Episode   15 :  134 steps | epsilon = 0.66 | return = -10.0
Episode   16 :  200 steps | epsilon = 0.63 | return = -10.0
Episode   17 :  302 steps | epsilon = 0.



TypeError: zip argument #64 must support iteration

In [11]:
agent.replay_buffer.tree.total()

1406.2579345703125

In [10]:
agent.replay_buffer.tree.n_entries

5046

In [49]:
agent.replay_buffer.tree.capacity

1000000

In [12]:
agent.replay_buffer.tree.get(agent.replay_buffer.tree.total() - 0.02)

(1005044,
 0.3313210606575012,
 Transition(state=tensor([-0.2161,  0.5510,  4.0701,  1.3220]), action=tensor([1]), reward=tensor([0.]), next_state=tensor([-0.2051,  0.7295,  4.0966,  1.2470]), done=tensor([0])))

In [14]:
e = 0
while e <= agent.replay_buffer.tree.total():
    print(agent.replay_buffer.tree.get(e)[1])
    e += 1

0.10779391974210739
0.4694114625453949
0.40345054864883423
0.3975670337677002
0.35289648175239563
0.25959840416908264
0.37099429965019226
0.37603694200515747
0.4159894585609436
0.5031104683876038
0.6805336475372314
0.7422857284545898
0.6383516788482666
0.2833240330219269
0.4410078227519989
0.3239527642726898
0.4206015169620514
0.46233949065208435
0.5516879558563232
0.28144407272338867
0.342303991317749
0.7345882654190063
0.6751468181610107
0.24015632271766663
0.18442586064338684
0.2297784984111786
0.07366155833005905
0.144300639629364
0.07827139645814896
0.3701689839363098
0.5057386159896851
0.2772115468978882
0.4427688419818878
0.11088944226503372
0.14837409555912018
0.3371681571006775
0.3454602360725403
0.33757418394088745
0.2057182639837265
0.3177955448627472
0.19138458371162415
0.24309928715229034
0.3079693615436554
0.30093249678611755
0.10048722475767136
0.5821244120597839
0.23987500369548798
0.08279425650835037
0.10048987716436386
0.06896474957466125
0.2923209071159363
0.51388502

In [22]:
(0.99 + 0.01) ** 0.6

1.0