Add a simple version of Adasum to Trax.
PiperOrigin-RevId: 359767146
Lukasz Kaiser authored and Copybara-Service committed Feb 26, 2021
1 parent 889768b commit 56c747a
Showing 3 changed files with 53 additions and 17 deletions.
56 changes: 42 additions & 14 deletions trax/optimizers/trainer.py
@@ -47,10 +47,11 @@ class Trainer:
side effect, the function also modifies the model weights and optimizer slots.
"""

def __init__(self, model_with_loss, optimizer, n_devices=None):
def __init__(self, model_with_loss, optimizer, n_devices=None, adasum=False):
self._model_with_loss = model_with_loss
self._optimizer = optimizer
self._n_devices = n_devices or fastmath.device_count()
self._adasum = adasum

# optimizer slots and opt_params may need to be replicated
self._slots, self._opt_params = tl.for_n_devices(
@@ -77,6 +78,7 @@ def __init__(self, model_with_loss, optimizer, n_devices=None):
self._optimizer,
n_devices=self._n_devices,
accelerate=True,
adasum=self._adasum
)
)

@@ -160,11 +162,30 @@ def _unreplicate(self, x):
return fastmath.nested_map(lambda x: x[0], x)


def _average_multidevice_gradients(gradients):
def _average_multidevice_gradients(gradients, adasum=False):
"""Averages gradients over all the devices across different hosts."""
gradients_psum = fastmath.psum(gradients, 'batch') # sum over all devices
n_devices_total = fastmath.psum(jnp.array(1.0), 'batch')
return fastmath.nested_map(lambda g: g / n_devices_total, gradients_psum)
n = fastmath.psum(jnp.array(1.0), 'batch') # number of devices on all hosts
if not adasum:
return fastmath.nested_map(lambda g: g / n, gradients_psum)
# This implements an approximation of the Adasum algorithm from the following
# paper: https://arxiv.org/pdf/2006.02924.pdf
# Since implementing halving and averaging half-by-half is tricky, we first
# sum the gradients over all devices and use that sum as the reference that
# each device's gradient is compared against. For 2 devices this is the same
# algorithm as in the paper, but with more devices it does a different kind
# of averaging. It still has the property that orthogonal gradients result
# in a sum while identical ones are averaged, as postulated in the paper.
adasum_nominator = fastmath.nested_map_multiarg(
lambda g, q: jnp.vdot(g, q), # pylint: disable=unnecessary-lambda
gradients, gradients_psum)
grad_norm = fastmath.nested_map(lambda g: jnp.vdot(g, g), gradients)
# If all devices have identical gradients, then the nominator is equal
# to n * grad_norm; if they're orthogonal, then nominator = grad_norm.
scaled_grads = fastmath.nested_map_multiarg(
lambda g, nominator, g_norm: g*(1 - (nominator - g_norm) / (n * g_norm)),
gradients, adasum_nominator, grad_norm)
return fastmath.psum(scaled_grads, 'batch')
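To make the combine rule above concrete, here is a small standalone sketch (not part of this commit) that simulates it with plain jax.numpy, using a Python list of per-device gradients and ordinary sums in place of fastmath.psum over the 'batch' axis. It checks the two limiting cases described in the comments: identical gradients come out averaged, orthogonal gradients come out summed.

```python
# Standalone illustration (assumed names; not Trax code) of the Adasum-style
# combine above. Summing the Python list stands in for psum over 'batch'.
import jax.numpy as jnp

def adasum_combine(device_grads):
  n = len(device_grads)
  g_sum = sum(device_grads)          # plays the role of fastmath.psum(...)
  scaled = []
  for g in device_grads:
    nominator = jnp.vdot(g, g_sum)   # g . (sum of all device gradients)
    g_norm = jnp.vdot(g, g)          # squared norm of this device's gradient
    scaled.append(g * (1 - (nominator - g_norm) / (n * g_norm)))
  return sum(scaled)                 # final psum of the scaled gradients

g = jnp.array([1.0, 2.0])
print(adasum_combine([g, g, g, g]))  # identical gradients -> ~[1. 2.] (average)

e1, e2 = jnp.array([3.0, 0.0]), jnp.array([0.0, 5.0])
print(adasum_combine([e1, e2]))      # orthogonal gradients -> ~[3. 5.] (sum)
```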


# Returns a function with the following signature:
@@ -173,7 +194,8 @@ def _average_multidevice_gradients(gradients):
def _accelerate_update_fn(forward_and_backward_fn,
optimizer,
n_devices,
accelerate=True):
accelerate=True,
adasum=False):
"""Accelerates the given forward_and_backward_fn function."""
if n_devices == 1:
def single_device_update_fn(
@@ -202,7 +224,7 @@ def _multi_device_update_fn(
weights, slots = weights_and_slots
(loss, state), gradients = (
forward_and_backward_fn(batch, weights, state, rng))
gradients = _average_multidevice_gradients(gradients)
gradients = _average_multidevice_gradients(gradients, adasum=adasum)
weights, slots, stats = optimizer.tree_update(
step, gradients, weights, slots, opt_params, store_slots=False)
stats['loss'] = loss
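For orientation, the averaging call above sits inside a pmapped update step. Below is a self-contained sketch (toy loss and SGD step, assumed names, not Trax code) of that pattern in plain JAX: gradients are computed per device, combined across devices inside the pmap via psum, and then used for the update. Plain averaging is shown; the Adasum branch would plug in at the same point.

```python
# Minimal sketch of a multi-device update step (toy placeholders, not Trax).
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
  return jnp.mean((x @ w - y) ** 2)

def update_fn(w, x, y, lr=0.1):
  grads = jax.grad(loss_fn)(w, x, y)
  grads_sum = jax.lax.psum(grads, axis_name='batch')   # sum over devices
  n = jax.lax.psum(jnp.array(1.0), axis_name='batch')  # number of devices
  return w - lr * grads_sum / n                        # plain averaging

n_dev = jax.local_device_count()
w = jnp.zeros((4, 1))
x = jnp.ones((n_dev, 8, 4))   # one batch shard per device
y = jnp.ones((n_dev, 8, 1))
new_w = jax.pmap(update_fn, axis_name='batch', in_axes=(None, 0, 0))(w, x, y)
print(new_w.shape)            # (n_dev, 4, 1): identical copies per device
```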
@@ -235,7 +257,7 @@ class ReversibleSerialTrainer:
"""

def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
memoize_jit=True, free_accelerators_on_step=False):
memoize_jit=True, free_accelerators_on_step=False, adasum=False):
"""Creates a ReversibleSerialTrainer and the needed optimizers.
This trainer performs updates equivalent to using the default Trainer on::
@@ -263,11 +285,13 @@ def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None,
free_accelerators_on_step: If true, frees memory on accelerators when
starting a step. All layers and arguments must be on host for that,
otherwise it can lead to failures. Can prevent memory fragmentation.
adasum: if True, use adaptive summation to gather multi-device gradients.
"""
self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
self._loss_layer = loss_layer
self._optimizer_fn = optimizer_fn
self._n_devices = n_devices or fastmath.device_count()
self._adasum = adasum
self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
self._n_steps_per_log = 100 # Log layers and stats every 100 steps.
self._n_async_layers = 1 # How many layers to run asynchronously.
@@ -303,11 +327,12 @@ def _make_optimizer(layer):
self._fbos = []
for i, (std_layer, rev_layers) in enumerate(self._blocks):
(std_opt, rev_opts) = self._optimizers[i]
std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt, self._n_devices)
std_fbo = _fbo_with_layer_and_opt(
std_layer, std_opt, self._n_devices, adasum=self._adasum)
rev_and_fbos = []
for layer, opt in zip(rev_layers, rev_opts):
rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
layer, opt, self._n_devices)
layer, opt, self._n_devices, self._adasum)
# The donated args are (outputs, weights, grads) and we can donate
# them because weights and grads are immediately replaced and in
# case of reversible layers, the outputs are never used again.
@@ -320,7 +345,7 @@ def _make_optimizer(layer):
self._fbos.append((jit_std_fbo, rev_and_fbos))

loss_fbo = _fbo_with_layer_and_opt(
self._loss_layer, self._loss_opt, self._n_devices, 'loss')
self._loss_layer, self._loss_opt, self._n_devices, 'loss', self._adasum)
self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))

@property
@@ -670,7 +695,8 @@ def _run_backward_one_reversible(self, layer, stack, grad_stack, step, rng,
# We call them FBO for short: "Forward + Backward + Optimizer update".


def _fbo_with_layer_and_opt(layer, optimizer, n_devices, stats_name=None):
def _fbo_with_layer_and_opt(layer, optimizer, n_devices,
stats_name=None, adasum=False):
"""Create the fbo function for a given layer and optimizer."""
def fbo(inputs, weights, grads, state, slots, opt_params, rng, step):
"""FBO of the layer."""
@@ -698,7 +724,8 @@ def pure_fn_without_state_and_rng(x, w):

# In multi-device setting, average gradients from multiple devices.
if n_devices > 1:
grads_weights = _average_multidevice_gradients(grads_weights)
grads_weights = _average_multidevice_gradients(
grads_weights, adasum=adasum)

# Run the optimizer.
new_weights, new_slots, stats = optimizer.tree_update(
@@ -715,7 +742,7 @@ def pure_fn_without_state_and_rng(x, w):
# This function uses the `reverse_and_grad` method of reversible layers.


def _reverse_and_fbo_with_layer_and_opt(layer, optimizer, n_devices):
def _reverse_and_fbo_with_layer_and_opt(layer, optimizer, n_devices, adasum):
"""Create the reverse_and_fbo function for a given layer and optimizer."""
def reverse_and_fbo(output, weights, grads, state, new_state,
slots, opt_params, rng, step):
@@ -730,7 +757,8 @@ def reverse_and_fbo(output, weights, grads, state, new_state,

# In multi-device setting, average gradients from multiple devices.
if n_devices > 1:
grads_weights = _average_multidevice_gradients(grads_weights)
grads_weights = _average_multidevice_gradients(
grads_weights, adasum=adasum)

# Run the optimizer.
new_weights, new_slots, stats = optimizer.tree_update(
5 changes: 4 additions & 1 deletion trax/supervised/trainer_lib.py
@@ -530,6 +530,7 @@ def train(output_dir,
use_loop=True,
loss_chunk_size=0,
use_memory_efficient_trainer=False,
adasum=False,
init_checkpoint=None,
callbacks=None,
additional_train_tasks=None,
@@ -564,7 +565,8 @@ def train(output_dir,
checkpoint_lowest: save the checkpoint with the lowest value of this metric.
use_loop: whether to use training.Loop instead of Trainer.
loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
use_memory_efficient_trainer: whether to use memory-efficient trainer..
use_memory_efficient_trainer: whether to use memory-efficient trainer.
adasum: if True, use adaptive summation for multi-device gradients.
init_checkpoint: a checkpoint for fine tuning.
callbacks: a list of callbacks to call during training.
additional_train_tasks: additional tasks which should be performed during
@@ -629,6 +631,7 @@ def train(output_dir,
n_devices=n_devices,
loss_chunk_size=loss_chunk_size,
use_memory_efficient_trainer=use_memory_efficient_trainer,
adasum=adasum,
random_seed=random_seed,
callbacks=callbacks,
)
9 changes: 7 additions & 2 deletions trax/supervised/training.py
@@ -113,6 +113,7 @@ def __init__(
random_seed=None,
loss_chunk_size=0,
use_memory_efficient_trainer=False,
adasum=False,
callbacks=None,
):
"""Configures a training ``Loop``, including a random initialization.
@@ -160,6 +161,7 @@ def __init__(
computation more memory-efficient.
use_memory_efficient_trainer: whether to use a special memory-efficient
trainer; if set to 2, the memory efficiency is very aggressive.
adasum: if True, use adaptive summation for multi-device gradients
callbacks: List of subclasses of StepCallback to call on training
steps.
"""
@@ -187,6 +189,7 @@ def __init__(

self._use_memory_efficient_trainer = use_memory_efficient_trainer
self._loss_chunk_size = loss_chunk_size
self._adasum = adasum
# TODO(lukaszkaiser): can we have different eval models and save memory?
if use_memory_efficient_trainer:
assert len(tasks) == 1, 'only single task supported for now'
@@ -302,7 +305,8 @@ def _init_trainer(self, task):
shapes.signature(task.sample_batch)
)
task.optimizer.tree_init(model_in_training.weights)
return optimizers.Trainer(model_in_training, task.optimizer)
return optimizers.Trainer(
model_in_training, task.optimizer, adasum=self._adasum)
# In the memory-efficient path, we initialize the model here.
blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
[self._model, task.loss_layer], loss_chunk_size=self._loss_chunk_size)
@@ -312,7 +316,8 @@ def _init_trainer(self, task):
# TODO(lukaszkaiser): here optimizer is a function, revisit this.
return optimizers.ReversibleSerialTrainer(
blocks, loss_layer, task.optimizer,
free_accelerators_on_step=(self._use_memory_efficient_trainer == 2))
free_accelerators_on_step=(self._use_memory_efficient_trainer == 2),
adasum=self._adasum)

def _init_evaluator(self, eval_task):
"""Initializes the per-task evaluator."""
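Finally, a hypothetical end-to-end sketch (not from this commit) of how a user would turn the new flag on. The toy data stream and tiny model are placeholders; the only point of interest is the adasum=True argument that Loop now forwards to the underlying trainer.

```python
# Hypothetical usage sketch: enabling Adasum via training.Loop.
import numpy as np
import trax
from trax import layers as tl
from trax.supervised import training

def toy_batches():
  """Hypothetical helper: yields (inputs, targets, weights) batches forever."""
  while True:
    x = np.random.rand(8, 4).astype(np.float32)
    y = (x.sum(axis=-1) > 2.0).astype(np.int32)
    yield (x, y, np.ones_like(y, dtype=np.float32))

model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
train_task = training.TrainTask(
    labeled_data=toy_batches(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
)
loop = training.Loop(model, train_task, adasum=True)  # flag added in this commit
loop.run(n_steps=10)
```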
