Commit 9ff0aed

Inject current learning rate to the optimizer state. Allow plotting the norm of any array in the train state.

We can now plot the norm of the current parameter values, the learning rate, or any other array that is part of the train state:
```
config.plot_norm_train_state_patterns = [
  # norm of the kernel params in MoE layers.
  'params/.*/moe/mlp/.*/kernel',
  # norm (=value) of current learning rate.
  'opt_state/.*/hyperparameter/learning_rate',
  # norm of 1st order moments of the kernel gradients in MoE layers (i.e. Adam inner state).
  'opt_state/.*/mu/.*/moe/mlp/.*/kernel',
]
```
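
The patterns are regular expressions matched against the flattened key names of the train state. Below is a minimal sketch of that selection step, assuming a toy state layout and `re.fullmatch` semantics; the real key names come from vmoe's `TrainState` and the real matching rule from `utils.make_match_fn_from_regex_list`, so treat both as illustrative.
```
import re
import jax.numpy as jnp
from flax import traverse_util

# Toy stand-in for flax.serialization.to_state_dict(train_state); the real
# nested dict has many more entries and different module names.
state_dict = {
    'params': {'encoder': {'moe': {'mlp': {'dense1': {'kernel': jnp.ones((8, 8))}}}}},
    'opt_state': {'3': {'hyperparameter': {'learning_rate': jnp.array(1e-3)}}},
}

patterns = [
    r'params/.*/moe/mlp/.*/kernel',
    r'opt_state/.*/hyperparameter/learning_rate',
]


def match(name):
  return any(re.fullmatch(p, name) for p in patterns)


flat = traverse_util.flatten_dict(state_dict, sep='/')
print([name for name in flat if match(name)])
# ['params/encoder/moe/mlp/dense1/kernel',
#  'opt_state/3/hyperparameter/learning_rate']
```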

PiperOrigin-RevId: 499588562
jpuigcerver authored and Copybara-Service committed Jan 5, 2023
1 parent fcf81b0 commit 9ff0aed
Showing 2 changed files with 33 additions and 11 deletions.
10 changes: 7 additions & 3 deletions vmoe/train/optimizer.py
```diff
@@ -104,8 +104,7 @@ def create_optimizer(
   # WARNING: Use this with caution. Notice that this is NOT equivalent to having
   # a specific learning rate per parameter, since the scale that you use here
   # will affect the state of the optimizers like momentum.
-  if gradient_scale:
-    ops.append(gradient_scaling(gradient_scale))
+  ops.append(gradient_scaling(gradient_scale))
   # Optionally, add gradient clipping.
   ops.append(gradient_clipping(**(gradient_clip or {})))
   # Optimizer-dependant scaling of gradients.
@@ -127,7 +126,12 @@ def create_optimizer(
     learning_rate = {'schedule': 'constant', 'value': learning_rate}
   lr_schedule = schedule.create_learning_rate_schedule(
       **learning_rate, total_steps=total_steps)
-  ops.append(optax.scale_by_schedule(lambda count: -lr_schedule(count)))
+  # Wrap scale with inject_hyperparams to keep the last learning rate in the
+  # optimizer state.
+  @optax.inject_hyperparams
+  def _scale_by_learning_rate(learning_rate):
+    return optax.scale(-learning_rate)
+  ops.append(_scale_by_learning_rate(lr_schedule))
   # Optionally, freeze some variables.
   ops.append(freeze_weights(
       frozen_pattern=frozen_pattern, trainable_pattern=trainable_pattern))
```
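
For reference, this is the optax pattern the new code relies on: `optax.inject_hyperparams` records the numeric hyperparameters of the wrapped transformation (here, the scheduled learning rate) in the optimizer state, so they can be read back, checkpointed, or plotted. A standalone sketch with an illustrative schedule, not the one built by `schedule.create_learning_rate_schedule`:
```
import jax.numpy as jnp
import optax

# Illustrative schedule, for this sketch only.
lr_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1_000)


@optax.inject_hyperparams
def scale_by_learning_rate(learning_rate):
  return optax.scale(-learning_rate)


tx = scale_by_learning_rate(lr_schedule)
params = {'w': jnp.ones((2, 2))}
opt_state = tx.init(params)
grads = {'w': jnp.ones((2, 2))}
updates, opt_state = tx.update(grads, opt_state)

# The learning rate used for this step now lives in the optimizer state.
print(opt_state.hyperparams['learning_rate'])
```
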
34 changes: 26 additions & 8 deletions vmoe/train/trainer.py
```diff
@@ -576,7 +576,8 @@ def train_step(
     images: Array,
     labels: Array,
     loss_fn: Callable[[Array, Array], Array],
-    plot_grad_norm_name_fn: Optional[Callable[[str], bool]] = None,
+    plot_norm_grad_fn: Optional[Callable[[str], bool]] = None,
+    plot_norm_train_state_fn: Optional[Callable[[str], bool]] = None,
 ) -> Tuple[TrainState, Mapping[str, Any]]:
   """Performs one update step of the given TrainState object ."""
   rngs, next_rngs = utils.tree_rngs_split(state.rngs)
@@ -592,15 +593,29 @@ def compute_grads_and_metrics(params):
     return total_loss, metrics

   grads, metrics = compute_grads_and_metrics(state.params)
+  # Update train state.
+  state = state.apply_gradients(grads=grads, rngs=next_rngs)

-  if plot_grad_norm_name_fn:
+  if plot_norm_grad_fn:
     # Compute norm of selected parameters and add them as auxiliary metrics.
     metrics.update({
-        f'grads_norm/{name}': jnp.sqrt(jnp.vdot(grad, grad))
+        f'norm_grads/{name}': jnp.sqrt(jnp.vdot(grad, grad))
         for name, grad in flax.traverse_util.flatten_dict(
-            grads, sep='/').items() if plot_grad_norm_name_fn(name)
+            grads, sep='/').items() if plot_norm_grad_fn(name)
     })
-  return state.apply_gradients(grads=grads, rngs=next_rngs), metrics
+
+  if plot_norm_train_state_fn:
+    # Compute the norm of selected arrays in the train state and add them as
+    # auxiliary metrics. This is useful to plot the norm of the values of
+    # parameters, the current learning rate or any other internal state.
+    metrics.update({
+        f'norm_train_state/{name}': jnp.sqrt(jnp.vdot(value, value))
+        for name, value in flax.traverse_util.flatten_dict(
+            flax.serialization.to_state_dict(state), sep='/').items()
+        if plot_norm_train_state_fn(name)
+    })
+
+  return state, metrics


 def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
```
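
Two details of the new `train_step` are worth noting: the gradients are now applied before the norms are computed, so the `norm_train_state/...` metrics describe the freshly updated state (including the learning rate just injected into `opt_state`), and each norm is `jnp.sqrt(jnp.vdot(x, x))`, i.e. the global L2 norm of the array regardless of its shape. A tiny check of that identity:
```
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)

# jnp.vdot flattens its arguments, so sqrt(vdot(x, x)) is the L2 norm of the
# flattened array, which for a matrix is its Frobenius norm.
norm = jnp.sqrt(jnp.vdot(x, x))
print(jnp.allclose(norm, jnp.linalg.norm(x.ravel())))  # True
```
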
```diff
@@ -662,12 +677,15 @@ def _train_and_evaluate(config: ml_collections.ConfigDict, workdir: str,
       initialization_kwargs=config.get('initialization'))
   init_step = int(train_state.step)
   train_loss_fn, eval_loss_fn, label_pred_fn = get_loss_fn(**config.loss)
-  plot_grad_norm_name_fn = utils.make_match_fn_from_regex_list(
-      config.get('plot_grad_norm_patterns'))
+  plot_norm_grad_fn = utils.make_match_fn_from_regex_list(
+      config.get('plot_norm_grad_patterns'))
+  plot_norm_train_state_fn = utils.make_match_fn_from_regex_list(
+      config.get('plot_norm_train_state_patterns'))
   train_step_fn = functools.partial(
       train_step,
       loss_fn=train_loss_fn,
-      plot_grad_norm_name_fn=plot_grad_norm_name_fn)
+      plot_norm_grad_fn=plot_norm_grad_fn,
+      plot_norm_train_state_fn=plot_norm_train_state_fn)
   if config.get('adversarial', {}):
     adversarial_config = config.adversarial.to_dict()
     train_step_fn = wrap_train_step_with_adversarial_attack(
```
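
On the config side, note that the gradient-norm option is renamed (`plot_grad_norm_patterns` becomes `plot_norm_grad_patterns`, with metrics now reported under `norm_grads/...` instead of `grads_norm/...`), and the new `plot_norm_train_state_patterns` option controls the train-state norms. A hypothetical config fragment using both; the patterns are illustrative and should be adapted to the actual parameter tree:
```
import ml_collections

config = ml_collections.ConfigDict()
# Renamed in this change: previously 'plot_grad_norm_patterns'.
config.plot_norm_grad_patterns = [
    r'.*/moe/mlp/.*/kernel',
]
# New in this change: norms of arrays anywhere in the train state.
config.plot_norm_train_state_patterns = [
    r'params/.*/moe/mlp/.*/kernel',
    r'opt_state/.*/hyperparameter/learning_rate',
]
```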
