
Commit

Merge pull request #689 from mv1388/documentation_for_multi_loss_multi_optimizer

Documentation for multi-loss multi-optimizer
mv1388 committed Jul 11, 2022
2 parents 5fb5e38 + b7b6456 commit 4d497cd
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion docs/source/torchtrain/multi_loss_opti.rst
@@ -1,13 +1,52 @@
Multi-Loss and Multi-Optimizer
==============================

``TrainLoop`` supports training with multiple separate losses and/or multiple different
optimizers at the same time.

The multi-loss/optimizer functionality is achieved by wrapping multiple loss or
optimizer objects into the ``MultiLoss`` and ``MultiOptimizer`` wrappers, respectively,
both provided in :mod:`aitoolbox.torchtrain.multi_loss_optim`.


Multi-Loss Training
-------------------

To implement training with multiple losses, use :class:`aitoolbox.torchtrain.multi_loss_optim.MultiLoss`
to wrap the separately calculated losses together and return them from the model's ``get_loss()`` function.
The ``TrainLoop`` will then automatically execute backpropagation through each of the losses.

Multiple losses need to be provided to the ``MultiLoss`` as a dict:

.. code-block:: python

    MultiLoss({'main_loss': main_loss, 'aux_loss': aux_loss})
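
For additional context, here is a minimal sketch of how the wrapped losses might be
computed and returned from the model's ``get_loss()``. The two-headed model, the batch
layout and the ``TTModel`` base class import shown here are illustrative assumptions:

.. code-block:: python

    import torch.nn.functional as F
    from aitoolbox.torchtrain.model import TTModel
    from aitoolbox.torchtrain.multi_loss_optim import MultiLoss


    class TwoHeadedModel(TTModel):
        # ... __init__() and a forward() returning (main_pred, aux_pred) ...

        def get_loss(self, batch_data, criterion, device):
            # Hypothetical batch layout: inputs plus two separate targets
            inputs, main_target, aux_target = batch_data
            main_pred, aux_pred = self(inputs.to(device))

            # Calculate each loss separately, then wrap them together
            main_loss = F.cross_entropy(main_pred, main_target.to(device))
            aux_loss = F.mse_loss(aux_pred, aux_target.to(device))

            return MultiLoss({'main_loss': main_loss, 'aux_loss': aux_loss})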

In case more elaborate backprop logic is needed, one can override the ``MultiLoss``
:meth:`aitoolbox.torchtrain.multi_loss_optim.MultiLoss.backward` method with the desired advanced logic.
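
As an illustration of such an override, the hypothetical sketch below backpropagates once
through a weighted sum of the wrapped losses instead of through each loss independently.
The ``WeightedSumMultiLoss`` name, the ``aux_weight`` parameter and the assumed
``backward()`` signature are all illustrative, not taken from the library:

.. code-block:: python

    from aitoolbox.torchtrain.multi_loss_optim import MultiLoss


    class WeightedSumMultiLoss(MultiLoss):
        def __init__(self, loss_dict, aux_weight=0.3):
            super().__init__(loss_dict)
            # Keep our own reference so this sketch doesn't depend on the
            # parent class' internal attribute names
            self.wrapped_losses = loss_dict
            self.aux_weight = aux_weight

        def backward(self, *args, **kwargs):
            # Backprop once through the weighted combination of the losses
            combined = (self.wrapped_losses['main_loss']
                        + self.aux_weight * self.wrapped_losses['aux_loss'])
            combined.backward()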


Multi-Optimizer Training
------------------------

To use multiple optimizers, for example with each one optimizing a different part of the model, define multiple
optimizers, each given access to a different subset of the model's parameters. These separate optimizers need
to be provided in a list to the :class:`aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer` wrapper.
The ``MultiOptimizer`` can subsequently be given to the ``TrainLoop`` in the same way as a normal single optimizer.

``MultiOptimizer`` definition example:

.. code-block:: python

    MultiOptimizer([optimizer_1, optimizer_2])
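
For example, each wrapped optimizer might cover a different submodule of the model. The
``encoder``/``decoder`` submodule names and the chosen optimizers below are hypothetical;
any disjoint split of the model's parameters works the same way:

.. code-block:: python

    import torch.optim as optim
    from aitoolbox.torchtrain.multi_loss_optim import MultiOptimizer

    # Hypothetical model with separate ``encoder`` and ``decoder`` submodules,
    # each trained with its own optimizer and learning rate
    optimizer_1 = optim.Adam(model.encoder.parameters(), lr=1e-4)
    optimizer_2 = optim.SGD(model.decoder.parameters(), lr=1e-2, momentum=0.9)

    optimizer = MultiOptimizer([optimizer_1, optimizer_2])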

When more advanced multi-optimizer training logic is required, the user can override the
:meth:`aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer.step` and/or the
:meth:`aitoolbox.torchtrain.multi_loss_optim.MultiOptimizer.zero_grad` methods as needed.
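
As an illustration, the hypothetical subclass below steps only one of the wrapped
optimizers per call, cycling through them on successive calls (a pattern used e.g. in
GAN-style training). The no-argument ``step()`` signature follows the standard PyTorch
optimizer convention and is an assumption here; the library's actual method may accept
additional arguments:

.. code-block:: python

    from aitoolbox.torchtrain.multi_loss_optim import MultiOptimizer


    class AlternatingMultiOptimizer(MultiOptimizer):
        def __init__(self, optimizer_list):
            super().__init__(optimizer_list)
            # Keep our own reference so this sketch doesn't depend on the
            # parent class' internal attribute names
            self.wrapped_optimizers = optimizer_list
            self.call_count = 0

        def step(self):
            # Step only one wrapped optimizer per call, alternating between them
            current_idx = self.call_count % len(self.wrapped_optimizers)
            self.wrapped_optimizers[current_idx].step()
            self.call_count += 1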

Lastly, when using the ``MultiOptimizer``, training state checkpoint saving is still automatically
handled by the ``TrainLoop``. As part of this, the ``TrainLoop`` automatically stores the state of
each of the optimizers wrapped inside the ``MultiOptimizer``. The same functionality is provided
when loading the saved model.
