From 88217cad6d741ea1510d13e54089739f5a0f4d7d Mon Sep 17 00:00:00 2001 From: Fei Sun Date: Mon, 10 Apr 2023 18:43:01 -0700 Subject: [PATCH] Add an option to zero out the gradient before the forward Summary: Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4905 Currently the optimizer zeros the gradients after the forward and before the backward. In a recent PyTorch change, it set all gradients to None by default. This has a benefit of reducing the memory consumption (since all gradients are None). However, doing this after the forward does not provide any memory saving, since the the memory consumption is maximum at the end of forward. It doesn't matter whether the gradient is set to None before the forward or after the forward. So we should set it before the forward the enjoy the memory saving. We add a flag to enable it instead of doing it by default for now. Since people can override the zero_grad function (as the comment indicates), we do not know exactly what is done inside the function. This is to be on the safe side so the current flows may not be broken. Once we have gone through the existing flows, we should make the flag enabled by default. Reviewed By: tglik Differential Revision: D44264848 fbshipit-source-id: a68c7cbd36439faf65801f0f771ae8bc9c130699 --- detectron2/engine/train_loop.py | 39 ++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py index 2ed26c0661..be42ed1570 100644 --- a/detectron2/engine/train_loop.py +++ b/detectron2/engine/train_loop.py @@ -242,7 +242,9 @@ class SimpleTrainer(TrainerBase): or write your own training loop. """ - def __init__(self, model, data_loader, optimizer, gather_metric_period=1): + def __init__( + self, model, data_loader, optimizer, gather_metric_period=1, zero_grad_before_forward=False + ): """ Args: model: a torch Module. Takes a data from data_loader and returns a @@ -251,6 +253,7 @@ def __init__(self, model, data_loader, optimizer, gather_metric_period=1): optimizer: a torch optimizer. gather_metric_period: an int. Every gather_metric_period iterations the metrics are gathered from all the ranks to rank 0 and logged. + zero_grad_before_forward: whether to zero the gradients before the forward. """ super().__init__() @@ -268,6 +271,7 @@ def __init__(self, model, data_loader, optimizer, gather_metric_period=1): self._data_loader_iter_obj = None self.optimizer = optimizer self.gather_metric_period = gather_metric_period + self.zero_grad_before_forward = zero_grad_before_forward def run_step(self): """ @@ -281,6 +285,13 @@ def run_step(self): data = next(self._data_loader_iter) data_time = time.perf_counter() - start + if self.zero_grad_before_forward: + """ + If you need to accumulate gradients or do something similar, you can + wrap the optimizer with your custom `zero_grad()` method. + """ + self.optimizer.zero_grad() + """ If you want to do something with the losses, you can wrap the model. """ @@ -290,12 +301,12 @@ def run_step(self): loss_dict = {"total_loss": loss_dict} else: losses = sum(loss_dict.values()) - - """ - If you need to accumulate gradients or do something similar, you can - wrap the optimizer with your custom `zero_grad()` method. - """ - self.optimizer.zero_grad() + if not self.zero_grad_before_forward: + """ + If you need to accumulate gradients or do something similar, you can + wrap the optimizer with your custom `zero_grad()` method. + """ + self.optimizer.zero_grad() losses.backward() self.after_backward() @@ -400,13 +411,15 @@ def __init__( data_loader, optimizer, gather_metric_period=1, + zero_grad_before_forward=False, grad_scaler=None, precision: torch.dtype = torch.float16, log_grad_scaler: bool = False, ): """ Args: - model, data_loader, optimizer, gather_metric_period: same as in :class:`SimpleTrainer`. + model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward: + same as in :class:`SimpleTrainer`. grad_scaler: torch GradScaler to automatically scale gradients. precision: torch.dtype as the target precision to cast to in computations """ @@ -415,7 +428,9 @@ def __init__( assert not (model.device_ids and len(model.device_ids) > 1), unsupported assert not isinstance(model, DataParallel), unsupported - super().__init__(model, data_loader, optimizer, gather_metric_period) + super().__init__( + model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward + ) if grad_scaler is None: from torch.cuda.amp import GradScaler @@ -437,6 +452,8 @@ def run_step(self): data = next(self._data_loader_iter) data_time = time.perf_counter() - start + if self.zero_grad_before_forward: + self.optimizer.zero_grad() with autocast(dtype=self.precision): loss_dict = self.model(data) if isinstance(loss_dict, torch.Tensor): @@ -445,7 +462,9 @@ def run_step(self): else: losses = sum(loss_dict.values()) - self.optimizer.zero_grad() + if not self.zero_grad_before_forward: + self.optimizer.zero_grad() + self.grad_scaler.scale(losses).backward() if self.log_grad_scaler: