In [1]:
import os, sys, time, logging, warnings
from pathlib import Path

import numpy as np
import seaborn as sns

import tensorflow as tf
from tensorflow.keras import Model

import ray, gym

from tensorflow_probability.python.distributions import MultivariateNormalDiag, TransformedDistribution, Categorical
import tensorflow_probability as tfp

In [2]:
# Run every Keras layer / weight in double precision (the replay buffer stores float64 too).
tf.keras.backend.set_floatx('float64')
# Quiet TensorFlow's own logger, and suppress stdlib log records of severity WARNING and below.
tf.get_logger().setLevel('ERROR')
logging.disable(logging.WARNING)

In [3]:
class ReplayBuffer:
    """FIFO experience replay buffer backed by preallocated float64 numpy arrays.

    Transitions are written into a ring of `size` slots; once the buffer is
    full, every new transition overwrites the oldest one.
    """

    def __init__(self, size, obs_shape, action_shape):
        self.size = size
        # total number of transitions ever stored (not capped at `size`)
        self.total = 0
        self.data = {
            'state': np.zeros((size, *obs_shape), dtype=np.float64),
            'action': np.zeros((size, *action_shape), dtype=np.float64),
            'reward': np.zeros((size, 1), dtype=np.float64),
            'next_state': np.zeros((size, *obs_shape), dtype=np.float64),
            'done': np.zeros((size, 1), dtype=np.float64)
        }

    @property
    def num_stored(self):
        """Number of valid transitions currently held (at most `size`)."""
        return min(self.total, self.size)

    @property
    def cur_idx(self):
        """Slot that the next call to `store` will write to."""
        return self.total % self.size

    def store(self, state, action, reward, next_state, done):
        """Write one transition into the next slot, overwriting the oldest one when full."""
        # Pick the slot *before* incrementing the counter: slots 0..size-1 are
        # filled in order and then reused cyclically (oldest first).
        idx = self.cur_idx
        self.data['state'][idx] = state
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['next_state'][idx] = next_state
        self.data['done'][idx] = done
        self.total += 1

    def sample_list_of_datasets(self,
                                batch_size,
                                dataset_size,
                                map_func=None,
                                prioritized_sampling=False,
                                prio_sampling_degree=0.2):
        """Sample stored transitions (with replacement) as batched tf datasets.

        Requires at least one stored transition.

        Args:
            batch_size: batch size of each returned dataset; an incomplete
                trailing batch is dropped.
            dataset_size: number of transitions to sample.
            map_func: optional function applied to every batch; skipped when None.
            prioritized_sampling: if True, transition j is drawn with probability
                proportional to |r_j| ** prio_sampling_degree instead of uniformly.
                Transitions with reward exactly 0 then get probability 0; if every
                stored reward is 0, sampling falls back to uniform.
            prio_sampling_degree: exponent on |r|; larger values favour
                large-magnitude rewards more strongly.

        Returns:
            list of tf.data.Dataset, one per key of `self.data` in insertion
            order ('state', 'action', 'reward', 'next_state', 'done'). All are
            built from the same sampled indices, so they can be zipped.

        When prioritized, the entropy gap between the uniform and the
        prioritized sampling distribution is logged with tf.summary.scalar.
        """
        n_stored = self.num_stored
        if prioritized_sampling:
            # index column 0 instead of squeeze() so a 1-element buffer stays 1-D
            weights = np.abs(self.data['reward'][:n_stored, 0]) ** prio_sampling_degree
            weight_sum = np.sum(weights)
            probs = weights / weight_sum if weight_sum > 0 else np.full(n_stored, 1. / n_stored)
            sample_dist = Categorical(probs=probs)
            indices = sample_dist.sample((dataset_size,)).numpy()
            # the entropy of a uniform distribution over n items is log(n)
            entropy_diff = np.log(n_stored) - sample_dist.entropy().numpy()
            tf.summary.scalar('metrics/prio_sample_dist_entropy_diff',
                              entropy_diff,
                              description="Difference in entropy between uniform and prioritized sampling distribution.")
        else:
            indices = np.random.randint(0, n_stored, dataset_size)

        dataset_list = []
        for key in self.data:
            dataset = tf.data.Dataset.from_tensor_slices(
                tf.convert_to_tensor(self.data[key][indices])
            ).batch(batch_size, drop_remainder=True)
            if map_func is not None:
                dataset = dataset.map(map_func)
            dataset_list.append(dataset)
        return dataset_list

In [4]:
# Clipping range for the policy's log standard deviation: sigma stays within
# [exp(-20), exp(2)] so it can neither collapse to zero nor blow up in training.
clip_log_sigma = (-20., 2.)

class ActorNetwork(Model):
    """Tanh-squashed diagonal Gaussian policy.

    A two-layer MLP maps an observation to the mean ``mu`` and the clipped
    standard deviation ``sigma`` of a diagonal Gaussian. Actions are samples
    from that Gaussian passed through ``tanh``, so they lie in [-1, 1].

    Args:
        input_shape: observation shape (must equal the shape of the
            normalization arrays when those are given).
        action_space: number of action dimensions.
        normalize_mean, normalize_sd: optional per-feature statistics. When
            given, observations are standardised as ``(x - mean) / sd``
            before the MLP.
    """

    def __init__(self,
                 input_shape=None,
                 action_space=2,
                 normalize_mean=None,
                 normalize_sd=None):
        super(ActorNetwork, self).__init__()

        self.action_space = action_space
        if normalize_mean is not None:
            assert normalize_sd is not None
            assert normalize_sd.shape == input_shape
            assert normalize_mean.shape == input_shape
        self.normalize_mean = normalize_mean
        self.normalize_sd = normalize_sd

        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(256, input_shape=input_shape,
                                  activation='relu', kernel_regularizer='l2'),
            tf.keras.layers.Dense(256, activation='relu', kernel_regularizer='l2'),
        ])
        self.fc_mu = tf.keras.layers.Dense(action_space, use_bias=False)
        self.fc_log_sigma = tf.keras.layers.Dense(action_space, use_bias=False)

    def call(self, x):
        """Return ``{"mu": ..., "sigma": ...}`` for a batch of observations ``x``."""
        if self.normalize_mean is not None:
            x = (x - self.normalize_mean) / self.normalize_sd

        x = self.mlp(x)
        output = {}
        output["mu"] = self.fc_mu(x)
        log_sigma = tf.clip_by_value(self.fc_log_sigma(x), *clip_log_sigma)
        output["sigma"] = tf.math.exp(log_sigma)
        return output

    def act(self, x, return_log_prob=False, random_agent=False):
        """Sample an action for one observation or for a batch of observations.

        Args:
            x: a single observation (1-D) or a batch of observations (2-D).
            return_log_prob: if True, also return the log-density of the sampled
                action under the squashed policy, with shape ``(batch, 1)``.
            random_agent: if True, ignore the network and return one action
                drawn uniformly from [-1, 1]^action_space (used for warm-up).

        Returns:
            dict with ``"action"``, shaped ``(action_space,)`` for a 1-D ``x``
            and ``(batch, action_space)`` for a 2-D ``x``. When requested (and
            not a random agent), it also holds ``"log_probability"``.
        """
        if random_agent:
            # uniform exploration does not need a forward pass
            return {"action": np.random.uniform(-1., 1., (self.action_space,))}

        unbatched = x.ndim == 1
        if unbatched:
            x = x[np.newaxis,]
        net_out = self.call(x)

        action_distribution = TransformedDistribution(
            MultivariateNormalDiag(loc=net_out["mu"], scale_diag=net_out["sigma"]),
            tfp.bijectors.Tanh())

        # bounded by [-1, 1]
        action = action_distribution.sample()

        output = {}
        # Remove only the batch axis added above. A plain tf.squeeze would also
        # drop a length-1 action axis (or a size-1 batch) and break the
        # (batch, action_dim) layout the critic concatenates on.
        output["action"] = tf.squeeze(action, axis=0) if unbatched else action
        if return_log_prob:
            output["log_probability"] = action_distribution.log_prob(action)[:, tf.newaxis]
        return output

In [5]:
def get_q_mlp(units=(256, 256), activation='relu', kernel_regularizer=None):
    """Build a Q-value MLP.

    Args:
        units: width of each hidden Dense layer, in order.
        activation: activation used by every hidden layer.
        kernel_regularizer: regularizer applied to every hidden layer's kernel.

    Returns:
        tf.keras.Sequential made of the hidden layers followed by a linear
        Dense(1) output head.
    """
    hidden = [
        tf.keras.layers.Dense(width, activation=activation, kernel_regularizer=kernel_regularizer)
        for width in units
    ]
    return tf.keras.Sequential(hidden + [tf.keras.layers.Dense(1)])
        
