Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support deepspeed #1101

Merged
merged 12 commits into from
Feb 27, 2024
Merged

support deepspeed #1101

merged 12 commits into from
Feb 27, 2024

Conversation

BootsofLagrangian
Copy link
Contributor

@BootsofLagrangian BootsofLagrangian commented Feb 3, 2024

Test Done!!

Introduction

This PR adds DeepSpeed support via Accelerate to sd-scripts, aiming to improve multi-GPU training with ZeRO-Stage. I've made these changes in my fork under the branch-deepspeed and I'm open to any feedback!

0. Environment

  • Linux 3.10.0-1160.15.2.el7.x86_64
  • Under anaconda environment
  • Windows support is partially supported with DeepSpeed (Microsoft said). so, NOT TESTED!

1. Install DeepSpeed

First, activate your virtual environment and install DeepSpeed with the following command:

DS_BUILD_OPS=0 pip install deepspeed

2. Configure Accelerate

You can easily set up your environment for DeepSpeed with accelerate config. It allows you to control basic DeepSpeed environment variables. You can also use command-line arguments for configuration. Here's how you can set up for ZeRO-2 stage using Accelerate:

(deepspeed) accelerate config
In which compute environment are you running? **This machine**
Which type of machine are you using? **multi-GPU**
How many different machines will you use (use more than 1 for multi-node training)? [1]: **1**
Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: **NO**
Do you wish to optimize your script with torch dynamo?[yes/NO]: **NO**
Do you want to use DeepSpeed? [yes/NO]: **yes**
Do you want to specify a json file to a DeepSpeed config? [yes/NO]: **NO**
What should be your DeepSpeed's ZeRO optimization stage? **2**
How many gradient accumulation steps you're passing in your script? [1]: **1**
Do you want to use gradient clipping? [yes/NO]: **NO**
How many GPU(s) should be used for distributed training? [1]: **4**
Do you wish to use FP16 or BF16 (mixed precision)? **bf16**
accelerate configuration saved at ~/.cache/huggingface/accelerate/default_config.yaml

Follow the prompts to select your environment settings, including using multi-GPU, enabling DeepSpeed, and setting ZeRO optimization stage to 2.

Your configuration will be saved in a YAML file, similar to the following example (path and values may vary):

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  zero_stage: 2
...

3. Use in Your Scripts

toml Configuration File

  • Add deepspeed=true and zero_stage=[zero_stage] to your toml config file. Refer to the ZeRO-stage and Accelerate DeepSpeed documentation for more details.

Bash Argument

  • Simply add --deepspeed --zero_stage=[zero_stage] to your script's command line arguments.
    1, 2, and 3 can be [zero_stage]

CPU/NVMe offloading

  • offload_optimizer_device = "cpu|nvme"
  • offload_param_device = "cpu|nvme"
  • offload_optimizer_nvme_path = "/path/to/offloading"
  • offload_param_nvme_path= "/path/to/offloading"

Add this argument in your toml or bash/batch scripts arguments.

full_fp16 training

  • DeepSpeed supports fp16_master_weights_and_gradients during training. But I think that it is not recommended and can run under restricted configuration. Only activated when optimizer is CPUAdam and ZeRO-2 stage.

Note

This PR aims to improve training efficiency in multi-GPU setups. It's been tested only in Linux environments and specifically for multi-GPU configurations. The DeepSpeed supports in Accelerate is still experimental, so please keep this in mind and feel free to provide feedback or comments on this PR.

Test Done!!

@FurkanGozukara
Copy link

hello.

so we add --deepspeed --zero_stage=[your_stage]

but what is your_stage here? thank you

what is zero_stage?

@BootsofLagrangian
Copy link
Contributor Author

BootsofLagrangian commented Feb 5, 2024

hello.

so we add --deepspeed --zero_stage=[your_stage]

but what is your_stage here? thank you

what is zero_stage?

your stage means one of ZeRO stage, [0, 1, 2, 3]. Details are in ZeRO documents.
Shortly

  1. ZeRO stage 1 : The optimizer states
  2. ZeRO stage 2 : The optimizer states + The gradient states
  3. ZeRO stage 3 : The optimizer states + The gradient states + The parameters

to shard them into multi-gpus.

@FurkanGozukara
Copy link

hello.
so we add --deepspeed --zero_stage=[your_stage]
but what is your_stage here? thank you
what is zero_stage?

your stage means one of ZeRO stage, [0, 1, 2, 3]. Details are in ZeRO documents. Shortly

  1. ZeRO stage 1 : The optimizer states
  2. ZeRO stage 2 : The optimizer states + The gradient states
  3. ZeRO stage 3 : The optimizer states + The gradient states + The parameters

to shard them into multi-gpus.

for 2 gpus which one you suggest? like --deepspeed --zero_stage=1?

what about 3 gpus?

currently it clone entire training on each gpus as far as i know

