
XLA FSDP V2 + TPU + T5 Family Models doesn't work #35142

@agemagician

Description


System Info

  • transformers version: 4.48.0.dev0
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.26.5
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (False)
  • Tensorflow version (GPU?): 2.15.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (tpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Using distributed or parallel set-up in script?:

Who can help?

@ArthurZucker @muellerzr @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello,

I am trying to run the official PyTorch text2text example scripts with the following combination:

  1. TPU
  2. T5 models family
  3. xla_fsdp_v2

For example, if I run the official translation example as follows:

!export PJRT_DEVICE=TPU;export XLA_USE_SPMD=1; PJRT_DEVICE=TPU XLA_USE_SPMD=1 python transformers/examples/pytorch/translation/run_translation.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --source_lang en \
    --target_lang ro \
    --source_prefix "translate English to Romanian: " \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --output_dir /tmp/tst-translation \
    --per_device_train_batch_size=8 \
    --per_device_eval_batch_size=8 \
    --overwrite_output_dir \
    --predict_with_generate \
    --pad_to_max_length True \
    --preprocessing_num_workers 32 \
    --fsdp "full_shard" \
    --fsdp_config '{"fsdp_transformer_layer_cls_to_wrap": ["T5Block"], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": false}'

I get the following error:

[WARNING|logging.py:328] 2024-12-07 18:51:58,712 >> Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.Seq2SeqLMOutput'>
  warnings.warn("For backward hooks to be called,"
Traceback (most recent call last):
  File "/content/transformers/examples/pytorch/translation/run_translation.py", line 697, in <module>
    main()
  File "/content/transformers/examples/pytorch/translation/run_translation.py", line 612, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2169, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2527, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3660, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3715, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/experimental/spmd_fully_sharded_data_parallel.py", line 168, in forward
    self._shard_output(output, self._mesh)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2015, in shard_output
    raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
ValueError: Something went wrong, the output of the model shouldn't be `None`
  0% 0/457740 [00:01<?, ?it/s]

However, turning xla_fsdp_v2 off or switching to another model family works fine.

So, the issue is in combining T5 + xla_fsdp_v2.
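From the traceback, the failure happens in the shard_output helper inside trainer.py. If I read the source correctly, it only recognizes plain tensors, tuples, and CausalLMOutputWithPast, so the Seq2SeqLMOutput returned by T5 falls through and real_output stays None. Roughly (my paraphrase of the logic, not the exact source):

import torch
import torch_xla.distributed.spmd as xs
from transformers.modeling_outputs import CausalLMOutputWithPast

def shard_output(output, mesh):
    real_output = None
    if isinstance(output, torch.Tensor):
        real_output = output
    elif isinstance(output, tuple):
        real_output = output[0]
    elif isinstance(output, CausalLMOutputWithPast):
        real_output = output.logits
    # A Seq2SeqLMOutput never matches any branch above, so real_output stays None
    if real_output is None:
        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))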

Any idea how we could fix it?
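One possible fix (an untested sketch on my side, not an official patch) would be to let shard_output fall back to any ModelOutput that exposes logits, which would cover Seq2SeqLMOutput as well:

import torch
import torch_xla.distributed.spmd as xs
from transformers.utils import ModelOutput

def shard_output(output, mesh):
    real_output = None
    if isinstance(output, torch.Tensor):
        real_output = output
    elif isinstance(output, tuple):
        real_output = output[0]
    elif isinstance(output, ModelOutput) and getattr(output, "logits", None) is not None:
        # Covers CausalLMOutputWithPast as well as the Seq2SeqLMOutput returned by T5
        real_output = output.logits
    if real_output is None:
        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
    # Shard the output logits along the batch ("fsdp") axis of the SPMD mesh
    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))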

Expected behavior

The training should start, and we should be able to use xla_fsdp_v2 with the T5 model family.
