In [None]:
# default_exp collections.callbacks.ema

In [None]:
# hide
# Notebook-only setup: load the `nb_black` and `autoreload` IPython extensions,
# re-import edited modules automatically before each cell runs (`%autoreload 2`)
# and render matplotlib figures inline.
%load_ext nb_black
%load_ext autoreload
%autoreload 2
%matplotlib inline

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.export import Config
from nbdev.showdoc import *
from timm.utils import *

# Suppress every Python warning for the remainder of the notebook session.
warnings.filterwarnings("ignore")
# NOTE(review): `setup_default_logging` is not imported by name; presumably it
# comes from the `timm.utils` star import above and configures logging — TODO confirm.
setup_default_logging()

<IPython.core.display.Javascript object>

# Model Exponential Moving Average Callback for PyTorch Lightning
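> Maintains an exponential moving average (EMA) of the model weights during training, providing functionality similar to [`tf.train.ExponentialMovingAverage`](https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage).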

In [None]:
# export
import logging

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
from timm.utils.model import get_state_dict, unwrap_model
from timm.utils.model_ema import ModelEmaV2

from gale.utils.logger import log_main_process

# Module-level logger; `EMACallback` uses it to report when the model weights
# are replaced with their EMA version.
_logger = logging.getLogger(__name__)

<IPython.core.display.Javascript object>

In [None]:
# export
class EMACallback(Callback):
    """
    Model Exponential Moving Average. Empirically it has been found that using the moving average
    of the trained parameters of a deep network is better than using its trained parameters directly.

    While validating, the parameters of the `LightningModule` are temporarily replaced by their
    EMA counterparts and restored afterwards, so the optimization process is not affected.

    If `use_ema_weights` is `True`, the EMA parameters are copied into the `LightningModule`
    when training ends.

    Arguments:
        decay: decay factor forwarded to timm's `ModelEmaV2`.
        use_ema_weights: whether to replace the model weights with the EMA weights at train end.
    """

    def __init__(self, decay=0.9999, use_ema_weights: bool = True):
        self.decay = decay
        # Created lazily in `on_fit_start`, once the LightningModule is available.
        self.ema = None
        self.use_ema_weights = use_ema_weights
        # Backup of the training parameters while the EMA weights are swapped in.
        self.collected_params = None
        # EMA state read from a checkpoint before `self.ema` exists; applied in `on_fit_start`.
        self._pending_ema_state = None

    def on_fit_start(self, trainer, pl_module):
        "Initialize `ModelEmaV2` from timm to keep a copy of the moving average of the weights"
        self.ema = ModelEmaV2(pl_module, decay=self.decay, device=None)

        # A checkpoint may have been loaded before the EMA model was created. Apply the
        # deferred state now so that resuming does not restart the average from scratch.
        pending_state = getattr(self, "_pending_ema_state", None)
        if pending_state is not None:
            self.ema.module.load_state_dict(pending_state)
            self._pending_ema_state = None

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        "Update the stored parameters using a moving average"
        # Update currently maintained parameters.
        self.ema.update(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module):
        "Do validation using the stored EMA parameters"
        # Without a preceding `on_fit_start` (e.g. stand-alone validation) there is no
        # EMA model yet, so validate with the weights currently held by the module.
        if self.ema is None:
            return

        # save original parameters before replacing with EMA version
        self.store(pl_module.parameters())

        # update the LightningModule with the EMA weights
        # ~ Copy EMA parameters to LightningModule
        self.copy_to(self.ema.module.parameters(), pl_module.parameters())

    def on_validation_end(self, trainer, pl_module):
        "Restore original parameters to resume training later"
        # Nothing was swapped in (see `on_validation_epoch_start`), so nothing to undo.
        if getattr(self, "collected_params", None) is None:
            return

        self.restore(pl_module.parameters())
        # Release the backup; it would otherwise keep a full extra copy of the
        # parameters alive for the whole training phase.
        self.collected_params = None

    def on_train_end(self, trainer, pl_module):
        "Optionally replace the weights of the `LightningModule` with the EMA version"
        # update the LightningModule with the EMA weights
        if self.use_ema_weights and self.ema is not None:
            self.copy_to(self.ema.module.parameters(), pl_module.parameters())
            msg = "Model weights replaced with the EMA version."
            log_main_process(_logger, logging.INFO, msg)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        "Return the EMA weights under the `state_dict_ema` key so they are saved with the checkpoint"
        if self.ema is not None:
            # NOTE(review): timm's `get_state_dict` with `unwrap_model` presumably yields the
            # state dict of the wrapped EMA module (`self.ema.module`) - confirm against timm.
            return {"state_dict_ema": get_state_dict(self.ema, unwrap_model)}

    def on_load_checkpoint(self, callback_state):
        """
        Load the EMA weights saved by `on_save_checkpoint`.

        If the EMA model has not been created yet, the state is kept and applied in
        `on_fit_start` instead of being silently dropped.
        """
        state_dict_ema = (callback_state or {}).get("state_dict_ema")
        if state_dict_ema is None:
            return

        if self.ema is not None:
            self.ema.module.load_state_dict(state_dict_ema)
        else:
            self._pending_ema_state = state_dict_ema

    def store(self, parameters):
        "Save the current parameters for restoring later."
        # Detach before cloning so the backup never holds on to an autograd graph.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process.

        Raises a `RuntimeError` if `store` has not been called beforehand.
        """
        collected_params = getattr(self, "collected_params", None)
        if collected_params is None:
            raise RuntimeError("`restore` called before any parameters were stored.")

        for c_param, param in zip(collected_params, parameters):
            param.data.copy_(c_param.data)

    def copy_to(self, shadow_parameters, parameters):
        """
        Copy `shadow_parameters` into the given collection of `parameters`.

        Both iterables are walked in lockstep; only parameters with `requires_grad`
        set are overwritten. Buffers are not part of either iterable and are left untouched.
        """
        for s_param, param in zip(shadow_parameters, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

<IPython.core.display.Javascript object>

In [None]:
# Render the signature and docstring of `EMACallback` as documentation output.
show_doc(EMACallback)

<h2 id="EMACallback" class="doc_header"><code>class</code> <code>EMACallback</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>EMACallback</code>(**`decay`**=*`0.9999`*, **`use_ema_weights`**:`bool`=*`True`*) :: `Callback`

Model Exponential Moving Average. Empirically it has been found that using the moving average 
of the trained parameters of a deep network is better than using its trained parameters directly.

If `use_ema_weights`, then the ema parameters of the network is set after training end.

<IPython.core.display.Javascript object>

## Export-

In [None]:
# hide
# Convert the project's notebooks into library modules; the cell output lists
# every notebook that was converted.
notebook2script()

Converted 00_utils.logger.ipynb.
Converted 00a_utils.display.ipynb.
Converted 00b_utils.structures.ipynb.
Converted 01_torch_utils.ipynb.
Converted 01a_losses.ipynb.
Converted 02_optimizer.ipynb.
Converted 02a_schedules.ipynb.
Converted 03_core-classes.ipynb.
Converted 04_classification.models.backbones.ipynb.
Converted 04a_classification.models.heads.ipynb.
Converted 04b_classification.model.meta_arch.common.ipynb.
Converted 04b_classification.model.meta_arch.vit.ipynb.
Converted 05_classification.core.ipynb.
Converted 05a_classification.augment.ipynb.
Converted 05b_classification.data.ipynb.
Converted 06_classification.task.ipynb.
Converted 07_collections.pandas.ipynb.
Converted 07a_collections.callbacks.notebook.ipynb.
Converted 07b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>