
Training "train_dreambooth_lora_sdxl.py" on multi-gpu using accelerate #6146

@ghost

Description

Describe the bug

Hi, I'm trying to run train_dreambooth_lora_sdxl.py on multiple GPUs. This is my accelerate config, and the error I get is: 'DistributedDataParallel' object has no attribute 'text_model'. I found out that on a multi-GPU launch the text encoder gets wrapped in a DistributedDataParallel object instead of staying a plain CLIPTextModel, but I was expecting accelerate to handle that. Do you have any ideas for me? Maybe something in the accelerate config needs to be changed, or does the script need to be changed to support DistributedDataParallel? A small standalone repro of what I think is happening is below.
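
To illustrate what I mean (this is a toy sketch, not code from the training script; FakeTextEncoder is just a stand-in for CLIPTextModel): once a module is wrapped in DistributedDataParallel, its attributes are only reachable through the `.module` attribute of the wrapper, which matches the AttributeError in my logs.

```python
import os
import torch
import torch.distributed as dist

class FakeTextEncoder(torch.nn.Module):  # stand-in for CLIPTextModel
    def __init__(self):
        super().__init__()
        self.text_model = torch.nn.Linear(4, 4)

# single-process "gloo" group just so DDP can be constructed for this demo
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

wrapped = torch.nn.parallel.DistributedDataParallel(FakeTextEncoder())

# DDP does not forward arbitrary attribute access to the wrapped model:
print(hasattr(wrapped, "text_model"))         # False -> this is why the script raises AttributeError
print(hasattr(wrapped.module, "text_model"))  # True  -> the original model lives on .module

dist.destroy_process_group()
```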

Reproduction

The code is based on "train_dreambooth_lora_sdxl.py", and the line that raises the error is:
text_encoder_one.text_model.embeddings.requires_grad_(True)
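
Would something like the following be the intended way to handle this? This is only a sketch against the variables already defined in the script's main() (`accelerator` and `text_encoder_one`), not a tested fix:

```python
# unwrap_model() is a no-op on a single GPU and returns the underlying
# CLIPTextModel when the model has been wrapped in DistributedDataParallel
unwrapped_text_encoder_one = accelerator.unwrap_model(text_encoder_one)
unwrapped_text_encoder_one.text_model.embeddings.requires_grad_(True)
```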

Logs

/usr/local/lib/python3.8/dist-packages/diffusers/loaders/lora.py:711: FutureWarning: `_modify_text_encoder` is deprecated and will be removed in version 0.27. You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future.
  deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
/usr/local/lib/python3.8/dist-packages/diffusers/loaders/lora.py:683: FutureWarning: `_remove_text_encoder_monkey_patch_classmethod` is deprecated and will be removed in version 0.27. You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future.
  deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
/usr/local/lib/python3.8/dist-packages/diffusers/loaders/lora.py:711: FutureWarning: `_modify_text_encoder` is deprecated and will be removed in version 0.27. You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future.
  deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
/usr/local/lib/python3.8/dist-packages/diffusers/loaders/lora.py:683: FutureWarning: `_remove_text_encoder_monkey_patch_classmethod` is deprecated and will be removed in version 0.27. You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future.
  deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
Traceback (most recent call last):
  File "/tmp/development/creative_training/training/train.py", line 1754, in <module>
    main(args)
  File "/tmp/development/creative_training/training/train.py", line 1438, in main
    text_encoder_one.text_model.embeddings.requires_grad_(True)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DistributedDataParallel' object has no attribute 'text_model'
Steps:   0%|                                             | 0/10 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/tmp/development/creative_training/training/train.py", line 1754, in <module>
    main(args)
  File "/tmp/development/creative_training/training/train.py", line 1438, in main
    text_encoder_one.text_model.embeddings.requires_grad_(True)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DistributedDataParallel' object has no attribute 'text_model'
Steps:   0%|                                             | 0/10 [00:00<?, ?it/s]
[2023-12-12 10:06:01,848] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 7289) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 1027, in <module>
    main()
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 1023, in main
    launch_command(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 1008, in launch_command
    multi_gpu_launcher(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 666, in multi_gpu_launcher
    distrib_run.run(args)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/tmp/development/creative_training/training/train.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-12-12_10:06:01
  host      : liorra-creative-accelerate-2gpu-a100-0-0
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 7290)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-12-12_10:06:01
  host      : liorra-creative-accelerate-2gpu-a100-0-0
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 7289)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

System Info

  • diffusers version: 0.25.0.dev0
  • Platform: Linux-5.4.0-167-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Huggingface_hub version: 0.19.4
  • Transformers version: 4.36.0.dev0
  • Accelerate version: 0.25.0
  • xFormers version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@sayakpaul @pat
