From 211605a12c92c0c681dde5e4245f9f21bc86fc14 Mon Sep 17 00:00:00 2001
From: Tim Hinderliter
Date: Fri, 9 Dec 2022 14:42:20 -0800
Subject: [PATCH 1/2] fix global steps tracking & --save_steps intermittent saves (#6, #8)

---
 train_lora_dreambooth.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py
index f32f81a..d57ca71 100644
--- a/train_lora_dreambooth.py
+++ b/train_lora_dreambooth.py
@@ -828,13 +828,14 @@ def collate_fn(examples):
     )
     progress_bar.set_description("Steps")
     global_step = 0
+    last_save = 0
 
     for epoch in range(args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
 
-        for step, batch in enumerate(train_dataloader):
+        for step, batch in enumerate(train_dataloader):
             # Convert images to latent space
             latents = vae.encode(
                 batch["pixel_values"].to(dtype=weight_dtype)
@@ -908,21 +909,24 @@ def collate_fn(examples):
             progress_bar.update(1)
             optimizer.zero_grad()
 
-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                global_step += 1
-                if global_step % args.save_steps == 0:
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                if args.save_steps and global_step - last_save >= args.save_steps:
                     if accelerator.is_main_process:
                         pipeline = StableDiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet),
-                            text_encoder=accelerator.unwrap_model(text_encoder),
+                            unet=accelerator.unwrap_model(unet, True),
+                            text_encoder=accelerator.unwrap_model(text_encoder, True),
                             revision=args.revision,
                         )
-                        save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
+                        filename = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
+                        print(f"save weights {filename}")
+                        save_lora_weight(pipeline.unet, filename)
+
+                        last_save = global_step
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

From 259795d7f421ab3edf20843cb45c095994c52ed0 Mon Sep 17 00:00:00 2001
From: Tim Hinderliter
Date: Fri, 9 Dec 2022 15:01:30 -0800
Subject: [PATCH 2/2] guard against calling unwrap_model with keep_fp32_wrapper arg on older versions of accelerate (which don't support it)

---
 train_lora_dreambooth.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py
index d57ca71..df7d994 100644
--- a/train_lora_dreambooth.py
+++ b/train_lora_dreambooth.py
@@ -6,6 +6,7 @@
 import itertools
 import math
 import os
+import inspect
 from pathlib import Path
 from typing import Optional
 
@@ -915,10 +916,18 @@ def collate_fn(examples):
             if accelerator.sync_gradients:
                 if args.save_steps and global_step - last_save >= args.save_steps:
                     if accelerator.is_main_process:
+                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
+                        # it, the models will be unwrapped, and when they are then used for further training,
+                        # we will crash. pass this, but only to newer versions of accelerate. fixes
+                        # https://github.com/huggingface/diffusers/issues/1566
+                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
+                            inspect.signature(accelerator.unwrap_model).parameters.keys()
+                        )
+                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                         pipeline = StableDiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet, True),
-                            text_encoder=accelerator.unwrap_model(text_encoder, True),
+                            unet=accelerator.unwrap_model(unet, **extra_args),
+                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
                             revision=args.revision,
                         )
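
The guard in the second patch boils down to a small feature-detection idiom: inspect the signature of accelerator.unwrap_model and pass keep_fp32_wrapper=True only when the installed accelerate version accepts it. Below is a minimal, standalone sketch of that idiom; the unwrap_old and unwrap_new functions are hypothetical stand-ins for the two accelerate signatures, not real accelerate APIs.

    import inspect

    def supports_kwarg(func, name):
        # True when `func` can be called with a keyword argument called `name`.
        return name in inspect.signature(func).parameters

    # Hypothetical stand-ins for accelerator.unwrap_model on old / new accelerate.
    def unwrap_old(model):
        return model

    def unwrap_new(model, keep_fp32_wrapper=False):
        return model

    for unwrap in (unwrap_old, unwrap_new):
        extra_args = {"keep_fp32_wrapper": True} if supports_kwarg(unwrap, "keep_fp32_wrapper") else {}
        unwrap(object(), **extra_args)
        print(unwrap.__name__, "called with", extra_args)

Because the check is done at call time against whatever unwrap_model the running accelerate provides, the same training script works both before and after the keyword was introduced.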