diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 2c70e64..32193e3 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -1,12 +1,12 @@
 import math
+from types import UnionType
 from typing import Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import PIL
 import torch
-import torch.nn.functional as F
-
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class LoraInjectedLinear(nn.Module):
@@ -30,9 +30,69 @@ def forward(self, input):
         return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
 
 
+UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
+TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
+
+DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
+
+
+def _find_children(
+    model,
+    search_class: type[nn.Module] | UnionType = nn.Linear,
+):
+    """
+    Find all modules of a certain class (or union of classes).
+
+    Returns all matching modules, along with the parent of those modules and the
+    names they are referenced by.
+    """
+    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
+    for parent in model.modules():
+        for name, module in parent.named_children():
+            if isinstance(module, search_class):
+                yield parent, name, module
+
+
+def _find_modules(
+    model,
+    ancestor_class: set[str] = DEFAULT_TARGET_REPLACE,
+    search_class: type[nn.Module] | UnionType = nn.Linear,
+    exclude_children_of: type[nn.Module] | UnionType = LoraInjectedLinear,
+):
+    """
+    Find all modules of a certain class (or union of classes) that are direct or
+    indirect descendants of other modules of a certain class (or union of classes).
+
+    Returns all matching modules, along with the parent of those modules and the
+    names they are referenced by.
+    """
+
+    # Get the targets we should replace all linears under
+    ancestors = (
+        module
+        for module in model.modules()
+        if module.__class__.__name__ in ancestor_class
+    )
+
+    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
+    for ancestor in ancestors:
+        for fullname, module in ancestor.named_modules():
+            if isinstance(module, search_class):
+                # Find the direct parent if this is a descendant, not a child, of target
+                *path, name = fullname.split(".")
+                parent = ancestor
+                while path:
+                    parent = parent.get_submodule(path.pop(0))
+                # Skip this linear if it's a child of a LoraInjectedLinear
+                if exclude_children_of and isinstance(parent, exclude_children_of):
+                    continue
+                # Otherwise, yield it
+                yield parent, name, module
+
+
 def inject_trainable_lora(
     model: nn.Module,
-    target_replace_module: List[str] = ["CrossAttention", "Attention", "GEGLU"],
+    target_replace_module: set[str] = DEFAULT_TARGET_REPLACE,
     r: int = 4,
     loras=None,  # path to lora .pt
 ):
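Editorial note, not part of the patch: a minimal sketch of how the new _find_modules generator is meant to be consumed, assuming a diffusers UNet2DConditionModel is available (the model id is only a placeholder). It lists the nn.Linear layers that inject_trainable_lora would wrap.

import torch.nn as nn
from diffusers import UNet2DConditionModel

from lora_diffusion.lora import UNET_DEFAULT_TARGET_REPLACE, _find_modules

# The model id is illustrative; any UNet2DConditionModel works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Each hit is (parent module, attribute name, nn.Linear) found under a
# CrossAttention, Attention, or GEGLU block.
for parent, name, child in _find_modules(
    unet, UNET_DEFAULT_TARGET_REPLACE, search_class=nn.Linear
):
    print(parent.__class__.__name__, name, child.in_features, child.out_features)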
@@ -46,64 +106,57 @@ def inject_trainable_lora(
     if loras != None:
         loras = torch.load(loras)
 
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=nn.Linear
+    ):
+        weight = _child_module.weight
+        bias = _child_module.bias
+        _tmp = LoraInjectedLinear(
+            _child_module.in_features,
+            _child_module.out_features,
+            _child_module.bias is not None,
+            r,
+        )
+        _tmp.linear.weight = weight
+        if bias is not None:
+            _tmp.linear.bias = bias
+
+        # switch the module
+        _module._modules[name] = _tmp
+
+        require_grad_params.append(_module._modules[name].lora_up.parameters())
+        require_grad_params.append(_module._modules[name].lora_down.parameters())
+
+        if loras != None:
+            _module._modules[name].lora_up.weight = loras.pop(0)
+            _module._modules[name].lora_down.weight = loras.pop(0)
+
+        _module._modules[name].lora_up.weight.requires_grad = True
+        _module._modules[name].lora_down.weight.requires_grad = True
+        names.append(name)
 
-        for name, _child_module in _module.named_modules():
-            if _child_module.__class__.__name__ == "Linear":
-
-                weight = _child_module.weight
-                bias = _child_module.bias
-                _tmp = LoraInjectedLinear(
-                    _child_module.in_features,
-                    _child_module.out_features,
-                    _child_module.bias is not None,
-                    r,
-                )
-                _tmp.linear.weight = weight
-                if bias is not None:
-                    _tmp.linear.bias = bias
-
-                # switch the module
-                _module._modules[name] = _tmp
-
-                require_grad_params.append(
-                    _module._modules[name].lora_up.parameters()
-                )
-                require_grad_params.append(
-                    _module._modules[name].lora_down.parameters()
-                )
-
-                if loras != None:
-                    _module._modules[name].lora_up.weight = loras.pop(0)
-                    _module._modules[name].lora_down.weight = loras.pop(0)
-
-                _module._modules[name].lora_up.weight.requires_grad = True
-                _module._modules[name].lora_down.weight.requires_grad = True
-                names.append(name)
     return require_grad_params, names
 
 
-def extract_lora_ups_down(
-    model, target_replace_module=["CrossAttention", "Attention", "GEGLU"]
-):
+def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
     loras = []
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
-            for _child_module in _module.modules():
-                if _child_module.__class__.__name__ == "LoraInjectedLinear":
-                    loras.append((_child_module.lora_up, _child_module.lora_down))
+    for _m, _n, _child_module in _find_modules(
+        model, target_replace_module, search_class=LoraInjectedLinear
+    ):
+        loras.append((_child_module.lora_up, _child_module.lora_down))
+
     if len(loras) == 0:
         raise ValueError("No lora injected.")
+
     return loras
 
 
 def save_lora_weight(
     model,
     path="./lora.pt",
-    target_replace_module=["CrossAttention", "Attention", "GEGLU"],
+    target_replace_module=DEFAULT_TARGET_REPLACE,
 ):
     weights = []
     for _up, _down in extract_lora_ups_down(
@@ -128,137 +181,174 @@ def save_lora_as_json(model, path="./lora.json"):
 
 
 def weight_apply_lora(
-    model,
-    loras,
-    target_replace_module=["CrossAttention", "Attention", "GEGLU"],
-    alpha=1.0,
+    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0
 ):
 
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
-            for _child_module in _module.modules():
-                if _child_module.__class__.__name__ == "Linear":
-
-                    weight = _child_module.weight
+    for _m, _n, _child_module in _find_modules(
+        model, target_replace_module, search_class=nn.Linear
+    ):
+        weight = _child_module.weight
 
-                    up_weight = loras.pop(0).detach().to(weight.device)
-                    down_weight = loras.pop(0).detach().to(weight.device)
+        up_weight = loras.pop(0).detach().to(weight.device)
+        down_weight = loras.pop(0).detach().to(weight.device)
 
-                    # W <- W + U * D
-                    weight = weight + alpha * (up_weight @ down_weight).type(
-                        weight.dtype
-                    )
-                    _child_module.weight = nn.Parameter(weight)
+        # W <- W + U * D
+        weight = weight + alpha * (up_weight @ down_weight).type(weight.dtype)
+        _child_module.weight = nn.Parameter(weight)
 
 
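Editorial aside, not part of the patch: the update that weight_apply_lora folds into each targeted Linear is just the rank-r product added onto the frozen weight. A toy, shapes-only sketch with made-up dimensions:

import torch

out_features, in_features, r = 320, 768, 4     # made-up sizes
W = torch.randn(out_features, in_features)     # frozen nn.Linear weight
up = torch.randn(out_features, r)              # lora_up.weight
down = torch.randn(r, in_features)             # lora_down.weight

alpha = 1.0
W_merged = W + alpha * (up @ down)             # W <- W + alpha * U @ D
assert W_merged.shape == W.shape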
 def monkeypatch_lora(
-    model,
-    loras,
-    target_replace_module=["CrossAttention", "Attention", "GEGLU"],
-    r: int = 4,
+    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
 ):
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
-            for name, _child_module in _module.named_modules():
-                if _child_module.__class__.__name__ == "Linear":
-
-                    weight = _child_module.weight
-                    bias = _child_module.bias
-                    _tmp = LoraInjectedLinear(
-                        _child_module.in_features,
-                        _child_module.out_features,
-                        _child_module.bias is not None,
-                        r=r,
-                    )
-                    _tmp.linear.weight = weight
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=nn.Linear
+    ):
+        weight = _child_module.weight
+        bias = _child_module.bias
+        _tmp = LoraInjectedLinear(
+            _child_module.in_features,
+            _child_module.out_features,
+            _child_module.bias is not None,
+            r=r,
+        )
+        _tmp.linear.weight = weight
 
-                    if bias is not None:
-                        _tmp.linear.bias = bias
+        if bias is not None:
+            _tmp.linear.bias = bias
 
-                    # switch the module
-                    _module._modules[name] = _tmp
+        # switch the module
+        _module._modules[name] = _tmp
 
-                    up_weight = loras.pop(0)
-                    down_weight = loras.pop(0)
+        up_weight = loras.pop(0)
+        down_weight = loras.pop(0)
 
-                    _module._modules[name].lora_up.weight = nn.Parameter(
-                        up_weight.type(weight.dtype)
-                    )
-                    _module._modules[name].lora_down.weight = nn.Parameter(
-                        down_weight.type(weight.dtype)
-                    )
+        _module._modules[name].lora_up.weight = nn.Parameter(
+            up_weight.type(weight.dtype)
+        )
+        _module._modules[name].lora_down.weight = nn.Parameter(
+            down_weight.type(weight.dtype)
+        )
 
-                    _module._modules[name].to(weight.device)
+        _module._modules[name].to(weight.device)
 
 
 def monkeypatch_replace_lora(
-    model,
-    loras,
-    target_replace_module=["CrossAttention", "Attention", "GEGLU"],
-    r: int = 4,
+    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
 ):
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
-            for name, _child_module in _module.named_modules():
-                if _child_module.__class__.__name__ == "LoraInjectedLinear":
-
-                    weight = _child_module.linear.weight
-                    bias = _child_module.linear.bias
-                    _tmp = LoraInjectedLinear(
-                        _child_module.linear.in_features,
-                        _child_module.linear.out_features,
-                        _child_module.linear.bias is not None,
-                        r=r,
-                    )
-                    _tmp.linear.weight = weight
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=LoraInjectedLinear
+    ):
+        weight = _child_module.linear.weight
+        bias = _child_module.linear.bias
+        _tmp = LoraInjectedLinear(
+            _child_module.linear.in_features,
+            _child_module.linear.out_features,
+            _child_module.linear.bias is not None,
+            r=r,
+        )
+        _tmp.linear.weight = weight
+
+        if bias is not None:
+            _tmp.linear.bias = bias
 
-                    if bias is not None:
-                        _tmp.linear.bias = bias
+        # switch the module
+        _module._modules[name] = _tmp
 
-                    # switch the module
-                    _module._modules[name] = _tmp
+        up_weight = loras.pop(0)
+        down_weight = loras.pop(0)
 
-                    up_weight = loras.pop(0)
-                    down_weight = loras.pop(0)
+        _module._modules[name].lora_up.weight = nn.Parameter(
+            up_weight.type(weight.dtype)
+        )
+        _module._modules[name].lora_down.weight = nn.Parameter(
+            down_weight.type(weight.dtype)
+        )
 
-                    _module._modules[name].lora_up.weight = nn.Parameter(
-                        up_weight.type(weight.dtype)
-                    )
-                    _module._modules[name].lora_down.weight = nn.Parameter(
-                        down_weight.type(weight.dtype)
-                    )
-                    _module._modules[name].to(weight.device)
+        _module._modules[name].to(weight.device)
+
+
+def monkeypatch_or_replace_lora(
+    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
+):
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=nn.Linear | LoraInjectedLinear
+    ):
+        _source = (
+            _child_module.linear
+            if isinstance(_child_module, LoraInjectedLinear)
+            else _child_module
+        )
+
+        weight = _source.weight
+        bias = _source.bias
+        _tmp = LoraInjectedLinear(
+            _source.in_features,
+            _source.out_features,
+            _source.bias is not None,
+            r=r,
+        )
+        _tmp.linear.weight = weight
+
+        if bias is not None:
+            _tmp.linear.bias = bias
+
+        # switch the module
+        _module._modules[name] = _tmp
+
+        up_weight = loras.pop(0)
+        down_weight = loras.pop(0)
+
+        _module._modules[name].lora_up.weight = nn.Parameter(
+            up_weight.type(weight.dtype)
+        )
+        _module._modules[name].lora_down.weight = nn.Parameter(
+            down_weight.type(weight.dtype)
+        )
+
+        _module._modules[name].to(weight.device)
+
+
+def monkeypatch_remove_lora(model):
+    for _module, name, _child_module in _find_children(
+        model, search_class=LoraInjectedLinear
+    ):
+        _source = _child_module.linear
+        weight, bias = _source.weight, _source.bias
+
+        _tmp = nn.Linear(_source.in_features, _source.out_features, bias is not None)
+
+        _tmp.weight = weight
+        if bias is not None:
+            _tmp.bias = bias
+
+        _module._modules[name] = _tmp
 
 
 def monkeypatch_add_lora(
     model,
     loras,
-    target_replace_module=["CrossAttention", "Attention", "GEGLU"],
+    target_replace_module=DEFAULT_TARGET_REPLACE,
     alpha: float = 1.0,
     beta: float = 1.0,
 ):
-    for _module in model.modules():
-        if _module.__class__.__name__ in target_replace_module:
-            for name, _child_module in _module.named_modules():
-                if _child_module.__class__.__name__ == "LoraInjectedLinear":
-
-                    weight = _child_module.linear.weight
+    for _module, name, _child_module in _find_modules(
+        model, target_replace_module, search_class=LoraInjectedLinear
+    ):
+        weight = _child_module.linear.weight
 
-                    up_weight = loras.pop(0)
-                    down_weight = loras.pop(0)
+        up_weight = loras.pop(0)
+        down_weight = loras.pop(0)
 
-                    _module._modules[name].lora_up.weight = nn.Parameter(
-                        up_weight.type(weight.dtype).to(weight.device) * alpha
-                        + _module._modules[name].lora_up.weight.to(weight.device) * beta
-                    )
-                    _module._modules[name].lora_down.weight = nn.Parameter(
-                        down_weight.type(weight.dtype).to(weight.device) * alpha
-                        + _module._modules[name].lora_down.weight.to(weight.device)
-                        * beta
-                    )
+        _module._modules[name].lora_up.weight = nn.Parameter(
+            up_weight.type(weight.dtype).to(weight.device) * alpha
+            + _module._modules[name].lora_up.weight.to(weight.device) * beta
+        )
+        _module._modules[name].lora_down.weight = nn.Parameter(
+            down_weight.type(weight.dtype).to(weight.device) * alpha
+            + _module._modules[name].lora_down.weight.to(weight.device) * beta
+        )
 
-                    _module._modules[name].to(weight.device)
+        _module._modules[name].to(weight.device)
 
 
 def tune_lora_scale(model, alpha: float = 1.0):
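Editorial note, not part of the patch: a rough usage sketch of the new idempotent patching path, assuming pipe is an already-loaded diffusers StableDiffusionPipeline and the .pt paths are placeholders for files produced by save_lora_weight:

import torch

from lora_diffusion.lora import (
    monkeypatch_add_lora,
    monkeypatch_or_replace_lora,
    monkeypatch_remove_lora,
)

# Safe to call whether or not LoRA was injected before: plain nn.Linear layers
# get wrapped, existing LoraInjectedLinear layers get replaced in place.
monkeypatch_or_replace_lora(pipe.unet, torch.load("lora.pt"), r=4)

# Blend a second set of LoRA weights into the one already applied
# (alpha weights the new LoRA, beta the existing one).
monkeypatch_add_lora(pipe.unet, torch.load("other_lora.pt"), alpha=0.5, beta=0.5)

# Strip every LoraInjectedLinear wrapper, restoring plain nn.Linear modules.
monkeypatch_remove_lora(pipe.unet)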
target_replace_module=["CLIPAttention"], - r=r, - ) - else: - - monkeypatch_replace_lora( - pipe.text_encoder, - torch.load(text_path), - target_replace_module=["CLIPAttention"], - r=r, - ) + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + r=r, + ) if patch_ti: print("LoRA : Patching token input") token = load_learned_embed_in_clip( @@ -377,7 +444,7 @@ def patch_pipe( @torch.no_grad() -def inspect_lora(model, target_replace_module=["CrossAttention", "Attention", "GEGLU"]): +def inspect_lora(model, target_replace_module=DEFAULT_TARGET_REPLACE): fnorm = {k: [] for k in target_replace_module}