@@ -115,7 +115,6 @@ def __init__(
         print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
         weights_summary: Optional[str] = 'full',
         weights_save_path: Optional[str] = None,
-        amp_level: str = 'O1',
         num_sanity_val_steps: int = 5,
         truncated_bptt_steps: Optional[int] = None,
         resume_from_checkpoint: Optional[str] = None,
@@ -124,6 +123,7 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
+        amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
         default_save_path=None,  # backward compatible, todo: remove in v0.8.0
         gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
         nb_gpu_nodes=None,  # backward compatible, todo: remove in v0.8.0
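For context on what the now-deprecated `amp_level` argument controls: it is forwarded to NVIDIA apex's `opt_level`. Below is a minimal sketch of plain apex usage, where `model` and `optimizer` are illustrative placeholders rather than names from this commit:

    import torch
    from apex import amp  # NVIDIA apex, the pre-native mixed-precision dependency

    # illustrative toy model and optimizer; assumes a CUDA device is available
    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # `amp_level` maps to apex's `opt_level` ('O0'..'O3');
    # 'O1' patches ops to run in fp16 where it is numerically safe
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')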
@@ -487,20 +487,18 @@ def __init__(
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)

-        # 16 bit mixed precision training using apex
+        # AMP init
+        # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.autocast_original_forward = None
+        self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+        if self.use_native_amp and self.precision == 16:
+            self.scaler = torch.cuda.amp.GradScaler()
+        self.precision = precision
+
+        # TODO: remove for v0.8.0
         self.amp_level = amp_level
         self.precision = precision
-
-        # Backward compatibility, TODO: remove in v0.9.0
-        if use_amp is not None:
-            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
-                           " and this argument will be removed in v0.9.0", DeprecationWarning)
-            self.precision = 16 if use_amp else 32
-
-        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
-
-        if self.precision == 16 and self.num_tpu_cores is None:
-            use_amp = True
         self.init_amp(use_amp)

         # Callback system
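The new `use_native_amp` check detects PyTorch's built-in AMP (`torch.cuda.amp`, available from PyTorch 1.6) instead of requiring apex. Below is a minimal sketch of the autocast-plus-GradScaler pattern the new init sets up; the training loop, `model`, `optimizer`, and data are illustrative, not code from this commit:

    import torch

    # same feature detection the commit adds to Trainer.__init__
    use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

    # illustrative toy model and optimizer; assumes a CUDA device is available
    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=use_native_amp)

    for x, y in [(torch.randn(4, 10).cuda(), torch.randn(4, 2).cuda())]:
        optimizer.zero_grad()
        # autocast runs the forward in mixed precision, which is what
        # wrapping the user's forward with autocast accomplishes in the commit
        with torch.cuda.amp.autocast(enabled=use_native_amp):
            loss = torch.nn.functional.mse_loss(model(x), y)
        # GradScaler scales the loss so fp16 gradients do not underflow
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()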