Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DreamBooth DeepSpeed support for under 8 GB VRAM training #735

Merged
merged 5 commits into from Oct 10, 2022

Conversation

Ttl
Copy link
Contributor

@Ttl Ttl commented Oct 5, 2022

Add instructions on how to enable DeepSpeed in DreamBooth example to allow training on under 8 GB VRAM. I was able to train a working network using this on 8 GB VRAM GPU.

It did not work out of the box with fp16 mixed precision and I had to add some explicit casts to make it run. Without them it raises Exception: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 5, 2022

The documentation is not available anymore as the PR was closed or merged.

@pink-red
Copy link
Contributor

pink-red commented Oct 5, 2022

Just tried this version, it works pretty well! DreamBooth finally fits onto a 3060, although it used 10743 MiB for me. Does DeepSpeed adjust the training to fit into VRAM or was it something else?

Here's time comparison for different cases (800 steps in each case, not including time to generate class images):

Device (script version) Time
GPU, T4 (ShivamShrirao's Colab) ~17 min*
GPU, 3060 + CPU, 5950x** (DeepSpeed) ~58 min
CPU, 5950x ~9 hours

*ShivamShrirao's fork caches the latents, which speeds up the training, but takes some time itself. For me, it was 02:34 for latent caching (for 103 instance images) and 14:36 for the training itself, resulting in 17:10.

**I've noticed that with DeepSpeed enabled my CPU was under pretty heavy load, so including it into table too.

By the way, if you don't want to run accelerate config, which will change the default settings, try these arguments:

accelerate launch --use_deepspeed --zero_stage=2 --gradient_accumulation_steps=1 --offload_param_device=cpu --offload_optimizer_device=cpu train_dreambooth.py *training_arguments*

Full command I've used for training:

accelerate launch --use_deepspeed --zero_stage=2 --gradient_accumulation_steps=1 --offload_param_device=cpu --offload_optimizer_device=cpu train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=200 \
  --max_train_steps=800 \
  --sample_batch_size=2 \
  --mixed_precision=fp16

@patil-suraj patil-suraj self-assigned this Oct 7, 2022
Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks clean to me!

@devilismyfriend
Copy link

devilismyfriend commented Oct 8, 2022

Unfortunately, this did not work for me, tested on a 3080 10GB and 64GB of RAM.

The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_cpu_threads_per_process` was set to `8` to improve out-of-box performance
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
[2022-10-06 20:56:28,679] [WARNING] [runner.py:178:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2022-10-06 20:56:29,043] [INFO] [runner.py:504:main] cmd = /root/anaconda3/envs/diffusers/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --no_local_rank train_dreambooth.py --pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4 --instance_data_dir=training --output_dir=model_output --class_data_dir=animeGrl --with_prior_preservation --prior_loss_weight=1.0 --instance_prompt=a photo of cplucy animeGrl --class_prompt=a photo of anime girl --resolution=512 --train_batch_size=1 --gradient_accumulation_steps=1 --gradient_checkpointing --learning_rate=5e-6 --lr_scheduler=constant --lr_warmup_steps=0 --num_class_images=200 --max_train_steps=3000 --mixed_precision=fp16
[2022-10-06 20:56:29,998] [INFO] [launch.py:136:main] WORLD INFO DICT: {'localhost': [0]}
[2022-10-06 20:56:29,998] [INFO] [launch.py:142:main] nnodes=1, num_local_procs=1, node_rank=0
[2022-10-06 20:56:29,998] [INFO] [launch.py:155:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2022-10-06 20:56:29,998] [INFO] [launch.py:156:main] dist_world_size=1
[2022-10-06 20:56:29,998] [INFO] [launch.py:158:main] Setting CUDA_VISIBLE_DEVICES=0
[2022-10-06 20:56:32,105] [INFO] [comm.py:633:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2022-10-06 20:56:45,211] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.7.3, git-hash=unknown, git-branch=unknown
[2022-10-06 20:56:48,066] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2022-10-06 20:56:48,066] [INFO] [logging.py:68:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer
[2022-10-06 20:56:48,066] [INFO] [logging.py:68:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2022-10-06 20:56:48,115] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}
[2022-10-06 20:56:48,115] [INFO] [utils.py:52:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=<class 'torch.optim.adamw.AdamW'>
[2022-10-06 20:56:48,115] [INFO] [logging.py:68:log_dist] [Rank 0] Creating fp16 ZeRO stage 2 optimizer
[2022-10-06 20:56:48,115] [INFO] [stage_1_and_2.py:134:__init__] Reduce bucket size 500000000
[2022-10-06 20:56:48,115] [INFO] [stage_1_and_2.py:135:__init__] Allgather bucket size 500000000
[2022-10-06 20:56:48,115] [INFO] [stage_1_and_2.py:136:__init__] CPU Offload: True
[2022-10-06 20:56:48,115] [INFO] [stage_1_and_2.py:137:__init__] Round robin gradient partitioning: False
Using /root/.cache/torch_extensions/py39_cu116 as PyTorch extensions root...
Emitting ninja build file /root/.cache/torch_extensions/py39_cu116/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
Time to load utils op: 0.1581101417541504 seconds
Rank: 0 partition count [1] and sizes[(859520964, False)]
[2022-10-06 20:56:50,230] [INFO] [utils.py:827:see_memory_usage] Before initializing optimizer states
[2022-10-06 20:56:50,231] [INFO] [utils.py:828:see_memory_usage] MA 1.66 GB         Max_MA 1.66 GB         CA 3.27 GB         Max_CA 3 GB
[2022-10-06 20:56:50,231] [INFO] [utils.py:836:see_memory_usage] CPU Virtual Memory:  used = 8.45 GB, percent = 27.0%
Traceback (most recent call last):
  File "/root/github/diffusers-dreambooth_deepspeed/examples/dreambooth/train_dreambooth.py", line 613, in <module>
    main()
  File "/root/github/diffusers-dreambooth_deepspeed/examples/dreambooth/train_dreambooth.py", line 489, in main
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/accelerate/accelerator.py", line 619, in prepare
    result = self._prepare_deepspeed(*args)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/accelerate/accelerator.py", line 805, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/__init__.py", line 124, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 320, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1144, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1395, in _configure_zero_optimizer
    optimizer = DeepSpeedZeroOptimizer(
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 512, in __init__
    self.initialize_optimizer_states()
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 599, in initialize_optimizer_states
    i].grad = single_grad_partition.pin_memory(
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[2022-10-06 20:56:52,025] [INFO] [launch.py:286:sigkill_handler] Killing subprocess 76
[2022-10-06 20:56:52,025] [ERROR] [launch.py:292:sigkill_handler] ['/root/anaconda3/envs/diffusers/bin/python', '-u', 'train_dreambooth.py', '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4', '--instance_data_dir=training', '--output_dir=model_output', '--class_data_dir=animeGrl', '--with_prior_preservation', '--prior_loss_weight=1.0', '--instance_prompt=a photo of cplucy animeGrl', '--class_prompt=a photo of anime girl', '--resolution=512', '--train_batch_size=1', '--gradient_accumulation_steps=1', '--gradient_checkpointing', '--learning_rate=5e-6', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--num_class_images=200', '--max_train_steps=3000', '--mixed_precision=fp16'] exits with return code = 1
Traceback (most recent call last):
  File "/root/anaconda3/envs/diffusers/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/accelerate/commands/accelerate_cli.py", line 43, in main
    args.func(args)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/accelerate/commands/launch.py", line 827, in launch_command
    deepspeed_launcher(args)
  File "/root/anaconda3/envs/diffusers/lib/python3.9/site-packages/accelerate/commands/launch.py", line 540, in deepspeed_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['deepspeed', '--no_local_rank', '--num_gpus', '1', 'train_dreambooth.py', '--pretrained_model_name_or_path=CompVis/stable-diffusion-v1-4', '--instance_data_dir=training', '--output_dir=model_output', '--class_data_dir=animeGrl', '--with_prior_preservation', '--prior_loss_weight=1.0', '--instance_prompt=a photo of cplucy animeGrl', '--class_prompt=a photo of anime girl', '--resolution=512', '--train_batch_size=1', '--gradient_accumulation_steps=1', '--gradient_checkpointing', '--learning_rate=5e-6', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--num_class_images=200', '--max_train_steps=3000', '--mixed_precision=fp16']' returned non-zero exit status 1.

Due to recent commits some casts to half precision are not necessary
anymore.

Mention that DeepSpeed's version of Adam is about 2x faster.
@Ttl
Copy link
Contributor Author

Ttl commented Oct 8, 2022

Enabling DeepSpeedCPUAdam optimizer gives about 2x speedup. I didn't make code changes for it but added a mention in the README.

@Thomas-MMJ
Copy link

Thomas-MMJ commented Oct 9, 2022

8-bit optimizer does not seem to be compatible with DeepSpeed at the moment.

There is a patch in the deepspeed repo to allow 8bit_adam to work with deepspeed, I haven't tested it so not sure if it works since there hasn't been any progress/comments on it since December.

microsoft/DeepSpeed#1582

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very cool! Thanks a lot for working on this. I just left a couple of nits and a questions about loss computation.

examples/dreambooth/train_dreambooth.py Show resolved Hide resolved
examples/dreambooth/train_dreambooth.py Outdated Show resolved Hide resolved

# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why cast to float32 here ? Do we always want to compute loss in full precision ?
cc @patrickvonplaten

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's mixed precision best practice to calculate large reduction in higher precision. This calculates mean of batch_size * 4 * 64 * 64 halfs and mse_loss is one of the operations that would be automatically casted to fp32 in fp16 with autocast. I'm not sure it it's necessary at low batch size as it does seem to work without it, but it doesn't really affect memory consumption since it's only one operation and should give some safety.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, makes sense!

examples/dreambooth/README.md Outdated Show resolved Hide resolved
examples/dreambooth/README.md Show resolved Hide resolved
@patrickvonplaten
Copy link
Contributor

Good to merge for me if @patil-suraj is happy with it!

@patil-suraj
Copy link
Contributor

Thanks a lot for the amazing contribution !

@patil-suraj patil-suraj merged commit 81bdbb5 into huggingface:main Oct 10, 2022
@tcapelle
Copy link

It appears that mixed precision is broken now

prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
…e#735)

* Support deepspeed

* Dreambooth DeepSpeed documentation

* Remove unnecessary casts, documentation

Due to recent commits some casts to half precision are not necessary
anymore.

Mention that DeepSpeed's version of Adam is about 2x faster.

* Review comments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants