In [2]:
import gym
from datetime import datetime
import numpy as np
from scipy.stats import entropy

# DR TRPO related files
from train_helper import *
from value import NNValueFunction
from utils import Logger
from dr_policy import DRPolicyKL, DRPolicyWass, DRPolicySinkhorn

import tensorflow as tf
import tensorflow.contrib.distributions as dist
import tensorflow.contrib.layers as layers

## Discriminator Network Setup 

In [4]:
class Discriminator:
    """GAIL-style discriminator over (observation, action) pairs.

    Builds a small fully connected network (two tanh hidden layers followed
    by a single linear output logit) in the current default graph, under the
    variable scope ``name``. The network is trained with sigmoid
    cross-entropy to output 1 for policy samples and 0 for expert samples,
    and exposes ``-log(sigmoid(logit))`` as a reward signal.

    Args:
        sess: session used to run the reward and training ops.
        hidden_size: number of units in each of the two hidden layers.
        lr: learning rate passed to the Adam optimizer.
        name: variable scope under which the network variables are created.
    """

    def __init__(self, sess, hidden_size, lr, name):
        self.sess = sess
        self.hidden_size = hidden_size
        self.lr = lr
        self.name = name

        # One row per sample; the two columns are (observation, action).
        self.ob_ac = tf.placeholder(dtype=tf.float32, shape=[None, 2])

        with tf.variable_scope(name) as scope:
            # Keep the fully qualified scope name so params() still finds the
            # variables when this object is created inside an outer scope.
            self._scope_name = scope.name
            self._build_network()

    def _build_network(self):
        """Create the network, the reward tensor, the loss and the train op."""
        with tf.variable_scope('discriminator'):
            hidden_1 = layers.fully_connected(self.ob_ac, self.hidden_size, activation_fn=tf.tanh)
            hidden_2 = layers.fully_connected(hidden_1, self.hidden_size, activation_fn=tf.tanh)
            logits = layers.fully_connected(hidden_2, 1, activation_fn=None)

        # reward = -log(sigmoid(logit)). Written as softplus(-logit), which is
        # the same function but stays finite when the sigmoid underflows to 0
        # (the naive log(sigmoid(.)) form would return +inf there).
        # tf.squeeze without an axis is kept from the original: it also drops
        # the batch dimension when a single row is fed, yielding a scalar.
        self.reward = tf.squeeze(tf.nn.softplus(-logits))

        # update() must be fed the expert rows first and the policy rows
        # second, in equal numbers: the batch is split in half along axis 0.
        expert_out, policy_out = tf.split(logits, num_or_size_splits=2, axis=0)

        # Policy samples are labelled 1 and expert samples 0, so a low logit
        # (expert-like sample) corresponds to a high reward above.
        self.loss = (tf.losses.sigmoid_cross_entropy(tf.ones_like(policy_out), policy_out)
                     + tf.losses.sigmoid_cross_entropy(tf.zeros_like(expert_out), expert_out))

        with tf.name_scope('train_op'):
            # Collect the variables once, before Adam adds its own slot
            # variables under the same scope.
            variables = self.params()
            grads = tf.gradients(self.loss, variables)
            self.grads = list(zip(grads, variables))
            self.train_op = tf.train.AdamOptimizer(self.lr).apply_gradients(self.grads)

    def params(self):
        """Return a new list of the global variables living under this scope."""
        # tf.global_variables() filters names with a regex prefix match; the
        # trailing '/' stops scope 'D' from also matching e.g. 'D2/...'.
        scope_name = getattr(self, '_scope_name', self.name)
        return list(tf.global_variables(scope_name + '/'))

    def get_reward(self, expert_ob_ac):
        """Evaluate the reward tensor for a batch of (observation, action) rows.

        Args:
            expert_ob_ac: array-like fed to the ``[None, 2]`` placeholder.

        Returns:
            Whatever ``sess.run`` returns for the reward tensor.
        """
        return self.sess.run(self.reward, feed_dict={self.ob_ac: expert_ob_ac})

    def update(self, all_ob_ac):
        """Run one optimizer step on a stacked batch.

        Args:
            all_ob_ac: array-like fed to the ``[None, 2]`` placeholder. The
                first half of the rows must be expert samples and the second
                half policy samples (see the split in ``_build_network``).
        """
        self.sess.run(self.train_op, feed_dict={self.ob_ac: all_ob_ac})

