Set things up

In [1]:
import numpy as np
import tensorflow as tf

from nn_policy import FeedForwardCritic
from nn_policy import FeedForwardPolicy
from rllab.envs.mujoco.half_cheetah_env import HalfCheetahEnv
from rllab.exploration_strategies.ou_strategy import OUStrategy
from sandbox.rocky.tf.algos.ddpg import DDPG as ShaneDDPG
from sandbox.rocky.tf.envs.base import TfEnv
from sandbox.rocky.tf.policies.deterministic_mlp_policy import \
    DeterministicMLPPolicy
from sandbox.rocky.tf.q_functions.continuous_mlp_q_function import \
    ContinuousMLPQFunction

from ddpg import DDPG as MyDDPG
from testing_utils import are_np_arrays_equal

In [2]:
env = TfEnv(HalfCheetahEnv())
action_dim = env.action_dim
obs_dim = env.observation_space.low.shape[0]

batch_size = 2
rewards = np.random.rand(batch_size)
terminals = (np.random.rand(batch_size) > 0.5).astype(np.int)
obs = np.random.rand(batch_size, obs_dim)
actions = np.random.rand(batch_size, action_dim)
next_obs = np.random.rand(batch_size, obs_dim)

ddpg_params = dict(
    batch_size=64,
    n_epochs=0,
    epoch_length=0,
    eval_samples=0,
    discount=0.99,
    qf_learning_rate=1e-3,
    policy_learning_rate=1e-4,
    soft_target_tau=0.001,
    replay_pool_size=1000000,
    min_pool_size=1000,
    scale_reward=0.1,
)
discount = ddpg_params['discount']

Create my stuff

In [3]:
sess1  tf.Session()
with sess1.as_default():
    es = OUStrategy(env_spec=env.spec)
    ddpg_params['Q_weight_decay'] = 0.
    qf_params = dict(
        embedded_hidden_sizes=(100, ),
        observation_hidden_sizes=(100, ),
        hidden_nonlinearity=tf.nn.relu,
    )
    policy_params = dict(
        observation_hidden_sizes=(100, 100),
        hidden_nonlinearity=tf.nn.relu,
        output_nonlinearity=tf.nn.tanh,
    )
    qf = FeedForwardCritic(
        "critic",
        env.observation_space.flat_dim,
        env.action_space.flat_dim,
        **qf_params
    )
    policy = FeedForwardPolicy(
        "actor",
        env.observation_space.flat_dim,
        env.action_space.flat_dim,
        **policy_params
    )
    my_algo = MyDDPG(
        env,
        es,
        policy,
        qf,
        **ddpg_params
    )

In [5]:
sess_shane = tf.Session()
with sess_shane.as_default():
    es = OUStrategy(env_spec=env.spec)
    policy = DeterministicMLPPolicy(
        name="init_policy",
        env_spec=env.spec,
        hidden_sizes=(100, 100),
        hidden_nonlinearity=tf.nn.relu,
        output_nonlinearity=tf.nn.tanh,
    )
    qf = ContinuousMLPQFunction(
        name="qf",
        env_spec=env.spec,
        hidden_sizes=(100, 100),
    )
    ddpg.pop('Q_weight_decay')
    shane_algo = ShaneDDPG(
        env,
        policy,
        qf,
        es,
        **ddpg_params
    )
    sess.run(tf.initialize_all_variables())
    shane_algo.init_opt()
    # This initializes the optimizer parameters
    sess.run(tf.initialize_all_variables())

NameError: name 'ddpg' is not defined

In [4]:
shane_policy = shane_algo.policy
shane_qf = shane_algo.qf
shane_policy_param_values = shane_policy.flat_to_params(
    shane_policy.get_param_values()
)
shane_qf_param_values = shane_qf.flat_to_params(
    shane_qf.get_param_values()
)
f_train_policy = shane_algo.opt_info['f_train_policy']
f_train_qf = shane_algo.opt_info['f_train_qf']
target_qf = shane_algo.opt_info["target_qf"]
target_policy = shane_algo.opt_info["target_policy"]
next_actions, _ = target_policy.get_actions(next_obs)
next_qvals = target_qf.get_qval(next_obs, next_actions)

ys = rewards + (1. - terminals) * discount * next_qvals
# qf_loss, qval, _ = f_train_qf(ys, obs, actions)
# policy_surr, _ = f_train_policy(obs)

In [None]:
my_algo.actor.set_param_values(shane_policy_param_values)
my_algo.target_actor.set_param_values(shane_policy_param_values)
my_algo.critic.set_param_values(shane_qf_param_values)
my_algo.target_critic.set_param_values(shane_qf_param_values)
feed_dict = my_algo._update_feed_dict(rewards, terminals, obs,
                                      actions, next_obs)

