Describe the bug
I'm using train_dreambooth_flux.py to fine-tune Flux. I get OOM on 4x A100 80GB with DeepSpeed ZeRO stage 2, gradient checkpointing, bf16 mixed precision, 1024x1024 inputs, the Adafactor optimizer, and batch size 1. Training only runs with DeepSpeed ZeRO stage 3, but that is too slow at about 16 s/it.
Reproduction
Just run train_dreambooth_flux.py from the repo with the settings described above.
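For reference, a minimal launch command approximating the setup described above. The model path, data directory, prompt, output directory, and the DeepSpeed config filename are placeholders; the exact optimizer flag is omitted since its name may differ between script versions:

```shell
# Sketch of the failing configuration; paths and the instance prompt are placeholders.
# Assumes an accelerate config file enabling DeepSpeed ZeRO stage 2 with 4 processes.
accelerate launch --config_file deepspeed_zero2.yaml train_dreambooth_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --instance_data_dir="./instance_images" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --output_dir="./flux-dreambooth"
```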
Logs
No response
System Info
- 🤗 Diffusers version: 0.31.0.dev0
- Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.0
- PyTorch version (GPU?): 2.3.0+cu118 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.23.4
- Transformers version: 4.44.2
- Accelerate version: 0.33.0
- PEFT version: 0.10.0
- Bitsandbytes version: 0.44.0.dev
- Safetensors version: 0.4.2
- xFormers version: 0.0.26.post1+cu118
- Accelerator: NVIDIA A800 80GB PCIe, 81920 MiB
NVIDIA A800 80GB PCIe, 81920 MiB
NVIDIA A800 80GB PCIe, 81920 MiB
NVIDIA A800 80GB PCIe, 81920 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?
@linoytsaban