class DoubleQNetwork(Model):
    """Two independent Q-value MLPs evaluated on the same [state, action] input.

    `call` concatenates state and action along axis 1, optionally standardises
    the result, and returns both heads' outputs. Combining them (the training
    loop in this notebook takes their elementwise minimum) is left to the caller.

    Args:
        action_space: number of action dimensions (stored; not used by the
            network itself).
        normalize_mean, normalize_sd: optional statistics for the
            concatenated [state, action] vector. When given, the input is
            standardised as ``(x - mean) / sd``.
    """

    def __init__(self,
                 action_space=2,
                 normalize_mean=None,
                 normalize_sd=None):
        # Keras expects the base initializer to run before any attribute is assigned.
        super(DoubleQNetwork, self).__init__()

        if normalize_mean is not None:
            assert normalize_sd is not None
        self.normalize_mean = normalize_mean
        self.normalize_sd = normalize_sd
        self.action_space = action_space

        self.qnet_1 = get_q_mlp()
        self.qnet_2 = get_q_mlp()

    def call(self, state, action):
        """Return ``{"q1": ..., "q2": ...}`` for batched ``state`` and ``action``."""
        x = tf.concat((state, action), axis=1)

        if self.normalize_mean is not None:
            x = (x - self.normalize_mean) / self.normalize_sd

        return {"q1": self.qnet_1(x),
                "q2": self.qnet_2(x)}

    def update_normalization(self, normalize_mean, normalize_sd):
        """Replace the input-normalisation statistics (same contract as in `__init__`)."""
        if normalize_mean is not None:
            assert normalize_sd is not None
        self.normalize_mean = normalize_mean
        self.normalize_sd = normalize_sd

In [6]:
# Continuous Lunar Lander: both action dimensions are bounded to [-1, 1],
# the same range as the actor's tanh-squashed actions.
env_id = "LunarLanderContinuous-v2"
env = gym.make(env_id)
# display the action bounds as a quick sanity check
(env.action_space.low, env.action_space.high)

(array([-1., -1.], dtype=float32), array([1., 1.], dtype=float32))

In [7]:
# environment steps (and gradient updates) per inner training iteration
train_every_n_step = 1

buffer_size = 100000
# transitions collected with a uniformly random policy before the first update
warmup_steps = 4000
# environment steps to run after warm-up
total_timesteps = 500000
# inner training iterations per epoch; each runs `train_every_n_step` env steps
epoch_length = 1000
# one epoch consumes epoch_length * train_every_n_step env steps
epochs = total_timesteps // (epoch_length * train_every_n_step)
# write TensorBoard scalars every this many gradient updates
log_every_n_step = 100
# evaluate the policy every this many epochs
test_every = 1


optim_batch_size = 256
# lock ratio between env steps and gradient updates to 1: each iteration samples
# `train_every_n_step` batches and takes one gradient step per batch
sample_size = optim_batch_size * train_every_n_step
# discount factor
gamma = 0.99
# entropy coefficient
alpha = 0.6
# polyak averaging coefficient for the target critic (weight kept on the old target)
polyak = 0.995
# exponential moving average rate for the input-normalisation statistics
alpha_exp_avg = 0.005
# whether to oversample SARS pairs with extreme (sparse) rewards
prioritized_sampling = True
prio_sampling_degree = 0.2
# whether to normalize inputs (zero mean, one std)
normalize_inputs = False

model_kwargs = {
    "input_shape": env.observation_space.shape,
    "action_space": env.action_space.shape[0],
    "normalize_mean": np.zeros(env.observation_space.shape),
    "normalize_sd": np.ones(env.observation_space.shape)
}

qnet_kwargs = {
    "action_space": env.action_space.shape[0],
    "normalize_mean": np.zeros((env.observation_space.shape[0] + env.action_space.shape[0],)),
    "normalize_sd": np.ones((env.observation_space.shape[0] + env.action_space.shape[0],))
}

