
Commit 29ebe92

williamFalcon and Borda authored
support for native amp (Lightning-AI#1561)
* adding native amp suppport
* adding native amp suppport
* adding native amp suppport
* adding native amp suppport
* autocast
* autocast
* autocast
* autocast
* autocast
* autocast
* removed comments
* removed comments
* added state saving
* added state saving
* try install amp again
* added state saving
* drop Apex reinstall

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 41b6cbb commit 29ebe92


11 files changed: +100, -25 lines changed


.drone.yml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ steps:
     - pip install pip -U
     - pip --version
     - nvidia-smi
-    # - bash ./tests/install_AMP.sh
+    #- bash ./tests/install_AMP.sh
     - apt-get update && apt-get install -y cmake
     - pip install -r requirements.txt --user -q
     - pip install -r ./tests/requirements-devel.txt --user -q

pytorch_lightning/core/hooks.py

Lines changed: 8 additions & 2 deletions
@@ -140,9 +140,15 @@ def backward(self, use_amp, loss, optimizer):
 
         """
         if trainer.precision == 16:
-
             # .backward is not special on 16-bit with TPUs
-            if not trainer.on_tpu:
+            if trainer.on_tpu:
+                return
+
+            if self.trainer.use_native_amp:
+                self.trainer.scaler.scale(loss).backward()
+
+            # TODO: remove in v0.8.0
+            else:
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
         else:
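
For context, the native-amp branch added here follows the standard torch.cuda.amp backward pattern. A minimal sketch (not the hook itself), assuming the trainer has already created a GradScaler:

import torch

scaler = torch.cuda.amp.GradScaler()  # in Lightning this lives on the trainer

def backward(loss, use_native_amp):
    if use_native_amp:
        # scale the loss so small fp16 gradients do not underflow, then backprop
        scaler.scale(loss).backward()
    else:
        # fp32 path (Apex handles its own scaling elsewhere)
        loss.backward()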

pytorch_lightning/core/lightning.py

Lines changed: 14 additions & 1 deletion
@@ -1157,9 +1157,22 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
         if self.trainer.use_tpu and XLA_AVAILABLE:
             xm.optimizer_step(optimizer)
         elif isinstance(optimizer, torch.optim.LBFGS):
+
+            # native amp + lbfgs is a no go right now
+            if self.use_amp and self.use_native_amp:
+                m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \
+                    'a Github issue in PyTorch and tag @mcarilli'
+                raise MisconfigurationException(m)
             optimizer.step(second_order_closure)
         else:
-            optimizer.step()
+            if self.use_amp and self.use_native_amp:
+                self.trainer.scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            # in native 16-bit we need to update scaler after optimizer step
+            if self.use_amp and self.use_native_amp:
+                self.trainer.scaler.update()
 
         # model hook
         self.on_before_zero_grad(optimizer)
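
The optimizer side of the same pattern, as a hedged sketch (the names are placeholders, not Lightning's API): stepping through the scaler lets it skip steps whose gradients contain inf/NaN, and update() re-tunes the loss scale afterwards.

import torch

scaler = torch.cuda.amp.GradScaler()  # the same scaler used for the backward pass

def optimizer_step(optimizer, use_native_amp):
    if use_native_amp:
        # unscales gradients and skips the step if any of them are inf/NaN
        scaler.step(optimizer)
        # grow or shrink the loss scale based on what just happened
        scaler.update()
    else:
        optimizer.step()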

pytorch_lightning/trainer/auto_mix_precision.py

Lines changed: 23 additions & 1 deletion
@@ -1,6 +1,8 @@
 from abc import ABC
+import torch
 
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import rank_zero_warn
 
 try:
     from apex import amp
@@ -15,8 +17,28 @@ class TrainerAMPMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     precision: int
+    use_native_amp: bool
 
     def init_amp(self, use_amp):
+        # TODO: remove in v 0.8.0
+        if self.use_native_amp:
+            rank_zero_warn("`amp_level` has been deprecated since v0.7.4 "
+                           "(native amp does not require it)"
+                           " and this argument will be removed in v0.8.0", DeprecationWarning)
+
+        # Backward compatibility, TODO: remove in v0.9.0
+        if use_amp is not None:
+            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
+                           " and this argument will be removed in v0.9.0", DeprecationWarning)
+            self.precision = 16 if use_amp else 32
+
+        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
+
+        if use_amp and self.use_native_amp:
+            log.info('Using 16bit precision.')
+            return
+
+        # TODO: remove all below for v0.8.0
         if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
             raise ModuleNotFoundError("""
     You set `use_amp=True` but do not have apex installed.
@@ -31,4 +53,4 @@ def init_amp(self, use_amp):
 
     @property
     def use_amp(self) -> bool:
-        return self.precision == 16 and APEX_AVAILABLE
+        return self.precision == 16
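
The property change at the bottom is the key behavioural shift: 16-bit no longer implies Apex. A toy illustration of that bookkeeping (not the actual mixin), assuming precision and use_native_amp are set by the Trainer:

class AMPState:
    """Minimal stand-in for the precision bookkeeping above."""

    def __init__(self, precision=32, use_native_amp=False):
        assert precision in (16, 32), 'only 32 or 16 bit precision supported'
        self.precision = precision
        self.use_native_amp = use_native_amp

    @property
    def use_amp(self):
        # after this commit, requesting 16-bit is enough; Apex is no longer a prerequisite
        return self.precision == 16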

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 3 additions & 2 deletions
@@ -151,6 +151,7 @@ class TrainerDDPMixin(ABC):
     amp_level: str
     use_tpu: bool
     default_root_dir: str
+    use_native_amp: bool
 
     @property
     @abstractmethod
@@ -350,8 +351,8 @@ def ddp_train(self, process_idx, model):
 
         # AMP
        # run through amp wrapper before going to distributed DP
-        if self.use_amp:
-            # An example
+        # TODO: remove in v0.8.0
+        if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 13 additions & 2 deletions
@@ -394,6 +394,7 @@ class TrainerDPMixin(ABC):
     tpu_local_core_rank: int
     tpu_global_core_rank: int
     use_tpu: bool
+    use_native_amp: bool
     data_parallel_device_ids: ...
     logger: Union[LightningLoggerBase, bool]
 
@@ -481,7 +482,8 @@ def single_gpu_train(self, model):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
 
-        if self.use_amp:
+        # TODO: update for 0.8.0
+        if self.use_amp and not self.use_native_amp:
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
@@ -528,9 +530,16 @@ def dp_train(self, model):
 
         model.cuda(self.root_gpu)
 
+        # hack forward to do autocast for the user
+        model_autocast_original_forward = model.forward
+        if self.use_amp and self.use_native_amp:
+            # wrap the user's forward in autocast and give it back at the end
+            model.forward = torch.cuda.amp.autocast()(model.forward)
+
+        # TODO: remove in v0.8.0
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp:
+        if self.use_dp and self.use_amp and not self.use_native_amp:
            if self.amp_level == 'O2':
                raise MisconfigurationException(
                    f'Amp level {self.amp_level} with DataParallel is not supported.'
@@ -551,6 +560,8 @@ def dp_train(self, model):
 
         self.run_pretrain_routine(model)
 
+        model.forward = model_autocast_original_forward
+
     def horovod_train(self, model):
         # Horovod: initialize library
         hvd.init()
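
The dp_train change wraps the user's forward in autocast so DataParallel replicas run mixed precision without touching user code, then restores the original forward after fit. A hedged, standalone sketch of that trick (assumes a CUDA device; the Linear model is a stand-in):

import torch
import torch.nn as nn

model = nn.Linear(4, 2).cuda()  # stand-in for the user's LightningModule

# keep a handle to the original forward, wrap it with autocast, restore it later
original_forward = model.forward
model.forward = torch.cuda.amp.autocast()(model.forward)

x = torch.randn(8, 4, device="cuda")
with torch.no_grad():
    y = model(x)  # runs under autocast; matmul outputs come back as fp16

model.forward = original_forward  # hand the untouched forward back to the user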

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 5 additions & 1 deletion
@@ -268,7 +268,11 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
-                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                if self.use_amp and self.use_native_amp:
+                    with torch.cuda.amp.autocast():
+                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                else:
+                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
 
                 # on dp / ddp2 might still want to do something with the batch parts
                 if test_mode:
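
Evaluation only needs the autocast context, not the GradScaler, because no backward pass runs. A minimal sketch under the same assumption of a CUDA device:

import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda().eval()
batch = torch.randn(32, 16, device="cuda")

# autocast alone handles the fp16 casts; no loss scaling is needed without gradients
with torch.no_grad(), torch.cuda.amp.autocast():
    output = model(batch)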

pytorch_lightning/trainer/trainer.py

Lines changed: 11 additions & 13 deletions
@@ -115,7 +115,6 @@ def __init__(
             print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
             weights_summary: Optional[str] = 'full',
             weights_save_path: Optional[str] = None,
-            amp_level: str = 'O1',
             num_sanity_val_steps: int = 5,
             truncated_bptt_steps: Optional[int] = None,
             resume_from_checkpoint: Optional[str] = None,
@@ -124,6 +123,7 @@ def __init__(
             reload_dataloaders_every_epoch: bool = False,
             auto_lr_find: Union[bool, str] = False,
             replace_sampler_ddp: bool = True,
+            amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
             default_save_path=None,  # backward compatible, todo: remove in v0.8.0
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
             nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
@@ -487,20 +487,18 @@ def __init__(
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)
 
-        # 16 bit mixed precision training using apex
+        # AMP init
+        # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
+        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+        self.precision = precision
+
+        # TODO: remove for v0.8.0
         self.amp_level = amp_level
         self.precision = precision
-
-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)
-            self.precision = 16 if use_amp else 32
-
-        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
-
-        if self.precision == 16 and self.num_tpu_cores is None:
-            use_amp = True
         self.init_amp(use_amp)
 
         # Callback system
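
In isolation, the two lines that decide whether native amp is used reduce to the following sketch (the precision value stands in for a hypothetical Trainer(precision=16) call):

import torch

precision = 16  # e.g. Trainer(precision=16)

# probe the installed PyTorch for the native amp API; create a scaler only for 16-bit runs
use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
scaler = torch.cuda.amp.GradScaler() if (use_native_amp and precision == 16) else None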

pytorch_lightning/trainer/training_io.py

Lines changed: 12 additions & 0 deletions
@@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         if on_gpu:
             model.cuda(self.root_gpu)
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
 
@@ -316,6 +320,10 @@ def dump_checkpoint(self):
 
         checkpoint['state_dict'] = model.state_dict()
 
+        # restore native amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
+
         if hasattr(model, "hparams"):
             is_namespace = isinstance(model.hparams, Namespace)
             checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
@@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
+        # restore amp scaling
+        if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
+
         if self.root_gpu is not None:
             model.cuda(self.root_gpu)
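
These three hunks persist and restore the GradScaler's loss-scale statistics alongside the model weights under the checkpoint key 'native_amp_scaling_state'. A hedged sketch of the round trip (the file name is hypothetical):

import torch

scaler = torch.cuda.amp.GradScaler()

# saving: store the scaler state next to the model weights
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}
torch.save(checkpoint, 'example.ckpt')

# restoring: only load the key if it was written by a native-amp run
checkpoint = torch.load('example.ckpt')
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])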

pytorch_lightning/trainer/training_loop.py

Lines changed: 9 additions & 2 deletions
@@ -148,6 +148,7 @@ def training_step(self, batch, batch_idx):
 
 import numpy as np
 from torch.utils.data import DataLoader
+import torch
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -588,8 +589,12 @@ def run_training_batch(self, batch, batch_idx):
         def optimizer_closure():
             # forward pass
             with self.profiler.profile('model_forward'):
-                output_dict = self.training_forward(
-                    split_batch, batch_idx, opt_idx, self.hiddens)
+                if self.use_amp and self.use_native_amp:
+                    with torch.cuda.amp.autocast():
+                        output_dict = self.training_forward(split_batch, batch_idx,
+                                                            opt_idx, self.hiddens)
+                else:
+                    output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)
 
             # format and reduce outputs accordingly
             processed_output = self.process_output(output_dict, train=True)
@@ -645,6 +650,8 @@ def optimizer_closure():
                                               self.track_grad_norm)
 
             # clip gradients
+            if self.use_amp and self.use_native_amp:
+                self.scaler.unscale_(optimizer)
             self.clip_gradients()
 
             # calls .step(), .zero_grad()
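
The unscale_ call matters for gradient clipping: gradients still carry the loss-scale factor after backward, so clipping must happen on unscaled values. A self-contained sketch under the assumption of a CUDA device:

import torch
import torch.nn as nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x, y = torch.randn(4, 10, device="cuda"), torch.randn(4, 1, device="cuda")

with torch.cuda.amp.autocast():
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()

# gradients are still multiplied by the loss scale here; unscale them first so the
# clipping threshold applies to the true gradient magnitudes
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)  # step() knows unscale_ was already called and will not unscale twice
scaler.update()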
