Skip to content

Commit

Permalink
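fixes acktr_cont issues (clamps the KL-adapted stepsize to [1e-8, 1.0] and removes the `optimizer` argument from learn(); run_mujoco.py now reads `args.env` instead of `args.env_id`; the value function's final layer uses weight_init=None and its KFAC optimizer uses max_grad_norm=1.0)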
Browse files (browse the repository at this point in the history)
  • Loading branch information
mansimov committed Sep 30, 2017
1 parent 699919f commit f8663ea
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
8 changes: 5 additions & 3 deletions baselines/acktr/acktr_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def rollout(env, policy, max_pathlength, animate=False, obfilter=None):
"action_dist": np.array(ac_dists), "logp" : np.array(logps)}

def learn(env, policy, vf, gamma, lam, timesteps_per_batch, num_timesteps,
animate=False, callback=None, optimizer="adam", desired_kl=0.002):
animate=False, callback=None, desired_kl=0.002):

obfilter = ZFilter(env.observation_space.shape)

Expand Down Expand Up @@ -117,14 +117,16 @@ def learn(env, policy, vf, gamma, lam, timesteps_per_batch, num_timesteps,
# Policy update
do_update(ob_no, action_na, standardized_adv_n)

min_stepsize = np.float32(1e-8)
max_stepsize = np.float32(1e0)
# Adjust stepsize
kl = policy.compute_kl(ob_no, oldac_dist)
if kl > desired_kl * 2:
logger.log("kl too high")
U.eval(tf.assign(stepsize, stepsize / 1.5))
U.eval(tf.assign(stepsize, tf.maximum(min_stepsize, stepsize / 1.5)))
elif kl < desired_kl / 2:
logger.log("kl too low")
U.eval(tf.assign(stepsize, stepsize * 1.5))
U.eval(tf.assign(stepsize, tf.minimum(max_stepsize, stepsize * 1.5)))
else:
logger.log("kl just right!")

Expand Down
2 changes: 1 addition & 1 deletion baselines/acktr/run_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def train(env_id, num_timesteps, seed):
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
args = parser.parse_args()
train(args.env_id, num_timesteps=1e6, seed=args.seed)
train(args.env, num_timesteps=1e6, seed=args.seed)
4 changes: 2 additions & 2 deletions baselines/acktr/value_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, ob_dim, ac_dim): #pylint: disable=W0613
wd_dict = {}
h1 = tf.nn.elu(dense(X, 64, "h1", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict))
h2 = tf.nn.elu(dense(h1, 64, "h2", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict))
vpred_n = dense(h2, 1, "hfinal", weight_init=U.normc_initializer(1.0), bias_init=0, weight_loss_dict=wd_dict)[:,0]
vpred_n = dense(h2, 1, "hfinal", weight_init=None, bias_init=0, weight_loss_dict=wd_dict)[:,0]
sample_vpred_n = vpred_n + tf.random_normal(tf.shape(vpred_n))
wd_loss = tf.get_collection("vf_losses", None)
loss = U.mean(tf.square(vpred_n - vtarg_n)) + tf.add_n(wd_loss)
Expand All @@ -22,7 +22,7 @@ def __init__(self, ob_dim, ac_dim): #pylint: disable=W0613
optim = kfac.KfacOptimizer(learning_rate=0.001, cold_lr=0.001*(1-0.9), momentum=0.9, \
clip_kl=0.3, epsilon=0.1, stats_decay=0.95, \
async=1, kfac_update=2, cold_iter=50, \
weight_decay_dict=wd_dict, max_grad_norm=None)
weight_decay_dict=wd_dict, max_grad_norm=1.0)
vf_var_list = []
for var in tf.trainable_variables():
if "vf" in var.name:
Expand Down

0 comments on commit f8663ea

Please sign in to comment.