
Commit d42a2d6

Revert "Enable resume training unet/text encoder (#48)" (#50)
This reverts commit 6f499df.
cloneofsimo authored Dec 16, 2022
1 parent 6f499df commit d42a2d6
Showing 2 changed files with 63 additions and 87 deletions.
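
In practical terms, this revert removes the `loras` checkpoint argument from inject_trainable_lora and the --resume_unet / --resume_text_encoder flags from the training script. A minimal sketch of the post-revert call pattern, assuming the lora_diffusion package plus diffusers/transformers are installed (the model id and hyperparameters below are placeholders, not taken from this commit):

import itertools

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel
from lora_diffusion.lora import inject_trainable_lora

# Placeholder model id; any Stable Diffusion checkpoint layout with
# "unet" and "text_encoder" subfolders would do.
model_id = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# Freeze the base weights; only the injected LoRA layers will train.
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

unet_lora_params, _ = inject_trainable_lora(unet, r=4)
text_encoder_lora_params, _ = inject_trainable_lora(
    text_encoder, target_replace_module=["CLIPAttention"], r=4
)

# Optimize only the lora_up / lora_down parameter groups returned above.
optimizer = torch.optim.AdamW(
    itertools.chain(*unet_lora_params, *text_encoder_lora_params), lr=1e-4
)

After this commit, resuming from a previously saved LoRA checkpoint is no longer handled inside inject_trainable_lora.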
11 changes: 2 additions & 9 deletions lora_diffusion/lora.py
@@ -34,7 +34,6 @@ def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: List[str] = ["CrossAttention", "Attention"],
    r: int = 4,
    loras = None  # path to lora .pt
):
    """
    inject lora into model, and returns lora parameter groups.
@@ -43,9 +42,6 @@ def inject_trainable_lora(
    require_grad_params = []
    names = []

    if loras != None:
        loras = torch.load(loras)

    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:

@@ -66,21 +62,18 @@

                    # switch the module
                    _module._modules[name] = _tmp

                    require_grad_params.append(
                        _module._modules[name].lora_up.parameters()
                    )
                    require_grad_params.append(
                        _module._modules[name].lora_down.parameters()
                    )

                    if loras != None:
                        _module._modules[name].lora_up.weight = loras.pop(0)
                        _module._modules[name].lora_down.weight = loras.pop(0)

                    _module._modules[name].lora_up.weight.requires_grad = True
                    _module._modules[name].lora_down.weight.requires_grad = True
                    names.append(name)

    return require_grad_params, names


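For reference, the removed `loras` branch above popped one lora_up weight and one lora_down weight per injected layer from a loaded list. A standalone sketch of the same idea (the helper name is hypothetical, and the flat list-of-weights checkpoint layout is an assumption inferred from the removed lines):

import torch
import torch.nn as nn


def load_lora_weights_into(model: nn.Module, path: str):
    # Hypothetical helper mirroring the reverted behaviour: assign one saved
    # up weight and one down weight to each injected LoRA layer, in the order
    # the modules are visited.
    loras = list(torch.load(path, map_location="cpu"))
    for module in model.modules():
        if hasattr(module, "lora_up") and hasattr(module, "lora_down"):
            module.lora_up.weight = nn.Parameter(loras.pop(0))
            module.lora_down.weight = nn.Parameter(loras.pop(0))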
139 changes: 61 additions & 78 deletions train_lora_dreambooth.py
@@ -413,22 +413,6 @@ def parse_args(input_args=None):
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--resume_unet",
        type=str,
        default=None,
        help=(
            "File path for unet lora to resume training."
        )
    )
    parser.add_argument(
        "--resume_text_encoder",
        type=str,
        default=None,
        help=(
            "File path for text encoder lora to resume training."
        )
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -592,7 +576,7 @@ def main(args):
        revision=args.revision,
    )
    unet.requires_grad_(False)
    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)
    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank)

    for _up, _down in extract_lora_ups_down(unet):
        print("Before training: Unet First Layer lora up", _up.weight.data)
@@ -606,7 +590,6 @@
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder, target_replace_module=["CLIPAttention"],
            r=args.lora_rank,
            loras=args.resume_text_encoder,
        )
        for _up, _down in extract_lora_ups_down(
            text_encoder, target_replace_module=["CLIPAttention"]
@@ -881,74 +864,74 @@ def collate_fn(examples):

            global_step += 1

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.save_steps and global_step - last_save >= args.save_steps:
                    if accelerator.is_main_process:
                        # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
                        # it, the models will be unwrapped, and when they are then used for further training,
                        # we will crash. pass this, but only to newer versions of accelerate. fixes
                        # https://github.com/huggingface/diffusers/issues/1566
                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
                            inspect.signature(accelerator.unwrap_model).parameters.keys()
                        )
                        extra_args = (
                            {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                        )
                        pipeline = StableDiffusionPipeline.from_pretrained(
                            args.pretrained_model_name_or_path,
                            unet=accelerator.unwrap_model(unet, **extra_args),
                            text_encoder=accelerator.unwrap_model(
                                text_encoder, **extra_args
                            ),
                            revision=args.revision,
                        )

                        filename_unet = (
                            f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
                        )
                        filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
                        print(f"save weights {filename_unet}, {filename_text_encoder}")
                        save_lora_weight(pipeline.unet, filename_unet)
                        if args.train_text_encoder:
                            save_lora_weight(
                                pipeline.text_encoder,
                                filename_text_encoder,
                                target_replace_module=["CLIPAttention"],
                            )

                        for _up, _down in extract_lora_ups_down(pipeline.unet):
                            print("First Unet Layer's Up Weight is now : ", _up.weight.data)
                            print(
                                "First Unet Layer's Down Weight is now : ",
                                _down.weight.data,
                            )
                            break
                        if args.train_text_encoder:
                            for _up, _down in extract_lora_ups_down(
                                pipeline.text_encoder,
                                target_replace_module=["CLIPAttention"],
                            ):
                                print(
                                    "First Text Encoder Layer's Up Weight is now : ",
                                    _up.weight.data,
                                )
                                print(
                                    "First Text Encoder Layer's Down Weight is now : ",
                                    _down.weight.data,
                                )
                                break

                        last_save = global_step

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
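The checkpoints written by save_lora_weight in the hunk above can be inspected directly with torch.load; a small sketch, assuming the .pt file holds a flat sequence of LoRA up/down weight tensors (an assumption inferred from the loading logic this commit removes; the path is a placeholder following the filename pattern in the diff):

import torch

# Placeholder path following f"{output_dir}/lora_weight_e{epoch}_s{global_step}.pt".
weights = torch.load("output/lora_weight_e0_s500.pt", map_location="cpu")
for i, w in enumerate(weights):
    print(i, tuple(w.shape))  # one up/down weight pair per injected layer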
