diff --git a/spinup/algos/ddpg/ddpg.py b/spinup/algos/ddpg/ddpg.py index 294379023..44ba9ff60 100644 --- a/spinup/algos/ddpg/ddpg.py +++ b/spinup/algos/ddpg/ddpg.py @@ -179,7 +179,7 @@ def ddpg(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q': q}) def get_action(o, noise_scale): - a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)}) + a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0] a += noise_scale * np.random.randn(act_dim) return np.clip(a, -act_limit, act_limit) diff --git a/spinup/algos/sac/sac.py b/spinup/algos/sac/sac.py index 49dcdecf0..0f0eb1d90 100644 --- a/spinup/algos/sac/sac.py +++ b/spinup/algos/sac/sac.py @@ -215,7 +215,7 @@ def sac(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, def get_action(o, deterministic=False): act_op = mu if deterministic else pi - return sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)}) + return sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})[0] def test_agent(n=10): global sess, mu, pi, q1, q2, q1_pi, q2_pi diff --git a/spinup/algos/td3/td3.py b/spinup/algos/td3/td3.py index 5ea42365e..3c375ca7c 100644 --- a/spinup/algos/td3/td3.py +++ b/spinup/algos/td3/td3.py @@ -205,7 +205,7 @@ def td3(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q1': q1, 'q2': q2}) def get_action(o, noise_scale): - a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)}) + a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0] a += noise_scale * np.random.randn(act_dim) return np.clip(a, -act_limit, act_limit)