In [8]:
# put update into function to make loop less cluttered
def update_normalization(moving_mean, moving_std):
    """Push new input-normalisation statistics into the actor and both critics.

    The statistics cover the concatenated [state, action] vector, as built in
    the training loop. The actor only sees states, so it receives the leading
    observation-sized slice; both critics receive the full vectors.
    Reads the notebook globals `env`, `actor`, `update_DoubleQNetwork` and
    `target_DoubleQNetwork`.
    """
    obs_dim = env.observation_space.shape[0]
    actor.normalize_mean, actor.normalize_sd = moving_mean[:obs_dim], moving_std[:obs_dim]
    for q_network in (update_DoubleQNetwork, target_DoubleQNetwork):
        q_network.update_normalization(moving_mean, moving_std)

In [9]:
def warmup(env, actor, buffer, steps):
    """Fill `buffer` with `steps` transitions collected by the random agent.

    Actions come from ``actor.act(state, random_agent=True)``. Episodes that
    end are reset and collection continues.

    A transition that ends an episode only because of the TimeLimit wrapper
    (``info['TimeLimit.truncated']``) is stored with done=False. The episode
    did not reach a true terminal state, so the critic should keep
    bootstrapping from the next state.

    Returns:
        the initial state of a freshly reset episode.
    """
    state = env.reset()
    for _ in range(steps):
        action = actor.act(state, random_agent=True)['action']
        next_state, reward, done, info = env.step(action)
        terminal = done and not info.get('TimeLimit.truncated', False)
        buffer.store(state, action, reward, next_state, terminal)
        state = next_state
        if done:
            state = env.reset()
    return env.reset()

In [10]:
def test(env_id, actor, test_episodes=10, render=False):
    test_env = gym.make(env_id)
    returns = []
    for _ in range(test_episodes):
        test_return = 0
        state = test_env.reset()
        done = False
        while not done:
            if render:
                env.render()
            action = actor.act(state)['action']
            state, reward, done, info = test_env.step(action)
            test_return += reward
        returns.append(test_return)
    test_env.close()
    return returns

In [11]:
# TODO: normalize Q-values

