diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py
index f32f81a..df7d994 100644
--- a/train_lora_dreambooth.py
+++ b/train_lora_dreambooth.py
@@ -6,6 +6,7 @@
 import itertools
 import math
 import os
+import inspect
 from pathlib import Path
 from typing import Optional

@@ -828,13 +829,14 @@ def collate_fn(examples):
     )
     progress_bar.set_description("Steps")
     global_step = 0
+    last_save = 0

     for epoch in range(args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
-        for step, batch in enumerate(train_dataloader):
+        for step, batch in enumerate(train_dataloader):
             # Convert images to latent space
             latents = vae.encode(
                 batch["pixel_values"].to(dtype=weight_dtype)
             ).latent_dist.sample()
@@ -908,21 +910,32 @@ def collate_fn(examples):
             progress_bar.update(1)
             optimizer.zero_grad()

-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                global_step += 1
-                if global_step % args.save_steps == 0:
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                if args.save_steps and global_step - last_save >= args.save_steps:
                     if accelerator.is_main_process:
+                        # Newer versions of accelerate allow the 'keep_fp32_wrapper' arg. Without
+                        # passing it, the models are unwrapped, and we crash when they are then
+                        # used for further training. Pass it, but only to newer versions of
+                        # accelerate. Fixes https://github.com/huggingface/diffusers/issues/1566
+                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
+                            inspect.signature(accelerator.unwrap_model).parameters.keys()
+                        )
+                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                         pipeline = StableDiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet),
-                            text_encoder=accelerator.unwrap_model(text_encoder),
+                            unet=accelerator.unwrap_model(unet, **extra_args),
+                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
                             revision=args.revision,
                         )
-                        save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
+                        filename = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
+                        print(f"save weights {filename}")
+                        save_lora_weight(pipeline.unet, filename)
+
+                        last_save = global_step

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
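
A note on the `keep_fp32_wrapper` handling above: rather than pinning a minimum `accelerate` version, the patch feature-detects the argument with `inspect.signature` and only passes it when it exists. Here is a minimal standalone sketch of the same pattern; the `unwrap` function below is a hypothetical stand-in, not the real `accelerator.unwrap_model`:

```python
import inspect

def call_with_optional_kwargs(fn, *args, **maybe_kwargs):
    # Keep only the keyword arguments that fn actually declares, so the call
    # works across library versions that added parameters over time.
    accepted = inspect.signature(fn).parameters.keys()
    kwargs = {k: v for k, v in maybe_kwargs.items() if k in accepted}
    return fn(*args, **kwargs)

# Hypothetical stand-in: 'keep_fp32_wrapper' exists only in newer versions,
# so older ones simply never see the argument.
def unwrap(model, keep_fp32_wrapper=False):
    return model

model = object()
unwrapped = call_with_optional_kwargs(unwrap, model, keep_fp32_wrapper=True)
```

The trade-off is one signature inspection per save in exchange for compatibility with both old and new `accelerate` releases.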
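
The save trigger also changes from `global_step % args.save_steps == 0` to `args.save_steps and global_step - last_save >= args.save_steps`. The interval form has two effects: a falsy `--save_steps` now cleanly disables checkpointing, and a checkpoint can no longer be skipped forever when the counter steps over an exact multiple without the condition being evaluated. A toy comparison, with hypothetical step values:

```python
SAVE_STEPS = 100

# Steps at which the save condition happens to be evaluated;
# the exact multiples of 100 are skipped.
evaluated_steps = [99, 101, 199, 201, 299, 301]

saves_modulo = [s for s in evaluated_steps if s % SAVE_STEPS == 0]

saves_interval, last_save = [], 0
for s in evaluated_steps:
    if s - last_save >= SAVE_STEPS:
        saves_interval.append(s)
        last_save = s

print(saves_modulo)    # [] -> never saves
print(saves_interval)  # [101, 201, 301] -> roughly every 100 steps
```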