
LayerDrop broken in various Flax models (Whisper/BART/more...) #35468

@sssshhhhhh

Description


System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.6.56+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.7
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (False)
  • Tensorflow version (GPU?): 2.17.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.8.4 (cpu)
  • Jax version: 0.4.26
  • JaxLib version: 0.4.26
  • Using distributed or parallel set-up in script?: N/A

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Train a FlaxWhisperForConditionalGeneration model with encoder/decoder LayerDrop enabled:

from transformers import FlaxWhisperForConditionalGeneration
import numpy as np
 
model = FlaxWhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
model.config.encoder_layerdrop = 1.0
model.encode(np.random.rand(1, 80, 3000), train=True)
 
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-14-4ffcfa8247ef> in <cell line: 6>()
      4 model = FlaxWhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
      5 model.config.encoder_layerdrop = 1.0
----> 6 model.encode(np.random.rand(1, 80, 3000), train=True)
 
/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_flax_whisper.py in encode(self, input_features, attention_mask, output_attentions, output_hidden_states, return_dict, train, params, dropout_rng, **kwargs)
   1006             return encode_module(input_features, **kwargs)
   1007 
-> 1008         return self.module.apply(
   1009             {"params": params or self.params},
   1010             input_features=jnp.array(input_features, dtype="f4"),
 
    [... skipping hidden 4 frame]
 
/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_flax_whisper.py in _encoder_forward(module, input_features, **kwargs)
   1004         def _encoder_forward(module, input_features, **kwargs):
   1005             encode_module = module._get_encoder_module()
-> 1006             return encode_module(input_features, **kwargs)
   1007 
   1008         return self.module.apply(
 
    [... skipping hidden 2 frame]
 
/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_flax_whisper.py in __call__(self, input_features, output_attentions, output_hidden_states, return_dict, deterministic)
    709         )
    710 
--> 711         last_hidden_states = outputs[0]
    712         last_hidden_states = self.layer_norm(last_hidden_states)
    713 
 
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py in __getitem__(self, k)
    431             return inner_dict[k]
    432         else:
--> 433             return self.to_tuple()[k]
    434 
    435     def __setattr__(self, name, value):
 
IndexError: tuple index out of range

Expected behavior

I'm using FlaxWhisperForConditionalGeneration, but the same LayerDrop code appears in several other Flax models (BART among them).

Here, hidden_states is set to None when the layer is dropped, and indexing the resulting model output then raises the IndexError above:

# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
    layer_outputs = (None, None)
else:
    layer_outputs = encoder_layer(
        hidden_states,
        attention_mask,
        output_attentions,
        deterministic,
    )
hidden_states = layer_outputs[0]
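A minimal, framework-agnostic sketch of what the fix should do (run_layers and its arguments are hypothetical illustrations, not the actual transformers code): when a layer is dropped, hidden_states should simply be left unchanged rather than overwritten with None.

```python
import random

def run_layers(hidden_states, layers, layerdrop, deterministic):
    """Apply each layer in turn, skipping dropped layers without
    clobbering the hidden-state stream (hypothetical sketch)."""
    for layer in layers:
        dropout_probability = random.uniform(0, 1)
        if not deterministic and dropout_probability < layerdrop:
            continue  # skip this layer but keep hidden_states intact
        hidden_states = layer(hidden_states)
    return hidden_states

# With layerdrop = 1.0 every layer is skipped, yet the input still
# flows through instead of becoming None.
out = run_layers(3.0, [lambda x: x + 1, lambda x: x * 2],
                 layerdrop=1.0, deterministic=False)
print(out)  # → 3.0
```

With this shape, layerdrop=1.0 degenerates to the identity function instead of crashing, and deterministic=True applies every layer as before.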

While fixing that, I also noticed that dropout_probability = random.uniform(0, 1) is plain Python randomness, so it only runs while the function is being traced; looping a compiled training step will therefore always drop the same layers.
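To see why, note that random.uniform executes as ordinary Python, so under jax.jit it runs once at trace time and its result is frozen into the compiled graph. A pure-Python sketch of that behavior (jit_like_compile is a hypothetical stand-in for jax.jit, not a real API):

```python
import random

def jit_like_compile(layerdrop, num_layers):
    # Mimics jax.jit tracing: this Python code runs exactly once, when
    # the function is traced, so the random draws become constants
    # baked into the "compiled" step.
    drop_mask = [random.uniform(0, 1) < layerdrop for _ in range(num_layers)]

    def compiled_step():
        return drop_mask  # identical drop pattern on every invocation

    return compiled_step

random.seed(0)
step = jit_like_compile(layerdrop=0.5, num_layers=4)
print(step() == step())  # → True: the same layers drop every call
```

A traceable alternative would draw the per-layer mask inside the computation from the model's dropout RNG (e.g. with jax.random.bernoulli), so a fresh pattern is sampled on every step even after compilation.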
