Skip to content

Commit

Permalink
Update to tensorflow 1.9.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacques KAISER committed Jul 20, 2018
1 parent 69ef481 commit c2193fe
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 18 deletions.
16 changes: 8 additions & 8 deletions README.md
Expand Up @@ -20,23 +20,23 @@ First, install prerequisites with:

To train a model for an environment with a continuous action space:

$ python main.py --env=Pendulum-v0 --is_train=True
$ python main.py --env=Pendulum-v0 --is_train=True --display=True
$ python main.py --env_name=Pendulum-v0 --is_train=True
$ python main.py --env_name=Pendulum-v0 --is_train=True --display=True

To test and record the screens with gym:

$ python main.py --env=Pendulum-v0 --is_train=False
$ python main.py --env=Pendulum-v0 --is_train=False --display=True
$ python main.py --env_name=Pendulum-v0 --is_train=False
$ python main.py --env_name=Pendulum-v0 --is_train=False --display=True


## Results

Training details of `Pendulum-v0` with different hyperparameters.

$ python main.py --env=Pendulum-v0 # dark green
$ python main.py --env=Pendulum-v0 --action_fn=tanh # light green
$ python main.py --env=Pendulum-v0 --use_batch_norm=True # yellow
$ python main.py --env=Pendulum-v0 --use_seperate_networks=True # green
$ python main.py --env_name=Pendulum-v0 # dark green
$ python main.py --env_name=Pendulum-v0 --action_fn=tanh # light green
$ python main.py --env_name=Pendulum-v0 --use_batch_norm=True # yellow
$ python main.py --env_name=Pendulum-v0 --use_seperate_networks=True # green

![Pendulum-v0_2016-07-15](https://github.com/carpedm20/naf-tensorflow/blob/master/assets/Pendulum-v0_2016-07-15.png)

Expand Down
7 changes: 4 additions & 3 deletions main.py
Expand Up @@ -57,14 +57,15 @@
np.random.seed(conf.random_seed)

def main(_):
model_dir = get_model_dir(conf,
model_dir = get_model_dir(conf,
['is_train', 'random_seed', 'monitor', 'display', 'log_level'])

preprocess_conf(conf)

with tf.Session() as sess:
# environment
env = gym.make(conf.env_name)
env._seed(conf.random_seed)
env.seed(conf.random_seed)

assert isinstance(env.observation_space, gym.spaces.Box), \
"observation space must be continuous"
Expand All @@ -86,7 +87,7 @@ def main(_):
'sess': sess,
'input_shape': env.observation_space.shape,
'action_size': env.action_space.shape[0],
'hidden_dims': conf.hidden_dims,
'hidden_dims': conf.hidden_dims,
'use_batch_norm': conf.use_batch_norm,
'use_seperate_networks': conf.use_seperate_networks,
'hidden_w': conf.hidden_w, 'action_w': conf.action_w,
Expand Down
8 changes: 4 additions & 4 deletions src/network.py
Expand Up @@ -63,16 +63,16 @@ def __init__(self, sess, input_shape, action_size, hidden_dims,

diag_elem = tf.exp(tf.slice(l, (0, pivot), (-1, 1)))
non_diag_elems = tf.slice(l, (0, pivot+1), (-1, count-1))
row = tf.pad(tf.concat(1, (diag_elem, non_diag_elems)), ((0, 0), (idx, 0)))
row = tf.pad(tf.concat((diag_elem, non_diag_elems), 1), ((0, 0), (idx, 0)))
rows.append(row)

pivot += count

L = tf.transpose(tf.pack(rows, axis=1), (0, 2, 1))
P = tf.batch_matmul(L, tf.transpose(L, (0, 2, 1)))
L = tf.transpose(tf.stack(rows, axis=1), (0, 2, 1))
P = tf.matmul(L, tf.transpose(L, (0, 2, 1)))

tmp = tf.expand_dims(u - mu, -1)
A = -tf.batch_matmul(tf.transpose(tmp, [0, 2, 1]), tf.batch_matmul(P, tmp))/2
A = -tf.matmul(tf.transpose(tmp, [0, 2, 1]), tf.matmul(P, tmp))/2
A = tf.reshape(A, [-1, 1])

with tf.name_scope('Q'):
Expand Down
4 changes: 2 additions & 2 deletions src/statistic.py
Expand Up @@ -20,7 +20,7 @@ def __init__(self, sess, env_name, model_dir, variables, max_update_per_step, ma

self.model_dir = model_dir
self.saver = tf.train.Saver(variables + [self.t_op], max_to_keep=max_to_keep)
self.writer = tf.train.SummaryWriter('./logs/%s' % self.model_dir, self.sess.graph)
self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph)

with tf.variable_scope('summary'):
scalar_summary_tags = ['total r', 'avg r', 'avg q', 'avg v', 'avg a', 'avg l']
Expand All @@ -30,7 +30,7 @@ def __init__(self, sess, env_name, model_dir, variables, max_update_per_step, ma

for tag in scalar_summary_tags:
self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_'))
self.summary_ops[tag] = tf.scalar_summary('%s/%s' % (self.env_name, tag), self.summary_placeholders[tag])
self.summary_ops[tag] = tf.summary.scalar('%s/%s' % (self.env_name, tag), self.summary_placeholders[tag])

def reset(self):
self.total_q = 0.
Expand Down
4 changes: 3 additions & 1 deletion utils.py
Expand Up @@ -7,7 +7,8 @@
pp = pprint.PrettyPrinter().pprint

def get_model_dir(config, exceptions=None):
attrs = config.__dict__['__flags']

attrs = config.__flags
pp(attrs)

keys = attrs.keys()
Expand All @@ -28,6 +29,7 @@ def preprocess_conf(conf):

for option, value in options.items():
option = option.lower()
value = value.value

if option == 'hidden_dims':
conf.hidden_dims = eval(conf.hidden_dims)
Expand Down

0 comments on commit c2193fe

Please sign in to comment.