diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 5d2107f024d1..efcfb39ab4c4 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 
 import argparse
+import gc
 import hashlib
 import itertools
 import logging
@@ -30,7 +31,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import create_repo, model_info, upload_folder
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -93,31 +94,61 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode
     f.write(yaml + model_card)
 
 
-def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+def log_validation(
+    text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds
+):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
+
+    pipeline_args = {}
+
+    if text_encoder is not None:
+        pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+    if vae is not None:
+        pipeline_args["vae"] = vae
+
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
-        text_encoder=accelerator.unwrap_model(text_encoder),
         tokenizer=tokenizer,
         unet=accelerator.unwrap_model(unet),
-        vae=vae,
         revision=args.revision,
         torch_dtype=weight_dtype,
+        **pipeline_args,
     )
-    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it.
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
+    if args.pre_compute_text_embeddings:
+        pipeline_args = {
+            "prompt_embeds": prompt_embeds,
+            "negative_prompt_embeds": negative_prompt_embeds,
+        }
+    else:
+        pipeline_args = {"prompt": args.validation_prompt}
+
     # run inference
     generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
     images = []
     for _ in range(args.num_validation_images):
         with torch.autocast("cuda"):
-            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+            image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
         images.append(image)
 
     for tracker in accelerator.trackers:
@@ -155,6 +186,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
 
         return RobertaSeriesModelWithTransformation
+    elif model_class == "T5EncoderModel":
+        from transformers import T5EncoderModel
+
+        return T5EncoderModel
     else:
         raise ValueError(f"{model_class} is not supported.")
 
@@ -459,6 +494,27 @@ def parse_args(input_args=None):
             " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
         ),
     )
+    parser.add_argument(
+        "--pre_compute_text_embeddings",
+        action="store_true",
+        help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training, which leaves more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+    )
+    parser.add_argument(
+        "--tokenizer_max_length",
+        type=int,
+        default=None,
+        required=False,
+        help="The maximum number of tokens to pad/truncate prompts to. If not set, defaults to the tokenizer's model_max_length.",
+    )
+    parser.add_argument(
+        "--text_encoder_use_attention_mask",
+        action="store_true",
+        required=False,
+        help="Whether to pass the tokenizer's attention mask to the text encoder.",
+    )
+    parser.add_argument(
+        "--skip_save_text_encoder", action="store_true", required=False, help="Whether to skip saving the text encoder with the final pipeline."
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -481,6 +537,9 @@ def parse_args(input_args=None):
         if args.class_prompt is not None:
             warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
 
+    if args.train_text_encoder and args.pre_compute_text_embeddings:
+        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
     return args
 
 
@@ -500,10 +559,16 @@ def __init__(
         class_num=None,
         size=512,
         center_crop=False,
+        encoder_hidden_states=None,
+        class_prompt_encoder_hidden_states=None,
+        tokenizer_max_length=None,
     ):
         self.size = size
         self.center_crop = center_crop
         self.tokenizer = tokenizer
+        self.encoder_hidden_states = encoder_hidden_states
+        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+        self.tokenizer_max_length = tokenizer_max_length
 
         self.instance_data_root = Path(instance_data_root)
         if not self.instance_data_root.exists():
@@ -545,40 +610,52 @@ def __getitem__(self, index):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
         example["instance_images"] = self.image_transforms(instance_image)
-        example["instance_prompt_ids"] = self.tokenizer(
-            self.instance_prompt,
-            truncation=True,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        ).input_ids
+
+        if self.encoder_hidden_states is not None:
+            example["instance_prompt_ids"] = self.encoder_hidden_states
+        else:
+            text_inputs = tokenize_prompt(
+                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+            )
+            example["instance_prompt_ids"] = text_inputs.input_ids
+            example["instance_attention_mask"] = text_inputs.attention_mask
 
         if self.class_data_root:
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
             example["class_images"] = self.image_transforms(class_image)
-            example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
-                truncation=True,
-                padding="max_length",
-                max_length=self.tokenizer.model_max_length,
-                return_tensors="pt",
-            ).input_ids
+
+            if self.class_prompt_encoder_hidden_states is not None:
+                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+            else:
+                class_text_inputs = tokenize_prompt(
+                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+                )
+                example["class_prompt_ids"] = class_text_inputs.input_ids
+                example["class_attention_mask"] = class_text_inputs.attention_mask
 
         return example
 
 
 def collate_fn(examples, with_prior_preservation=False):
+    has_attention_mask = "instance_attention_mask" in examples[0]
+
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
+    if has_attention_mask:
+        attention_mask = [example["instance_attention_mask"] for example in examples]
+
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
+        if has_attention_mask:
+            attention_mask += [example["class_attention_mask"] for example in examples]
+
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
@@ -588,6 +665,11 @@ def collate_fn(examples, with_prior_preservation=False):
         "input_ids": input_ids,
         "pixel_values": pixel_values,
     }
+
+    if has_attention_mask:
+        attention_mask = torch.cat(attention_mask, dim=0)
+        batch["attention_mask"] = attention_mask
+
     return batch
 
 
@@ -608,6 +689,50 @@ def __getitem__(self, index):
         return example
 
 
+def model_has_vae(args):
+    config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+    if os.path.isdir(args.pretrained_model_name_or_path):
+        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+        return os.path.isfile(config_file_name)
+    else:
+        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+        return any(file.rfilename == config_file_name for file in files_in_repo)
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+    if tokenizer_max_length is not None:
+        max_length = tokenizer_max_length
+    else:
+        max_length = tokenizer.model_max_length
+
+    text_inputs = tokenizer(
+        prompt,
+        truncation=True,
+        padding="max_length",
+        max_length=max_length,
+        return_tensors="pt",
+    )
+
+    return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+    text_input_ids = input_ids.to(text_encoder.device)
+
+    if text_encoder_use_attention_mask:
+        attention_mask = attention_mask.to(text_encoder.device)
+    else:
+        attention_mask = None
+
+    prompt_embeds = text_encoder(
+        text_input_ids,
+        attention_mask=attention_mask,
+    )
+    prompt_embeds = prompt_embeds[0]
+
+    return prompt_embeds
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -727,7 +852,14 @@ def main(args):
     text_encoder = text_encoder_cls.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+
+    if model_has_vae(args):
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
+    else:
+        vae = None
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
@@ -761,7 +893,9 @@ def load_model_hook(models, input_dir):
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
 
-    vae.requires_grad_(False)
+    if vae is not None:
+        vae.requires_grad_(False)
+
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
 
@@ -835,6 +969,44 @@ def load_model_hook(models, input_dir):
         eps=args.adam_epsilon,
     )
 
+    if args.pre_compute_text_embeddings:
+
+        def compute_text_embeddings(prompt):
+            with torch.no_grad():
+                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+                prompt_embeds = encode_prompt(
+                    text_encoder,
+                    text_inputs.input_ids,
+                    text_inputs.attention_mask,
+                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                )
+
+            return prompt_embeds
+
+        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+        validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+        if args.validation_prompt is not None:
+            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+        else:
+            validation_prompt_encoder_hidden_states = None
+
+        if args.class_prompt is not None:
+            pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+        else:
+            pre_computed_class_prompt_encoder_hidden_states = None
+
+        text_encoder = None
+        tokenizer = None
+
+        gc.collect()
+        torch.cuda.empty_cache()
+    else:
+        pre_computed_encoder_hidden_states = None
+        validation_prompt_encoder_hidden_states = None
+        validation_prompt_negative_prompt_embeds = None
+        pre_computed_class_prompt_encoder_hidden_states = None
+
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -845,6 +1017,9 @@ def load_model_hook(models, input_dir):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        encoder_hidden_states=pre_computed_encoder_hidden_states,
+        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+        tokenizer_max_length=args.tokenizer_max_length,
     )
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -890,8 +1065,10 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.bfloat16
 
     # Move vae and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=weight_dtype)
-    if not args.train_text_encoder:
+    if vae is not None:
+        vae.to(accelerator.device, dtype=weight_dtype)
+
+    if not args.train_text_encoder and text_encoder is not None:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -961,37 +1138,56 @@ def load_model_hook(models, input_dir):
                     continue
 
             with accelerator.accumulate(unet):
-                # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * vae.config.scaling_factor
+                pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
 
-                # Sample noise that we'll add to the latents
+                if vae is not None:
+                    # Convert images to latent space
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+                    model_input = model_input * vae.config.scaling_factor
+                else:
+                    model_input = pixel_values
+
+                # Sample noise that we'll add to the model input
                 if args.offset_noise:
-                    noise = torch.randn_like(latents) + 0.1 * torch.randn(
-                        latents.shape[0], latents.shape[1], 1, 1, device=latents.device
+                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+                        model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
                     )
                 else:
-                    noise = torch.randn_like(latents)
-                bsz = latents.shape[0]
+                    noise = torch.randn_like(model_input)
+                bsz = model_input.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                )
                 timesteps = timesteps.long()
 
-                # Add noise to the latents according to the noise magnitude at each timestep
+                # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                if args.pre_compute_text_embeddings:
+                    encoder_hidden_states = batch["input_ids"]
+                else:
+                    encoder_hidden_states = encode_prompt(
+                        text_encoder,
+                        batch["input_ids"],
+                        batch["attention_mask"],
+                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+                    )
 
                 # Predict the noise residual
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+
+                # If the model also predicts variance (6 output channels, as in DeepFloyd IF), keep only the noise prediction for the loss
+                if model_pred.shape[1] == 6:
+                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1037,7 +1232,16 @@ def load_model_hook(models, input_dir):
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
-                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+                            text_encoder,
+                            tokenizer,
+                            unet,
+                            vae,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            epoch,
+                            validation_prompt_encoder_hidden_states,
+                            validation_prompt_negative_prompt_embeds,
                         )
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1050,12 +1254,34 @@ def load_model_hook(models, input_dir):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        pipeline_args = {}
+
+        if text_encoder is not None:
+            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+
+        if args.skip_save_text_encoder:
+            pipeline_args["text_encoder"] = None
+
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
-            text_encoder=accelerator.unwrap_model(text_encoder),
             revision=args.revision,
+            **pipeline_args,
         )
+
+        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it.
+        scheduler_args = {}
+
+        if "variance_type" in pipeline.scheduler.config:
+            variance_type = pipeline.scheduler.config.variance_type
+
+            if variance_type in ["learned", "learned_range"]:
+                variance_type = "fixed_small"
+
+            scheduler_args["variance_type"] = variance_type
+
+        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
diff --git a/examples/test_examples.py b/examples/test_examples.py
index d9e7de717f47..59c96f44fe93 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -147,6 +147,32 @@ def test_dreambooth(self):
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
         self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
 
+    def test_dreambooth_if(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --pre_compute_text_embeddings
+                --tokenizer_max_length=77
+                --text_encoder_use_attention_mask
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
     def test_dreambooth_checkpointing(self):
         instance_prompt = "photo"
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 7b76dd7e37bd..75d9eb3e03df 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -1507,16 +1507,33 @@ def forward(
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            # resnet
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
 
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
             output_states = output_states + (hidden_states,)
 
@@ -2593,15 +2610,33 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
 
-            # attn
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                **cross_attention_kwargs,
-            )
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn, return_dict=False),
+                    hidden_states,
+                    encoder_hidden_states,
+                    cross_attention_kwargs,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    **cross_attention_kwargs,
+                )
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
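
For reviewers who want to try the new `--pre_compute_text_embeddings` path outside the training script, here is a minimal standalone sketch of what it does. It is an illustration, not part of the diff: the checkpoint id and prompt are placeholders, and `compute_text_embeddings` below simply mirrors the `tokenize_prompt`/`encode_prompt` helpers added in train_dreambooth.py.

import gc

import torch
from transformers import AutoTokenizer, T5EncoderModel

# Placeholder checkpoint: any diffusers repo with T5-based "tokenizer" and "text_encoder" subfolders.
model_id = "DeepFloyd/IF-I-XL-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")


def compute_text_embeddings(prompt, tokenizer_max_length=77):
    # Fixed-length padding (what --tokenizer_max_length controls) keeps instance and
    # class embeddings the same shape so collate_fn can concatenate them.
    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer_max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        # Passing the mask is what --text_encoder_use_attention_mask enables; it stops
        # the encoder from attending to the padding positions.
        prompt_embeds = text_encoder(
            text_inputs.input_ids.to(text_encoder.device),
            attention_mask=text_inputs.attention_mask.to(text_encoder.device),
        )[0]
    return prompt_embeds


instance_embeds = compute_text_embeddings("a photo of sks dog")
negative_embeds = compute_text_embeddings("")

# With the embeddings cached, the text encoder can be dropped before training starts,
# which is the memory saving the flag is after.
del text_encoder
gc.collect()
torch.cuda.empty_cache()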