
Commit

Merge bc95649 into 714fb8e
trax-robot committed Jul 1, 2020
2 parents 714fb8e + bc95649 commit 254cb8a
Showing 3 changed files with 43 additions and 8 deletions.
1 change: 1 addition & 0 deletions trax/fastmath/jax.py
@@ -418,6 +418,7 @@ def _custom_grad(f_vjp, f_original):
'erf': jax_special.erf,
'expit': jax_special.expit,
'grad': jax.grad,
'value_and_grad': jax.value_and_grad,
'jit': jax.jit,
'logsumexp': jax_special.logsumexp,
'lt': lax.lt,
20 changes: 20 additions & 0 deletions trax/fastmath/ops.py
@@ -164,6 +164,26 @@ def grad(*args, **kwargs):
  return backend()['grad'](*args, **kwargs)


def value_and_grad(*args, **kwargs):
  """Computes the gradient of the specified function together with its value."""
  if 'value_and_grad' in backend():
    return backend()['value_and_grad'](*args, **kwargs)
  grad_fn = grad(*args, **kwargs)
  fn = args[0]
  has_aux = False
  if 'has_aux' in kwargs:
    has_aux = kwargs['has_aux']
  if not has_aux:
    def val_and_grad(*fn_args, **fn_kwargs):
      return fn(*fn_args, **fn_kwargs), grad_fn(*fn_args, **fn_kwargs)
    return val_and_grad
  def val_and_grad_aux(*fn_args, **fn_kwargs):
    g, aux = grad_fn(*fn_args, **fn_kwargs)
    res, _ = fn(*fn_args, **fn_kwargs)
    return (res, aux), g
  return val_and_grad_aux


def vjp(*args, **kwargs):
"""Computes the vector-Jacobian product for the specified function."""
return backend()['vjp'](*args, **kwargs)
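
Aside (not part of the commit): a minimal usage sketch of the new fastmath.value_and_grad under the default JAX backend; the function and variable names below are illustrative only.

from trax import fastmath
from trax.fastmath import numpy as jnp

def loss(w):
  return jnp.sum(w * w)

# Returns the loss value and its gradient with respect to w in one call.
value, gradient = fastmath.value_and_grad(loss)(jnp.ones(3))
# value -> 3.0, gradient -> [2. 2. 2.]
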
30 changes: 22 additions & 8 deletions trax/supervised/training.py
@@ -103,7 +103,11 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
self._model_in_training = tl.Serial(model, task.loss_layer)
self._eval_model = model if eval_model is None else eval_model
self._eval_task = eval_task
self._rjust_len = max([0] + [len(name) for name in eval_task.metric_names])

self._output_dir = os.path.expanduser(output_dir) if output_dir else None
if output_dir is not None:
tf.io.gfile.makedirs(output_dir)
default_fn = _at_step_1_and_periodically_at(task.n_steps_per_checkpoint)
self._checkpoint_at = checkpoint_at or default_fn
self._eval_at = eval_at or default_fn
@@ -120,9 +124,10 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
_, _ = task.optimizer.tree_init(self._model_in_training.weights)

self._gradients_and_state_fn = (
fastmath.jit(fastmath.grad(self._model_in_training.pure_fn,
argnums=1, # arg1 of pure_fn: weights
has_aux=True))) # return (gradients, state)
fastmath.jit(fastmath.value_and_grad(
self._model_in_training.pure_fn,
argnums=1, # arg1 of pure_fn: weights
has_aux=True))) # return (loss, state), gradients

if eval_task is not None:
model_with_metrics = _model_with_metrics(self._eval_model, eval_task)
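
Aside (not part of the diff): with argnums=1 and has_aux=True, the jitted function built above returns (loss, new_state) paired with the gradients, which is what lets the training loop below report a loss. A self-contained sketch of the same pattern, using a toy stand-in for pure_fn and made-up names:

from trax import fastmath
from trax.fastmath import numpy as jnp

def toy_pure_fn(batch, weights, state, rng):
  """Stand-in for a layer's pure_fn: returns (scalar loss, new state)."""
  del rng  # unused in this toy example
  loss = jnp.sum((batch * weights) ** 2)
  return loss, state

step_fn = fastmath.jit(fastmath.value_and_grad(
    toy_pure_fn, argnums=1, has_aux=True))
(loss, new_state), grads = step_fn(jnp.ones(4), 0.5 * jnp.ones(4), (), None)
# loss -> 1.0, grads -> [1. 1. 1. 1.]
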
@@ -142,13 +147,23 @@ def run(self, n_steps=1):
weights = self._model_in_training.weights
state = self._model_in_training.state
slots = self._task.optimizer.slots
loss_acc, step_acc = 0.0, 0
for _ in range(n_steps):
self._step += 1
weights, state, slots = self._run_one_step(weights, state, slots)
loss, weights, state, slots = self._run_one_step(weights, state, slots)
loss_acc += loss
step_acc += 1
if self._eval_at(self._step):
self._model_in_training.weights = weights
self._model_in_training.state = state
self._eval_model.weights = self._model.weights
# TODO(lukaszkaiser): move this to a better place with other reporting
loss_name = self._task.loss_layer.name
step_acc = max(1, step_acc)  # only here, to avoid a potential divide-by-zero
self._log_step('%s %s | % .8f' % (
'train'.ljust(5), loss_name.rjust(self._rjust_len),
loss_acc / float(step_acc)))
loss_acc, step_acc = 0.0, 0
self.run_evals(weights, state)
if self._checkpoint_at(self._step):
self.save_checkpoint(weights, state, slots)
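
Aside (not part of the diff): the accumulator added above keeps a running sum of per-step losses and reports the average at each eval point before resetting. A standalone sketch of the pattern with made-up loss values:

loss_acc, step_acc = 0.0, 0
for step, loss in enumerate([0.9, 0.7, 0.5, 0.3], start=1):
  loss_acc += loss
  step_acc += 1
  if step % 2 == 0:  # stands in for self._eval_at(step)
    step_acc = max(1, step_acc)  # guard against divide-by-zero
    print('train loss %.8f' % (loss_acc / float(step_acc)))
    loss_acc, step_acc = 0.0, 0
# prints 0.80000000, then 0.40000000
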
@@ -199,11 +214,11 @@ def _run_one_step(self, weights, state, slots):
opt_params = optimizer._init_opt_params # pylint: disable=protected-access
opt_params.update({'learning_rate': self._task.learning_rate(step)})

gradients, updated_state = (
(loss, updated_state), gradients = (
self._gradients_and_state_fn(batch, weights, state, self.new_rng()))
updated_weights, updated_slots, _ = (
optimizer.tree_update(step, gradients, weights, slots, opt_params))
return updated_weights, updated_state, updated_slots
return loss, updated_weights, updated_state, updated_slots

def run_evals(self, weights=None, state=None):
"""Runs and records evals for this training session.
@@ -230,10 +245,9 @@ def run_evals(self, weights=None, state=None):
self._metrics_fn(batch, metrics_weights, metrics_state, rng))
sums += metric_values
averages = sums / n_batches
rjust_len = max([0] + [len(name) for name in eval_task.metric_names])
for name, average_value in zip(eval_task.metric_names, averages):
self._log_step('%s %s | % .8f' % (
'eval'.ljust(5), name.rjust(rjust_len), average_value))
'eval'.ljust(5), name.rjust(self._rjust_len), average_value))

def _log_step(self, msg):
"""Logs message, labeled with the current training step number."""
