Enable resuming (#52)
* v 0.0.5 (#42)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

* Fix lora inject, added weight self apply lora (#39)

* Develop (#34)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

Co-authored-by: brian6091 <brian6091@gmail.com>

* release : version 0.0.4, now able to tune rank, now add loras dynamically

* readme : add brian6091's discussions

* fix: inject lora in to_out module list

* feat: added weight self apply lora

* chore: add import copy

* fix: re-added r

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: SimoRyu <cloneofsimo@korea.ac.kr>

* Revert "Fix lora inject, added weight self apply lora (#39)" (#40)

This reverts commit fececf3.

* fix : rank bug in monkeypatch

* fix cli fix

* visualization of the effect of LR

Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>

* Enable resume training unet/text encoder (#48)

* Enable resume training unet/text encoder

New flags --resume_text_encoder and --resume_unet accept paths to .pt files to resume from.
Make sure to change the output directory from the previous training session, or the .pt files will be overwritten, since training does not resume from the previous global step.

* Load weights from .pt with inject_trainable_lora

Adds a new loras argument to the inject_trainable_lora function, which accepts a path to a .pt file containing previously trained weights (see the usage sketch below).

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>
4 people authored Dec 16, 2022
1 parent a386525 commit 6767142
Showing 2 changed files with 27 additions and 3 deletions.
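
At the library level, the resume path described in the commit message amounts to pointing inject_trainable_lora at a previously saved .pt file. A minimal sketch, assuming the lora_diffusion package from this repo and diffusers are installed; the model id and checkpoint path are illustrative, not taken from the diff:

from diffusers import UNet2DConditionModel
from lora_diffusion import inject_trainable_lora

# Load and freeze the base UNet, as train_lora_dreambooth.py does.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.requires_grad_(False)

# With loras=None the injected LoRA layers start from fresh initialization;
# passing a .pt path instead loads the saved up/down weights into them.
unet_lora_params, names = inject_trainable_lora(
    unet,
    r=4,
    loras="previous_run/lora_weight.pt",  # illustrative path from an earlier run
)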
11 changes: 9 additions & 2 deletions lora_diffusion/lora.py
@@ -34,6 +34,7 @@ def inject_trainable_lora(
     model: nn.Module,
     target_replace_module: List[str] = ["CrossAttention", "Attention"],
     r: int = 4,
+    loras = None  # path to lora .pt
 ):
     """
     inject lora into model, and returns lora parameter groups.
@@ -42,6 +43,9 @@ def inject_trainable_lora(
     require_grad_params = []
     names = []

+    if loras != None:
+        loras = torch.load(loras)
+
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:

@@ -62,18 +66,21 @@ def inject_trainable_lora(

                     # switch the module
                     _module._modules[name] = _tmp

                     require_grad_params.append(
                         _module._modules[name].lora_up.parameters()
                     )
                     require_grad_params.append(
                         _module._modules[name].lora_down.parameters()
                     )

+                    if loras != None:
+                        _module._modules[name].lora_up.weight = loras.pop(0)
+                        _module._modules[name].lora_down.weight = loras.pop(0)

                     _module._modules[name].lora_up.weight.requires_grad = True
                     _module._modules[name].lora_down.weight.requires_grad = True
                     names.append(name)

     return require_grad_params, names

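
As the hunk above shows, the loras file is consumed as a flat list with pop(0), two tensors per injected layer (lora_up first, then lora_down), in injection order. A self-contained sketch of that layout; the shapes and file name are made up for illustration:

import torch
import torch.nn as nn

# Pretend these are the up/down weights of two injected layers at rank r=4,
# saved in the same order the layers were injected.
weights = [
    nn.Parameter(torch.zeros(320, 4)),  # layer 1: lora_up.weight
    nn.Parameter(torch.randn(4, 320)),  # layer 1: lora_down.weight
    nn.Parameter(torch.zeros(640, 4)),  # layer 2: lora_up.weight
    nn.Parameter(torch.randn(4, 640)),  # layer 2: lora_down.weight
]
torch.save(weights, "example_lora.pt")

# inject_trainable_lora(..., loras="example_lora.pt") pops pairs in this order.
loras = torch.load("example_lora.pt")
while loras:
    up, down = loras.pop(0), loras.pop(0)
    print(tuple(up.shape), tuple(down.shape))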
19 changes: 18 additions & 1 deletion train_lora_dreambooth.py
@@ -413,6 +413,22 @@ def parse_args(input_args=None):
         default=-1,
         help="For distributed training: local_rank",
     )
+    parser.add_argument(
+        "--resume_unet",
+        type=str,
+        default=None,
+        help=(
+            "File path for unet lora to resume training."
+        )
+    )
+    parser.add_argument(
+        "--resume_text_encoder",
+        type=str,
+        default=None,
+        help=(
+            "File path for text encoder lora to resume training."
+        )
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -576,7 +592,7 @@ def main(args):
         revision=args.revision,
     )
     unet.requires_grad_(False)
-    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank)
+    unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)

     for _up, _down in extract_lora_ups_down(unet):
         print("Before training: Unet First Layer lora up", _up.weight.data)
@@ -590,6 +606,7 @@ def main(args):
         text_encoder_lora_params, _ = inject_trainable_lora(
             text_encoder, target_replace_module=["CLIPAttention"],
             r=args.lora_rank,
+            loras=args.resume_text_encoder,
         )
         for _up, _down in extract_lora_ups_down(
             text_encoder, target_replace_module=["CLIPAttention"]
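
On the script side, the two new arguments default to None, so existing invocations are unchanged; when set, they are forwarded as the loras argument. A condensed sketch of the flag handling (standalone argparse, not the actual training script; the paths are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lora_rank", type=int, default=4)
parser.add_argument(
    "--resume_unet",
    type=str,
    default=None,
    help="File path for unet lora to resume training.",
)
parser.add_argument(
    "--resume_text_encoder",
    type=str,
    default=None,
    help="File path for text encoder lora to resume training.",
)

# Simulate a resumed run; in train_lora_dreambooth.py these values end up as
#   inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)
#   inject_trainable_lora(text_encoder, target_replace_module=["CLIPAttention"],
#                         r=args.lora_rank, loras=args.resume_text_encoder)
args = parser.parse_args(
    [
        "--resume_unet", "prev_output/lora_weight.pt",
        "--resume_text_encoder", "prev_output/lora_weight.text_encoder.pt",
    ]
)
print(args.lora_rank, args.resume_unet, args.resume_text_encoder)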
