Add an option to zero out the gradient before the forward
Summary:
Pull Request resolved: #4905

Currently the optimizer zeros the gradients after the forward and before the backward. A recent PyTorch change makes `zero_grad()` set all gradients to None by default, which reduces memory consumption (the gradient tensors are freed rather than kept around as zeros).

However, doing this after the forward provides no memory saving, since memory consumption peaks at the end of the forward pass.

Functionally, it does not matter whether the gradients are set to None before or after the forward, so we should set them before the forward to get the memory saving.
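To make the timing argument concrete, here is a minimal PyTorch sketch (a toy model, not detectron2 code); the two branches differ only in where `zero_grad()` is called relative to the forward:

```python
import torch

model = torch.nn.Linear(128, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step(batch, zero_grad_before_forward=True):
    if zero_grad_before_forward:
        # Old gradients are freed (set to None) before the forward allocates
        # activations, so they do not add to the peak memory.
        optimizer.zero_grad(set_to_none=True)
    loss = model(batch).sum()  # forward: activation memory peaks here
    if not zero_grad_before_forward:
        # Freeing gradients here is too late: the peak (activations plus the
        # old gradients) has already happened during the forward.
        optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


train_step(torch.randn(32, 128))
```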

For now we add a flag to enable this behavior instead of making it the default. Since users can override the `zero_grad()` function (as the in-code comment indicates), we do not know exactly what happens inside it; the flag keeps us on the safe side so that existing flows are not broken.
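For illustration, a hypothetical wrapper of the kind that in-code comment has in mind (not from this repo) might make `zero_grad()` a no-op on most iterations in order to accumulate gradients, which is exactly why silently moving the call could surprise such code:

```python
import torch


class AccumulatingOptimizer:
    """Hypothetical wrapper: zero_grad() only takes effect every `accum_steps`
    calls, so gradients accumulate across the iterations in between."""

    def __init__(self, optimizer, accum_steps=4):
        self.optimizer = optimizer
        self.accum_steps = accum_steps
        self._calls = 0

    def zero_grad(self, set_to_none=True):
        self._calls += 1
        if self._calls % self.accum_steps == 0:
            self.optimizer.zero_grad(set_to_none=set_to_none)

    def step(self, *args, **kwargs):
        return self.optimizer.step(*args, **kwargs)

    @property
    def param_groups(self):
        return self.optimizer.param_groups


opt = AccumulatingOptimizer(torch.optim.SGD(torch.nn.Linear(8, 1).parameters(), lr=0.1))
opt.zero_grad()  # no-op until the 4th call
```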

Once we have gone through the existing flows, we should enable the flag by default.
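Opting in would look like the following sketch; the toy model and data loader are placeholders just to make it self-contained, while `SimpleTrainer` and its `train(start_iter, max_iter)` entry point are the existing detectron2 API:

```python
import itertools

import torch
from detectron2.engine import SimpleTrainer


class ToyModel(torch.nn.Module):
    """Stand-in for a real detectron2 model: returns a dict of losses."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, batch):
        return {"loss": self.linear(batch).pow(2).mean()}


model = ToyModel()
data_loader = itertools.cycle([torch.randn(4, 8)])  # any iterable of batches
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = SimpleTrainer(
    model,
    data_loader,
    optimizer,
    zero_grad_before_forward=True,  # the new opt-in flag from this diff
)
trainer.train(start_iter=0, max_iter=20)
```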

Reviewed By: tglik

Differential Revision: D44264848

fbshipit-source-id: a68c7cbd36439faf65801f0f771ae8bc9c130699
sf-wind authored and facebook-github-bot committed Apr 11, 2023
1 parent 1bc3a33 commit 88217ca
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions detectron2/engine/train_loop.py
@@ -242,7 +242,9 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """

-    def __init__(self, model, data_loader, optimizer, gather_metric_period=1):
+    def __init__(
+        self, model, data_loader, optimizer, gather_metric_period=1, zero_grad_before_forward=False
+    ):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a
@@ -251,6 +253,7 @@ def __init__(self, model, data_loader, optimizer, gather_metric_period=1):
             optimizer: a torch optimizer.
             gather_metric_period: an int. Every gather_metric_period iterations
                 the metrics are gathered from all the ranks to rank 0 and logged.
+            zero_grad_before_forward: whether to zero the gradients before the forward.
         """
         super().__init__()

@@ -268,6 +271,7 @@ def __init__(self, model, data_loader, optimizer, gather_metric_period=1):
         self._data_loader_iter_obj = None
         self.optimizer = optimizer
         self.gather_metric_period = gather_metric_period
+        self.zero_grad_before_forward = zero_grad_before_forward

     def run_step(self):
         """
@@ -281,6 +285,13 @@ def run_step(self):
         data = next(self._data_loader_iter)
         data_time = time.perf_counter() - start

+        if self.zero_grad_before_forward:
+            """
+            If you need to accumulate gradients or do something similar, you can
+            wrap the optimizer with your custom `zero_grad()` method.
+            """
+            self.optimizer.zero_grad()
+
         """
         If you want to do something with the losses, you can wrap the model.
         """
@@ -290,12 +301,12 @@ def run_step(self):
             loss_dict = {"total_loss": loss_dict}
         else:
             losses = sum(loss_dict.values())
-
-        """
-        If you need to accumulate gradients or do something similar, you can
-        wrap the optimizer with your custom `zero_grad()` method.
-        """
-        self.optimizer.zero_grad()
+        if not self.zero_grad_before_forward:
+            """
+            If you need to accumulate gradients or do something similar, you can
+            wrap the optimizer with your custom `zero_grad()` method.
+            """
+            self.optimizer.zero_grad()
         losses.backward()

         self.after_backward()
@@ -400,13 +411,15 @@ def __init__(
         data_loader,
         optimizer,
         gather_metric_period=1,
+        zero_grad_before_forward=False,
         grad_scaler=None,
         precision: torch.dtype = torch.float16,
         log_grad_scaler: bool = False,
     ):
         """
         Args:
-            model, data_loader, optimizer, gather_metric_period: same as in :class:`SimpleTrainer`.
+            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward:
+                same as in :class:`SimpleTrainer`.
             grad_scaler: torch GradScaler to automatically scale gradients.
             precision: torch.dtype as the target precision to cast to in computations
         """
@@ -415,7 +428,9 @@ def __init__(
             assert not (model.device_ids and len(model.device_ids) > 1), unsupported
         assert not isinstance(model, DataParallel), unsupported

-        super().__init__(model, data_loader, optimizer, gather_metric_period)
+        super().__init__(
+            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward
+        )

         if grad_scaler is None:
             from torch.cuda.amp import GradScaler
@@ -437,6 +452,8 @@ def run_step(self):
         data = next(self._data_loader_iter)
         data_time = time.perf_counter() - start

+        if self.zero_grad_before_forward:
+            self.optimizer.zero_grad()
         with autocast(dtype=self.precision):
             loss_dict = self.model(data)
             if isinstance(loss_dict, torch.Tensor):
@@ -445,7 +462,9 @@ def run_step(self):
             else:
                 losses = sum(loss_dict.values())

-        self.optimizer.zero_grad()
+        if not self.zero_grad_before_forward:
+            self.optimizer.zero_grad()
+
         self.grad_scaler.scale(losses).backward()

         if self.log_grad_scaler:
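The same flag is threaded through `AMPTrainer` (forwarded to `super().__init__` as shown above). A usage sketch, reusing the `ToyModel` stand-in from the `SimpleTrainer` example and assuming a CUDA device since `AMPTrainer` builds on `torch.cuda.amp`:

```python
import itertools

import torch
from detectron2.engine import AMPTrainer

# ToyModel is the same stand-in defined in the SimpleTrainer sketch above.
model = ToyModel().cuda()
data_loader = itertools.cycle([torch.randn(4, 8, device="cuda")])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = AMPTrainer(
    model,
    data_loader,
    optimizer,
    zero_grad_before_forward=True,  # same opt-in flag as SimpleTrainer
    precision=torch.float16,
)
trainer.train(start_iter=0, max_iter=20)
```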
