
Commit 0a092f6

making optimization steps for hooks (Lightning-AI#2363)
* simplified optimizer step and zero grad overriding
1 parent d221817 commit 0a092f6

File tree

README.md
docs/source/optimizers.rst
pytorch_lightning/core/lightning.py
pytorch_lightning/trainer/training_loop.py

4 files changed (+68, -35 lines)

README.md

Lines changed: 5 additions & 3 deletions
@@ -143,13 +143,15 @@ As you see, you're just organizing your PyTorch code - there's no abstraction.
 
 And for the stuff that the Trainer abstracts out, you can [override any part](https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#extensibility) you want to do things like implement your own distributed training, 16-bit precision, or even a custom backward pass.
 
-For example, here you could do your own backward pass
+For example, here you could do your own backward pass without worrying about GPUs, TPUs or 16-bit since we already handle it.
 
 ```python
 class LitModel(LightningModule):
-    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
-                       second_order_closure=None):
+    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
+                       second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
         optimizer.step()
+
+    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
         optimizer.zero_grad()
 ```

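With this change, a `LightningModule` that overrides `optimizer_step` receives the new `on_tpu` / `using_native_amp` / `using_lbfgs` flags instead of inspecting trainer state itself, and gradient clearing moves into the separate `optimizer_zero_grad` hook. A minimal sketch of such an override, mirroring the default dispatch from the `lightning.py` hunk further down (the TPU branch is left to the default implementation, and the rest of the module is elided):

```python
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    # training_step / configure_optimizers etc. elided

    def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None, on_tpu=False,
                       using_native_amp=False, using_lbfgs=False):
        if using_native_amp:
            # native AMP: the GradScaler performs the actual step
            self.trainer.scaler.step(optimizer)
        elif using_lbfgs:
            # LBFGS re-evaluates the loss through the closure
            optimizer.step(second_order_closure)
        else:
            optimizer.step()

    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
        # gradient clearing now lives in its own hook
        optimizer.zero_grad()
```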
docs/source/optimizers.rst

Lines changed: 6 additions & 4 deletions
@@ -83,12 +83,14 @@ For example, here step optimizer A every 2 batches and optimizer B every 4 batches
 
 .. testcode::
 
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
         optimizer.step()
-        optimizer.zero_grad()
+
+    def optimizer_zero_grad(self, current_epoch, batch_idx, optimizer, opt_idx):
+        optimizer.zero_grad()
 
     # Alternating schedule for optimizer steps (ie: GANs)
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
         # update generator opt every 2 steps
         if optimizer_i == 0:
             if batch_nb % 2 == 0 :
@@ -109,7 +111,7 @@ Here we add a learning-rate warm up
 .. testcode::
 
     # learning rate warm-up
-    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
+    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
         # warm up lr
         if self.trainer.global_step < 500:
             lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)

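The first hunk above is cut off inside the alternating-schedule example. For context, a sketch of how the full schedule could read under the new signature; the discriminator branch is a reconstruction of the truncated docs example, not part of the diff shown:

```python
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    # update generator opt every 2 steps
    if optimizer_i == 0:
        if batch_nb % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()

    # update discriminator opt every 4 steps (reconstructed continuation,
    # not visible in the hunk above)
    if optimizer_i == 1:
        if batch_nb % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()
```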
pytorch_lightning/core/lightning.py

Lines changed: 19 additions & 24 deletions
@@ -1133,6 +1133,9 @@ def optimizer_step(
             optimizer: Optimizer,
             optimizer_idx: int,
             second_order_closure: Optional[Callable] = None,
+            on_tpu: bool = False,
+            using_native_amp: bool = False,
+            using_lbfgs: bool = False,
     ) -> None:
         r"""
         Override this method to adjust the default way the
@@ -1146,19 +1149,21 @@ def optimizer_step(
             optimizer: A PyTorch optimizer
             optimizer_idx: If you used multiple optimizers this indexes into that list.
             second_order_closure: closure for second order methods
+            on_tpu: true if TPU backward is required
+            using_native_amp: True if using native amp
+            using_lbfgs: True if the matching optimizer is lbfgs
 
         Examples:
             .. code-block:: python
 
                 # DEFAULT
                 def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
-                                   second_order_closure=None):
+                                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
                     optimizer.step()
-                    optimizer.zero_grad()
 
                 # Alternating schedule for optimizer steps (i.e.: GANs)
                 def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
-                                   second_order_closure=None):
+                                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
                     # update generator opt every 2 steps
                     if optimizer_idx == 0:
                         if batch_idx % 2 == 0 :
@@ -1182,7 +1187,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
 
                 # learning rate warm-up
                 def optimizer_step(self, current_epoch, batch_idx, optimizer,
-                                   optimizer_idx, second_order_closure=None):
+                                   optimizer_idx, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
                     # warm up lr
                     if self.trainer.global_step < 500:
                         lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
@@ -1198,30 +1203,20 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
         model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.
 
         """
-        if self.trainer.use_tpu and XLA_AVAILABLE:
+        if on_tpu:
             xm.optimizer_step(optimizer)
-        elif isinstance(optimizer, torch.optim.LBFGS):
-
-            # native amp + lbfgs is a no go right now
-            if self.trainer.use_amp and self.trainer.use_native_amp:
-                raise MisconfigurationException(
-                    'native PyTorch amp and lbfgs are not compatible.'
-                    ' To request, please file a Github issue in PyTorch and tag @mcarilli')
+        elif using_native_amp:
+            self.trainer.scaler.step(optimizer)
+        elif using_lbfgs:
             optimizer.step(second_order_closure)
         else:
-            if self.trainer.use_amp and self.trainer.use_native_amp:
-                self.trainer.scaler.step(optimizer)
-            else:
-                optimizer.step()
-
-            # in native 16-bit we need to update scaler after optimizer step
-            if self.trainer.use_amp and self.trainer.use_native_amp:
-                self.trainer.scaler.update()
-
-        # model hook
-        self.on_before_zero_grad(optimizer)
+            optimizer.step()
 
-        # clear gradients
+    def optimizer_zero_grad(self,
+                            epoch: int,
+                            batch_idx: int,
+                            optimizer: Optimizer,
+                            optimizer_idx: int):
         optimizer.zero_grad()
 
     def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:

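The warm-up example in the docstring is likewise truncated at the `lr_scale` line. A sketch of the complete pattern under the new signature; the param-group update is a reconstructed continuation and assumes the base rate is stored on `self.hparams.learning_rate`, which is not shown in the diff. Note that `optimizer.zero_grad()` now happens in the separate `optimizer_zero_grad` hook rather than here:

```python
# learning rate warm-up with the new hook signature
def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                   second_order_closure, on_tpu, using_native_amp, using_lbfgs):
    # scale the lr linearly over the first 500 global steps
    if self.trainer.global_step < 500:
        lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
        for pg in optimizer.param_groups:
            # assumes the base rate lives on self.hparams.learning_rate
            pg['lr'] = lr_scale * self.hparams.learning_rate

    # update params; gradients are cleared by optimizer_zero_grad
    optimizer.step()
```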
pytorch_lightning/trainer/training_loop.py

Lines changed: 38 additions & 4 deletions
@@ -716,7 +716,15 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
         # .STEP + ZERO_GRAD
         # ------------------
+        self.call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
+
+        return grad_norm_dic
+
+    def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
+        # calls .step(), .zero_grad()
+        # override function to modify this behavior
         model = self.get_model()
+
         with self.profiler.profile('optimizer_step'):
             lambda_closure = lambda: self.optimizer_closure(
                 split_batch,
@@ -725,11 +733,37 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
                 optimizer,
                 self.hiddens
             ).loss
-            model.optimizer_step(self.current_epoch, batch_idx,
-                                 optimizer, opt_idx,
-                                 lambda_closure)
 
-        return grad_norm_dic
+            # apply TPU optimizer
+            if self.use_tpu and XLA_AVAILABLE:
+                model.optimizer_step(self.current_epoch, batch_idx,
+                                     optimizer, opt_idx, lambda_closure, on_tpu=True)
+
+            # for LBFGS do something a bit different
+            elif isinstance(optimizer, torch.optim.LBFGS):
+
+                # native amp + lbfgs is a no go right now
+                if self.use_amp and self.use_native_amp:
+                    raise MisconfigurationException(
+                        'native PyTorch amp and lbfgs are not compatible.'
+                        ' To request, please file a Github issue in PyTorch and tag @mcarilli')
+                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
+                                     using_lbfgs=True)
+
+            # when using 16-bit
+            else:
+                native_amp = self.use_amp and self.use_native_amp
+                model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, native_amp)
+
+            # in native 16-bit we need to update scaler after optimizer step
+            if self.use_amp and self.use_native_amp:
+                self.scaler.update()
+
+            # model hook
+            model.on_before_zero_grad(optimizer)
+
+            # clear gradients
+            model.optimizer_zero_grad(self.current_epoch, batch_idx, optimizer, opt_idx)
 
     def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
         """

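The new `call_optimizer_step` is factored out of `run_batch_backward_pass` precisely so it can be overridden ("override function to modify this behavior"). A hypothetical sketch of a `Trainer` subclass that uses that seam; the subclass name and the skip condition are illustrative only, not part of the commit:

```python
from pytorch_lightning import Trainer


class SkipOddBatchesTrainer(Trainer):
    # hypothetical override: only run .step() / .zero_grad() on even batches
    def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
        if batch_idx % 2 == 0:
            super().call_optimizer_step(optimizer, opt_idx, batch_idx, split_batch)
```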
0 commit comments