 from abc import ABC
-import torch

 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_warn
-
-try:
-    from apex import amp
-except ImportError:
-    APEX_AVAILABLE = False
-else:
-    APEX_AVAILABLE = True
+from pytorch_lightning.utilities import rank_zero_warn, APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities.distributed import rank_zero_debug


 class TrainerAMPMixin(ABC):

     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     precision: int
-    use_native_amp: bool
-
-    def init_amp(self, use_amp):
-        if self.use_native_amp:
-            rank_zero_warn("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)

-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)
-            self.precision = 16 if use_amp else 32
+    def init_amp(self):
+        if NATIVE_AMP_AVALAIBLE:
+            log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")

         assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

-        if use_amp and self.use_native_amp:
-            log.info('Using 16bit precision.')
+        if self.use_amp and NATIVE_AMP_AVALAIBLE:
+            log.info('Using native 16bit precision.')
             return

-        # TODO: remove all below for v0.9.0
-        if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
-            raise ModuleNotFoundError("""
-                You set `use_amp=True` but do not have apex installed.
-                Install apex first using this guide and rerun with use_amp=True:
-                https://github.com/NVIDIA/apex#linux
-                this run will NOT use 16 bit precision
-            """)
+        # TODO: replace `use_amp` with `precision` everywhere below for v0.9.0
+        if self.use_amp and not APEX_AVAILABLE:  # pragma: no-cover
+            raise ModuleNotFoundError(
+                "You set `use_amp=True` but do not have apex installed. "
+                "Install apex first using this guide and rerun with use_amp=True: "
+                "https://github.com/NVIDIA/apex#linux (this run will NOT use 16 bit precision)"
+            )

         if self.use_amp:
-            log.info('Using 16bit precision.')
+            log.info('Using APEX 16bit precision.')

     @property
     def use_amp(self) -> bool:
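
The diff view cuts off at the `use_amp` property. For orientation, a minimal sketch of what such a property would plausibly return, assuming it is derived purely from `precision` (its body is not shown in this diff):

@property
def use_amp(self) -> bool:
    # assumed: requesting 16-bit precision is what turns AMP on
    return self.precision == 16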
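
The availability flags imported at the top replace the deleted apex guard. A sketch of how `pytorch_lightning.utilities` could define them, mirroring the removed try/except; the `torch.cuda.amp.autocast` probe for `NATIVE_AMP_AVALAIBLE` (the misspelling is the real constant name) is an assumption, not something this diff shows:

import torch

# apex is optional: record whether the import succeeds
try:
    from apex import amp  # noqa: F401
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True

# assumption: native AMP counts as available when torch ships
# torch.cuda.amp.autocast (PyTorch 1.6+)
NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")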
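
A minimal, hypothetical usage of the mixin after this change: a concrete trainer only has to set `precision` before calling `init_amp()`. The `MyTrainer` stub is illustrative, and the import path assumes the mixin lives in `pytorch_lightning.trainer.auto_mix_precision`:

from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin


class MyTrainer(TrainerAMPMixin):
    def __init__(self, precision: int = 32):
        self.precision = precision


trainer = MyTrainer(precision=16)
trainer.init_amp()  # logs 'Using native 16bit precision.' when native AMP is available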