In [12]:
# SAC training loop.
# Each pass of the outer `while` builds fresh networks, a fresh buffer and a
# fresh TensorBoard writer, then trains for `epochs` epochs. Within the first
# 25 epochs, a run whose test return has dropped since the last checkpoint
# epoch is abandoned ("early stop") and training starts over. The `while`
# exits after the first run in which some evaluation reached a mean return
# >= 200; otherwise it restarts, even after a run that used all `epochs`.
tf.keras.backend.clear_session()
# set to a checkpoint directory to resume, e.g.
# saving_path.parent/"model_08-03-2021_10-58-58_AM"/'checkpoints'
restore_ckpt = None
restart = True
while restart:
    now = time.strftime("%d-%m-%Y_%I-%M-%S_%p", time.gmtime())
    # initialize buffer
    buffer = ReplayBuffer(buffer_size, env.observation_space.shape, env.action_space.shape)
    # where to save the results
    saving_path = Path(os.getcwd() + "/../Homework/A3/progress_test/" + f"model_{now}")

    # initialize actor
    actor = ActorNetwork(**model_kwargs)
    dummy = (np.zeros((1, *env.observation_space.shape)), np.zeros((1, *env.action_space.shape)))
    actor(dummy[0])

    # initialize update and target double q networks
    update_DoubleQNetwork = DoubleQNetwork(**qnet_kwargs)
    update_DoubleQNetwork(*dummy)
    # separate network for the delayed (target) critic, starting from the same weights
    target_DoubleQNetwork = DoubleQNetwork(**qnet_kwargs)
    target_DoubleQNetwork(*dummy)
    target_DoubleQNetwork.set_weights(update_DoubleQNetwork.get_weights())
    # the target critic is only changed by polyak averaging, never by an optimizer
    for layer in target_DoubleQNetwork.layers:
        layer.trainable = False

    # init optimizer
    critic_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
    actor_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
    # map function for datasets
    reshape_and_cast = lambda x: tf.cast(tf.reshape(x, (optim_batch_size, -1)), tf.float64)

    step, episode = tf.Variable(0, dtype=tf.int64), 0

    # init checkpoints
    ckpt_q_update = tf.train.Checkpoint(model=update_DoubleQNetwork, optimizer=critic_optimizer, step=step)
    ckpt_q_target = tf.train.Checkpoint(model=target_DoubleQNetwork, step=step)
    ckpt_actor = tf.train.Checkpoint(model=actor, optimizer=actor_optimizer, step=step)

    if restore_ckpt is not None:
        ckpt_q_update.restore(tf.train.latest_checkpoint(restore_ckpt/'critic'/'update'))#.assert_consumed()
        ckpt_q_target.restore(tf.train.latest_checkpoint(restore_ckpt/'critic'/'target')).assert_consumed()
        ckpt_actor.restore(tf.train.latest_checkpoint(restore_ckpt/'actor'))#.assert_consumed()
        start_steps = int(step)

    writer = tf.summary.create_file_writer(str(saving_path.parent/'logs'/f"model_{now}"))

    start_time = time.time()
    critic_losses, actor_losses, q_values, q_targets, log_probabilities = [], [], [], [], []
    moving_mean, moving_reward_avg, moving_std, moving_reward_std = 0, 0, 0, 0
    with writer.as_default(step=step):
        for e in range(epochs):
            # fill the buffer with random-policy transitions before the first update
            if e == 0:
                cur_state = warmup(env, actor, buffer, warmup_steps)
                episode += np.sum(buffer.data['done'])
            for i in range(epoch_length):
                # run agent
                for _ in range(train_every_n_step):
                    action = actor.act(cur_state)['action']
                    next_state, reward, done, info = env.step(action)
                    # a TimeLimit cut-off is not a true terminal state: keep bootstrapping through it
                    terminal = done and not info.get('TimeLimit.truncated', False)
                    buffer.store(cur_state, action, reward, next_state, terminal)
                    cur_state = next_state
                    if done:
                        cur_state = env.reset()
                        episode += 1

                # update moving average of the [state, action] statistics
                if normalize_inputs:
                    n_stored = buffer.num_stored
                    state_action = np.concatenate(
                        (buffer.data['state'][:n_stored], buffer.data['action'][:n_stored]), axis=1)
                    # the very first update initialises the averages; later ones use the EMA rate
                    ema_rate = alpha_exp_avg if (e + i) != 0 else 1.
                    moving_mean = moving_mean + ema_rate * (np.mean(state_action, axis=0) - moving_mean)
                    moving_std = moving_std + ema_rate * (np.std(state_action, axis=0) - moving_std)
                    # floor the std so a constant feature cannot cause a division by zero
                    update_normalization(moving_mean, np.maximum(moving_std, 1e-8))

                # sample data as tf datasets to optimize on from buffer
                data_list = buffer.sample_list_of_datasets(
                    optim_batch_size,
                    sample_size,
                    map_func=reshape_and_cast,
                    prioritized_sampling=prioritized_sampling,
                    prio_sampling_degree=prio_sampling_degree)

                # one gradient step per batch
                for state, action, reward, next_state, done in zip(*data_list):
                    # sample next action from policy
                    output = actor.act(next_state, return_log_prob=True)
                    # get Q-values for next state of delayed network
                    target_qs = target_DoubleQNetwork(next_state, output['action'])
                    q_value_next = tf.reduce_min([target_qs['q1'], target_qs['q2']], axis=0)
                    # soft Bellman target; computed outside any tape, so it is treated as a constant
                    q_target = reward + (1 - done) * gamma * (q_value_next - alpha * output["log_probability"])
                    # record forward pass of the critic
                    with tf.GradientTape() as tape:
                        qs_pred = update_DoubleQNetwork(state, action)
                        loss_q1 = tf.reduce_mean(tf.keras.losses.MSE(q_target, qs_pred['q1']))
                        loss_q2 = tf.reduce_mean(tf.keras.losses.MSE(q_target, qs_pred['q2']))

                    # backpropagate both critic losses (their gradients are summed per variable)
                    gradients = tape.gradient([loss_q1, loss_q2], update_DoubleQNetwork.trainable_variables)
                    # update q-network
                    critic_optimizer.apply_gradients(zip(gradients, update_DoubleQNetwork.trainable_variables))

                    # record forward pass of the actor; gradients are taken w.r.t. the actor's
                    # variables only, so the critic is left untouched by this step
                    with tf.GradientTape() as tape:
                        # sample action with updated policy
                        output = actor.act(state, return_log_prob=True)
                        qs = update_DoubleQNetwork(state, output['action'])
                        q_value = tf.reduce_min([qs['q1'], qs['q2']], axis=0)
                        actor_loss = tf.reduce_mean(alpha * output["log_probability"] - q_value)

                    q_targets.append(np.mean(q_target))
                    q_values.append(np.mean(q_value))

                    # backpropagate policy gradient
                    gradients = tape.gradient(actor_loss, actor.trainable_variables)
                    # update policy network unless a NaN slipped into the gradients
                    if not np.isnan(np.mean([np.mean(grad) for grad in gradients])):
                        actor_optimizer.apply_gradients(zip(gradients, actor.trainable_variables))
                    else:
                        print("Skipping nan gradient.")

                    critic_losses.append(float(loss_q1 + loss_q2))
                    actor_losses.append(float(actor_loss))
                    log_probabilities.append(np.mean(output["log_probability"]))

                    # polyak-average the update critic into the delayed (target) critic in place
                    for target_weight, update_weight in zip(target_DoubleQNetwork.weights,
                                                            update_DoubleQNetwork.weights):
                        target_weight.assign(polyak * target_weight + (1 - polyak) * update_weight)

                    step.assign_add(1)
                if int(step) % (log_every_n_step) == 0:
                    tf.summary.scalar('actor/actor_loss', np.mean(actor_losses[-log_every_n_step:]))
                    tf.summary.scalar('critic/critic_loss', np.mean(critic_losses[-log_every_n_step:]))
                    tf.summary.scalar('actor/log_prob', np.mean(log_probabilities[-log_every_n_step:]))
                    tf.summary.scalar('critic/q_vals', np.mean(q_values[-log_every_n_step:]))
                    tf.summary.scalar('critic/q_targets', np.mean(q_targets[-log_every_n_step:]))

            print(f"Update step: {int(step)}")

            if e % test_every == 0:
                # print progress
                returns = test(env_id, actor, test_episodes=6, render=False)
                returns.extend(test(env_id, actor, test_episodes=4, render=True))
                mean_return = np.mean(returns)

                tf.summary.scalar('metrics/test_return', mean_return)
                print(
                    f"epoch ::: {e}   episode ::: {episode}   update step ::: {int(step)}   time elapsed ::: {time.strftime('%H:%M:%S', time.gmtime((time.time() - start_time)))}",
                    f"critic loss avg ::: {np.round(np.mean(critic_losses), 2)}   min ::: {np.round(np.min(critic_losses), 2)}    max ::: {np.round(np.nanmax(critic_losses), 2)}",
                    f"actor loss avg ::: {np.round(np.mean(actor_losses), 2)}   min ::: {np.round(np.min(actor_losses), 2)}   max ::: {np.round(np.nanmax(actor_losses), 2)}",
                    f"q-vals avg ::: {np.round(np.mean(q_values), 2)}   min ::: {np.round(np.nanmin(q_values), 2)}   max ::: {np.round(np.nanmax(q_values), 2)}",
                    f"log_prob avg ::: {np.round(np.mean(log_probabilities), 2)}   min ::: {np.round(np.min(log_probabilities), 2)}   max ::: {np.round(np.nanmax(log_probabilities), 2)}",
                    f"env return avg ::: {np.round(mean_return, 2)}   buffer size ::: {buffer.num_stored}",
                    f"current lr ::: {np.round(actor_optimizer._decayed_lr('float32').numpy(), 5)}",
                    sep='\n'
                )
                if mean_return >= 200:
                    restart = False

            # checkpoint every 5 evaluations; early in training, abandon runs that got worse
            if (e % (test_every * 5) == 0) and (e != 0):
                now = time.strftime("%d-%m-%Y_%I-%M-%S_%p", time.gmtime())
                ckpt_q_update.save(saving_path/'checkpoints'/'critic'/'update'/f"model_{e}_{now}")
                ckpt_q_target.save(saving_path/'checkpoints'/'critic'/'target'/f"model_{e}_{now}")
                ckpt_actor.save(saving_path/'checkpoints'/'actor'/f"model_{e}_{now}")
                if (mean_return < old_return_avg) and (e < 25):
                    tf.keras.backend.clear_session()
                    print("Early stop! Resetting...")
                    break

            # reference return for the next early-stopping comparison (also set at e == 0)
            if e % (test_every * 5) == 0:
                old_return_avg = mean_return

            if e % test_every == 0:
                critic_losses, actor_losses, q_values, q_targets, log_probabilities = [], [], [], [], []

            writer.flush()
    # release the event file before a possible restart opens a new writer
    writer.close()