so can you tell some suggested guidelines?

@BootsofLagrangian
Copy link
Contributor Author

BootsofLagrangian commented Feb 5, 2024

hello.
so we add --deepspeed --zero_stage=[your_stage]
but what is your_stage here? thank you
what is zero_stage?

your stage means one of ZeRO stage, [0, 1, 2, 3]. Details are in ZeRO documents. Shortly

  1. ZeRO stage 1 : The optimizer states
  2. ZeRO stage 2 : The optimizer states + The gradient states
  3. ZeRO stage 3 : The optimizer states + The gradient states + The parameters

to shard them into multi-gpus.

for 2 gpus which one you suggest? like --deepspeed --zero_stage=1?

what about 3 gpus?

currently it clone entire training on each gpus as far as i know

so can you tell some suggested guidelines?

I think that ZeRO-2 stage is optimal stage if you have enough amount of VRAM. i.e. --deepspeed --zero_stage=2 only.

Without offloading, total amount of VRAM is still a major factor to decide training parameters.

You have to choose between training speed and saving VRAM, which will require a trade-off.

The number of GPUs is important, but deciding on the ZeRO stage is a strategy to choose the method you like better.

And if you want to use cpu/nvme offload, you might meet some kind of error in this commit. I will fix it soon.

But you can use ZeRO-2 stage without offloading in this commit.

@FurkanGozukara
Copy link

when i tested fp16 and bf16 on SD 1.5 I had horrible results. are you able to get any decent results?

@mchdks
Copy link

mchdks commented Feb 10, 2024

any update here? @kohya-ss

@BootsofLagrangian
Copy link
Contributor Author

BootsofLagrangian commented Feb 19, 2024

Here is a report of deepspeed in sd-scripts

Environment

Model

  • Stable Diffusion XL 1.0

GPUs

  • 24GB VRAM. No NVLink
    1. 2 x RTX 3090, 4 x RTX 3090, 6 x RTX 3090, and 8 x RTX 3090
    2. 2 x RTX 4090, 4 x RTX 4090, and 6 x RTX 4090
  • 48GB VRAM. No NVLink
    1. 2 x A6000, 4 x A6000, 6 x A6000, and 8 x A6000

No NVLink make bottleneck among GPU communication.

Requirements

  1. torch 2.2.0 cu121
  2. bitsandbytes 0.0.42
  • AdamW8bit
  1. accelerate 0.25.0
  2. deepspeed 0.13.1
  3. lycoris 2.0.2
  • LoCon

Training Settings

  • Methodology
    1. Full Finetuning(FT). sdxl_train.py
    2. LoCon with rank=16(PEFT). sdxl_train_network.py
  • Training Precision
    1. full_fp16, only active for DDP
    2. full_bf16
    3. bf16
  • Distributed Method
    1. Distributed Data Parallel(DDP). sd-scripts and accelerate default
    2. ZeRO. stage 1, 2, and 3
    • In both full_bf16 and bf16 training, I also tested ZeRO stage 2 with optimizer cpu offloading.
    • If deepspeed engine face OOM, they try to move variables from VRAM to RAM. i.e. if no OOM, it same as ZeRO-2.
  • Resolution
    1. 1024x1024 with aspect ratio bucketing(ARB)
  • Batch and Optimizer
    1. batch size = 4
    2. gradient accumulation steps = 16
    3. AdamW8bit

With 24GB VRAM, sd-scripts can run barely PEFT, and can not run FT on DDP.

Experiment

Fisrt, I'm sorry for some missing element of table. My budget is limited.

Lower is better.

full_fp16, FT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3
RTX 3090 2 OOM OOM OOM OOM
4 OOM OOM OOM 21539
6 OOM OOM OOM 21569
8 OOM OOM 23400 21013
RTX 4090 2 OOM OOM OOM OOM
4 OOM OOM OOM 21154
6 OOM OOM OOM 21571
A6000 2 42672 46391 45084 45277
4 43007 40060 36713 40912
6 41655 39553 36725 41537
8 42609 33300 29475 30474

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3
RTX 3090 2 - - - -
4 - - - 191.73
6 - - - 477.72
8 - - 143.58 190.09
RTX 4090 2 - - - -
4 - - - 50.02
6 - - - 104.65
A6000 2 42.03 37.35 43.61 45.80
4 46.18 37.90 45.64 49.48
6 47.98 38.42 48.32 104.51
8 40.97 35.51 45.59 46.63

full_fp16, PEFT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3
RTX 3090 2 23282 22097 22026 21045
4 23113 23206 23154 20968
6 22764 23280 23034 21949
8 23109 22582 SIGABRT* 21955
RTX 4090 2 22450 22163 22175 20791
4 22720 23059 23528 22209
6 21767 OOM* 23167 22142
A6000 2 37697 44323 43847 26219
4 35612 41004 38782 31543
6 28099 34084 34044 31085
8 30579 31521 31513 30701

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3
RTX 3090 2 58.64 51.70 56.25 114.71
4 59.27 55.92 58.37 167.77
6 58.32 52.61 57.07 375.46
8 53.08 48.88 SIGABRT* 165.85
RTX 4090 2 30.33 29.23 29.76 66.43
4 32.36 29.42 31.17 75.34
6 32.91 - 31.91 112.83
A6000 2 45.89 41.96 42.80 70.17
4 48.72 42.47 43.39 79.54
6 46.86 41.24 44.93 117.68
8 45.11 39.91 39.58 84.05

*SIGABRT occurred. IDK why.
*Ideally this OOM should not to be happened.

full_bf16, FT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 OOM OOM OOM OOM 23553
4 OOM OOM OOM 22206 23534
6 OOM OOM OOM 22059 23435
8 OOM OOM OOM* 21185 23388
RTX 4090 2 OOM OOM OOM OOM 23709
4 OOM OOM OOM 21646 23689
6 OOM OOM OOM 20867 23605
A6000 2 42673 46368 45117 42914 42320
4 42995 39963 36662 39542 33359
6 41653 39357 36639 40631 31007
8 41953 33268 29594 30988 32420

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 - - - - 112.18
4 - - - 113.79 123.54
6 - - - 481.29 183.95
8 - - OOM* 192.69 167.70
RTX 4090 2 - - - - 45.24
4 - - - 50.01 48.54
6 - - - 105.13 48.30
A6000 2 42.79 35.99 42.76 45.39 57.25
4 42.61 38.42 47.54 46.59 60.92
6 43.09 40.05 49.84 103.77 60.39
8 47.67 36.09 48.59 50.04 60.16

*Ideally this OOM should not to be happened.

full_bf16, PEFT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 22413 22099 22016 20786 21830
4 23118 23204 23139 20668 23010
6 22736 23377 22971 21960 23056
8 23192 22576 22575 21767 23056
RTX 4090 2 22432 22146 22175 20896 22001
4 22715 23077 23247 21329 23150
6 21570 OOM* 23170 22618 23175
A6000 2 37674 43901 43813 26004 43606
4 35439 40972 38718 32233 38611
6 28098 34106 34011 31077 33971
8 30590 31429 32644 30472 33516

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 57.72 52.67 54.85 108.90 55.82
4 57.01 52.24 55.58 123.89 56.76
6 57.33 52.95 57.02 409.32 57.65
8 53.62 49.53 52.23 167.19 52.57
RTX 4090 2 34.38 28.74 29.97 69.83 34.20
4 31.66 30.12 30.46 75.44 30.93
6 32.48 - 32.68 110.62 34.28
A6000 2 46.19 40.35 41.49 67.89 43.10
4 45.88 41.03 42.82 76.99 43.41
6 46.30 42.03 43.75 117.78 43.91
8 45.93 37.86 54.25 87.34 42.61

*Ideally this OOM should not to be happened.

bf16, FT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 - - - - 23549
4 - - - 22041 23509
6 - - - 21477 23454
8 - - - 20672 23398
RTX 4090 2 - - - - 23712
4 - - - 21620 23627
6 - - - 21234 23660
A6000 2 - 46397 45082 42974 42320
4 - 40096 36711 39840 33405
6 - 39357 36680 40632 31034
8 - 33289 29604 30599 32457

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 - - - - 172.08
4 - - - 216.22 195.97
6 - - - 506.07 189.02
8 - - - 189.56 165.13
RTX 4090 2 - - - - 44.25
4 - - - 53.28 52.91
6 - - - 110.34 56.24
A6000 2 - 36.68 46.37 45.86 61.53
4 - 41.75 46.25 48.04 61.95
6 - 41.84 49.43 104.32 65.88
8 - 35.50 45.23 47.68 56.39

bf16, PEFT

Average VRAM usage(MB)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 22684 22090 21817 21773 21816
4 22086 23207 23064 21082 22988
6 21531 23396 23130 22124 23131
8 22141 22568 22743 22038 22763
RTX 4090 2 21948 22156 21993 21705 21987
4 22224 23055 23199 20992 23181
6 21944 MISSING* 23695 22003 23138
A6000 2 38117 43926 43641 26064 43641
4 36059 40986 38665 32496 38642
6 28546 34141 33961 31079 33981
8 31020 31440 33448 31245 33499

Training Speed(s/it)

GPU Name # of GPUs DDP ZeRO-1 ZeRO-2 ZeRO-3 ZeRO-2 Optimizer cpu offloading
RTX 3090 2 57.31 52.40 56.29 156.13 56.48
4 58.78 54.07 57.65 179.51 56.78
6 59.57 55.31 60.12 426.70 61.33
8 53.92 49.99 55.16 165.99 54.26
RTX 4090 2 30.16 27.90 29.36 66.05 29.41
4 31.36 29.43 31.11 77.10 32.79
6 31.79 MISSING* 41.31 117.37 31.95
A6000 2 48.37 42.36 44.73 71.43 43.77
4 54.13 44.26 44.34 78.75 45.91
6 49.05 43.56 44.79 121.62 45.53
8 46.27 37.56 42.29 85.80 44.37

*I lost this element.

Results

  1. Wrapping models(U-Net, TEs, and network) just works.
  2. ZeRO-1 stage is most capable strategy, IMO.
  3. If you want to use multiple of odd number gpus in ZeRO-3 stage, you will meet very slow script.
  4. 2 x 24GB VRAM gpus can run FT on ZeRO-2 stage with CPU offloading.
  5. Ada Lovelace is super fast.

@FurkanGozukara
Copy link

the only way to utilize multiple consumer GPU is i think cloning the training if you don't have pro GPUs.

@tinbtb
Copy link

tinbtb commented Feb 19, 2024

full_fp16, only active for DDP

Why not use bf16 for these cards?

Training Speed(s/it). Lower is Better.
29.23

Could you please compare it with training on a single card? As far as I remember I get roughly the same speeds with just one card.

@BootsofLagrangian
Copy link
Contributor Author

full_fp16, only active for DDP

Why not use bf16 for these cards?

Training Speed(s/it). Lower is Better.
29.23

Could you please compare it with training on a single card? As far as I remember I get roughly the same speeds with just one card.

full_bf16 and bf16 is on running.

For a single card, training speed is almost same as DDP, DDP is slight slower.

@mchdks
Copy link

mchdks commented Feb 19, 2024

full_fp16, only active for DDP

Why not use bf16 for these cards?

Training Speed(s/it). Lower is Better.
29.23

Could you please compare it with training on a single card? As far as I remember I get roughly the same speeds with just one card.

full_bf16 and bf16 is on running.

For a single card, training speed is almost same as DDP, DDP is slight slower.

The results look very promising, I can help you by providing you with different multi-gpu machines if it will help with your tests.

@FurkanGozukara
Copy link

what is your effective batch size?

this is cloned on each GPU?

like 2 gpu means 4 * 16 * 2 ?

@BootsofLagrangian
Copy link
Contributor Author

full_fp16, only active for DDP

Why not use bf16 for these cards?

Training Speed(s/it). Lower is Better.
29.23

Could you please compare it with training on a single card? As far as I remember I get roughly the same speeds with just one card.

full_bf16 and bf16 is on running.
For a single card, training speed is almost same as DDP, DDP is slight slower.

The results look very promising, I can help you by providing you with different multi-gpu machines if it will help with your tests.

Thank your suggestion. But I think now-days Diffusion Models are not too big and necessary to run with multi-machines or multi-nodes.

@BootsofLagrangian
Copy link
Contributor Author

what is your effective batch size?

this is cloned on each GPU?

like 2 gpu means 4 * 16 * 2 ?

Effective Batch calculation in sd-scripts is on below

[effective batch] = [number of machine] x [number of gpus] x [train_batch_size] x [gradient_accumulation_steps]

For example, in my 2 gpu settings.

effective batch = 1 machine x 2 gpus x 4 train batch size x 16 gradient accumulation steps
= 1 x 2 x 4 x 16 = 128

@mchdks
Copy link

mchdks commented Feb 22, 2024

full_fp16, only active for DDP

Why not use bf16 for these cards?

Training Speed(s/it). Lower is Better.
29.23

Could you please compare it with training on a single card? As far as I remember I get roughly the same speeds with just one card.

full_bf16 and bf16 is on running.
For a single card, training speed is almost same as DDP, DDP is slight slower.

The results look very promising, I can help you by providing you with different multi-gpu machines if it will help with your tests.

Thank your suggestion. But I think now-days Diffusion Models are not too big and necessary to run with multi-machines or multi-nodes.

Actually, I was talking about multi-GPUs, not multi-machines. If you want to test on 8xa100 and a10g, please send a message.

@mchdks
Copy link

mchdks commented Feb 23, 2024

Here is a report of deepspeed in sd-scripts

Environment

Model

  • Stable Diffusion XL 1.0

GPUs

  • 24GB VRAM. No NVLink

    1. 2 x RTX 3090, 4 x RTX 3090, 6 x RTX 3090, and 8 x RTX 3090

    2. 2 x RTX 4090, 4 x RTX 4090, and 6 x RTX 4090

  • 48GB VRAM. No NVLink

    1. 2 x A6000, 4 x A6000, 6 x A6000, and 8 x A6000

No NVLink make bottleneck among GPU communication.

Requirements

  1. torch 2.2.0 cu121

  2. bitsandbytes 0.0.42

  • AdamW8bit
  1. accelerate 0.25.0

  2. deepspeed 0.13.1

  3. lycoris 2.0.2

  • LoCon

Training Settings

  • Methodology

    1. Full Finetuning(FT). sdxl_train.py

    2. LoCon with rank=16(PEFT). sdxl_train_network.py

  • Training Precision

    1. full_fp16, only active for DDP

    2. full_bf16

    3. bf16

  • Distributed Method

    1. Distributed Data Parallel(DDP). sd-scripts and accelerate default

    2. ZeRO. stage 1, 2, and 3

    • In both full_bf16 and bf16 training, I also tested ZeRO stage 2 with optimizer cpu offloading.

    • If deepspeed engine face OOM, they try to move variables from VRAM to RAM. i.e. if no OOM, it same as ZeRO-2.

  • Resolution

    1. 1024x1024 with aspect ratio bucketing(ARB)
  • Batch and Optimizer

    1. batch size = 4

    2. gradient accumulation steps = 16

    3. AdamW8bit

With 24GB VRAM, sd-scripts can run barely PEFT, and can not run FT on DDP.

Experiment

Fisrt, I'm sorry for some missing element of table. My budget is limited.

Lower is better.

full_fp16, FT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |

|:---------:|:---------:|:------:|:------:|:------:|:------:|

| RTX 3090 | 2 | OOM | OOM | OOM | OOM |

| | 4 | OOM | OOM | OOM | 21539 |

| | 6 | OOM | OOM | OOM | 21569 |

| | 8 | OOM | OOM | 23400 | 21013 |

| RTX 4090 | 2 | OOM | OOM | OOM | OOM |

| | 4 | OOM | OOM | OOM | 21154 |

| | 6 | OOM | OOM | OOM | 21571 |

| A6000 | 2 | 42672 | 46391 | 45084 | 45277 |

| | 4 | 43007 | 40060 | 36713 | 40912 |

| | 6 | 41655 | 39553 | 36725 | 41537 |

| | 8 | 42609 | 33300 | 29475 | 30474 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |

|:---------:|:---------:|:------:|:------:|:------:|:------:|

| RTX 3090 | 2 | - | - | - | - |

| | 4 | - | - | - | 191.73 |

| | 6 | - | - | - | 477.72 |

| | 8 | - | - | 143.58 | 190.09 |

| RTX 4090 | 2 | - | - | - | - |

| | 4 | - | - | - | 50.02 |

| | 6 | - | - | - | 104.65 |

| A6000 | 2 | 42.03 | 37.35 | 43.61 | 45.80 |

| | 4 | 46.18 | 37.90 | 45.64 | 49.48 |

| | 6 | 47.98 | 38.42 | 48.32 | 104.51 |

| | 8 | 40.97 | 35.51 | 45.59 | 46.63 |

full_fp16, PEFT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |

|:---------:|:---------:|:------:|:------:|:------:|:------:|

| RTX 3090 | 2 | 23282 | 22097 | 22026 | 21045 |

| | 4 | 23113 | 23206 | 23154 | 20968 |

| | 6 | 22764 | 23280 | 23034 | 21949 |

| | 8 | 23109 | 22582 |SIGABRT*| 21955 |

| RTX 4090 | 2 | 22450 | 22163 | 22175 | 20791 |

| | 4 | 22720 | 23059 | 23528 | 22209 |

| | 6 | 21767 | OOM* | 23167 | 22142 |

| A6000 | 2 | 37697 | 44323 | 43847 | 26219 |

| | 4 | 35612 | 41004 | 38782 | 31543 |

| | 6 | 28099 | 34084 | 34044 | 31085 |

| | 8 | 30579 | 31521 | 31513 | 30701 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 |

|:---------:|:---------:|:------:|:------:|:------:|:------:|

| RTX 3090 | 2 | 58.64 | 51.70 | 56.25 | 114.71 |

| | 4 | 59.27 | 55.92 | 58.37 | 167.77 |

| | 6 | 58.32 | 52.61 | 57.07 | 375.46 |

| | 8 | 53.08 | 48.88 |SIGABRT*| 165.85 |

| RTX 4090 | 2 | 30.33 | 29.23 | 29.76 | 66.43 |

| | 4 | 32.36 | 29.42 | 31.17 | 75.34 |

| | 6 | 32.91 | - | 31.91 | 112.83 |

| A6000 | 2 | 45.89 | 41.96 | 42.80 | 70.17 |

| | 4 | 48.72 | 42.47 | 43.39 | 79.54 |

| | 6 | 46.86 | 41.24 | 44.93 | 117.68 |

| | 8 | 45.11 | 39.91 | 39.58 | 84.05 |

*SIGABRT occurred. IDK why.

*Ideally this OOM should not to be happened.

full_bf16, FT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | OOM | OOM | OOM | OOM | 23553 |

| | 4 | OOM | OOM | OOM | 22206 | 23534 |

| | 6 | OOM | OOM | OOM | 22059 | 23435 |

| | 8 | OOM | OOM | OOM* | 21185 | 23388 |

| RTX 4090 | 2 | OOM | OOM | OOM | OOM | 23709 |

| | 4 | OOM | OOM | OOM | 21646 | 23689 |

| | 6 | OOM | OOM | OOM | 20867 | 23605 |

| A6000 | 2 | 42673 | 46368 | 45117 | 42914 | 42320 |

| | 4 | 42995 | 39963 | 36662 | 39542 | 33359 |

| | 6 | 41653 | 39357 | 36639 | 40631 | 31007 |

| | 8 | 41953 | 33268 | 29594 | 30988 | 32420 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | - | - | - | - | 112.18 |

| | 4 | - | - | - | 113.79 | 123.54 |

| | 6 | - | - | - | 481.29 | 183.95 |

| | 8 | - | - | OOM* | 192.69 | 167.70 |

| RTX 4090 | 2 | - | - | - | - | 45.24 |

| | 4 | - | - | - | 50.01 | 48.54 |

| | 6 | - | - | - | 105.13 | 48.30 |

| A6000 | 2 | 42.79 | 35.99 | 42.76 | 45.39 | 57.25 |

| | 4 | 42.61 | 38.42 | 47.54 | 46.59 | 60.92 |

| | 6 | 43.09 | 40.05 | 49.84 | 103.77 | 60.39 |

| | 8 | 47.67 | 36.09 | 48.59 | 50.04 | 60.16 |

*Ideally this OOM should not to be happened.

full_bf16, PEFT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | 22413 | 22099 | 22016 | 20786 | 21830 |

| | 4 | 23118 | 23204 | 23139 | 20668 | 23010 |

| | 6 | 22736 | 23377 | 22971 | 21960 | 23056 |

| | 8 | 23192 | 22576 | 22575 | 21767 | 23056 |

| RTX 4090 | 2 | 22432 | 22146 | 22175 | 20896 | 22001 |

| | 4 | 22715 | 23077 | 23247 | 21329 | 23150 |

| | 6 | 21570 | OOM* | 23170 | 22618 | 23175 |

| A6000 | 2 | 37674 | 43901 | 43813 | 26004 | 43606 |

| | 4 | 35439 | 40972 | 38718 | 32233 | 38611 |

| | 6 | 28098 | 34106 | 34011 | 31077 | 33971 |

| | 8 | 30590 | 31429 | 32644 | 30472 | 33516 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | 57.72 | 52.67 | 54.85 | 108.90 | 55.82 |

| | 4 | 57.01 | 52.24 | 55.58 | 123.89 | 56.76 |

| | 6 | 57.33 | 52.95 | 57.02 | 409.32 | 57.65 |

| | 8 | 53.62 | 49.53 | 52.23 | 167.19 | 52.57 |

| RTX 4090 | 2 | 34.38 | 28.74 | 29.97 | 69.83 | 34.20 |

| | 4 | 31.66 | 30.12 | 30.46 | 75.44 | 30.93 |

| | 6 | 32.48 | - | 32.68 | 110.62 | 34.28 |

| A6000 | 2 | 46.19 | 40.35 | 41.49 | 67.89 | 43.10 |

| | 4 | 45.88 | 41.03 | 42.82 | 76.99 | 43.41 |

| | 6 | 46.30 | 42.03 | 43.75 | 117.78 | 43.91 |

| | 8 | 45.93 | 37.86 | 54.25 | 87.34 | 42.61 |

*Ideally this OOM should not to be happened.

bf16, FT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | - | - | - | - | 23549 |

| | 4 | - | - | - | 22041 | 23509 |

| | 6 | - | - | - | 21477 | 23454 |

| | 8 | - | - | - | 20672 | 23398 |

| RTX 4090 | 2 | - | - | - | - | 23712 |

| | 4 | - | - | - | 21620 | 23627 |

| | 6 | - | - | - | 21234 | 23660 |

| A6000 | 2 | - | 46397 | 45082 | 42974 | 42320 |

| | 4 | - | 40096 | 36711 | 39840 | 33405 |

| | 6 | - | 39357 | 36680 | 40632 | 31034 |

| | 8 | - | 33289 | 29604 | 30599 | 32457 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | - | - | - | - | 172.08 |

| | 4 | - | - | - | 216.22 | 195.97 |

| | 6 | - | - | - | 506.07 | 189.02 |

| | 8 | - | - | - | 189.56 | 165.13 |

| RTX 4090 | 2 | - | - | - | - | 44.25 |

| | 4 | - | - | - | 53.28 | 52.91 |

| | 6 | - | - | - | 110.34 | 56.24 |

| A6000 | 2 | - | 36.68 | 46.37 | 45.86 | 61.53 |

| | 4 | - | 41.75 | 46.25 | 48.04 | 61.95 |

| | 6 | - | 41.84 | 49.43 | 104.32 | 65.88 |

| | 8 | - | 35.50 | 45.23 | 47.68 | 56.39 |

bf16, PEFT

Average VRAM usage(MB)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | 22684 | 22090 | 21817 | 21773 | 21816 |

| | 4 | 22086 | 23207 | 23064 | 21082 | 22988 |

| | 6 | 21531 | 23396 | 23130 | 22124 | 23131 |

| | 8 | 22141 | 22568 | 22743 | 22038 | 22763 |

| RTX 4090 | 2 | 21948 | 22156 | 21993 | 21705 | 21987 |

| | 4 | 22224 | 23055 | 23199 | 20992 | 23181 |

| | 6 | 21944 | MISSING*| 23695 | 22003 | 23138 |

| A6000 | 2 | 38117 | 43926 | 43641 | 26064 | 43641 |

| | 4 | 36059 | 40986 | 38665 | 32496 | 38642 |

| | 6 | 28546 | 34141 | 33961 | 31079 | 33981 |

| | 8 | 31020 | 31440 | 33448 | 31245 | 33499 |

Training Speed(s/it)

| GPU Name | # of GPUs | DDP | ZeRO-1 | ZeRO-2 | ZeRO-3 | ZeRO-2 Optimizer cpu offloading |

|:---------:|:---------:|:------:|:------:|:------:|:------:|:-------------------------------:|

| RTX 3090 | 2 | 57.31 | 52.40 | 56.29 | 156.13 | 56.48 |

| | 4 | 58.78 | 54.07 | 57.65 | 179.51 | 56.78 |

| | 6 | 59.57 | 55.31 | 60.12 | 426.70 | 61.33 |

| | 8 | 53.92 | 49.99 | 55.16 | 165.99 | 54.26 |

| RTX 4090 | 2 | 30.16 | 27.90 | 29.36 | 66.05 | 29.41 |

| | 4 | 31.36 | 29.43 | 31.11 | 77.10 | 32.79 |

| | 6 | 31.79 | MISSING*| 41.31 | 117.37 | 31.95 |

| A6000 | 2 | 48.37 | 42.36 | 44.73 | 71.43 | 43.77 |

| | 4 | 54.13 | 44.26 | 44.34 | 78.75 | 45.91 |

| | 6 | 49.05 | 43.56 | 44.79 | 121.62 | 45.53 |

| | 8 | 46.27 | 37.56 | 42.29 | 85.80 | 44.37 |

*I lost this element.

Results

  1. Wrapping models(U-Net, TEs, and network) just works.

  2. ZeRO-1 stage is most capable strategy, IMO.

  3. If you want to use multiple of odd number gpus in ZeRO-3 stage, you will meet very slow script.

  4. 2 x 24GB VRAM gpus can run FT on ZeRO-2 stage with CPU offloading.

  5. Ada Lovelace is super fast.

I tried it with a clean installation on a new machineto do a test, but I received warnings and errors that were too long to include here, unfortunately the result was unsuccessful. Could there be a requirement you missed?

@BootsofLagrangian
Copy link
Contributor Author

BootsofLagrangian commented Feb 23, 2024

@mchdks

Here is a simple yet all about installation.

installation

I recommend to use deepspeed anaconda envrionment. First, you need to clone my deepspeed branch.

  1. install anaconda
    • I recommend python=3.10
  2. install anaconda deepspeed env via yml file.
    • conda create -n deepspeed --file=/path/to/YAL_FILE.yml
  3. move dir to sd-scripts deepspeed branch
    • cd /path/to/deepspeed/branch
  4. activate anaconda environment
    • conda activate deepspeed
  5. install pytorch and sd-scripts requirements.txt
    • pip install torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu121
    • pip install -r requirements.txt
    • tested torch version is 2.2.0 but it can run on other torch version >= 2.0.1.
  6. install bitsandbytes, xformers, and lycoris
    • pip install bitsandbytes xformers lycoris_lora
  7. install deepspeed
    • DS_BUILD_OPS=0 pip install deepspeed==0.13.1
  8. accelerate config
  9. prepare test image set and script configuration file, like

