Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for FP16 in SageMaker model parallelism #11407

Merged
merged 6 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,26 @@ def __init__(
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.use_amp = True
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
if is_sagemaker_mp_enabled():
self.scaler = smp.amp.GradScaler()
elif self.sharded_ddp is not None:
self.scaler = ShardedGradScaler()
else:
self.scaler = torch.cuda.amp.GradScaler()
else:
if not is_apex_available():
raise ImportError(
"Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
)
self.use_apex = True

# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Use "
"pass along 'max_grad_norm': 0 in your hyperparameters."
sgugger marked this conversation as resolved.
Show resolved Hide resolved
)

# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
Expand Down Expand Up @@ -1607,7 +1619,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
inputs = self._prepare_inputs(inputs)

if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
scaler = self.scaler if self.use_amp else None
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
return loss_mb.reduce_mean().detach().to(self.args.device)

if self.use_amp:
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,10 +974,15 @@ def get_parameter_names(model, forbidden_layer_types):
import smdistributed.modelparallel.torch as smp

@smp.step()
def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
outputs = model(**inputs)
def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
with torch.cuda.amp.autocast(enabled=(scaler is not None)):
outputs = model(**inputs)

loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss /= gradient_accumulation_steps
if scaler is not None:
loss = scaler.scale(loss).squeeze()

model.backward(loss)
return loss

Expand Down