Update step: 1000
epoch ::: 0   episode ::: 42.0   update step ::: 1000   time elapsed ::: 00:03:03
critic loss avg ::: 184.11   min ::: 10.32    max ::: 1006.13
actor loss avg ::: 5.06   min ::: -0.65   max ::: 11.68
q-vals avg ::: -5.55   min ::: -11.95   max ::: -0.14
log_prob avg ::: -0.81   min ::: -1.39   max ::: -0.21
env return avg ::: -271.6   buffer size ::: 5000
current lr ::: 0.0003000000142492354
Update step: 2000
epoch ::: 1   episode ::: 48.0   update step ::: 2000   time elapsed ::: 00:06:05
critic loss avg ::: 70.97   min ::: 7.86    max ::: 265.54
actor loss avg ::: 9.42   min ::: 4.19   max ::: 15.71
q-vals avg ::: -9.58   min ::: -15.79   max ::: -4.5
log_prob avg ::: -0.27   min ::: -0.7   max ::: 0.15
env return avg ::: -267.4   buffer size ::: 6000
current lr ::: 0.0003000000142492354
Update step: 3000
epoch ::: 2   episode ::: 55.0   update step ::: 3000   time elapsed ::: 00:09:07
critic loss avg ::: 49.59   min ::: 9.06    max ::: 150.71
actor loss avg ::: 12.