In [9]:
my_ys = sess.run(my_algo.ys, feed_dict=feed_dict).flatten()
actor_loss = sess.run(
    my_algo.actor_surrogate_loss,
    feed_dict=feed_dict)
critic_loss = sess.run(
    my_algo.actor_surrogate_loss,
    feed_dict=feed_dict)
critic_output = sess.run(
    my_algo.critic.output,
    feed_dict=feed_dict).flatten()

In [6]:
# Check params didn't change
shane_policy_param_values_new = shane_policy.flat_to_params(
    shane_policy.get_param_values()
)
shane_qf_param_values_new = shane_qf.flat_to_params(
    shane_qf.get_param_values()
)
print(all((a==b).all() for a, b in zip(shane_qf_param_values_new, shane_qf_param_values)))
print(all((a==b).all() for a, b in zip(shane_policy_param_values_new, shane_policy_param_values)))
print(shane_qf_param_values_new[0])
print(shane_qf_param_values[0])

False
False
[[ 0.10921419  0.09765881 -0.03137018 ..., -0.02751763  0.22333932
   0.02181816]
 [-0.17966431 -0.18096715 -0.17333549 ...,  0.1443533  -0.0943089
   0.02416623]
 [ 0.11363882  0.15522802  0.12664583 ..., -0.22307341  0.12125483
  -0.08646591]
 ..., 
 [-0.06120278 -0.1051338  -0.12680353 ...,  0.17530295  0.16245732
   0.13607773]
 [-0.05959937 -0.22201149 -0.21675102 ...,  0.17212844  0.05435613
  -0.21109346]
 [-0.00831135 -0.16856278  0.04452699 ..., -0.05924542  0.0239968
   0.1407443 ]]
[[-0.19854721  0.02298436 -0.04613249 ..., -0.00690316  0.06697407
  -0.17650957]
 [ 0.17003807  0.08754197  0.08480752 ..., -0.17002124 -0.0353867
   0.19973361]
 [-0.06235437 -0.07145138  0.14938217 ...,  0.12414679 -0.00750153
   0.07396972]
 ..., 
 [-0.08707686 -0.12958983  0.21109462 ...,  0.06713736  0.13400322
  -0.14379099]
 [-0.11963328  0.11264029  0.01875345 ..., -0.1142823  -0.16798946
  -0.04623538]
 [-0.10833817  0.13236687 -0.16835923 ...,  0.0404802   0.0754824
   0.086

Check params didn't change

In [7]:
critic = my_algo.critic
actor = my_algo.actor

print(all((a==b).all() for a, b in zip(critic.get_param_values(), shane_qf_param_values)))
print(all((a==b).all() for a, b in zip(actor.get_param_values(), shane_policy_param_values)))


my_critic_out = sess.run(
    critic.output,
    feed_dict={
        critic.actions_placeholder: actions,
        critic.observations_placeholder: obs,
    }).flatten()
shane_critic = shane_algo.qf
shane_critic_out = shane_critic.get_qval(next_obs, next_actions)
# print(my_critic_out)
# print(shane_critic_out)
# print(shane_critic._output_layer.input_layer.nonlinearity)
print([a.shape for a in shane_qf_param_values])
W1, b1, W2, b2, W3, b3 = shane_qf_param_values
output = np.matmul(obs, W1) + b1
output = np.maximum(output, 0)
output = np.hstack((output, actions))
output = np.matmul(output, W2) + b2
output = np.maximum(output, 0)
output = np.matmul(output, W3) + b3
output = output.flatten()


print(my_critic_out.shape)
print(shane_critic_out.shape)
print(output.shape)

print(my_critic_out)
print(shane_critic_out)
print(output)

False
True
True
False
[(20, 100), (100,), (106, 100), (100,), (100, 1), (1,)]
(2,)
(2,)
(2,)
[-0.28969401 -0.26044056]
[ 0.38071227  0.14097965]
[-0.289694  -0.2604406]


In [12]:
print(W1)
print(b1)

[[ 0.18006641 -0.19114682 -0.17156905 ..., -0.08719799 -0.02502914
   0.030267  ]
 [ 0.14154136  0.16746089 -0.01075949 ..., -0.04649469 -0.13204986
  -0.08563615]
 [-0.03190383  0.14813185  0.1691182  ..., -0.12389179 -0.0197947
  -0.0118479 ]
 ..., 
 [ 0.16140383 -0.15739915 -0.1783531  ..., -0.10120998  0.12150282
  -0.13049315]
 [ 0.19560936 -0.22234261 -0.15613042 ..., -0.17660858  0.03528771
   0.16670644]
 [-0.06725742 -0.01214127  0.16213682 ..., -0.12453543 -0.16090882
  -0.08211282]]
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]


In [14]:
print(W2)
print(b2)

[[ 0.11715043 -0.01768987  0.15925303 ...,  0.07064421  0.03645761
   0.16479751]
 [ 0.10044381  0.02719294 -0.16904333 ..., -0.11369165  0.01187196
   0.09673491]
 [-0.03212716 -0.15086609 -0.08499511 ..., -0.10084184  0.06693818
  -0.06268194]
 ..., 
 [ 0.09487638 -0.01168025 -0.02640516 ...,  0.09314603 -0.16379729
   0.08736303]
 [-0.07443461  0.02360873  0.06644437 ...,  0.11957365 -0.134672
  -0.01987462]
 [-0.14258642  0.05320536 -0.11457152 ..., -0.16631398 -0.14329247
   0.00755855]]
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]