# Discrete State Space - SPO + GAIL
### 'Taxi-v3', 'Roulette-v0', 'NChain-v0', 'FrozenLake-v0', 'CliffWalking-v0', 'FrozenLake8x8-v0'

In [8]:
# Sinkhorn policy optimisation (SPO) + GAIL on a discrete environment.

# Expert demonstrations are loop-invariant, so load them once up front.
# The `with` block closes the NpzFile handle, which a bare np.load() inside
# the training loop would leave open on every iteration. Loading first also
# makes a missing file fail before any session is created.
# Assumes 'obs' and 'actions' are stored as (N, 1) columns -- TODO confirm.
with np.load('expert_traj/expert_taxi.npz') as data:
    expert_obs = data['obs'].T[0]
    expert_actions = data['actions'].T[0]
expert_ob_ac = np.stack((expert_obs, expert_actions), 1)

config = tf.ConfigProto(
    device_count={'GPU': 1},
    intra_op_parallelism_threads=1,
    allow_soft_placement=True
)
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)
try:
    tf.keras.backend.set_session(sess)
    discriminator = Discriminator(sess, 10, 0.01, 'D')
    sess.run(tf.global_variables_initializer())
    # NOTE(review): the discriminator ops stay attached to `sess`'s graph;
    # the reset presumably gives later code a clean default graph -- confirm.
    # TensorFlow documents calling this while a session is active as
    # undefined behaviour.
    tf.reset_default_graph()

    env_name = 'Taxi-v3'
    env = gym.make(env_name)
    sta_num = env.observation_space.n
    act_num = env.action_space.n
    policy = DRPolicySinkhorn(sta_num, act_num)
    val_func = NNValueFunction(1, 10)
    gamma = 0.9
    gae_weight = 1
    total_eps = 2000
    batch_eps = 10
    timesteps = 0
    # weight for RL (handed to run_policy together with the discriminator)
    lamb = 0.9
    logger = Logger(logname=env_name + '_DR-Sinkhorn_Batch=' + str(batch_eps), now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

    try:
        eps = 0
        while eps < total_eps:
            # the randomness of the actions an agent takes can be quantified by the entropy
            # NOTE(review): `entro` is computed but never logged or used in this cell.
            entro = entropy(policy.distributions)
            trajectories = run_policy(env, policy, batch_eps, discriminator, lamb, logger)
            eps += len(trajectories)
            # add estimated values to episodes
            add_value(trajectories, val_func)
            # calculated discounted sum of Rs
            add_disc_sum_rew(trajectories, gamma, logger)
            # calculate advantage
            add_gae(trajectories, gamma, gae_weight)
            # concatenate all episodes into single NumPy arrays
            observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
            timesteps += len(observes)
            logger.log({'_Timesteps': timesteps})
            disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
            log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
            policy.update(observes, actions, advantages, disc_freqs, env_name, eps)
            val_func.fit(observes, disc_sum_rew, logger)
            # write logger results to file and stdout
            logger.write(display=True)

            # Discriminator step: expert rows first, policy rows second, in
            # equal numbers, because Discriminator splits the batch in half.
            policy_ob_ac = np.stack((observes, actions), 1)
            min_len = min(len(expert_ob_ac), len(policy_ob_ac))
            discriminator.update(np.concatenate([expert_ob_ac[:min_len], policy_ob_ac[:min_len]], axis=0))
    finally:
        # Close the log even when training is interrupted or raises.
        logger.close()
finally:
    # Always release the session so a failed run does not leave it active.
    sess.close()



Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
***** Episode 10, Timesteps 1871, Mean Return = -658.6  *****
ExplainedVarNew: -2.21e-09
ExplainedVarOld: -3.7e-05
ValFuncLoss: 941


***** Episode 20, Timesteps 3871, Mean Return = -687.4  *****
ExplainedVarNew: -2.85e-10
ExplainedVarOld: -2.79e-11
ValFuncLoss: 750


***** Episode 30, Timesteps 5871, Mean Return = -638.7  *****
ExplainedVarNew: -9e-09
ExplainedVarOld: -1.66e-07
ValFuncLoss: 478


***** Episode 40, Timesteps 7871, Mean Return = -677.6  *****
ExplainedVarNew: -1.41e-09
ExplainedVarOld: -4.98e-09
ValFuncLoss: 460


***** Episode 50, Timesteps 9871, Mean Return = -652.7  *****
ExplainedVarNew: -1.2e-10
ExplainedVarOld: -2.65e-10
ValFuncLoss: 330


***** Episode 60, Timesteps 11871, Mean Return = -608.4  *****
ExplainedVarNew: -1e-11
ExplainedVarOld: -1.54e-11
ValFuncLoss: 221


***** Episode 70, Timesteps 13867, Mean Return = -632.8  *****
ExplainedVarNew: -3.3e-10
ExplainedVarOld: -4.59e-10
ValFuncLoss: 211


***** Episod

***** Episode 610, Timesteps 109069, Mean Return = -178.0  *****
ExplainedVarNew: -0.000107
ExplainedVarOld: -7.27e-05
ValFuncLoss: 101


***** Episode 620, Timesteps 110407, Mean Return = -297.7  *****
ExplainedVarNew: -6.98e-05
ExplainedVarOld: -5.38e-06
ValFuncLoss: 120


***** Episode 630, Timesteps 111701, Mean Return = -269.5  *****
ExplainedVarNew: -0.000512
ExplainedVarOld: -0.000119
ValFuncLoss: 131


***** Episode 640, Timesteps 112329, Mean Return = -86.3  *****
ExplainedVarNew: -0.00104
ExplainedVarOld: -3.19e-05
ValFuncLoss: 139


***** Episode 650, Timesteps 113555, Mean Return = -231.9  *****
ExplainedVarNew: -0.00882
ExplainedVarOld: -0.00996
ValFuncLoss: 108


***** Episode 660, Timesteps 114469, Mean Return = -165.8  *****
ExplainedVarNew: -0.00436
ExplainedVarOld: -0.00527
ValFuncLoss: 108


***** Episode 670, Timesteps 115013, Mean Return = -102.6  *****
ExplainedVarNew: -0.00915
ExplainedVarOld: -0.00579
ValFuncLoss: 207


***** Episode 680, Timesteps 116102, Mean 

***** Episode 1230, Timesteps 164057, Mean Return = -138.8  *****
ExplainedVarNew: -0.0046
ExplainedVarOld: -0.00824
ValFuncLoss: 209


***** Episode 1240, Timesteps 164918, Mean Return = -191.9  *****
ExplainedVarNew: -0.0236
ExplainedVarOld: -0.0171
ValFuncLoss: 149


***** Episode 1250, Timesteps 166092, Mean Return = -272.5  *****
ExplainedVarNew: -0.0376
ExplainedVarOld: -0.0317
ValFuncLoss: 140


***** Episode 1260, Timesteps 166474, Mean Return = -64.0  *****
ExplainedVarNew: -0.0672
ExplainedVarOld: -0.0527
ValFuncLoss: 278


***** Episode 1270, Timesteps 167228, Mean Return = -183.0  *****
ExplainedVarNew: -0.0218
ExplainedVarOld: -0.0224
ValFuncLoss: 241


***** Episode 1280, Timesteps 168034, Mean Return = -179.6  *****
ExplainedVarNew: -0.0343
ExplainedVarOld: -0.0397
ValFuncLoss: 203


***** Episode 1290, Timesteps 168767, Mean Return = -168.0  *****
ExplainedVarNew: -0.0604
ExplainedVarOld: -0.0472
ValFuncLoss: 177


***** Episode 1300, Timesteps 169434, Mean Return = -15

***** Episode 1840, Timesteps 205713, Mean Return = -161.7  *****
ExplainedVarNew: -0.0302
ExplainedVarOld: -0.0305
ValFuncLoss: 268


***** Episode 1850, Timesteps 206216, Mean Return = -117.8  *****
ExplainedVarNew: -0.0547
ExplainedVarOld: -0.0629
ValFuncLoss: 307


***** Episode 1860, Timesteps 206582, Mean Return = -64.4  *****
ExplainedVarNew: -0.0909
ExplainedVarOld: -0.154
ValFuncLoss: 306


***** Episode 1870, Timesteps 207029, Mean Return = -96.1  *****
ExplainedVarNew: -0.0281
ExplainedVarOld: -0.0336
ValFuncLoss: 285


***** Episode 1880, Timesteps 207538, Mean Return = -115.2  *****
ExplainedVarNew: -0.0332
ExplainedVarOld: -0.0244
ValFuncLoss: 302


***** Episode 1890, Timesteps 208613, Mean Return = -281.6  *****
ExplainedVarNew: -0.0351
ExplainedVarOld: -0.0338
ValFuncLoss: 163


***** Episode 1900, Timesteps 209163, Mean Return = -115.6  *****
ExplainedVarNew: -0.0316
ExplainedVarOld: -0.0348
ValFuncLoss: 265


***** Episode 1910, Timesteps 209926, Mean Return = -184.2

# Discrete State Space - WPO + GAIL
### 'Taxi-v3', 'Roulette-v0', 'NChain-v0', 'FrozenLake-v0', 'CliffWalking-v0', 'FrozenLake8x8-v0'

In [9]:
# Wasserstein policy optimisation (WPO) + GAIL on a discrete environment.

# Expert demonstrations are loop-invariant, so load them once up front.
# The `with` block closes the NpzFile handle, which a bare np.load() inside
# the training loop would leave open on every iteration. Loading first also
# makes a missing file fail before any session is created.
# Assumes 'obs' and 'actions' are stored as (N, 1) columns -- TODO confirm.
with np.load('expert_traj/expert_taxi.npz') as data:
    expert_obs = data['obs'].T[0]
    expert_actions = data['actions'].T[0]
expert_ob_ac = np.stack((expert_obs, expert_actions), 1)

config = tf.ConfigProto(
    device_count={'GPU': 1},
    intra_op_parallelism_threads=1,
    allow_soft_placement=True
)
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)
try:
    tf.keras.backend.set_session(sess)
    discriminator = Discriminator(sess, 10, 0.01, 'D')
    sess.run(tf.global_variables_initializer())
    # NOTE(review): the discriminator ops stay attached to `sess`'s graph;
    # the reset presumably gives later code a clean default graph -- confirm.
    # TensorFlow documents calling this while a session is active as
    # undefined behaviour.
    tf.reset_default_graph()

    env_name = 'Taxi-v3'
    env = gym.make(env_name)
    sta_num = env.observation_space.n
    act_num = env.action_space.n
    policy = DRPolicyWass(sta_num, act_num)
    val_func = NNValueFunction(1, 10)
    gamma = 0.9
    gae_weight = 1
    total_eps = 2000
    batch_eps = 10
    timesteps = 0
    # weight for RL (handed to run_policy together with the discriminator)
    lamb = 0.9
    logger = Logger(logname=env_name + '_DR-Wass_Batch=' + str(batch_eps), now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

    try:
        eps = 0
        while eps < total_eps:
            # the randomness of the actions an agent takes can be quantified by the entropy
            # NOTE(review): `entro` is computed but never logged or used in this cell.
            entro = entropy(policy.distributions)
            trajectories = run_policy(env, policy, batch_eps, discriminator, lamb, logger)
            eps += len(trajectories)
            # add estimated values to episodes
            add_value(trajectories, val_func)
            # calculated discounted sum of Rs
            add_disc_sum_rew(trajectories, gamma, logger)
            # calculate advantage
            add_gae(trajectories, gamma, gae_weight)
            # concatenate all episodes into single NumPy arrays
            observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
            timesteps += len(observes)
            logger.log({'_Timesteps': timesteps})
            disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
            log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
            policy.update(observes, actions, advantages, disc_freqs, env_name, eps)
            val_func.fit(observes, disc_sum_rew, logger)
            # write logger results to file and stdout
            logger.write(display=True)

            # Discriminator step: expert rows first, policy rows second, in
            # equal numbers, because Discriminator splits the batch in half.
            policy_ob_ac = np.stack((observes, actions), 1)
            min_len = min(len(expert_ob_ac), len(policy_ob_ac))
            discriminator.update(np.concatenate([expert_ob_ac[:min_len], policy_ob_ac[:min_len]], axis=0))
    finally:
        # Close the log even when training is interrupted or raises.
        logger.close()
finally:
    # Always release the session so a failed run does not leave it active.
    sess.close()



Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
***** Episode 10, Timesteps 2000, Mean Return = -719.0  *****
ExplainedVarNew: -2.12e-06
ExplainedVarOld: -3.59e-06
ValFuncLoss: 1.08e+03


***** Episode 20, Timesteps 4000, Mean Return = -581.8  *****
ExplainedVarNew: -6.88e-08
ExplainedVarOld: -1.97e-06
ValFuncLoss: 550


***** Episode 30, Timesteps 6000, Mean Return = -649.5  *****
ExplainedVarNew: -1.67e-08
ExplainedVarOld: -7.88e-08
ValFuncLoss: 541


***** Episode 40, Timesteps 8000, Mean Return = -663.8  *****
ExplainedVarNew: -7.23e-09
ExplainedVarOld: -2.42e-08
ValFuncLoss: 502


***** Episode 50, Timesteps 10000, Mean Return = -555.6  *****
ExplainedVarNew: -4.79e-10
ExplainedVarOld: -9.48e-10
ValFuncLoss: 250


***** Episode 60, Timesteps 11872, Mean Return = -577.6  *****
ExplainedVarNew: -7.57e-14
ExplainedVarOld: -1.09e-13
ValFuncLoss: 277


***** Episode 70, Timesteps 13620, Mean Return = -430.5  *****
ExplainedVarNew: -2.35e-13
ExplainedVarOld: -2.88e-13
ValFuncLoss: 197

***** Episode 610, Timesteps 93341, Mean Return = -218.8  *****
ExplainedVarNew: -1.13e-08
ExplainedVarOld: -4.54e-08
ValFuncLoss: 120


***** Episode 620, Timesteps 94396, Mean Return = -203.1  *****
ExplainedVarNew: -3.11e-10
ExplainedVarOld: -2.77e-09
ValFuncLoss: 144


***** Episode 630, Timesteps 95827, Mean Return = -172.3  *****
ExplainedVarNew: -3.87e-13
ExplainedVarOld: -2.46e-13
ValFuncLoss: 65.6


***** Episode 640, Timesteps 97636, Mean Return = -258.6  *****
ExplainedVarNew: -2.95e-10
ExplainedVarOld: -8.37e-11
ValFuncLoss: 107


***** Episode 650, Timesteps 98936, Mean Return = -154.7  *****
ExplainedVarNew: -5.84e-11
ExplainedVarOld: -1.44e-11
ValFuncLoss: 53.9


***** Episode 660, Timesteps 99999, Mean Return = -99.1  *****
ExplainedVarNew: -6.66e-10
ExplainedVarOld: -3.22e-11
ValFuncLoss: 36


***** Episode 670, Timesteps 100896, Mean Return = -106.9  *****
ExplainedVarNew: -7.85e-11
ExplainedVarOld: -7.47e-08
ValFuncLoss: 79


***** Episode 680, Timesteps 102149, Mean

***** Episode 1210, Timesteps 172016, Mean Return = -131.0  *****
ExplainedVarNew: 0
ExplainedVarOld: -3.55e-15
ValFuncLoss: 4.92


***** Episode 1220, Timesteps 173450, Mean Return = -113.3  *****
ExplainedVarNew: -9.69e-12
ExplainedVarOld: -8.94e-12
ValFuncLoss: 7.57


***** Episode 1230, Timesteps 175068, Mean Return = -131.4  *****
ExplainedVarNew: -1.61e-11
ExplainedVarOld: -1.28e-11
ValFuncLoss: 4.98


***** Episode 1240, Timesteps 176692, Mean Return = -131.8  *****
ExplainedVarNew: -1.02e-10
ExplainedVarOld: -1.04e-10
ValFuncLoss: 5.34


***** Episode 1250, Timesteps 178136, Mean Return = -143.5  *****
ExplainedVarNew: -1.32e-10
ExplainedVarOld: -2.71e-10
ValFuncLoss: 35.8


***** Episode 1260, Timesteps 179219, Mean Return = -114.2  *****
ExplainedVarNew: -1.26e-13
ExplainedVarOld: -2.4e-13
ValFuncLoss: 63


***** Episode 1270, Timesteps 180473, Mean Return = -97.0  *****
ExplainedVarNew: 1.11e-16
ExplainedVarOld: 1.11e-16
ValFuncLoss: 12.5


***** Episode 1280, Timesteps 1815

***** Episode 1800, Timesteps 250696, Mean Return = -96.3  *****
ExplainedVarNew: -1.15e-14
ExplainedVarOld: -1.24e-12
ValFuncLoss: 11.2


***** Episode 1810, Timesteps 251952, Mean Return = -97.1  *****
ExplainedVarNew: 6.66e-16
ExplainedVarOld: -2.22e-16
ValFuncLoss: 11.3


***** Episode 1820, Timesteps 253395, Mean Return = -114.5  *****
ExplainedVarNew: -8.88e-16
ExplainedVarOld: -1.55e-15
ValFuncLoss: 8.02


***** Episode 1830, Timesteps 254838, Mean Return = -114.7  *****
ExplainedVarNew: -2e-15
ExplainedVarOld: -4.44e-16
ValFuncLoss: 8.07


***** Episode 1840, Timesteps 256097, Mean Return = -97.5  *****
ExplainedVarNew: -4.44e-16
ExplainedVarOld: -1.11e-15
ValFuncLoss: 11.5


***** Episode 1850, Timesteps 256786, Mean Return = -43.4  *****
ExplainedVarNew: -3.11e-15
ExplainedVarOld: -2.2e-13
ValFuncLoss: 30.8


***** Episode 1860, Timesteps 258211, Mean Return = -112.9  *****
ExplainedVarNew: -5.26e-13
ExplainedVarOld: -2.31e-12
ValFuncLoss: 7.66


***** Episode 1870, Timesteps

# Discrete State Space - KL PO
### 'Taxi-v3', 'Roulette-v0', 'NChain-v0', 'FrozenLake-v0', 'CliffWalking-v0', 'FrozenLake8x8-v0'

In [None]:
# KL-constrained policy optimisation on a discrete environment (no discriminator).
env_name = 'Taxi-v3'
env = gym.make(env_name)
sta_num = env.observation_space.n
act_num = env.action_space.n
policy = DRPolicyKL(sta_num, act_num)
val_func = NNValueFunction(1, 10)
gamma = 0.9
# passed to add_gae as its third argument (the GAIL cells call it `gae_weight`)
lam = 1
total_eps = 1000
batch_eps = 60
logger = Logger(logname=env_name + '_DR-KL_Batch=' + str(batch_eps), now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

try:
    eps = 0
    while eps < total_eps:
        # NOTE(review): the GAIL cells call
        # run_policy(env, policy, batch_eps, discriminator, lamb, logger);
        # this 4-argument call presumably targets an older signature --
        # verify against train_helper before running this cell.
        trajectories = run_policy(env, policy, batch_eps, logger)
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculated discounted sum of Rs
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantage
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        policy.update(observes, actions, advantages, disc_freqs, env_name, eps)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    # Close the log even when training is interrupted or raises.
    logger.close()

## Generate Expert Trajectories

In [None]:
import os

from stable_baselines import PPO2
from stable_baselines.gail import generate_expert_traj

# The GAIL training cells read 'expert_traj/expert_taxi.npz', so write the
# dataset there. The directory is created first in case the saver does not
# create missing parent directories itself.
expert_dir = 'expert_traj'
os.makedirs(expert_dir, exist_ok=True)

# Generate expert trajectories (train expert)
model = PPO2('MlpPolicy', 'Taxi-v3', verbose=1)
# Train for 600000 timesteps and record 10 trajectories;
# all the data will be saved in the 'expert_traj/expert_taxi.npz' file
generate_expert_traj(model, os.path.join(expert_dir, 'expert_taxi'), n_timesteps=600000, n_episodes=10)