Enable resume training unet/text encoder (#48)
* Enable resuming training of the unet/text encoder

New flags --resume_text_encoder and --resume_unet accept paths to the .pt files to resume from.
Make sure to change the output directory from the previous training session; otherwise its .pt files will be overwritten, since training does not resume from the previous global step (see the resume sketch below).

* Load weights from .pt with inject_trainable_lora

Adds a new loras argument to the inject_trainable_lora function, which accepts a path to a .pt file containing previously trained weights.
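
As a sketch of the resulting resume workflow (a hypothetical invocation, not part of this commit: the model name, dataset flags, and all paths below are illustrative, and importing parse_args/main assumes the script guards its entry point):

# Hypothetical resume run. parse_args does accept an argv-style list via its
# input_args parameter; everything else here is an illustrative assumption.
from train_lora_dreambooth import parse_args, main

args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
    "--instance_data_dir", "./data/instance",      # assumed standard DreamBooth flags
    "--instance_prompt", "a photo of sks dog",
    "--train_text_encoder",
    # Fresh output dir: reusing the old one would overwrite its .pt files,
    # because training restarts from global step zero.
    "--output_dir", "./output_run2",
    "--resume_unet", "./output_run1/lora_weight_e4_s2000.pt",
    "--resume_text_encoder", "./output_run1/lora_weight_e4_s2000.text_encoder.pt",
])
main(args)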
hdon96 committed Dec 16, 2022
1 parent b64b1d4 commit 6f499df
Showing 2 changed files with 87 additions and 63 deletions.
11 changes: 9 additions & 2 deletions lora_diffusion/lora.py
@@ -34,6 +34,7 @@ def inject_trainable_lora(
model: nn.Module,
target_replace_module: List[str] = ["CrossAttention", "Attention"],
r: int = 4,
+    loras=None,  # path to lora .pt
):
"""
inject lora into model, and returns lora parameter groups.
@@ -42,6 +43,9 @@
require_grad_params = []
names = []

+    if loras != None:
+        loras = torch.load(loras)

for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:

@@ -62,18 +66,21 @@

# switch the module
_module._modules[name] = _tmp

require_grad_params.append(
_module._modules[name].lora_up.parameters()
)
require_grad_params.append(
_module._modules[name].lora_down.parameters()
)

+                if loras != None:
+                    _module._modules[name].lora_up.weight = loras.pop(0)
+                    _module._modules[name].lora_down.weight = loras.pop(0)

_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)

return require_grad_params, names


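A minimal sketch of the on-disk format the loader above assumes: torch.load(loras) must yield a flat, ordered list of nn.Parameter objects, alternating lora_up.weight then lora_down.weight for each injected module, because the loop consumes them pairwise with pop(0) (plain tensors would not do, since PyTorch refuses to assign a non-Parameter to a module's .weight). Shapes and module count here are illustrative:

import torch
import torch.nn as nn

r, in_features, out_features = 4, 320, 320  # illustrative rank and layer shapes

# Two injected modules -> four entries: up, down, up, down, in injection order.
weights = []
for _ in range(2):
    weights.append(nn.Parameter(torch.zeros(out_features, r)))  # lora_up.weight, shape (out, r)
    weights.append(nn.Parameter(torch.randn(r, in_features)))   # lora_down.weight, shape (r, in)

torch.save(weights, "lora_weight.pt")

# inject_trainable_lora(model, r=r, loras="lora_weight.pt") would then pop
# these Parameters off the front of the list and assign them to each newly
# injected module's lora_up / lora_down before training resumes.

This also implies the resumed run must inject into the same modules, in the same order, as the run that saved the file.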
139 changes: 78 additions & 61 deletions train_lora_dreambooth.py
@@ -413,6 +413,22 @@ def parse_args(input_args=None):
default=-1,
help="For distributed training: local_rank",
)
+    parser.add_argument(
+        "--resume_unet",
+        type=str,
+        default=None,
+        help=(
+            "File path for unet lora to resume training."
+        )
+    )
+    parser.add_argument(
+        "--resume_text_encoder",
+        type=str,
+        default=None,
+        help=(
+            "File path for text encoder lora to resume training."
+        )
+    )

if input_args is not None:
args = parser.parse_args(input_args)
@@ -576,7 +592,7 @@ def main(args):
revision=args.revision,
)
unet.requires_grad_(False)
-    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank)
+    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)

for _up, _down in extract_lora_ups_down(unet):
print("Before training: Unet First Layer lora up", _up.weight.data)
@@ -590,6 +606,7 @@
text_encoder_lora_params, _ = inject_trainable_lora(
text_encoder, target_replace_module=["CLIPAttention"],
r=args.lora_rank,
+            loras=args.resume_text_encoder,
)
for _up, _down in extract_lora_ups_down(
text_encoder, target_replace_module=["CLIPAttention"]
@@ -864,74 +881,74 @@ def collate_fn(examples):

global_step += 1

(This hunk re-indents the checkpoint-saving block; its content is unchanged.)

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    if args.save_steps and global_step - last_save >= args.save_steps:
                        if accelerator.is_main_process:
                            # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                            # it, the models will be unwrapped, and when they are then used for further training,
                            # we will crash. pass this, but only to newer versions of accelerate. fixes
                            # https://github.com/huggingface/diffusers/issues/1566
                            accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                                inspect.signature(accelerator.unwrap_model).parameters.keys()
                            )
                            extra_args = (
                                {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                            )
                            pipeline = StableDiffusionPipeline.from_pretrained(
                                args.pretrained_model_name_or_path,
                                unet=accelerator.unwrap_model(unet, **extra_args),
                                text_encoder=accelerator.unwrap_model(
                                    text_encoder, **extra_args
                                ),
                                revision=args.revision,
                            )

                            filename_unet = (
                                f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                            )
                            filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                            print(f"save weights {filename_unet}, {filename_text_encoder}")
                            save_lora_weight(pipeline.unet, filename_unet)
                            if args.train_text_encoder:
                                save_lora_weight(
                                    pipeline.text_encoder,
                                    filename_text_encoder,
                                    target_replace_module=["CLIPAttention"],
                                )

                            for _up, _down in extract_lora_ups_down(pipeline.unet):
                                print("First Unet Layer's Up Weight is now : ", _up.weight.data)
                                print(
                                    "First Unet Layer's Down Weight is now : ",
                                    _down.weight.data,
                                )
                                break
                            if args.train_text_encoder:
                                for _up, _down in extract_lora_ups_down(
                                    pipeline.text_encoder,
                                    target_replace_module=["CLIPAttention"],
                                ):
                                    print(
                                        "First Text Encoder Layer's Up Weight is now : ",
                                        _up.weight.data,
                                    )
                                    print(
                                        "First Text Encoder Layer's Down Weight is now : ",
                                        _down.weight.data,
                                    )
                                    break

                            last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()

# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
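Because checkpoints are named lora_weight_e{epoch}_s{global_step}.pt with a .text_encoder.pt sibling, a small helper along these lines could locate the newest pair to pass back via --resume_unet / --resume_text_encoder. This helper is hypothetical, not part of the commit:

import re
from pathlib import Path

_CKPT = re.compile(r"^lora_weight_e(\d+)_s(\d+)\.pt$")

def latest_checkpoints(output_dir: str):
    """Return (unet_path, text_encoder_path) for the highest global step, or (None, None)."""
    candidates = [
        (int(m.group(2)), p)
        for p in Path(output_dir).iterdir()
        if (m := _CKPT.match(p.name))  # the anchored pattern also skips .text_encoder.pt siblings
    ]
    if not candidates:
        return None, None
    _, unet_path = max(candidates)
    text_path = unet_path.with_name(unet_path.name[:-3] + ".text_encoder.pt")
    return unet_path, text_path if text_path.exists() else None

Remember to point --output_dir of the resumed run somewhere else, per the commit message, so these files are not clobbered.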