CONFIG_FILE.toml
pretrained_model_name_or_path = "./training/base_model/sd_xl_base_1.0.safetensors"
xformers = true
deepspeed = true
zero_stage = 1
mixed_precision = "bf16"
save_precision = "bf16"
full_bf16 = true
output_name = "full_bf16_ff_zero_1"
output_dir = "./training/ds_test/model"
train_data_dir = "./training/test_img"
shuffle_caption = true
caption_extension = ".txt"
random_crop = true
resolution = "1024,1024"
enable_bucket = true
bucket_no_upscale = true
save_every_n_epochs = 1
train_batch_size = 4
max_token_length = 225
max_train_epochs = 1
max_data_loader_n_workers = 4
persistent_data_loader_workers = true
seed = 42
gradient_checkpointing = true
gradient_accumulation_steps = 16
logging_dir = "./training/ds_test/logs"
caption_separator = ". "
noise_offset = 0.0357
learning_rate = 1e-4
unet_lr = 1e-4
learning_rate_te1 = 5e-5
learning_rate_te2 = 5e-5
train_text_encoder = true
max_grad_norm = 1.0
optimizer_type = "AdamW8bit"
save_model_as = "safetensors"
optimizer_args = [ "weight_decay=1e-1", ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 340
no_half_vae = true

  1. run script like this.
accelerate launch --mixed_precision=bf16 \
        --num_processes=8 --num_machines=1 --multi_gpu \
        --main_process_ip=localhost --main_process_port=29555 \
        --num_cpu_threads_per_process=4 \
        ./sdxl_train.py --config_file=$CONFIG_FILE

CONFIG_FILE is a path to above toml file.

  1. you will meet very long warnings and logs. but just ignore them, it doesn't effect on training.

@kohya-ss
Copy link
Owner

Thank you for this great PR! It looks very nice.

However, I dont' have an environment to test DeepSpeed. I know that I can test it with cloud environments, but I prefer to develop other features than testing DeepSpeed.

In addition, the update to the scripts is not a little, so it will be a little hard to maintain.

Therefore, is it OK if I move the features to the single script which supports DeepSpeed as much as possible after merging? I will make a new branch for it, and I'd be happy if you test and review the branch.

I think that if the script works well, it will not be required for me to maintain the script in future, and someone would update the script if necessary.

@BootsofLagrangian
Copy link
Contributor Author

Thank you for this great PR! It looks very nice.

However, I dont' have an environment to test DeepSpeed. I know that I can test it with cloud environments, but I prefer to develop other features than testing DeepSpeed.

In addition, the update to the scripts is not a little, so it will be a little hard to maintain.

Therefore, is it OK if I move the features to the single script which supports DeepSpeed as much as possible after merging? I will make a new branch for it, and I'd be happy if you test and review the branch.

I think that if the script works well, it will not be required for me to maintain the script in future, and someone would update the script if necessary.

Sounds good. It is good to move DeepSpeed features into dev-branch and to postpone merging it.

@kohya-ss kohya-ss changed the base branch from main to deep-speed February 27, 2024 09:55
@kohya-ss kohya-ss merged commit 0e4a573 into kohya-ss:deep-speed Feb 27, 2024
1 check passed
@kohya-ss
Copy link
Owner

@BootsofLagrangian
Hi! I've merged the PR to the new branch, and I'm refactoring a bit the code.

I have a question for DeepSpeedWrapper. In my understanding, Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.

If this is correct, when we pass the wrapper to accelerator.accumulate, the argument for accumulate might be the wrapper, instead of the list of models. Because accumulate takes the prepared model.

Therefore, training_models = [ds_model] might be ok. Is this correct?

@kohya-ss kohya-ss mentioned this pull request Feb 27, 2024
@BootsofLagrangian
Copy link
Contributor Author

@BootsofLagrangian Hi! I've merged the PR to the new branch, and I'm refactoring a bit the code.

I have a question for DeepSpeedWrapper. In my understanding, Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.

If this is correct, when we pass the wrapper to accelerator.accumulate, the argument for accumulate might be the wrapper, instead of the list of models. Because accumulate takes the prepared model.

Therefore, training_models = [ds_model] might be ok. Is this correct?

Yep, that is correct. accelerate do something magical, accumulate method accepts accelerate-compatible Modules.

I tested training_models = [ds_model] and it works.

@kohya-ss
Copy link
Owner

Thank you for clarification! I opened a new PR #1139, I would appreciate your comments and suggestions.

@storuky
Copy link

storuky commented Mar 13, 2024

I got very optimistic results!
I have only 3x4090 in my PC and I was able to run AdamW (not Adam8bit) optimizer on BF16 (without full_bf16) with batch size 16 and 1st text encoder training.
I used zero stage 2 and cpu offloading and cached latents on disk.
Speed is incredible: 10s/it 😱
It took 123 GB of RAM.
Just to note: to run AdamW with batch size 16 you need ~ 75GB VRAM (A100 or H100) and it performs with 3.5s/it (H100) and 5s/it (A100) but effective batch is 16 while for 3x4090 it's 48!

My dataset: 353 images
Repeats: 40
Epochs: 5

Training time
3x4090: 4 hours
1xH100: 4.2 hours
1xA100: 6.1 hours

Unbelievable, but 3x4090 faster than 1 H100 Pcie.

@BootsofLagrangian you are a wizard!

@FurkanGozukara
Copy link

@storuky wow nice results

@BootsofLagrangian BootsofLagrangian deleted the deepspeed branch March 20, 2024 12:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants