diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9af81aa5a95d..0bf3333a6209 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import gc import hashlib import itertools import logging @@ -30,7 +31,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder +from huggingface_hub import create_repo, model_info, upload_folder from packaging import version from PIL import Image from torch.utils.data import Dataset @@ -48,7 +49,13 @@ UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + SlicedAttnAddedKVProcessor, +) from diffusers.optimization import get_scheduler from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -108,6 +115,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") @@ -387,6 +398,24 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. 
If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -409,6 +438,9 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + return args @@ -428,10 +460,16 @@ def __init__( class_num=None, size=512, center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -473,39 +511,50 @@ def __getitem__(self, index): if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask return example def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. 
if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -516,6 +565,10 @@ def collate_fn(examples, with_prior_preservation=False): "input_ids": input_ids, "pixel_values": pixel_values, } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + return batch @@ -536,6 +589,50 @@ def __getitem__(self, index): return example +def model_has_vae(args): + config_file_name = os.path.join("vae", AutoencoderKL.config_name) + if os.path.isdir(args.pretrained_model_name_or_path): + config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) + return os.path.isfile(config_file_name) + else: + files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings + return any(file.rfilename == config_file_name for file in files_in_repo) + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -656,13 +753,20 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + if model_has_vae(args): + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + else: + vae = None + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # We only train the additional adapter LoRA layers - vae.requires_grad_(False) + if vae is not None: + vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) @@ -676,7 +780,8 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: @@ -707,7 +812,7 @@ def main(args): # Set correct lora layers unet_lora_attn_procs = {} - for name in unet.attn_processors.keys(): + for name, attn_processor in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] @@ -718,7 +823,12 @@ def 
main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - unet_lora_attn_procs[name] = LoRAAttnProcessor( + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = LoRAAttnProcessor + + unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim ) @@ -783,6 +893,44 @@ def main(args): eps=args.adam_epsilon, ) + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -793,6 +941,9 @@ def main(args): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, ) train_dataloader = torch.utils.data.DataLoader( @@ -896,32 +1047,53 @@ def main(args): continue with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep + # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = 
noise_scheduler.add_noise(latents, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -988,19 +1160,40 @@ def main(args): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), revision=args.revision, torch_dtype=weight_dtype, ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) + pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) ] for tracker in accelerator.trackers: @@ -1024,7 +1217,8 @@ def main(args): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) - text_encoder = text_encoder.to(torch.float32) + if text_encoder is not None: + text_encoder = text_encoder.to(torch.float32) LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, @@ -1036,7 +1230,20 @@ def main(args): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) # load attention processors diff --git a/examples/test_examples.py b/examples/test_examples.py index 648c2cb8a1b7..d9e7de717f47 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -292,6 +292,41 @@ def test_dreambooth_lora_with_text_encoder(self): is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys) self.assertTrue(is_correct_naming) + def test_dreambooth_lora_if_model(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --pre_compute_text_embeddings + --tokenizer_max_length=77 + --text_encoder_use_attention_mask + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` in their names. 
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_unet) + def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b4b0f4bb3bd6..7620270b4dc2 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -21,9 +21,13 @@ from huggingface_hub import hf_hub_download from .models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, LoRAAttnProcessor, + SlicedAttnAddedKVProcessor, ) from .utils import ( DIFFUSERS_CACHE, @@ -250,10 +254,22 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict for key, value_dict in lora_grouped_dict.items(): rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processors[key] = LoRAAttnProcessor( + attn_processor = self + for sub_key in key.split("."): + attn_processor = getattr(attn_processor, sub_key) + + if isinstance( + attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) + ): + cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnAddedKVProcessor + else: + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnProcessor + + attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank ) attn_processors[key].load_state_dict(value_dict) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7ac88b17999a..6701122fc13b 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -669,6 +669,73 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states +class LoRAAttnAddedKVProcessor(nn.Module): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) 
+ query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states) + value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + class XFormersAttnProcessor: def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op @@ -1022,6 +1089,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnAddedKVProcessor2_0, LoRAAttnProcessor, LoRAXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, ] diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py index 448389b9f1f6..cd1015dc03bb 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -7,6 +7,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -85,7 +86,7 @@ """ -class IFPipeline(DiffusionPipeline): +class IFPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -804,6 +805,9 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index 231ee02b1bb8..6bae2071173b 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -9,6 +9,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import 
UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -109,7 +110,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFImg2ImgPipeline(DiffusionPipeline): +class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -929,6 +930,9 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index 6986387ca995..9c1f71126ac5 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -9,6 +9,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -112,7 +113,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFInpaintingPipeline(DiffusionPipeline): +class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -1044,6 +1045,9 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 prev_intermediate_images = intermediate_images diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index adada63b83f7..e48d8a46423e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -560,7 +560,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}."
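
The following is not part of the patch: a minimal, hedged sketch of how the pieces above are expected to fit together at inference time, assuming training was launched the way the new test_dreambooth_lora_if_model test does (with --pre_compute_text_embeddings, --tokenizer_max_length=77 and --text_encoder_use_attention_mask). The checkpoint id, output directory and prompt below are placeholders, not values taken from this diff.

# Hedged sketch, not part of the diff: load LoRA weights produced by
# train_dreambooth_lora.py into a DeepFloyd IF stage-I pipeline.
# "DeepFloyd/IF-I-XL-v1.0", "path/to/output_dir" and the prompt are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0",  # any IF stage-I checkpoint should work here
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# IFPipeline now inherits LoraLoaderMixin, so the weights saved by
# LoraLoaderMixin.save_lora_weights (pytorch_lora_weights.bin in --output_dir)
# load directly; the added-KV attention layers are matched to
# LoRAAttnAddedKVProcessor by the updated load_attn_procs in loaders.py.
pipe.load_lora_weights("path/to/output_dir")

image = pipe("a photo of sks dog", num_inference_steps=100).images[0]
image.save("if_lora_sample.png")

Because the UNet still predicts a variance alongside the noise, the pipeline's default learned-variance scheduler keeps working at inference; the scheduler_args handling added to the training script only swaps in a fixed variance type for the DPMSolverMultistepScheduler used during validation and for the final pipeline check.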