In [2]:
import tensorflow as tf
import numpy as np

In [3]:
import datetime
import random

## Experience replay

In [6]:
class PrioritizedExperienceReplay(object):
    '''
    Fixed-capacity replay memory that samples transitions in proportion to
    their stored priority and returns importance-sampling (IS) weights.

    max_size: capacity; once reached, the oldest slot is overwritten (ring buffer)
    alpha: tradeoff between sampling high priority transitions and random sampling.
           In this class it is stored on the instance and only applied to the
           initial priority (1.0**alpha); stored priorities are used as-is
           when they are normalised into probabilities.
    beta: used to compute importance sampling weights, increased from beta_start
          to 1.0 over the course of beta_steps calls to sample()
    TL;DR of PER: https://medium.com/arxiv-bytes/summary-prioritized-experience-replay-e5f9257cef2d
    '''

    def __init__(self, max_size=500000, alpha=0.6, beta_start=0.4, beta_steps=100000):
        self.max_size = max_size
        self.replay_memory = []
        self.priorities = np.zeros((max_size,), dtype=np.float32)

        self.alpha = alpha
        self.beta = beta_start
        self.beta_incr = (1.0 - beta_start) / beta_steps

        # next slot to write; wraps around to 0 when it reaches max_size
        self.index = 0
        # seed slot 0 so that priorities.max() is non-zero for the very first insert
        self.priorities[0] = 1.0**alpha

    def update_beta(self):
        '''
        move beta one step towards 1.0 (never past it)
        '''
        self.beta = min(1.0, self.beta + self.beta_incr)

    def get_probabilities(self):
        '''
        turn the priorities of the stored transitions into sampling probabilities

        returns: float64 array with one entry per stored transition, summing
                 to 1 (empty array when nothing has been inserted yet)
        '''
        # len(replay_memory) never exceeds max_size, so it is the number of
        # valid priority slots both while filling and once the buffer is full.
        # (The write cursor self.index must NOT be used here: it wraps to 0.)
        size = len(self.replay_memory)
        if size == 0:
            return np.zeros((0,), dtype=np.float64)

        # normalise in float64 so the result sums to 1 tightly enough for
        # np.random.choice even with a very large buffer
        prios = self.priorities[:size].astype(np.float64)
        return prios / prios.sum()

    def insert(self, transition):
        '''
        store a transition, overwriting the oldest one when the memory is full;
        the new transition gets the current maximum priority
        '''
        # add the transition to the memory
        if len(self.replay_memory) < self.max_size:
            self.replay_memory.append(transition)
        else:
            self.replay_memory[self.index] = transition

        # update priorities and index
        self.priorities[self.index] = self.priorities.max()
        self.index = (self.index + 1) % self.max_size

    def sample(self, batch_size):
        '''
        sample a batch of transitions (with replacement) according to their
        probabilities, and advance beta by one step

        returns: (samples, is_weights)
            samples: list of batch_size stored transitions
            is_weights: array of batch_size importance sampling weights,
                        scaled so that the largest possible weight is 1
        raises: ValueError if the memory is empty
        '''
        current_size = len(self.replay_memory)
        if current_size == 0:
            raise ValueError("cannot sample from an empty replay memory")

        # samples
        probs = self.get_probabilities()
        indices = np.random.choice(current_size, batch_size, p=probs)
        samples = [self.replay_memory[i] for i in indices]

        # importance sampling weights: w_i = (N * P(i)) ** -beta
        # the least likely transition has the largest weight; dividing by it
        # keeps every weight <= 1
        max_weight = (current_size * probs.min()) ** (-self.beta)

        is_weights = (current_size * probs[indices]) ** (-self.beta)
        is_weights /= max_weight

        self.update_beta()
        return samples, is_weights
        

## Main DQN function

In [None]:
def dqn(sess, env):
    '''
    Entry point of the DQN routine (work in progress).

    sess: TensorFlow session; not used by the code written so far
    env: environment; not used by the code written so far

    For now this only builds a prioritized replay memory with its default
    settings and returns None.
    NOTE(review): the main entrypoint iterates over dqn(sess, env) expecting
    (t, stats) pairs, so this function still has to become a generator.
    '''
    per_buffer = PrioritizedExperienceReplay()

## Main entrypoint

In [None]:
# Reset TensorFlow's default graph before building anything in this cell.
# NOTE(review): presumably here so that re-running the cell in the same kernel
# does not add ops to a graph left over from a previous run - confirm.
tf.reset_default_graph()

# Where we save our checkpoints and graphs
# The folder name embeds the start time as YYYYMMDDHHMM (minute resolution),
# so two runs started within the same minute get the same folder name.
exp_folder = "experiments/exp%s" % (datetime.datetime.now().strftime("%Y%m%d%H%M"))

# main run
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for t, stats in dqn(sess,
                        env):