From 091c5894c32c5f87013f50d9f11c5c90adc970e2 Mon Sep 17 00:00:00 2001 From: 2510 Date: Sat, 30 Dec 2023 13:57:52 +0900 Subject: [PATCH 1/3] Fix gradient-checkpointing option being ignored in SDXL+LoRA training. (#6388) --- examples/text_to_image/train_text_to_image_lora_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 3dab86a48492..bc9faf774308 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -706,6 +706,9 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: From 07a52194218e8c22660d10fe2a4e5dfe6f3e2697 Mon Sep 17 00:00:00 2001 From: 2510 Date: Sat, 30 Dec 2023 21:33:19 +0900 Subject: [PATCH 2/3] Fix gradient-checkpointing option being ignored in SD+LoRA training. 
--- examples/text_to_image/train_text_to_image_lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 6d285362434d..f6575f7494ca 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -486,6 +486,9 @@ def main(): lora_layers = filter(lambda p: p.requires_grad, unet.parameters()) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: From 8d7feb69cacfccaaa465ca70874921602f7b62b3 Mon Sep 17 00:00:00 2001 From: 2510 Date: Sun, 31 Dec 2023 15:53:52 +0900 Subject: [PATCH 3/3] Fix gradient checkpointing not being applied to text encoders (SDXL+LoRA). --- examples/text_to_image/train_text_to_image_lora_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index bc9faf774308..81dbd0f1f612 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -708,6 +708,9 @@ def load_model_hook(models, input_dir): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices