
Commit 0be78d1

native amp (Lightning-AI#2373)
* native amp
* typo
* imports
* apex
1 parent f1c9693 commit 0be78d1

File tree

14 files changed: +70 -63 lines changed


.circleci/config.yml

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ references:
   name: Testing Documentation
   command: |
     # Second run examples in docs
+    bash tests/install_AMP.sh
     sudo apt-get update && sudo apt-get install -y cmake
     sudo pip install -r requirements/docs.txt
     cd docs; make doctest; make coverage

pytorch_lightning/core/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -149,6 +149,7 @@
 Thus, if we wanted to add a validation loop you would add this to your
 :class:`~LightningModule`:

+>>> import pytorch_lightning as pl
 >>> class LitModel(pl.LightningModule):
 ...     def validation_step(self, batch, batch_idx):
 ...         x, y = batch
@@ -166,6 +167,7 @@
 Add test loop
 ^^^^^^^^^^^^^

+>>> import pytorch_lightning as pl
 >>> class LitModel(pl.LightningModule):
 ...     def test_step(self, batch, batch_idx):
 ...         x, y = batch
@@ -264,6 +266,7 @@ def training_step(self, batch, batch_idx):
 :class:`~LightningModule.prepare_data` method to
 allow for this:

+>>> import pytorch_lightning as pl
 >>> class LitModel(pl.LightningModule):
 ...     def prepare_data(self):
 ...         # download

pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
-from pytorch_lightning.utilities import move_data_to_device
+from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE


 try:
@@ -189,7 +189,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         loss.backward()

     def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
-        if self.trainer.use_native_amp:
+        if NATIVE_AMP_AVALAIBLE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)

         else:
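
For context, the native branch of `amp_scale_loss` delegates to the `torch.cuda.amp.GradScaler` that the Trainer creates (see the `trainer.py` section below): the raw loss is multiplied by a running scale factor before `backward()` so small fp16 gradients do not underflow. A minimal sketch of just that step, assuming PyTorch >= 1.6; the function and variable names are placeholders, not part of this commit:

    import torch

    # minimal sketch of native-amp loss scaling (PyTorch >= 1.6); this mirrors
    # what amp_scale_loss hands back to the training loop before backward()
    scaler = torch.cuda.amp.GradScaler()

    def scale_and_backward(unscaled_loss):
        scaled_loss = scaler.scale(unscaled_loss)  # multiply by the running scale factor
        scaled_loss.backward()                     # gradients come out scaled as well
        return scaled_loss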

pytorch_lightning/core/lightning.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
 from pytorch_lightning.core.saving import ModelIO, PRIMITIVE_TYPES, ALLOWED_CONFIG_TYPES
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args


pytorch_lightning/core/memory.py

Lines changed: 5 additions & 3 deletions
@@ -9,7 +9,8 @@
 import torch.nn as nn
 from torch.utils.hooks import RemovableHandle

-import pytorch_lightning as pl
+
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection

 PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
@@ -126,6 +127,7 @@ class ModelSummary(object):

     Example::

+    >>> import pytorch_lightning as pl
     >>> class LitModel(pl.LightningModule):
     ...
     ...     def __init__(self):
@@ -154,7 +156,7 @@ class ModelSummary(object):
     MODE_DEFAULT = MODE_TOP
     MODES = [MODE_FULL, MODE_TOP]

-    def __init__(self, model: "pl.LightningModule", mode: str = MODE_DEFAULT):
+    def __init__(self, model, mode: str = MODE_DEFAULT):
         self._model = model
         self._mode = mode
         self._layer_summary = self.summarize()
@@ -209,7 +211,7 @@ def _forward_example_input(self) -> None:
         input_ = apply_to_collection(input_, torch.Tensor, lambda x: x.type(model.dtype))

         if trainer is not None and trainer.use_amp:
-            if model.use_native_amp:
+            if NATIVE_AMP_AVALAIBLE:
                 model.forward = torch.cuda.amp.autocast()(model.forward)

         mode = model.training

pytorch_lightning/trainer/auto_mix_precision.py

Lines changed: 15 additions & 31 deletions
@@ -1,52 +1,36 @@
 from abc import ABC
-import torch

 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_warn
-
-try:
-    from apex import amp
-except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+from pytorch_lightning.utilities import rank_zero_warn, APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities.distributed import rank_zero_debug


 class TrainerAMPMixin(ABC):

     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     precision: int
-    use_native_amp: bool
-
-    def init_amp(self, use_amp):
-        if self.use_native_amp:
-            rank_zero_warn("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)

-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)
-            self.precision = 16 if use_amp else 32
+    def init_amp(self):
+        if NATIVE_AMP_AVALAIBLE:
+            log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")

         assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

-        if use_amp and self.use_native_amp:
-            log.info('Using 16bit precision.')
+        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+            log.info('Using native 16bit precision.')
             return

-        # TODO: remove all below for v0.9.0
-        if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
-            raise ModuleNotFoundError("""
-            You set `use_amp=True` but do not have apex installed.
-            Install apex first using this guide and rerun with use_amp=True:
-            https://github.com/NVIDIA/apex#linux
-            this run will NOT use 16 bit precision
-            """)
+        # TODO: replace `use_amp` by `precision` all below for v0.9.0
+        if self.use_amp and not APEX_AVAILABLE:  # pragma: no-cover
+            raise ModuleNotFoundError(
+                "You set `use_amp=True` but do not have apex installed."
+                "Install apex first using this guide and rerun with use_amp=True:"
+                "https://github.com/NVIDIA/apex#linux his run will NOT use 16 bit precision"
+            )

         if self.use_amp:
-            log.info('Using 16bit precision.')
+            log.info('Using APEX 16bit precision.')

     @property
     def use_amp(self) -> bool:
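
The `APEX_AVAILABLE` and `NATIVE_AMP_AVALAIBLE` flags are now imported from `pytorch_lightning.utilities`, whose definition is not part of this diff. Judging by the try/except removed above and the `hasattr` check removed from `trainer.py` below, they are presumably derived roughly as in this sketch:

    import torch

    # sketch only: plausible definitions for the utilities-level flags used here
    try:
        from apex import amp  # noqa: F401
        APEX_AVAILABLE = True
    except ImportError:
        APEX_AVAILABLE = False

    # native amp ships with torch.cuda.amp (autocast + GradScaler) in PyTorch >= 1.6
    NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")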

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -127,6 +127,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info

@@ -177,7 +178,6 @@ class TrainerDDPMixin(ABC):
     amp_level: str
     use_tpu: bool
     default_root_dir: str
-    use_native_amp: bool
     progress_bar_callback: ...
     num_processes: int
     num_nodes: int
@@ -519,7 +519,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         # AMP
         # run through amp wrapper before going to distributed DP
         # TODO: remove with dropping NVIDIA AMP support
-        if self.use_amp and not self.use_native_amp:
+        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
             self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
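
The apex branch kept here calls the model's `configure_apex` hook, which this commit does not touch. Its default in `LightningModule` is essentially a thin wrapper around `apex.amp.initialize`; a hedged sketch of that shape (opt levels such as 'O1'/'O2' come from NVIDIA apex, not from this diff):

    # rough sketch of the configure_apex hook this code path expects; the real
    # default lives in LightningModule and is not shown in this commit
    def configure_apex(self, amp, model, optimizers, amp_level):
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers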

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 4 additions & 5 deletions
@@ -18,7 +18,7 @@
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
-from pytorch_lightning.utilities import move_data_to_device
+from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -61,7 +61,6 @@ class TrainerDPMixin(ABC):
     tpu_local_core_rank: int
     tpu_global_core_rank: int
     use_tpu: bool
-    use_native_amp: bool
     data_parallel_device_ids: ...
     progress_bar_callback: ...
     tpu_id: Optional[int]
@@ -175,7 +174,7 @@ def single_gpu_train(self, model):
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

         # TODO: remove with dropping NVIDIA AMP support
-        if self.use_amp and not self.use_native_amp:
+        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
@@ -236,14 +235,14 @@ def dp_train(self, model):

         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
-        if self.use_amp and self.use_native_amp:
+        if self.use_amp and NATIVE_AMP_AVALAIBLE:
             # wrap the user's forward in autocast and give it back at the end
             model.forward = torch.cuda.amp.autocast()(model.forward)

         # TODO: remove with dropping NVIDIA AMP support
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp and not self.use_native_amp:
+        if self.use_dp and self.use_amp and not NATIVE_AMP_AVALAIBLE:
             if self.amp_level == 'O2':
                 raise MisconfigurationException(
                     f'Amp level {self.amp_level} with DataParallel is not supported.'
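
The `torch.cuda.amp.autocast()(model.forward)` line in `dp_train` works because an `autocast` instance can also be used as a decorator. A self-contained sketch of the same trick; `TinyNet` is a made-up module for illustration only:

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):                       # stand-in for the user's LightningModule
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(8, 2)

        def forward(self, x):
            return self.layer(x)

    model = TinyNet()
    original_forward = model.forward                # keep a handle so it can be restored after fit
    model.forward = torch.cuda.amp.autocast()(model.forward)

    out = model(torch.randn(4, 8))                  # on CUDA this forward now runs in mixed precision
    model.forward = original_forward                # hand the original forward back, as the trainer does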

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 3 additions & 4 deletions
@@ -124,15 +124,14 @@

 from abc import ABC, abstractmethod
 from pprint import pprint
-from typing import Callable, Optional, List, Union
+from typing import Callable, List, Union

 import torch
 from torch.utils.data import DataLoader

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE

 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -285,7 +284,7 @@ def _evaluate(
             # -----------------
             # RUN EVALUATION STEP
             # -----------------
-            if self.use_amp and self.use_native_amp:
+            if self.use_amp and NATIVE_AMP_AVALAIBLE:
                 with torch.cuda.amp.autocast():
                     output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
             else:
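
Note that evaluation only needs the `autocast` context; the `GradScaler` is irrelevant here because gradient underflow can only occur during `backward()`. A minimal sketch, with `model` and `batch` as placeholder names:

    import torch

    def evaluation_forward_amp(model, batch):
        # sketch: eval under native amp needs autocast but no GradScaler,
        # since no backward pass (and hence no gradient scaling) happens here
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                return model(batch)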

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 4 deletions
@@ -13,7 +13,7 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
-from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
+from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@@ -532,12 +532,17 @@ def __init__(
         # These are the only lines needed after v0.8.0
         # we wrap the user's forward with autocast and give it back at the end of fit
         self.autocast_original_forward = None
-        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         self.precision = precision
         self.scaler = None

+        # Backward compatibility, TODO: remove in v0.9.0
+        if use_amp is not None:
+            rank_zero_warn("Argument `use_amp` is now set by `precision` since v0.7.0"
+                           " and this method will be removed in v0.9.0", DeprecationWarning)
+            self.precision = 16 if use_amp else 32
+
         self.amp_level = amp_level
-        self.init_amp(use_amp)
+        self.init_amp()

         self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

@@ -1002,7 +1007,7 @@ def run_pretrain_routine(self, model: LightningModule):
         self.copy_trainer_model_properties(ref_model)

         # init amp. Must be done here instead of __init__ to allow ddp to work
-        if self.use_native_amp and self.precision == 16:
+        if NATIVE_AMP_AVALAIBLE and self.precision == 16:
             self.scaler = torch.cuda.amp.GradScaler()

         # log hyper-parameters
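
Taken together, the pieces this commit touches (the `GradScaler` created in `run_pretrain_routine`, the autocast-wrapped forward, and `amp_scale_loss`) add up to the standard native-amp recipe. A standalone sketch of that recipe in plain PyTorch; `model`, `optimizer`, and `dataloader` are placeholders, not Trainer internals:

    import torch
    import torch.nn.functional as F

    def train_native_amp(model, optimizer, dataloader, device="cuda"):
        # standalone sketch of the native-amp loop the Trainer now wires up internally
        scaler = torch.cuda.amp.GradScaler()
        model.to(device)
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():      # forward + loss in mixed precision
                loss = F.cross_entropy(model(x), y)
            scaler.scale(loss).backward()        # scale the loss, as amp_scale_loss does
            scaler.step(optimizer)               # unscales grads; skips the step on inf/nan
            scaler.update()                      # adjust the scale factor for the next step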
