# In [None]:
###############################################
###############################################
###############################################

# Generic imports
import os
import warnings
import random

# Import tensorflow and filter warning messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '10'
warnings.filterwarnings('ignore',category=FutureWarning)
import tensorflow                    as     tf
import tensorflow_probability        as     tfp
from   tensorflow.keras              import Model
from   tensorflow.keras.layers       import Dense
from   tensorflow.keras.optimizers   import Adam

# Define alias
tfd = tfp.distributions

###############################################
### Q-network
class q_network(Model):
    """Fully-connected network producing one output value per action.

    The hidden layers are ReLU ``Dense`` layers whose widths are given by
    ``arch``; the final layer is a linear ``Dense`` layer with ``act_dim``
    units. The network also owns its Adam optimizer, stored in ``self.opt``.

    Args:
        arch: Sequence of hidden layer widths, in forward order.
        act_dim: Number of units of the linear output layer.
        lr: Learning rate handed to the Adam optimizer.
    """
    def __init__(self, arch, act_dim, lr):
        super().__init__()

        # Hidden ReLU stack followed by the linear output head
        hidden  = [Dense(width, activation='relu') for width in arch]
        head    = Dense(act_dim, activation='linear')
        self.ac = hidden + [head]

        # Adam optimizer with gradient value clipping set to 100.0
        self.opt = Adam(learning_rate=lr, clipvalue=100.0)

    # Network forward pass
    def call(self, state):
        """Feed ``state`` through every layer in order and return the result."""
        out = state
        for dense in self.ac:
            out = dense(out)

        return out

###############################################
### Ring buffer

##########
# TO DO
# Implement a ring buffer class:
# - That contains a fixed-size list of size ring_size
# - With an "append" function to add an element to it
# - With a "get_mini_batch" function to obtain a shuffled
#   mini batch from an input set of indices
# - With a "check" function that returns the number of elements in it
##########
class ring_buffer():
    """Fixed-size circular buffer backed by a pre-allocated Python list.

    Elements are written at a moving cursor; once ``ring_size`` elements
    have been stored, every new element overwrites the oldest one.

    Args:
        ring_size: Capacity of the buffer (must be a positive integer).

    Raises:
        ValueError: If ``ring_size`` is not strictly positive.
    """
    def __init__(self, ring_size):
        if ring_size <= 0:
            raise ValueError("ring_size must be a positive integer")

        self.ring_size = ring_size
        self.buff      = [None]*ring_size  # fixed-length storage
        self.idx       = 0                 # slot the next append writes to
        self.count     = 0                 # filled slots, capped at ring_size

    # Append element
    def append(self, elt):
        """Store ``elt``, overwriting the oldest entry once the buffer is full."""
        self.buff[self.idx] = elt

        # Wrap the cursor so the write position cycles over the storage
        self.idx = (self.idx + 1) % self.ring_size
        if self.count < self.ring_size:
            self.count += 1

    # Get mini-batch given a list of indices
    def get_mini_batch(self, idx):
        """Return the stored elements located at the given indices.

        The elements are returned in the order of ``idx``. No extra shuffling
        is done here: several buffers queried with the same index list must
        return matching entries, so any randomisation has to come from the
        indices themselves.

        Args:
            idx: Iterable of integer positions in ``[0, check())``.

        Returns:
            A list with one element per requested index.

        Raises:
            IndexError: If an index lies outside the filled part of the buffer.
        """
        batch = []
        for i in idx:
            # Reject negative indices too: they would silently wrap around
            if not 0 <= i < self.count:
                raise IndexError("ring_buffer index out of range: "+str(i))
            batch.append(self.buff[i])

        return batch

    # Return filling level
    def check(self):
        """Return the number of stored elements (at most ``ring_size``)."""
        return self.count


###############################################
###############################################
###############################################

# Generic imports
import gym
import math
import numpy as np
import copy  as cp

