
Commit a6e7aa7

Borda and awaelchli authored
allow using apex with any PT version (Lightning-AI#2865)
* wip
* setup
* type
* name
* wip
* docs
* imports
* fix if
* fix if
* use_amp
* Apply suggestions from code review
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* Apply suggestions from code review
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* fix tests
* Apply suggestions from code review
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* fix tests
* todos

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent fed0ac8 commit a6e7aa7

20 files changed: +140 -139 lines changed

dockers/cuda-extras/Dockerfile

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ RUN apt-get update && \
     && \
 
     # Install AMP
-    # TODO: skip this instrall for PT >= 1.6
     bash install_AMP.sh && \
     # Install all requirements
     pip install -r requirements.txt && \

pytorch_lightning/accelerators/cpu_backend.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def __init__(self, trainer):
 
     def setup(self, model):
         # run through amp wrapper
-        if self.trainer.use_amp:
+        if self.trainer.amp_type:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
 
        # call setup after the ddp process has connected
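The CPU backend now gates on `self.trainer.amp_type` instead of the old `use_amp` boolean. The sketch below is an illustration, not the library's actual code: it shows why a plain truthiness test is enough when `amp_type` is `None` with AMP disabled and an enum member otherwise. The toy `Trainer` and the exact `AMPType` definition are assumptions; the real `AMPType` lives in `pytorch_lightning.utilities` and may differ in detail.

```python
# Minimal sketch (assumption): an Enum-valued flag that stays None when AMP is
# off, so `if trainer.amp_type:` can replace the old boolean `trainer.use_amp`.
from enum import Enum
from typing import Optional


class AMPType(Enum):
    APEX = 'apex'
    NATIVE = 'native'


class Trainer:  # toy stand-in, not the library's Trainer
    def __init__(self, amp_type: Optional[str] = None):
        # map the user-facing string onto the enum; None keeps AMP disabled
        self.amp_type = AMPType(amp_type) if amp_type is not None else None


assert not Trainer().amp_type                        # AMP disabled -> falsy
assert Trainer(amp_type='apex').amp_type is AMPType.APEX
```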

pytorch_lightning/accelerators/ddp2_backend.py

Lines changed: 4 additions & 8 deletions
@@ -17,7 +17,7 @@
 import torch
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -32,9 +32,7 @@
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class DDP2Backend(object):
@@ -135,10 +133,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # set model properties before going into wrapper
         self.trainer.copy_trainer_model_properties(model)
 
-        # AMP
-        # run through amp wrapper before going to distributed DP
-        # TODO: remove with dropping NVIDIA AMP support
-        if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
+        # AMP - run through amp wrapper before going to distributed DP
+        if self.trainer.amp_type == AMPType.APEX:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
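The same two changes repeat across the distributed backends: the Apex import guard now leaves `amp = None` when Apex is missing, and the wrapping step runs only when the trainer explicitly chose the Apex backend. A condensed sketch of that pattern follows; the free-standing `maybe_configure_apex` helper is hypothetical, standing in for the `ddp_train` logic shown in the hunks above.

```python
# Sketch of the import-guard + explicit-backend pattern applied in this commit.
from pytorch_lightning.utilities import AMPType

try:
    from apex import amp
except ImportError:
    # Apex not installed: keep the name defined so the module still imports;
    # the Apex branch below is simply never taken in that case.
    amp = None


def maybe_configure_apex(trainer, model):
    # mirrors the backends' `if self.trainer.amp_type == AMPType.APEX:` branch
    if trainer.amp_type == AMPType.APEX:
        model, optimizers = model.configure_apex(
            amp, model, trainer.optimizers, trainer.amp_level
        )
        trainer.optimizers = optimizers
        trainer.reinit_scheduler_properties(trainer.optimizers, trainer.lr_schedulers)
    return model
```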

pytorch_lightning/accelerators/ddp_backend.py

Lines changed: 4 additions & 8 deletions
@@ -23,7 +23,7 @@
 import torch
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 try:
@@ -37,9 +37,7 @@
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class DDPBackend(object):
@@ -202,10 +200,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         # set model properties before going into wrapper
         self.trainer.copy_trainer_model_properties(model)
 
-        # AMP
-        # run through amp wrapper before going to distributed DP
-        # TODO: remove with dropping NVIDIA AMP support
-        if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
+        # AMP - run through amp wrapper before going to distributed DP
+        if self.trainer.amp_type == AMPType.APEX:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

pytorch_lightning/accelerators/ddp_spawn_backend.py

Lines changed: 4 additions & 7 deletions
@@ -16,14 +16,13 @@
 import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class DDPSpawnBackend(object):
@@ -133,11 +132,9 @@ def ddp_train(self, process_idx, mp_queue, model):
         # set model properties before going into wrapper
         self.trainer.copy_trainer_model_properties(model)
 
-        # AMP
+        # AMP -
         # run through amp wrapper before going to distributed DP
-        # TODO: remove with dropping NVIDIA AMP support
-        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
-        if self.trainer.use_amp and not native_amp_available:
+        if self.trainer.amp_type == AMPType.APEX:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

pytorch_lightning/accelerators/dp_backend.py

Lines changed: 4 additions & 7 deletions
@@ -16,14 +16,13 @@
 from torch import optim
 
 from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class DataParallelBackend(object):
@@ -50,7 +49,7 @@ def setup(self, model):
         self.model_autocast_original_forward = model.forward
 
         # init half precision
-        if self.trainer.use_amp:
+        if self.trainer.amp_type:
            model = self.__init_half_precision(model)
 
        # init torch data parallel
@@ -70,9 +69,7 @@ def __init_torch_data_parallel(self, model):
         return model
 
     def __init_half_precision(self, model):
-        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
-
-        if native_amp_available:
+        if self.trainer.amp_type == AMPType.NATIVE:
            self.__init_native_amp(model)
        else:
            model = self.__init_nvidia_apex(model)
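Previously each backend probed `torch.cuda.amp` at runtime; after this change the dispatch hangs off the trainer's explicit `amp_type`. Below is a hedged sketch of how such a resolution step could look when done once up front; the `resolve_amp_type` helper and its fall-back-to-Apex behaviour are assumptions for illustration, not the library's exact logic.

```python
# Sketch (assumption): resolve the user's amp_type string once, reusing the
# same torch.cuda.amp feature check the old per-backend code performed.
from typing import Optional

import torch

from pytorch_lightning.utilities import AMPType


def resolve_amp_type(requested: Optional[str]) -> Optional[AMPType]:
    if requested is None:
        return None  # AMP disabled
    native_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    if requested.lower() == 'native' and native_available:
        return AMPType.NATIVE
    # explicitly requested Apex, or native AMP unavailable on this PyTorch
    return AMPType.APEX
```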

pytorch_lightning/accelerators/gpu_backend.py

Lines changed: 4 additions & 8 deletions
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
-
 from pytorch_lightning.core import LightningModule
+from pytorch_lightning.utilities import AMPType
 
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class GPUBackend(object):
+    amp_type: AMPType
 
     def __init__(self, trainer):
         self.trainer = trainer
@@ -43,9 +41,7 @@ def setup(self, model):
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
 
-        # TODO: remove with dropping NVIDIA AMP support
-        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
-        if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
+        if self.trainer.amp_type == AMPType.APEX:
            model = self._setup_nvidia_apex(model)
        return model
 

pytorch_lightning/core/hooks.py

Lines changed: 4 additions & 6 deletions
@@ -5,14 +5,12 @@
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 
-from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import move_data_to_device, AMPType
 
 try:
     from apex import amp
 except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+    amp = None
 
 
 class ModelHooks(Module):
@@ -267,8 +265,8 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         """
         loss.backward()
 
-    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
-        if NATIVE_AMP_AVALAIBLE:
+    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType):
+        if amp_type == AMPType.NATIVE:
            scaled_loss = self.trainer.scaler.scale(unscaled_loss)
        else:
            scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
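`amp_scale_loss` now receives the AMP backend explicitly instead of consulting a module-level constant. A condensed, self-contained sketch of the two paths it selects between; the free-standing `scale_loss` function and its `scaler` argument are stand-ins for the hook's `self.trainer.scaler` and optimizer, not part of the library.

```python
# Sketch of the two scaling paths chosen by the explicit amp_type argument.
from pytorch_lightning.utilities import AMPType


def scale_loss(unscaled_loss, optimizer, scaler, amp_type: AMPType):
    if amp_type == AMPType.NATIVE:
        # PyTorch >= 1.6 built-in AMP: GradScaler.scale returns the scaled loss tensor
        return scaler.scale(unscaled_loss)
    # NVIDIA Apex: amp.scale_loss is a context manager yielding the scaled loss
    from apex import amp
    return amp.scale_loss(unscaled_loss, optimizer)
```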

pytorch_lightning/core/memory.py

Lines changed: 2 additions & 3 deletions
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from torch.utils.hooks import RemovableHandle
 
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import AMPType
 
 PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
 UNKNOWN_SIZE = "?"
@@ -207,8 +207,7 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model.transfer_batch_to_device(input_, model.device)
 
-        if trainer is not None and trainer.use_amp and not trainer.use_tpu:
-            if NATIVE_AMP_AVALAIBLE:
+        if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu:
            model.forward = torch.cuda.amp.autocast()(model.forward)
 
        mode = model.training
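For native AMP, the summary's example forward pass is wrapped in `torch.cuda.amp.autocast`. A standalone sketch of that wrapping, assuming PyTorch >= 1.6 and an available CUDA device; the tiny `nn.Linear` model is only for illustration.

```python
# Sketch: decorate a model's bound forward with autocast so the example
# forward pass used for layer summaries runs under mixed precision.
import torch
import torch.nn as nn

if torch.cuda.is_available():  # requires PyTorch >= 1.6 and a CUDA device
    model = nn.Linear(4, 2).cuda()
    example_input = torch.randn(1, 4, device='cuda')

    # subsequent calls to model(...) now run inside an autocast region
    model.forward = torch.cuda.amp.autocast()(model.forward)
    with torch.no_grad():
        out = model(example_input)
    print(out.dtype)  # typically torch.float16 under autocast
```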

pytorch_lightning/trainer/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -864,6 +864,19 @@ def on_train_end(self, trainer, pl_module):
 
     trainer = Trainer(sync_batchnorm=True)
 
+amp_type
+^^^^^^^^
+
+Define a preferable mixed precision, either NVIDIA Apex ("apex") or PyTorch built-in ("native") AMP which is supported from v1.6.
+
+.. testcode::
+
+    # using NVIDIA Apex
+    trainer = Trainer(amp_type='apex')
+
+    # using PyTorch built-in AMP
+    trainer = Trainer(amp_type='native')
+
 val_percent_check
 ^^^^^^^^^^^^^^^^^
