It seems that mt5/t5 models that have been pre-trained under bfloat16 don't quite work under fp16 mixed precision and require special handling.
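For background, the failure comes down to dynamic range: fp16 can represent finite values only up to 65504, while bfloat16 shares float32's exponent range (up to roughly 3.4e38), so activations that a bfloat16-trained model produces can be perfectly finite in bf16 yet overflow to inf in fp16. Here is a minimal pure-Python sketch of that mismatch using the `struct` module's binary16/binary32 packing; the helper names are made up for illustration, not from torch:

```python
import struct

def fits_in_fp16(x):
    """Return True if x can be packed as an IEEE-754 half (binary16)."""
    try:
        struct.pack('<e', x)  # 'e' is the half-precision format code
        return True
    except OverflowError:
        return False

def to_bf16(x):
    """Truncate a float to bfloat16 precision (top 16 bits of float32)."""
    b = struct.pack('>f', x)
    return struct.unpack('>f', b[:2] + b'\x00\x00')[0]

assert fits_in_fp16(60000.0)          # within fp16's finite range
assert not fits_in_fp16(1e10)         # overflows fp16
assert to_bf16(1e10) != float('inf')  # but is a finite bfloat16 value
```

This is why keeping a few overflow-prone layers in full precision (as below) can be enough to stabilize the whole model.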
I found a workaround by disabling autocast (PyTorch native AMP) for the one layer that caused all the problems:
huggingface/transformers#10956
The gist of the change is:
```python
def _forward(self, hidden_states):
    forwarded_states = self.layer_norm(hidden_states)
    forwarded_states = self.DenseReluDense(forwarded_states)
    hidden_states = hidden_states + self.dropout(forwarded_states)
    return hidden_states

def forward(self, hidden_states):
    # many t5/mt5 models are trained in bfloat16 and don't do well under mixed precision (fp16).
    # It appears that it's enough to disable autocast for this FF layer to avoid inf/nan
    # problems for the whole model
    if torch.is_autocast_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(hidden_states)
    else:
        return self._forward(hidden_states)
```
Is there a way to do the same for DeepSpeed, i.e. continue using fp16 mixed precision for everything except a specific context?
Actually further testing shows that for a simple case, this is enough:
```python
def forward(self, hidden_states):
    forwarded_states = self.layer_norm(hidden_states)
    with torch.cuda.amp.autocast(enabled=False):
        forwarded_states = self.DenseReluDense(forwarded_states)
    hidden_states = hidden_states + self.dropout(forwarded_states)
    return hidden_states
```
But I haven't done full testing to know if it covers all bases.
Thank you!