In [1]:
import gym
import pybullet
import pybullet_envs
from gym import wrappers
from datetime import datetime

# DR TRPO related files
from train_helper import *
from value import NNValueFunction
from utils import Logger
from dr_policy import DRPolicyKL, DRPolicyWass

# Discrete State Space - KL DR TRPO Policy
### Candidate discrete-state environments for `env_name`: 'Taxi-v3', 'Roulette-v0', 'NChain-v0', 'FrozenLake-v0', 'CliffWalking-v0', 'FrozenLake8x8-v0'

In [2]:
# Train a DR TRPO policy (DRPolicyKL) on a discrete-state gym environment.
# Both the observation and action spaces must be Discrete, because their
# `.n` sizes are read below to size the policy.
env_name = 'NChain-v0'
# NOTE(review): NChain-v0 is not a pybullet env, so this client looks unused
# here; it is kept for parity with pybullet runs and released in `finally`.
physics_client = pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n
act_num = env.action_space.n
policy = DRPolicyKL(sta_num, act_num)
# presumably (obs_dim, hid1_mult): the logged "h1: 10" matches 1 * 10 -- confirm
val_func = NNValueFunction(1, 10)
gamma = 0.9      # discount factor used for returns, GAE and discounted state frequencies
lam = 1          # lambda passed to add_gae
total_eps = 200  # stop once at least this many episodes have been collected
batch_eps = 20   # episodes requested from run_policy per update
logger = Logger(logname=env_name + '_DR-KL_Batch=' + str(batch_eps), now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

eps = 0
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        if not trajectories:
            # Without this guard `eps` would never advance and the loop would spin forever.
            raise RuntimeError('run_policy returned no trajectories')
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantage
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        policy.update(observes, actions, advantages, disc_freqs)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    # Release the log file, the environment and the physics client even if training fails.
    logger.close()
    env.close()
    if physics_client >= 0:
        pybullet.disconnect(physicsClientId=physics_client)

Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
***** Episode 20, Mean Return = 1312.2, Mean Discounted Return = 14.0 *****
ExplainedVarNew: -9.93e-07
ExplainedVarOld: -0.00323
ValFuncLoss: 32.7


***** Episode 40, Mean Return = 1646.3, Mean Discounted Return = 14.3 *****
ExplainedVarNew: -0.0761
ExplainedVarOld: -3.64e-07
ValFuncLoss: 99.7


***** Episode 60, Mean Return = 2091.4, Mean Discounted Return = 13.6 *****
ExplainedVarNew: -0.218
ExplainedVarOld: -0.0502
ValFuncLoss: 222


***** Episode 80, Mean Return = 2709.9, Mean Discounted Return = 19.7 *****
ExplainedVarNew: -0.285
ExplainedVarOld: -0.191
ValFuncLoss: 351


***** Episode 100, Mean Return = 3005.5, Mean Discounted Return = 20.1 *****
ExplainedVarNew: -0.296
ExplainedVarOld: -0.275
ValFuncLoss: 396


***** Episode 120, Mean Return = 3271.7, Mean Discounted Return = 22.7 *****
ExplainedVarNew: -0.258
ExplainedVarOld: -0.263
ValFuncLoss: 448


***** Episode 140, Mean Return = 3294.0, Mean Discounted Return = 21.5 *****
E

# Discrete State Space - Wasserstein DR TRPO Policy
### Candidate discrete-state environments for `env_name`: 'Taxi-v3', 'Roulette-v0', 'NChain-v0', 'FrozenLake-v0', 'CliffWalking-v0', 'FrozenLake8x8-v0'

In [3]:
# Train a DR TRPO policy (DRPolicyWass) on a discrete-state gym environment.
# Both the observation and action spaces must be Discrete, because their
# `.n` sizes are read below to size the policy.
env_name = 'NChain-v0'
# NOTE(review): NChain-v0 is not a pybullet env, so this client looks unused
# here; it is kept for parity with pybullet runs and released in `finally`.
physics_client = pybullet.connect(pybullet.DIRECT)
env = gym.make(env_name)
sta_num = env.observation_space.n
act_num = env.action_space.n
policy = DRPolicyWass(sta_num, act_num)
# presumably (obs_dim, hid1_mult): the logged "h1: 10" matches 1 * 10 -- confirm
val_func = NNValueFunction(1, 10)
gamma = 0.8      # discount factor used for returns, GAE and discounted state frequencies
lam = 1          # lambda passed to add_gae
total_eps = 150  # stop once at least this many episodes have been collected
batch_eps = 15   # episodes requested from run_policy per update
logger = Logger(logname=env_name + '_DR-Wass_Batch=' + str(batch_eps), now=datetime.utcnow().strftime("%b-%d_%H:%M:%S"))

eps = 0
try:
    while eps < total_eps:
        trajectories = run_policy(env, policy, batch_eps, logger)
        if not trajectories:
            # Without this guard `eps` would never advance and the loop would spin forever.
            raise RuntimeError('run_policy returned no trajectories')
        eps += len(trajectories)
        # add estimated values to episodes
        add_value(trajectories, val_func)
        # calculate discounted sum of rewards
        add_disc_sum_rew(trajectories, gamma, logger)
        # calculate advantage
        add_gae(trajectories, gamma, lam)
        # concatenate all episodes into single NumPy arrays
        observes, actions, advantages, disc_sum_rew = build_train_set(trajectories)
        disc_freqs = find_disc_freqs(trajectories, sta_num, gamma)
        log_batch_stats(observes, actions, advantages, disc_sum_rew, eps, logger)
        policy.update(observes, actions, advantages, disc_freqs)
        val_func.fit(observes, disc_sum_rew, logger)
        # write logger results to file and stdout
        logger.write(display=True)
finally:
    # Release the log file, the environment and the physics client even if training fails.
    logger.close()
    env.close()
    if physics_client >= 0:
        pybullet.disconnect(physicsClientId=physics_client)

Value Params -- h1: 10, h2: 7, h3: 5, lr: 0.00378
***** Episode 15, Mean Return = 1302.3, Mean Discounted Return = 5.2 *****
ExplainedVarNew: -0.00463
ExplainedVarOld: -0.0029
ValFuncLoss: 13.9


***** Episode 30, Mean Return = 2000.5, Mean Discounted Return = 6.2 *****
ExplainedVarNew: -0.249
ExplainedVarOld: -0.00077
ValFuncLoss: 105


***** Episode 45, Mean Return = 2594.5, Mean Discounted Return = 7.1 *****
ExplainedVarNew: -0.487
ExplainedVarOld: -0.237
ValFuncLoss: 181


***** Episode 60, Mean Return = 2350.1, Mean Discounted Return = 7.9 *****
ExplainedVarNew: -0.496
ExplainedVarOld: -0.5
ValFuncLoss: 159


***** Episode 75, Mean Return = 3046.4, Mean Discounted Return = 8.8 *****
ExplainedVarNew: -0.486
ExplainedVarOld: -0.463
ValFuncLoss: 216


***** Episode 90, Mean Return = 2987.1, Mean Discounted Return = 9.2 *****
ExplainedVarNew: -0.485
ExplainedVarOld: -0.528
ValFuncLoss: 196


***** Episode 105, Mean Return = 3067.9, Mean Discounted Return = 8.8 *****
ExplainedVarNew: -