[TRAX] s/eval_task/eval_tasks in training.Loop
PiperOrigin-RevId: 322692154
afrozenator authored and Copybara-Service committed Jul 23, 2020
1 parent 5e6dc04 commit 969dffa
Showing 1 changed file with 9 additions and 6 deletions.
trax/supervised/trainer_lib.py
@@ -613,9 +613,9 @@ def train(output_dir,
   if checkpoints_at is not None:
     checkpoint_at = lambda step: step in checkpoints_at
   loop = training.Loop(model(mode='train'),
-                       train_task,
+                       [train_task],
                        eval_model=model(mode='eval'),
-                       eval_task=eval_task,
+                       eval_tasks=[eval_task],
                        output_dir=output_dir,
                        checkpoint_at=checkpoint_at)
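
For context, training.Loop now receives a list of training tasks positionally and a list of evaluation tasks via eval_tasks, so a single task has to be wrapped in a one-element list as the hunk above does. A hedged sketch of what a call with several tasks might look like; the task names here are hypothetical placeholders, not part of this commit:

loop = training.Loop(model(mode='train'),
                     [train_task_a, train_task_b],           # list of TrainTask objects
                     eval_model=model(mode='eval'),
                     eval_tasks=[eval_task_a, eval_task_b],  # list of EvalTask objects
                     output_dir=output_dir,
                     checkpoint_at=checkpoint_at)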

@@ -704,10 +704,13 @@ def mapped_update(weights_and_slots, i, opt_params, batch, state, rng):
     # the number of devices on this host machine, however psum goes over all
     # devices of all hosts (ex: a TPU pod) and we need to be averaging over all
     # of them.
-    grads = jax.tree_util.tree_map(
-        lambda g: (  # pylint: disable=g-long-lambda
-            fastmath.psum(g, 'batch') / fastmath.psum(np.array(1.0), 'batch')),
-        grads)
+    #
+    # Collect all gradients.
+    grads = fastmath.psum(grads, 'batch')
+    n_devices_total = fastmath.psum(np.array(1.0), 'batch')
+    # Average across hosts.
+    grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads)

     new_weights, new_slots, stats = optimizer.tree_update(
         i, grads, weights, slots, opt_params)
     return (new_weights, new_slots), stats, state, subrng
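The new update code sums the whole gradient tree across the 'batch' device axis once, then divides by fastmath.psum(np.array(1.0), 'batch'), which counts every participating device across all hosts, instead of issuing a psum per gradient leaf inside a tree_map. A standalone sketch of the same averaging pattern, using jax.lax.psum directly rather than Trax's fastmath wrapper, with a toy loss and SGD step that are illustrative only and not code from trainer_lib.py:

from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
  # Toy squared-error loss, just to have something to differentiate.
  return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name='batch')
def update(w, x, y):
  grads = jax.grad(loss_fn)(w, x, y)
  # Sum gradients over every device on the 'batch' axis; on a multi-host
  # setup psum spans all hosts, not just the local ones.
  grads = jax.lax.psum(grads, 'batch')
  # psum of 1.0 counts the devices, so dividing turns the sum into a mean.
  n_devices_total = jax.lax.psum(jnp.array(1.0), 'batch')
  grads = jax.tree_util.tree_map(lambda g: g / n_devices_total, grads)
  return w - 0.01 * grads  # plain SGD step for illustration

# Example: replicate weights and shard data across the local devices.
n = jax.local_device_count()
ws = jnp.stack([jnp.zeros(3)] * n)
xs = jnp.ones((n, 8, 3))
ys = jnp.ones((n, 8))
ws = update(ws, xs, ys)

Under the JAX backend, fastmath.psum in the diff is Trax's dispatching wrapper around this same collective.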
