FLUX dreambooth train on multigpu with deepspeed #9484

@zhangvia

Description

Describe the bug

I'm using train_dreambooth_flux.py to fine-tune FLUX. I get OOM on 4x A100 80 GB with DeepSpeed stage 2, gradient checkpointing, bf16 mixed precision, 1024px × 1024px input, the Adafactor optimizer, and batch size 1. It only runs with DeepSpeed stage 3, but that is too slow, at about 16 s/it.
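For context, a back-of-envelope sketch of why stage 2 sits much closer to the 80 GB limit than stage 3 once 1024px activations are added. All numbers here are assumptions for illustration: roughly 12B transformer parameters, and optimizer state modeled as one fp32 value per parameter (Adafactor's factored states are smaller in practice):

```python
# Rough per-GPU memory for a ~12B-parameter FLUX transformer under
# DeepSpeed ZeRO stage 2 vs stage 3. Assumptions: 12B params, optimizer
# state modeled as one fp32 value per parameter (an overestimate for
# Adafactor, whose second moments are factored).
PARAMS = 12e9      # approximate FLUX transformer parameter count
GPUS = 4
BF16, FP32 = 2, 4  # bytes per value

def zero2_gb():
    # Stage 2 shards gradients and optimizer states across GPUs, but
    # every GPU still holds a full bf16 copy of the parameters.
    weights = PARAMS * BF16
    grads = PARAMS * BF16 / GPUS
    opt = PARAMS * FP32 / GPUS
    return (weights + grads + opt) / 1e9

def zero3_gb():
    # Stage 3 shards the parameters as well.
    return PARAMS * (BF16 + BF16 + FP32) / GPUS / 1e9

print(f"ZeRO-2: ~{zero2_gb():.0f} GB per GPU before activations")
print(f"ZeRO-3: ~{zero3_gb():.0f} GB per GPU before activations")
```

Under these assumptions stage 2 leaves far less headroom than stage 3 for the 1024px activations, which would explain the OOM.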

Reproduction

Just use train_dreambooth_flux.py from the repo.
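For reference, a minimal DeepSpeed ZeRO stage 2 config matching the settings described (bf16, batch size 1); this is a sketch of a typical setup, not the reporter's exact file:

```json
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  }
}
```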

Logs

No response

System Info

  • 🤗 Diffusers version: 0.31.0.dev0
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.10.0
  • PyTorch version (GPU?): 2.3.0+cu118 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.23.4
  • Transformers version: 4.44.2
  • Accelerate version: 0.33.0
  • PEFT version: 0.10.0
  • Bitsandbytes version: 0.44.0.dev
  • Safetensors version: 0.4.2
  • xFormers version: 0.0.26.post1+cu118
  • Accelerator: NVIDIA A800 80GB PCIe, 81920 MiB
    NVIDIA A800 80GB PCIe, 81920 MiB
    NVIDIA A800 80GB PCIe, 81920 MiB
    NVIDIA A800 80GB PCIe, 81920 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@linoytsaban

Labels

bug (Something isn't working), stale (Issues that haven't received updates)