epoch ::: 20   episode ::: 90.0   update step ::: 21000   time elapsed ::: 01:21:13
critic loss avg ::: 32.21   min ::: 5.65    max ::: 184.07
actor loss avg ::: -5.55   min ::: -14.79   max ::: 3.74
q-vals avg ::: 5.36   min ::: -3.88   max ::: 14.51
log_prob avg ::: -0.32   min ::: -0.62   max ::: -0.01
env return avg ::: 13.88   buffer size ::: 25000
current lr ::: 0.0003000000142492354
Early stop! Resetting...
Update step: 1000
epoch ::: 0   episode ::: 44.0   update step ::: 1000   time elapsed ::: 00:03:05
critic loss avg ::: 206.42   min ::: 13.11    max ::: 981.42
actor loss avg ::: 5.8   min ::: -0.58   max ::: 11.25
q-vals avg ::: -6.31   min ::: -11.53   max ::: -0.2
log_prob avg ::: -0.85   min ::: -1.38   max ::: -0.27
env return avg ::: -232.67   buffer size ::: 5000
current lr ::: 0.0003000000142492354
Update step: 2000
epoch ::: 1   episode ::: 50.0   update step ::: 2000   time elapsed ::: 00:06:44
critic loss avg ::: 85.42   min ::: 10.45    max ::: 242.88
actor loss 

Update step: 9000
epoch ::: 8   episode ::: 63.0   update step ::: 9000   time elapsed ::: 00:35:03
critic loss avg ::: 28.44   min ::: 5.84    max ::: 116.51
actor loss avg ::: 6.66   min ::: -3.42   max ::: 15.38
q-vals avg ::: -6.72   min ::: -15.36   max ::: 3.35
log_prob avg ::: -0.1   min ::: -0.5   max ::: 0.2
env return avg ::: -84.25   buffer size ::: 13000
current lr ::: 0.0003000000142492354
Update step: 10000
epoch ::: 9   episode ::: 64.0   update step ::: 10000   time elapsed ::: 00:39:01
critic loss avg ::: 24.96   min ::: 6.27    max ::: 108.61
actor loss avg ::: 5.1   min ::: -2.62   max ::: 12.81
q-vals avg ::: -5.15   min ::: -12.81   max ::: 2.51
log_prob avg ::: -0.09   min ::: -0.45   max ::: 0.28
env return avg ::: -4.82   buffer size ::: 14000
current lr ::: 0.0003000000142492354
Update step: 11000
epoch ::: 10   episode ::: 65.0   update step ::: 11000   time elapsed ::: 00:42:30
critic loss avg ::: 23.92   min ::: 6.04    max ::: 103.2
actor loss avg ::: 5.63 

KeyboardInterrupt: 