In [15]:
print(W3)
print(b3)

[[-0.09496497]
 [ 0.02086809]
 [ 0.02640805]
 [ 0.0055166 ]
 [ 0.09683746]
 [-0.10159475]
 [ 0.14121649]
 [ 0.03647819]
 [-0.04363871]
 [ 0.09008625]
 [ 0.20288092]
 [ 0.10421315]
 [ 0.14983174]
 [ 0.1653336 ]
 [ 0.17695239]
 [-0.03653538]
 [-0.16007495]
 [ 0.14253625]
 [-0.04394285]
 [-0.17573562]
 [-0.23238011]
 [ 0.10677379]
 [-0.03567053]
 [ 0.16729721]
 [-0.19013867]
 [ 0.23726842]
 [-0.20066203]
 [-0.03206511]
 [-0.00955209]
 [ 0.18400922]
 [-0.15858348]
 [-0.09610504]
 [ 0.1731275 ]
 [ 0.08832544]
 [-0.09939101]
 [ 0.10766405]
 [-0.18003237]
 [-0.16474082]
 [ 0.09337634]
 [ 0.06137091]
 [ 0.08485803]
 [ 0.12888962]
 [-0.1710934 ]
 [-0.14716667]
 [-0.09536344]
 [ 0.07462418]
 [-0.0192439 ]
 [-0.04157276]
 [ 0.13935924]
 [ 0.1240854 ]
 [ 0.09136358]
 [-0.14306489]
 [-0.06212461]
 [-0.0628694 ]
 [-0.1938062 ]
 [ 0.18955332]
 [-0.05474287]
 [ 0.19348317]
 [ 0.13769644]
 [ 0.15709323]
 [ 0.13692456]
 [-0.19098249]
 [-0.22752208]
 [ 0.17719397]
 [-0.00713232]
 [-0.19903697]
 [-0.04470

In [None]:
same_locations = abs(my_ys - ys) < 0.001
diff_locations = abs(my_ys - ys) > 1e-3

print(ys)
print(my_ys)
print(rewards)

print(same_locations == terminals)

In [12]:
print(critic_output)
print(next_qvals)

True
True
[ 0.08043291  0.15407091  0.04160417  0.25431705 -0.01546162  0.12317318
  0.08870779  0.0034778   0.11139356  0.02977742  0.09610271  0.04143379
  0.14590845  0.30310166 -0.03043694  0.29100132  0.24680986  0.14942929
  0.19308668  0.21655881  0.05423871  0.28347656  0.23354399  0.1701221
  0.00468823  0.28256524  0.33678555  0.00583877  0.18574598  0.19084638
 -0.06781723  0.1096947 ]
[-0.41515172 -0.47687221 -0.3862282  -0.37483859 -0.36577362 -0.35058931
 -0.55418265 -0.28721622 -0.36617911 -0.39964348 -0.35683811 -0.25585258
 -0.42136523 -0.33881104 -0.30203792 -0.31551418 -0.20384291 -0.39162099
 -0.20989272 -0.38798833 -0.4832077  -0.47076946 -0.39890233 -0.33049214
 -0.32134044 -0.37559023 -0.43229294 -0.44028047 -0.39340836 -0.28017768
 -0.23092277 -0.41609758]


In [5]:
    print("foo")
    print(are_np_arrays_equal(
        my_ys,
        ys
    ))
    print(are_np_arrays_equal(
        actor_loss,
        policy_surr
    ))
    print(are_np_arrays_equal(
        qf_loss,
        critic_loss
    ))

foo


NameError: name 'my_ys' is not defined

In [None]:
sess.close()