###############################################
### A DQN agent
class dqn_agent:
    """DQN agent with replay buffers, a target network and epsilon-greedy actions.

    Transitions are kept in five parallel ``ring_buffer`` instances
    (observation, next observation, reward, action, terminal flag). They are
    always appended to together, so one index designates the same transition
    in each of them.

    Args:
        act_dim: Number of discrete actions.
        obs_dim: Size of an observation vector.
        lr: Learning rate of the Q-network optimizer.
        arch: Hidden layer widths of the Q-networks.
        gamma: Discount factor used in the training target.
        mem_size: Capacity of each replay buffer.
        batch_size: Number of transitions sampled per training step.
        target_freq: Counter value at which the target network is refreshed.
        eps_start: Initial epsilon of the epsilon-greedy policy.
        eps_end: Final epsilon of the epsilon-greedy policy.
        n_ep_decay: Number of episodes over which epsilon is interpolated.
    """
    def __init__(self, act_dim, obs_dim, lr, arch, gamma, mem_size,
                 batch_size, target_freq, eps_start, eps_end, n_ep_decay):

        # Initialize from arguments
        self.act_dim     = act_dim
        self.obs_dim     = obs_dim
        self.lr          = lr
        self.arch        = arch
        self.gamma       = gamma
        self.mem_size    = mem_size
        self.batch_size  = batch_size
        self.target_freq = target_freq
        self.eps_start   = eps_start
        self.eps_end     = eps_end
        self.n_ep_decay  = n_ep_decay
        self.eps         = eps_start

        # Initialize target update counter
        self.tgt_update  = 0

        # Initialize buffers
        self.buff_rwd = ring_buffer(self.mem_size)
        self.buff_obs = ring_buffer(self.mem_size)
        self.buff_nxt = ring_buffer(self.mem_size)
        self.buff_act = ring_buffer(self.mem_size)
        self.buff_trm = ring_buffer(self.mem_size)

        # Build networks
        self.q_net = q_network(self.arch,
                               self.act_dim,
                               self.lr)
        self.q_tgt = q_network(self.arch,
                               self.act_dim,
                               self.lr)

        # Run one forward pass on each network so that its weights exist
        self.q_net(tf.ones([1,self.obs_dim]))
        self.q_tgt(tf.ones([1,self.obs_dim]))

        # Start the target network as an exact copy of the online network;
        # otherwise the first targets come from unrelated random weights
        self.q_tgt.set_weights(self.q_net.get_weights())

    # Get actions from network
    def get_actions(self, obs):
        """Pick an action for ``obs`` following the epsilon-greedy rule.

        With probability ``self.eps`` a uniformly random action index is
        returned; otherwise the index of the largest ``q_net`` output.
        """
        # Handle epsilon-greedy strategy
        p = random.uniform(0, 1)
        if (p < self.eps):
            return random.randrange(0, self.act_dim)

        # Wrap the observation in a list to add the batch dimension
        obs    = tf.cast([obs], tf.float32)
        values = self.q_net(obs).numpy()

        return np.argmax(values)

    # Update epsilon with linear decay
    def update_epsilon(self, ep):
        """Interpolate epsilon linearly between ``eps_start`` and ``eps_end``.

        The interpolation ratio is ``ep/n_ep_decay`` capped at 1, so epsilon
        stays at ``eps_end`` once ``ep`` reaches ``n_ep_decay``.
        """
        r        = min(float(ep/self.n_ep_decay),1.0)
        self.eps = self.eps_start + r*(self.eps_end-self.eps_start)

    # Train network
    def train(self):
        """Run one training step on a random mini-batch, if enough data exists.

        Does nothing while the buffers hold fewer than ``batch_size``
        transitions. After the gradient step the target-update counter is
        either incremented or, once it equals ``target_freq``, reset to zero
        with the online weights copied into the target network.
        """
        # Check that ring buffer has enough samples
        n = self.buff_rwd.check()
        if (n < self.batch_size): return

        # The same indices are used on every buffer to keep samples aligned
        idx = random.sample(range(0,n),self.batch_size)
        rwd = self.buff_rwd.get_mini_batch(idx)
        obs = self.buff_obs.get_mini_batch(idx)
        nxt = self.buff_nxt.get_mini_batch(idx)
        act = self.buff_act.get_mini_batch(idx)
        trm = self.buff_trm.get_mini_batch(idx)

        # Cast and reshape to one row per sampled transition
        rwd = tf.reshape(tf.cast(rwd, tf.float32), [-1,1])
        obs = tf.reshape(tf.cast(obs, tf.float32), [-1,self.obs_dim])
        nxt = tf.reshape(tf.cast(nxt, tf.float32), [-1,self.obs_dim])
        act = tf.reshape(tf.cast(act, tf.int32),   [-1,1])
        trm = tf.reshape(tf.cast(trm, tf.float32), [-1,1])

        # Train
        self.train_dqn(rwd, obs, nxt, act, trm)

        # Update target if necessary
        if (self.tgt_update == self.target_freq):
            self.q_tgt.set_weights(self.q_net.get_weights())
            self.tgt_update  = 0
        else:
            self.tgt_update += 1

    # Reset environment, whatever the gym API version
    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return the initial observation only.

        Accepts both reset conventions: a bare observation, or an
        ``(observation, info_dict)`` pair.
        """
        out = env.reset()
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out[0]

        return out

    # Step environment, whatever the gym API version
    @staticmethod
    def _step_env(env, act):
        """Apply ``act`` to ``env`` and return ``(nxt, rwd, done)``.

        Accepts both step conventions: the 4-value
        ``(obs, reward, done, info)`` and the 5-value
        ``(obs, reward, terminated, truncated, info)``. In the latter case
        ``done`` is true when either flag is set.
        """
        out = env.step(act)
        if len(out) == 5:
            nxt, rwd, terminated, truncated, _ = out
            return nxt, rwd, bool(terminated or truncated)

        nxt, rwd, done, _ = out
        return nxt, rwd, done

    # Unroll episode
    def unroll_episode(self, env, score, ep, n_ep):
        """Play one full episode, storing transitions and training each step.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``.
            score: Indexable container; ``score[ep]`` accumulates the rewards.
            ep: Index of the current episode.
            n_ep: Total number of episodes (only forwarded to the printing).
        """
        # Reset observation and done flag
        obs  = self._reset_env(env)
        done = False

        # Loop
        while (not done):

            # Make one iteration
            act            = self.get_actions(obs)
            nxt, rwd, done = self._step_env(env, act)

            # Store in buffers
            self.buff_obs.append(obs)
            self.buff_nxt.append(nxt)
            self.buff_rwd.append(rwd)
            self.buff_act.append(act)
            self.buff_trm.append(float(done))

            # Update observation
            obs = nxt

            # Update score
            score[ep] += rwd

            # Train agent
            self.train()

        # Print
        self.print_episode(ep, n_ep, score)

        # Update epsilon
        self.update_epsilon(ep)

    # Training function for the Q-network
    @tf.function
    def train_dqn(self, rwd, obs, nxt, act, trm):
        """Apply one gradient step on the mean squared TD error.

        The target is ``rwd + (1-trm)*gamma*max(q_tgt(nxt))`` and the
        prediction is the ``q_net(obs)`` output selected by ``act``. Only the
        variables of ``q_net`` are differentiated and updated.
        """
        q_var = self.q_net.trainable_variables

        with tf.GradientTape() as tape:

            # Compute loss
            tgt  = tf.reshape(tf.reduce_max(self.q_tgt(nxt),axis=1), [-1,1])
            tgt  = rwd + (1.0-trm)*self.gamma*tgt
            val  = tf.gather(self.q_net(obs), act, axis=1, batch_dims=1)
            diff = tf.square(tgt - val)
            loss = tf.reduce_mean(diff)

        # Gradients are taken outside the tape context so that the gradient
        # computation itself is not recorded on the tape
        grads = tape.gradient(loss, q_var)
        self.q_net.opt.apply_gradients(zip(grads,q_var))

    # Printings at the end of an episode
    def print_episode(self, ep, n_ep, score):
        """Print an average score every 50 episodes (never for episode 0).

        The average covers ``score[ep-lgt:ep]`` with ``lgt = min(ep, 25)``,
        i.e. up to 25 episodes preceding ``ep``. ``n_ep`` is not used.
        """
        if (ep < 1) or (ep%50 != 0): return

        lgt = min(ep, 25)
        avg = np.mean(score[ep-lgt:ep])

        print(f'# Ep #{ep}, avg score = {avg:.3f}', end='\n')

###############################################
###############################################
###############################################

# Generic imports
import time

# Process training
def train(run, env_name, n_ep, lr, arch, gamma, mem_size, batch_size,
          update_freq, eps_start, eps_end, n_ep_decay):
    """Train a fresh DQN agent on a gym environment and return its scores.

    Args:
        run: Index of the current run (currently unused by this function).
        env_name: Identifier passed to ``gym.make``.
        n_ep: Number of episodes to play.
        lr, arch, gamma, mem_size, batch_size, eps_start, eps_end,
        n_ep_decay: Forwarded unchanged to ``dqn_agent``.
        update_freq: Forwarded to ``dqn_agent`` as its ``target_freq``.

    Returns:
        A numpy array of length ``n_ep`` holding the score of each episode.
    """
    # Declare environment
    env = gym.make(env_name)

    # The environment is closed even if building the agent or an episode fails
    try:
        act_dim = env.action_space.n
        obs_dim = env.observation_space.shape[0]
        agent   = dqn_agent(act_dim, obs_dim, lr, arch, gamma,
                            mem_size, batch_size, update_freq,
                            eps_start, eps_end, n_ep_decay)
        score   = np.zeros(n_ep)

        # Loop until max episode number is reached
        for ep in range(n_ep):
            agent.unroll_episode(env, score, ep, n_ep)
    finally:
        env.close()

    # Return array of episode scores
    return score

###############################################
###############################################
###############################################

# Generic imports
import matplotlib.pyplot as plt

# Plot avg/std of score as a function of episodes
def plot_score(score, title, window=25):
    """Plot the mean score and a +/- one standard deviation band over runs.

    Each run is first smoothed with a trailing sliding average; the mean and
    standard deviation are then taken across runs for every episode.

    Args:
        score: Array of shape ``(n_run, n_ep)``; a 1-D array is treated as a
            single run.
        title: Title of the figure.
        window: Length of the trailing averaging window, in episodes. Values
            below 1 are treated as 1, which disables the smoothing.
    """
    score  = np.atleast_2d(np.asarray(score, dtype=float))
    n_ep   = score.shape[1]
    window = max(int(window), 1)

    # Trailing sliding average from a zero-prefixed cumulative sum: the
    # window ending at episode i covers episodes max(0, i-window+1) .. i,
    # so the first episodes are averaged over fewer samples
    csum   = np.cumsum(np.pad(score, ((0,0),(1,0))), axis=1)
    hi     = np.arange(1, n_ep+1)
    lo     = np.maximum(hi - window, 0)
    smooth = (csum[:,hi] - csum[:,lo])/(hi - lo)

    # Statistics over runs
    avg = smooth.mean(axis=0)
    std = smooth.std(axis=0)

    # Plot avg, avg-std and avg+std as a function of episodes
    episodes = np.arange(n_ep)
    fig, ax  = plt.subplots()
    ax.plot(episodes, avg, label='avg')
    ax.fill_between(episodes, avg-std, avg+std,
                    alpha=0.3, label='avg +/- std')
    ax.set_xlabel('episode')
    ax.set_ylabel('score')
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    plt.show()


###############################################
###############################################
###############################################

# Run configuration
env_name    = "CartPole-v0" # Gym environment identifier
n_run       = 2             # Number of independent runs to average over
n_ep        = 1000          # Number of episodes per run
lr          = 5.0e-4        # Q-network learning rate
arch        = [32,32]       # Q-network hidden layer widths
gamma       = 0.99          # Discount factor
mem_size    = 10000         # Capacity of the replay memory
batch_size  = 64            # Mini-batch size for training
target_freq = 100           # Update frequency for target network
eps_start   = 0.1           # Initial epsilon-greedy value
eps_end     = 0.05          # Final   epsilon-greedy value
n_ep_decay  = 500           # Number of episodes over which epsilon decays
title       = 'CartPole-v0, dqn'

# One row of scores per run; each run trains a brand new agent
score = np.zeros((n_run, n_ep))
for ep in range(n_run):
    print(f'### Avg run #{ep}')
    start_time   = time.time()
    score[ep, :] = train(ep, env_name, n_ep, lr, arch, gamma,
                         mem_size, batch_size, target_freq,
                         eps_start, eps_end, n_ep_decay)

    # Wall-clock duration of the run, formatted with three decimals
    ctime        = f"{time.time() - start_time:.3f}"
    print(f"--- {ctime} seconds ---")

# Plot the scores gathered over all runs
plot_score(score, title)