Skip to content

Commit

Permalink
add multiple devices support
Browse files Browse the repository at this point in the history
  • Loading branch information
freedomtan committed Apr 18, 2018
1 parent 955b7f2 commit a3b37ef
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion models/official/squeezenet/squeezenet_main.py
Expand Up @@ -137,8 +137,9 @@ def main(argv):
params=dict(params, use_tpu=FLAGS.use_tpu),
)
else:
#model_fn=squeezenet_model.model_fn,
estimator = tf.estimator.Estimator(
model_fn=squeezenet_model.model_fn,
model_fn=tf.contrib.estimator.replicate_model_fn(squeezenet_model.model_fn),
config=run_config,
params=dict(params, batch_size=FLAGS.batch_size, use_tpu=FLAGS.use_tpu),
)
Expand Down
2 changes: 2 additions & 0 deletions models/official/squeezenet/squeezenet_model.py
Expand Up @@ -132,6 +132,8 @@ def model_fn(features, labels, mode, params):

if params["use_tpu"]:
optimizer = tpu_optimizer.CrossShardOptimizer(optimizer)
else:
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

train_op = optimizer.minimize(loss, tf.train.get_global_step())

Expand Down

0 comments on commit a3b37ef

Please sign in to comment.