Is it possible to disable mixed precision on a specific layer? #908

@stas00

It seems that mt5/t5 models that have been pre-trained under bfloat16 don't quite work under fp16 mixed precision and require special handling.

I found a workaround by disabling autocast (pytorch native amp) for one layer that caused all the problems:
huggingface/transformers#10956
The gist of the change is:

    def _forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states

    def forward(self, hidden_states):
        # many t5/mt5 models are trained in bfloat16 and don't do well under mixed precision (fp16).
        # It appears that it's enough to disable autocast for this FF layer to avoid inf/nan
        # problems for the whole model
        if torch.is_autocast_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(hidden_states)
        else:
            return self._forward(hidden_states)

Is there a way to do the same for DeepSpeed, i.e. to keep using fp16 mixed precision for everything except a specific context?

Actually, further testing shows that for a simple case this is enough:

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        with torch.cuda.amp.autocast(enabled=False):
            forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states

But I haven't done full testing to know if it covers all bases.
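
One base that may be worth checking: disabling autocast only stops the automatic casting of ops inside the block, it does not change the dtype of tensors that were already produced in lower precision outside it. Here is a minimal sketch of the same pattern with an explicit upcast of the input. `run_in_fp32` is a hypothetical helper name, not a PyTorch or DeepSpeed API, and `nn.Linear` is only a stand-in for `DenseReluDense`. The demo uses CPU autocast with bfloat16 so it runs without a GPU, and assumes PyTorch 1.10 or newer for `torch.autocast`; the same pattern should apply to CUDA fp16 autocast.

```python
import torch
from torch import nn


def run_in_fp32(module, hidden_states):
    """Run `module` with autocast disabled (hypothetical helper, not a library API)."""
    device_type = hidden_states.device.type
    with torch.autocast(device_type=device_type, enabled=False):
        # Disabling autocast does not change the dtype of tensors produced
        # earlier inside the autocast region, so upcast the input explicitly.
        return module(hidden_states.float())


torch.manual_seed(0)
dense = nn.Linear(8, 8)  # stand-in for DenseReluDense, weights stay in float32
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    low = dense(x)                  # autocast runs this linear in bfloat16
    full = run_in_fp32(dense, low)  # same layer, computed in float32

print(low.dtype, full.dtype)
```

Without the `.float()` call the layer would receive a lower-precision input while its weights are float32, so the explicit cast is what makes the disabled region actually compute in full precision. This only covers native PyTorch autocast; whether it carries over to DeepSpeed's fp16 mode is exactly the open question above.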

Thank you!
