
Commit 2bfa55f

Authored by younesbelkada, patrickvonplaten, pacman100, BenjaminBossan, and sayakpaul
[core / PEFT / LoRA] Integrate PEFT into Unet (huggingface#5151)
* v1
* add tests and fix previous failing tests
* fix CI
* add tests + v1 `PeftLayerScaler`
* style
* add scale retrieving mechanism system
* fix CI
* up
* up
* simple approach --> not same results for some reason
* fix issues
* fix copies
* remove unneeded method
* active adapters!
* fix merge conflicts
* up
* up
* kohya - test-1
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix scale
* fix copies
* add comment
* multi adapters
* fix tests
* oops
* v1 faster loading - in progress
* Revert "v1 faster loading - in progress" This reverts commit ac925f8.
* kohya same generation
* fix some slow tests
* peft integration features for unet lora
  1. Support for Multiple ranks/alphas
  2. Support for Multiple active adapters
  3. Support for enabling/disabling LoRAs
* fix `get_peft_kwargs`
* Update loaders.py
* add some tests
* add unfuse tests
* fix tests
* up
* add set adapter from sourab and tests
* fix multi adapter tests
* style & quality
* style
* remove comment
* fix `adapter_name` issues
* fix unet adapter name for sdxl
* fix enabling/disabling adapters
* fix fuse / unfuse unet
* nit
* fix
* up
* fix cpu offloading
* fix another slow test
* fix another offload test
* add more tests
* all slow tests pass
* style
* fix alpha pattern for unet and text encoder
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Update src/diffusers/models/attention.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* up
* up
* clarify comment
* comments
* change comment order
* change comment order
* stylr & quality
* Update tests/lora/test_lora_layers_peft.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix bugs and add tests
* Update src/diffusers/models/modeling_utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Update src/diffusers/models/modeling_utils.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* refactor
* suggestion
* add break statemebt
* add compile tests
* move slow tests to peft tests as I modified them
* quality
* refactor a bit
* style
* change import
* style
* fix CI
* refactor slow tests one last time
* style
* oops
* oops
* oops
* final tweak tests
* Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update src/diffusers/loaders.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* comments
* Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* remove comments
* more comments
* try
* revert
* add `safe_merge` tests
* add comment
* style, comments and run tests in fp16
* add warnings
* fix doc test
* replace with `adapter_weights`
* add `get_active_adapters()`
* expose `get_list_adapters` method
* better error message
* Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
* style
* trigger slow lora tests
* fix tests
* maybe fix last test
* revert
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* move `MIN_PEFT_VERSION`
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* let's not use class variable
* fix few nits
* change a bit offloading logic
* check earlier
* rm unneeded block
* break long line
* return empty list
* change logic a bit and address comments
* add typehint
* remove parenthesis
* fix
* revert to fp16 in tests
* add to gpu
* revert to old test
* style
* Update src/diffusers/loaders.py Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* change indent
* Apply suggestions from code review
* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
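The squash history above calls out three unet-LoRA features: multiple ranks/alphas, multiple active adapters, and enabling/disabling LoRAs. A minimal sketch of how they surface to users follows; it assumes the PEFT backend is installed, and the LoRA repo names ("some-user/pixel-lora", "some-user/toy-lora") are placeholders, not artifacts referenced by this commit.

# Hedged sketch: pipeline-level view of the features listed in the squash message.
# Requires `peft` to be installed so diffusers picks the PEFT backend.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# 1. Multiple ranks/alphas: each LoRA keeps its own rank/alpha when loaded.
pipe.load_lora_weights("some-user/pixel-lora", adapter_name="pixel")  # placeholder repos
pipe.load_lora_weights("some-user/toy-lora", adapter_name="toy")

# 2. Multiple active adapters, each with its own weight.
pipe.set_adapters(["pixel", "toy"], adapter_weights=[1.0, 0.5])
image = pipe("a toy robot, pixel art", num_inference_steps=30).images[0]

# 3. Enabling/disabling LoRAs on the unet (methods added in modeling_utils.py below).
pipe.unet.disable_adapters()  # run the base unet only
pipe.unet.enable_adapters()   # re-activate the adapters selected above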
1 parent 9bc55e8 commit 2bfa55f

File tree

51 files changed: +2271 additions, -322 deletions


src/diffusers/loaders.py

Lines changed: 377 additions & 66 deletions
Large diffs are not rendered by default.
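The loaders.py diff itself is not rendered, but the squash history names the helpers it adds: `get_active_adapters()`, `get_list_adapters`, `set_adapters` with `adapter_weights`, and fuse/unfuse support with `safe_merge` tests. A hedged sketch of how they fit together, reusing the placeholder `pipe` and "pixel"/"toy" adapters from the example above; the return values in the comments are illustrative.

# Hedged sketch of the new loader helpers named in the commit message.
print(pipe.get_active_adapters())
# e.g. ["pixel", "toy"]

print(pipe.get_list_adapters())
# e.g. {"unet": ["pixel", "toy"], "text_encoder": ["pixel"]} -- adapters per component

# Narrow down to one adapter, then merge it into the base weights for faster inference.
pipe.set_adapters("pixel")
pipe.fuse_lora()    # merging can optionally guard against broken (NaN) weights, per the `safe_merge` tests
# ... run inference without the extra LoRA forward pass ...
pipe.unfuse_lora()  # restore the original, unfused weights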

src/diffusers/models/attention.py

Lines changed: 10 additions & 4 deletions
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import USE_PEFT_BACKEND
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import get_activation
 from .attention_processor import Attention
@@ -300,6 +301,7 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
@@ -316,14 +318,15 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
     def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
         for module in self.net:
-            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+            if isinstance(module, compatible_cls):
                 hidden_states = module(hidden_states, scale)
             else:
                 hidden_states = module(hidden_states)
@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
 
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -377,7 +382,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
     def forward(self, hidden_states, scale: float = 1.0):
-        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
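The pattern above repeats across attention.py, attention_processor.py, and embeddings.py: use plain nn.Linear when the PEFT backend is active (so PEFT can wrap the layers itself), keep LoRACompatibleLinear otherwise, and forward the `scale` argument only on the non-PEFT path. A self-contained toy version of the pattern; it is not the diffusers classes themselves, and the stand-in LoRACompatibleLinear is a simplified stub.

# Minimal, standalone sketch of the backend switch used throughout this commit.
import torch
from torch import nn

USE_PEFT_BACKEND = True  # in diffusers this flag lives in diffusers.utils


class LoRACompatibleLinear(nn.Linear):
    """Simplified stand-in: accepts a `scale` argument (the real class also applies a scaled LoRA delta)."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return super().forward(hidden_states)


class ToyFeedForward(nn.Module):
    def __init__(self, dim: int, dim_out: int):
        super().__init__()
        # PEFT backend active -> plain nn.Linear so PEFT can inject LoRA layers around it;
        # otherwise keep the LoRA-aware class.
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
        self.proj = linear_cls(dim, dim_out)

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # nn.Linear.forward only accepts the input tensor, so `scale` is forwarded
        # only on the non-PEFT code path.
        args = () if USE_PEFT_BACKEND else (scale,)
        return self.proj(hidden_states, *args)


x = torch.randn(2, 8)
print(ToyFeedForward(8, 16)(x, scale=0.5).shape)  # torch.Size([2, 16])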

src/diffusers/models/attention_processor.py

Lines changed: 29 additions & 15 deletions
@@ -18,7 +18,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate, logging
+from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
@@ -137,22 +137,27 @@ def __init__(
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
+        if USE_PEFT_BACKEND:
+            linear_cls = nn.Linear
+        else:
+            linear_cls = LoRACompatibleLinear
+
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -545,6 +550,8 @@ def __call__(
     ):
         residual = hidden_states
 
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -562,15 +569,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
 
         if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
 
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -581,7 +588,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -1007,15 +1014,20 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, scale=scale)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        query = attn.to_q(hidden_states, *args)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = (
+            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
+        )
+        value = (
+            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
+        )
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1035,7 +1047,9 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = (
+            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
+        )
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
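The ternaries above exist because the two linear classes have different call signatures: LoRACompatibleLinear.forward accepts a `scale` argument, while nn.Linear.forward takes only the input tensor. A tiny standalone demonstration of what would happen if the processors forwarded `scale` unconditionally:

# Standalone demonstration (plain PyTorch, independent of diffusers).
import torch
from torch import nn

proj = nn.Linear(8, 8)
x = torch.randn(2, 8)

proj(x)  # fine: nn.Linear.forward takes a single tensor

try:
    proj(x, 1.0)  # what `attn.to_q(hidden_states, scale)` would amount to under the PEFT backend
except TypeError as err:
    print(f"TypeError: {err}")  # forward() takes 2 positional arguments but 3 were given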

src/diffusers/models/embeddings.py

Lines changed: 4 additions & 2 deletions
@@ -18,6 +18,7 @@
 import torch
 from torch import nn
 
+from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
 from .lora import LoRACompatibleLinear
 
@@ -166,8 +167,9 @@ def __init__(
         cond_proj_dim=None,
     ):
         super().__init__()
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
 
-        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
+        self.linear_1 = linear_cls(in_channels, time_embed_dim)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +182,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
 
         if post_act_fn is None:
             self.post_act = None
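The same switch applies to the timestep embedding. A quick check of which class ends up inside TimestepEmbedding, assuming `USE_PEFT_BACKEND` is importable from `diffusers.utils` as the relative import in the diff suggests:

# Hedged sketch: inspect which linear class TimestepEmbedding instantiated.
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils import USE_PEFT_BACKEND

emb = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
print(USE_PEFT_BACKEND, type(emb.linear_1))
# True  -> <class 'torch.nn.modules.linear.Linear'> (PEFT can wrap these directly)
# False -> <class 'diffusers.models.lora.LoRACompatibleLinear'>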

src/diffusers/models/modeling_utils.py

Lines changed: 150 additions & 0 deletions
@@ -32,10 +32,12 @@
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
+    MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    check_peft_version,
     deprecate,
     is_accelerate_available,
     is_torch_version,
@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _hf_peft_config_loaded = False
 
     def __init__(self):
         super().__init__()
@@ -292,6 +295,153 @@ def disable_xformers_memory_efficient_attention(self):
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
+        to the adapter to follow the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config (`[~peft.PeftConfig]`):
+                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
+                methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (Union[str, List[str]])):
+                The list of adapters to set or the adapter name in case of single adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" current loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_has_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT does not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_has_been_set = True
+
+        if not _adapters_has_been_set:
+            raise ValueError(
+                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
+
+    def disable_adapters(self) -> None:
+        r"""
+        Disable all adapters attached to the model and fallback to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = False
+
+    def active_adapters(self) -> List[str]:
+        """
+        Gets the current list of active adapters of the model.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
save_directory: Union[str, os.PathLike],
