Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

if dreambooth lora #3360

Merged
merged 15 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 246 additions & 39 deletions examples/dreambooth/train_dreambooth_lora.py

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,41 @@ def test_dreambooth_lora_with_text_encoder(self):
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
self.assertTrue(is_correct_naming)

def test_dreambooth_lora_if_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--pre_compute_text_embeddings
--tokenizer_max_length=77
--text_encoder_use_attention_mask
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` in their names.
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)

def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
Expand Down
20 changes: 18 additions & 2 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
from huggingface_hub import hf_hub_download

from .models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
SlicedAttnAddedKVProcessor,
)
from .utils import (
DIFFUSERS_CACHE,
Expand Down Expand Up @@ -250,10 +254,22 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processors[key] = LoRAAttnProcessor(
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)

if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnProcessor

attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
Expand Down
68 changes: 68 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,73 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
return hidden_states


class LoRAAttnAddedKVProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
encoder_hidden_states
)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
encoder_hidden_states
)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj

attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual

return hidden_states


class XFormersAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
Expand Down Expand Up @@ -1022,6 +1089,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
AttnAddedKVProcessor2_0,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
williamberman marked this conversation as resolved.
Show resolved Hide resolved
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
]
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...loaders import LoraLoaderMixin
from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
Expand Down Expand Up @@ -85,7 +86,7 @@
"""


class IFPipeline(DiffusionPipeline):
class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
tokenizer: T5Tokenizer
text_encoder: T5EncoderModel

Expand Down Expand Up @@ -804,6 +805,9 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
williamberman marked this conversation as resolved.
Show resolved Hide resolved

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...loaders import LoraLoaderMixin
from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
Expand Down Expand Up @@ -109,7 +110,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
"""


class IFImg2ImgPipeline(DiffusionPipeline):
class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
tokenizer: T5Tokenizer
text_encoder: T5EncoderModel

Expand Down Expand Up @@ -929,6 +930,9 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

# compute the previous noisy sample x_t -> x_t-1
intermediate_images = self.scheduler.step(
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer

from ...loaders import LoraLoaderMixin
from ...models import UNet2DConditionModel
from ...schedulers import DDPMScheduler
from ...utils import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
"""


class IFInpaintingPipeline(DiffusionPipeline):
class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
tokenizer: T5Tokenizer
text_encoder: T5EncoderModel

Expand Down Expand Up @@ -1044,6 +1045,9 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

# compute the previous noisy sample x_t -> x_t-1
prev_intermediate_images = intermediate_images

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def _encode_prompt(
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
Expand Down
Loading