Skip to content

Commit

Permalink
Update run_dagger.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
bustermachinej committed Apr 10, 2017
1 parent 01b545a commit 25f0dba
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions run_dagger.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,14 +77,10 @@ def main():
# architecture of the MLP policy function
x = tf.placeholder(tf.float32, shape=[None, obs_dim])
yhot = tf.placeholder(tf.float32, shape=[None, act_dim])
is_train = tf.placeholder(tf.bool)

h1 = tf.layers.dense(inputs=x, units=128, activation=tf.nn.relu)
# drop1=tf.layers.dropout(inputs=h1,rate=0.5,training=is_train)
h2 = tf.layers.dense(inputs=h1, units=64, activation=tf.nn.relu)
# drop2=tf.layers.dropout(inputs=h2,rate=0.5,training=is_train)
h3 = tf.layers.dense(inputs=h2, units=32, activation=tf.nn.relu)
# drop3=tf.layers.dropout(inputs=h3,rate=0.5,training=is_train)
yhat = tf.layers.dense(inputs=h3, units=act_dim, activation=None)

loss_l2 = tf.reduce_mean(tf.square(yhot - yhat))
Expand All @@ -106,10 +102,10 @@ def main():
batch_size = 25
for step in range(10000):
batch_i = np.random.randint(0, obs_data.shape[0], size=batch_size)
train_step.run(feed_dict={x: obs_data[batch_i, ], yhot: act_data[batch_i, ], is_train: 1})
train_step.run(feed_dict={x: obs_data[batch_i, ], yhot: act_data[batch_i, ]})
if (step % 1000 == 0):
print 'opmization step ', step
print 'obj value is ', loss_l2.eval(feed_dict={x:obs_data, yhot:act_data, is_train: 0})
print 'obj value is ', loss_l2.eval(feed_dict={x:obs_data, yhot:act_data})
print 'Optimization Finished!'
# use the trained MLP policy to choose actions while stepping the environment
max_steps = env.spec.timestep_limit
Expand All @@ -124,7 +120,7 @@ def main():
totalr = 0.
steps = 0
while not done:
action = yhat.eval(feed_dict={x:obs[None, :], is_train:0})
action = yhat.eval(feed_dict={x:obs[None, :]})
observations.append(obs)
actions.append(action)
obs, r, done, _ = env.step(action)
Expand Down

0 comments on commit 25f0dba

Please sign in to comment.