In [2]:
import gym
import pybullet
import pybullet_envs
from gym import wrappers
from datetime import datetime

# DR TRPO related files
from train_helper import *
from value import NNValueFunction
from utils import Logger
from dr_policy import DRPolicyKL, DRPolicyWass

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Discrete State Space - KL DR TRPO Policy
### Candidate discrete environments (set `env_name` in the cell below): `Taxi-v3`, `Roulette-v0`, `NChain-v0`, `FrozenLake-v0`, `CliffWalking-v0`, `FrozenLake8x8-v0`

In [None]:
# Train a KL-divergence DR-TRPO policy (DRPolicyKL) on a discrete-state gym
# environment. Each iteration: collect `batch_eps` episodes -> compute values,
# discounted returns and advantages -> update the policy -> refit the value
# function -> write a log line.
env_name = 'NChain-v0'
# NOTE(review): NChain-v0 is a toy-text env and presumably does not need a
# pybullet physics client; kept in case env_name is switched to a bullet env.
# Guarded so re-running this cell (or the Wasserstein cell) does not open an
# additional, never-released physics client every time.
if not pybullet.isConnected():
    pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n  # number of discrete states
act_num = env.action_space.n       # number of discrete actions
policy = DRPolicyKL(sta_num, act_num)
# NOTE(review): first argument presumably is the observation dimension (a
# Discrete observation is a single integer); second presumably a hidden-layer
# size multiplier — confirm against value.NNValueFunction.
val_func = NNValueFunction(1, 10)
gamma = 0.9       # discount passed to add_disc_sum_rew, add_gae and find_disc_freqs
lam = 1           # lambda passed to add_gae
total_eps = 2000  # stop once at least this many episodes have been collected
batch_eps = 20    # episodes collected per policy/value update
logger = Logger(logname=env_name + '_DR-KL_Batch=' + str(batch_eps),
                now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

eps = 0
# try/finally: an exception or a kernel interrupt mid-training still closes the
# logger, so the log file written so far is not left open.
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantages
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        # presumably discounted state-visitation frequencies over sta_num states,
        # consumed by the DR policy update
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        policy.update(observes, actions, advantages, disc_freqs)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    logger.close()

# Discrete State Space - Wasserstein DR TRPO Policy
### Candidate discrete environments (set `env_name` in the cell below): `Taxi-v3`, `Roulette-v0`, `NChain-v0`, `FrozenLake-v0`, `CliffWalking-v0`, `FrozenLake8x8-v0`

In [6]:
# Train a Wasserstein DR-TRPO policy (DRPolicyWass) on a discrete-state gym
# environment. Each iteration: collect `batch_eps` episodes -> compute values,
# discounted returns and advantages -> update the policy -> refit the value
# function -> write a log line.
env_name = 'NChain-v0'
# NOTE(review): NChain-v0 is a toy-text env and presumably does not need a
# pybullet physics client; kept in case env_name is switched to a bullet env.
# Guarded so re-running this cell (or the KL cell) does not open an
# additional, never-released physics client every time.
if not pybullet.isConnected():
    pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n  # number of discrete states
act_num = env.action_space.n       # number of discrete actions
policy = DRPolicyWass(sta_num, act_num)
# NOTE(review): first argument presumably is the observation dimension (a
# Discrete observation is a single integer); second presumably a hidden-layer
# size multiplier — confirm against value.NNValueFunction.
val_func = NNValueFunction(1, 10)
gamma = 0.9       # discount passed to add_disc_sum_rew, add_gae and find_disc_freqs
lam = 1           # lambda passed to add_gae
total_eps = 1500  # stop once at least this many episodes have been collected
batch_eps = 5     # episodes collected per policy/value update
logger = Logger(logname=env_name + '_DR-Wass_Batch=' + str(batch_eps),
                now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

eps = 0
# try/finally: an exception or a kernel interrupt mid-training still closes the
# logger, so the log file written so far is not left open.
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantages
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        # presumably discounted state-visitation frequencies over sta_num states,
        # consumed by the DR policy update
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        policy.update(observes, actions, advantages, disc_freqs)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    logger.close()

Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
***** Episode 5, Mean Return = 1321.2, Mean Discounted Return = 13.5 *****
ExplainedVarNew: -7.53e-06
ExplainedVarOld: -0.000427
ValFuncLoss: 67.1


***** Episode 10, Mean Return = 1372.0, Mean Discounted Return = 11.4 *****
ExplainedVarNew: -7.39e-07
ExplainedVarOld: -5.8e-06
ValFuncLoss: 41.4


***** Episode 15, Mean Return = 1302.0, Mean Discounted Return = 10.8 *****
ExplainedVarNew: -1.44e-06
ExplainedVarOld: -9.29e-07
ValFuncLoss: 30.7


***** Episode 20, Mean Return = 1284.0, Mean Discounted Return = 11.5 *****
ExplainedVarNew: -2.35e-05
ExplainedVarOld: -1.58e-06
ValFuncLoss: 27.8


***** Episode 25, Mean Return = 1343.6, Mean Discounted Return = 11.4 *****
ExplainedVarNew: -0.0244
ExplainedVarOld: -2.02e-05
ValFuncLoss: 33.9


***** Episode 30, Mean Return = 1338.0, Mean Discounted Return = 12.4 *****
ExplainedVarNew: -0.0411
ExplainedVarOld: -0.02
ValFuncLoss: 41.7


***** Episode 35, Mean Return = 1270.8, Mean Discounted Retu

***** Episode 290, Mean Return = 2597.6, Mean Discounted Return = 14.5 *****
ExplainedVarNew: -0.294
ExplainedVarOld: -0.268
ValFuncLoss: 388


***** Episode 295, Mean Return = 2458.4, Mean Discounted Return = 14.6 *****
ExplainedVarNew: -0.34
ExplainedVarOld: -0.288
ValFuncLoss: 384


***** Episode 300, Mean Return = 2414.4, Mean Discounted Return = 25.7 *****
ExplainedVarNew: -0.355
ExplainedVarOld: -0.376
ValFuncLoss: 353


***** Episode 305, Mean Return = 2416.4, Mean Discounted Return = 15.8 *****
ExplainedVarNew: -0.321
ExplainedVarOld: -0.355
ValFuncLoss: 345


***** Episode 310, Mean Return = 2552.0, Mean Discounted Return = 15.5 *****
ExplainedVarNew: -0.312
ExplainedVarOld: -0.335
ValFuncLoss: 348


***** Episode 315, Mean Return = 2735.6, Mean Discounted Return = 22.6 *****
ExplainedVarNew: -0.284
ExplainedVarOld: -0.277
ValFuncLoss: 403


***** Episode 320, Mean Return = 2376.4, Mean Discounted Return = 19.1 *****
ExplainedVarNew: -0.336
ExplainedVarOld: -0.315
ValFuncLoss:

***** Episode 575, Mean Return = 2997.2, Mean Discounted Return = 19.0 *****
ExplainedVarNew: -0.268
ExplainedVarOld: -0.283
ValFuncLoss: 393


***** Episode 580, Mean Return = 3190.4, Mean Discounted Return = 16.1 *****
ExplainedVarNew: -0.26
ExplainedVarOld: -0.234
ValFuncLoss: 463


***** Episode 585, Mean Return = 3205.2, Mean Discounted Return = 24.9 *****
ExplainedVarNew: -0.279
ExplainedVarOld: -0.276
ValFuncLoss: 441


***** Episode 590, Mean Return = 3387.2, Mean Discounted Return = 19.2 *****
ExplainedVarNew: -0.269
ExplainedVarOld: -0.265
ValFuncLoss: 472


***** Episode 595, Mean Return = 2911.2, Mean Discounted Return = 23.7 *****
ExplainedVarNew: -0.336
ExplainedVarOld: -0.336
ValFuncLoss: 377


***** Episode 600, Mean Return = 3100.8, Mean Discounted Return = 13.7 *****
ExplainedVarNew: -0.264
ExplainedVarOld: -0.296
ValFuncLoss: 410


***** Episode 605, Mean Return = 3093.2, Mean Discounted Return = 25.3 *****
ExplainedVarNew: -0.27
ExplainedVarOld: -0.25
ValFuncLoss: 4

***** Episode 860, Mean Return = 2959.2, Mean Discounted Return = 27.6 *****
ExplainedVarNew: -0.297
ExplainedVarOld: -0.346
ValFuncLoss: 344


***** Episode 865, Mean Return = 2978.4, Mean Discounted Return = 24.3 *****
ExplainedVarNew: -0.276
ExplainedVarOld: -0.274
ValFuncLoss: 368


***** Episode 870, Mean Return = 3032.4, Mean Discounted Return = 21.3 *****
ExplainedVarNew: -0.256
ExplainedVarOld: -0.251
ValFuncLoss: 403


***** Episode 875, Mean Return = 3052.4, Mean Discounted Return = 18.3 *****
ExplainedVarNew: -0.288
ExplainedVarOld: -0.283
ValFuncLoss: 374


***** Episode 880, Mean Return = 2904.8, Mean Discounted Return = 17.6 *****
ExplainedVarNew: -0.274
ExplainedVarOld: -0.262
ValFuncLoss: 397


***** Episode 885, Mean Return = 2831.6, Mean Discounted Return = 17.7 *****
ExplainedVarNew: -0.3
ExplainedVarOld: -0.3
ValFuncLoss: 365


***** Episode 890, Mean Return = 3078.0, Mean Discounted Return = 23.7 *****
ExplainedVarNew: -0.261
ExplainedVarOld: -0.244
ValFuncLoss: 45

***** Episode 1145, Mean Return = 3106.4, Mean Discounted Return = 19.9 *****
ExplainedVarNew: -0.349
ExplainedVarOld: -0.34
ValFuncLoss: 417


***** Episode 1150, Mean Return = 2882.8, Mean Discounted Return = 16.2 *****
ExplainedVarNew: -0.313
ExplainedVarOld: -0.357
ValFuncLoss: 379


***** Episode 1155, Mean Return = 3102.4, Mean Discounted Return = 30.4 *****
ExplainedVarNew: -0.284
ExplainedVarOld: -0.297
ValFuncLoss: 410


***** Episode 1160, Mean Return = 3148.0, Mean Discounted Return = 30.5 *****
ExplainedVarNew: -0.277
ExplainedVarOld: -0.276
ValFuncLoss: 420


***** Episode 1165, Mean Return = 3358.0, Mean Discounted Return = 26.8 *****
ExplainedVarNew: -0.258
ExplainedVarOld: -0.232
ValFuncLoss: 509


***** Episode 1170, Mean Return = 2998.8, Mean Discounted Return = 22.0 *****
ExplainedVarNew: -0.317
ExplainedVarOld: -0.316
ValFuncLoss: 414


***** Episode 1175, Mean Return = 2822.8, Mean Discounted Return = 15.9 *****
ExplainedVarNew: -0.258
ExplainedVarOld: -0.357
ValFu

***** Episode 1430, Mean Return = 3254.4, Mean Discounted Return = 26.3 *****
ExplainedVarNew: -0.253
ExplainedVarOld: -0.222
ValFuncLoss: 442


***** Episode 1435, Mean Return = 2829.2, Mean Discounted Return = 27.0 *****
ExplainedVarNew: -0.287
ExplainedVarOld: -0.264
ValFuncLoss: 407


***** Episode 1440, Mean Return = 3221.2, Mean Discounted Return = 32.9 *****
ExplainedVarNew: -0.288
ExplainedVarOld: -0.294
ValFuncLoss: 424


***** Episode 1445, Mean Return = 3176.0, Mean Discounted Return = 21.9 *****
ExplainedVarNew: -0.278
ExplainedVarOld: -0.276
ValFuncLoss: 435


***** Episode 1450, Mean Return = 2787.6, Mean Discounted Return = 22.1 *****
ExplainedVarNew: -0.307
ExplainedVarOld: -0.326
ValFuncLoss: 357


***** Episode 1455, Mean Return = 3095.2, Mean Discounted Return = 17.8 *****
ExplainedVarNew: -0.263
ExplainedVarOld: -0.282
ValFuncLoss: 397


***** Episode 1460, Mean Return = 3003.2, Mean Discounted Return = 22.5 *****
ExplainedVarNew: -0.259
ExplainedVarOld: -0.232
ValF