From c6c9f3ab211fbfb4c15dfa2f4fe5828d30fbcbcc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 16:05:35 +0530 Subject: [PATCH 01/47] add: controlnet sdxl. --- examples/controlnet/train_controlnet_sdxl.py | 1215 +++++++++++++++++ .../controlnet/pipeline_controlnet_sd_xl.py | 1163 ++++++++++++++++ 2 files changed, 2378 insertions(+) create mode 100644 examples/controlnet/train_controlnet_sdxl.py create mode 100644 src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py new file mode 100644 index 000000000000..2bac54c02eb0 --- /dev/null +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -0,0 +1,1215 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(controlnet, args, accelerator, weight_dtype, step): + logger.info("Running validation... 
") + + controlnet = accelerator.unwrap_model(controlnet) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + + images = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + return image_logs + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise 
ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." + ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
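+            # e.g. `--dataloader_num_workers 4` decodes and transforms images in four worker
+            # processes; the default of 0 keeps data loading on the main process, which is
+            # simplest but can bottleneck training on fast GPUs.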
+ ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." 
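+            # e.g. `--validation_steps 100 --num_validation_images 4` generates four images per
+            # validation prompt/image pair every 100 optimizer steps.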
+ ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="sd_xl_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[args.image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = 
torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) + add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = 
ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
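+                # bitsandbytes' AdamW8bit keeps the optimizer's moment buffers in 8-bit rather
+                # than fp32, roughly quartering optimizer-state memory for the trained
+                # ControlNet parameters.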
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + prompt_batch = batch[args.caption_column] + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + train_dataset = get_train_dataset(args, accelerator) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + ) + with accelerator.main_process_first(): + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + # Then get the training dataset ready to be passed to the dataloader. + train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. 
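+    # Worked example of the math below: with 1000 dataloader batches and
+    # `--gradient_accumulation_steps 4`, there are ceil(1000 / 4) = 250 optimizer updates per
+    # epoch. If `--max_train_steps` is unset, it is derived as `num_train_epochs * 250`, and the
+    # override is recorded so the value can be recomputed after `accelerator.prepare` has
+    # re-sharded the dataloader across processes.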
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
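+        # `initial=initial_global_step` keeps the displayed step count accurate when
+        # resuming via `--resume_from_checkpoint`.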
+ disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * vae.config.scaling_factor + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + 
shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        image_logs = log_validation(controlnet, args, accelerator, weight_dtype, global_step)
+
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+
+            if global_step >= args.max_train_steps:
+                break
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet.save_pretrained(args.output_dir)
+
+        if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                image_logs=image_logs,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
new file mode 100644
index 000000000000..46fce339766e
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -0,0 +1,1163 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
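+
+# A minimal usage sketch for the pipeline defined below. The checkpoint names are illustrative
+# only -- no official SDXL ControlNet weights exist at the time of this patch:
+#
+#     import torch
+#     from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
+#     from diffusers.utils import load_image
+#
+#     controlnet = ControlNetModel.from_pretrained("path/to/sdxl-controlnet", torch_dtype=torch.float16)
+#     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+#         "stabilityai/stable-diffusion-xl-base-0.9", controlnet=controlnet, torch_dtype=torch.float16
+#     ).to("cuda")
+#     conditioning = load_image("path/to/conditioning.png")  # e.g. a Canny edge map
+#     image = pipe("a portrait painting", image=conditioning, num_inference_steps=30).images[0]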
+ + +import inspect +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO (sayakpaul): update this +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + TODO (sayakpaul): Update docstrings. 
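+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text encoder used by SDXL (an OpenCLIP ViT-bigG variant); its pooled
+            output supplies the `text_embeds` conditioning expected by the SDXL UNet.
+        tokenizer_2 (`CLIPTokenizer`):
+            Tokenizer paired with `text_encoder_2`.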
+ vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
+ ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config( + requires_safety_checker=requires_safety_checker, force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif 
do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            negative_prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: process multi-vector tokens if necessary
+                if isinstance(self, TextualInversionLoaderMixin):
+                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are always interested only in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                if do_classifier_free_guidance:
+                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                    seq_len = negative_prompt_embeds.shape[1]
+
+                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                    negative_prompt_embeds = negative_prompt_embeds.view(
+                        batch_size * num_images_per_prompt, seq_len, -1
+                    )
+
+                    # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
+ ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        return add_time_ids
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        original_size: Tuple[int, int] = (1024, 1024),
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = (1024, 1024),
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet.
+                If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image`
+                can also be accepted as an image. The dimensions of the output image default to `image`'s dimensions.
+                If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
+                specified in init, images must be passed as a list such that each element of the list can be correctly
+                batched for input to a single controlnet.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
+                corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
+                if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the controlnet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the controlnet stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. 
Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(num_inference_steps): + keeps = [ + 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 7.2 Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ] + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if not use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in 
has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 185b67b09a34dd326c0698eb08279710cd7565fd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 16:14:13 +0530 Subject: [PATCH 02/47] modifications to controlnet. --- src/diffusers/models/controlnet.py | 123 ++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index b0f566020079..87ea39cca4c6 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -131,12 +131,25 @@ class ControlNetModel(ModelMixin, ConfigMixin): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. 
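
To make the new `addition_embed_type` options concrete, the sketch below shows how the "text_time" path (the one SDXL relies on) turns pooled text embeddings and micro-conditioning time ids into an extra embedding that is summed with the timestep embedding. This is an illustration only; the dimension values are typical SDXL-like numbers and are not taken verbatim from this patch.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# Illustrative dimensions: 6 time ids * 256 sinusoidal features + 1280 pooled dims = 2816.
addition_time_embed_dim, projection_dim, time_embed_dim = 256, 2816, 1280

add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
add_embedding = TimestepEmbedding(projection_dim, time_embed_dim)

text_embeds = torch.randn(2, 1280)  # pooled prompt embeddings from the second text encoder
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * 2)  # original size + crop + target size

time_embeds = add_time_proj(time_ids.flatten())                # sinusoidal features per id
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))  # (2, 1536)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)  # (2, 2816)
aug_emb = add_embedding(add_embeds)                            # (2, 1280), summed with the time embedding
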
@@ -177,10 +190,15 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", @@ -188,6 +206,7 @@ def __init__( controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, ): super().__init__() @@ -215,6 +234,9 @@ def __init__( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 @@ -224,16 +246,43 @@ def __init__( # time time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
+ ) + else: + self.encoder_hid_proj = None + # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) @@ -257,6 +306,29 @@ def __init__( else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], @@ -291,6 +363,7 @@ def __init__( down_block = get_down_block( down_block_type, num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -327,6 +400,7 @@ def __init__( self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, @@ -356,7 +430,22 @@ def from_unet( The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. 
""" + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, @@ -542,6 +631,7 @@ def forward( class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, @@ -564,6 +654,8 @@ def forward( Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): @@ -618,6 +710,7 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -629,6 +722,30 @@ def forward( class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + # 2. pre-process sample = self.conv_in(sample) From db78a4cb4e3f105cbc7534890f606e25e906e23a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 16:22:24 +0530 Subject: [PATCH 03/47] run styling. 
--- src/diffusers/models/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 87ea39cca4c6..875348f8b027 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -654,7 +654,7 @@ def forward( Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - added_cond_kwargs (`dict`): + added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. From af8273a0aa7abe29a82526732397fdbfda744326 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 20:15:34 +0530 Subject: [PATCH 04/47] add: __init__.pys --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/controlnet/__init__.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a7fc9a36f271..8e0c135b615e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -172,6 +172,7 @@ StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c3968406ed90..bcd338e5f642 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -50,6 +50,7 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import ( IFImg2ImgPipeline, diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index 76ab63bdb116..d5f7eb6b4fcc 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -16,6 +16,7 @@ from .pipeline_controlnet import StableDiffusionControlNetPipeline from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline if is_transformers_available() and is_flax_available(): From c8b00dea5e24f5c9a3783a432700b72bc78fedbe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 20:19:07 +0530 Subject: [PATCH 05/47] incorporate https://github.com/huggingface/diffusers/pull/4019 changes. 
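
The hunk below flips the attention-processor check. The old membership test compared a processor instance against a list of classes, which is always false, so the branch never behaved as intended. A minimal reproduction (requires PyTorch 2.0 for `AttnProcessor2_0`):

from diffusers.models.attention_processor import AttnProcessor2_0

proc = AttnProcessor2_0()

print(proc in [AttnProcessor2_0])             # False: an instance is compared against a class
print(isinstance(proc, (AttnProcessor2_0,)))  # True: the check this patch switches to
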
--- .../controlnet/pipeline_controlnet_sd_xl.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 46fce339766e..46ba98270d5c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1117,15 +1117,18 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ] + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory - if not use_torch_2_0_or_xformers: + if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) From 68f2c3894d3c317051747221f13dee24b0c1089a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 11 Jul 2023 20:19:55 +0530 Subject: [PATCH 06/47] run make fix-copies. --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 164206d776fa..adce4b42ef79 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -722,6 +722,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 4b86d63231ed8f92c27c33d75a9f7f5167f19653 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 07:39:25 +0530 Subject: [PATCH 07/47] resize the conditioning images. --- examples/controlnet/train_controlnet_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 2bac54c02eb0..79e4d40b0de5 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -116,6 +116,7 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) images = [] From dade6812dfdd5135ed363db989f0657f93be6a42 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 07:45:58 +0530 Subject: [PATCH 08/47] remove autocast. 
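
As a quick smoke check of the export wiring added in PATCH 04 above (assuming this branch is installed), the pipeline should be importable from the package root and resolve to the same class as the subpackage import:

from diffusers import StableDiffusionXLControlNetPipeline
from diffusers.pipelines.controlnet import StableDiffusionXLControlNetPipeline as _SDXLControlNet

assert StableDiffusionXLControlNetPipeline is _SDXLControlNet
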
--- examples/controlnet/train_controlnet_sdxl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 79e4d40b0de5..f0951e802ebf 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -121,10 +121,10 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): images = [] for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline( - validation_prompt, validation_image, num_inference_steps=20, generator=generator - ).images[0] + # with torch.autocast("cuda"): + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] images.append(image) image_logs.append( @@ -166,6 +166,10 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): else: logger.warn(f"image logging not implemented for {tracker.name}") + del pipeline + gc.collect() + torch.cuda.empty_cache() + return image_logs From 15d4afd8372d034502c998f32215c7e2af6130f5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 07:52:10 +0530 Subject: [PATCH 09/47] run styling. --- examples/controlnet/train_controlnet_sdxl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index f0951e802ebf..8febd238552d 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -122,9 +122,9 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): for _ in range(args.num_validation_images): # with torch.autocast("cuda"): - image = pipeline( - validation_prompt, validation_image, num_inference_steps=20, generator=generator - ).images[0] + image = pipeline(validation_prompt, validation_image, num_inference_steps=20, generator=generator).images[ + 0 + ] images.append(image) image_logs.append( From 5ed7d3e31fa414e07abadfed900ccf15e8ea72dc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 08:09:16 +0530 Subject: [PATCH 10/47] disable autocast. --- examples/controlnet/train_controlnet_sdxl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 8febd238552d..52089a97c04d 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -74,17 +74,20 @@ def image_grid(imgs, rows, cols): return grid -def log_validation(controlnet, args, accelerator, weight_dtype, step): +def log_validation(controlnet, args, accelerator, step): logger.info("Running validation... ") controlnet = accelerator.unwrap_model(controlnet) + # We're loading the pipeline in full-precision to avoid numerical instabilities. + # Using autocast also doesn't help. 
More details: + # https://github.com/huggingface/diffusers/pull/4038#issuecomment-1631245691 pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, controlnet=controlnet, safety_checker=None, revision=args.revision, - torch_dtype=weight_dtype, + torch_dtype=torch.float32, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) @@ -121,7 +124,6 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): images = [] for _ in range(args.num_validation_images): - # with torch.autocast("cuda"): image = pipeline(validation_prompt, validation_image, num_inference_steps=20, generator=generator).images[ 0 ] @@ -1183,7 +1185,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - image_logs = log_validation(controlnet, args, accelerator, weight_dtype, global_step) + image_logs = log_validation(controlnet, args, accelerator, global_step) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 64b0e20655fec95892f7d5c58a33bdb61e51fd8c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 08:32:53 +0530 Subject: [PATCH 11/47] debugging --- examples/controlnet/train_controlnet_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 52089a97c04d..c296f9586aa3 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -84,10 +84,10 @@ def log_validation(controlnet, args, accelerator, step): # https://github.com/huggingface/diffusers/pull/4038#issuecomment-1631245691 pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, - controlnet=controlnet, + controlnet=controlnet.to(torch.float16), safety_checker=None, revision=args.revision, - torch_dtype=torch.float32, + torch_dtype=torch.float16, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) From f482f46e532d2b8d4657e449f0e94039359e47bb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 08:43:25 +0530 Subject: [PATCH 12/47] device placement. --- examples/controlnet/train_controlnet_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index c296f9586aa3..8b75cf2ae3b4 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -84,7 +84,7 @@ def log_validation(controlnet, args, accelerator, step): # https://github.com/huggingface/diffusers/pull/4038#issuecomment-1631245691 pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, - controlnet=controlnet.to(torch.float16), + controlnet=controlnet.to(accelerator.device, dtype=torch.float16), safety_checker=None, revision=args.revision, torch_dtype=torch.float16, From c2bbb2b7d4ccbf0c2c3759f0a48604f6d0dc6575 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 08:52:27 +0530 Subject: [PATCH 13/47] back to autocast. 
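
The behaviour this patch returns to can be seen in isolation below: under `torch.autocast`, float32 module weights stay float32 while eligible ops run in half precision, which is why validation can keep the full-precision weights. A minimal illustration, assuming a CUDA device is available:

import torch

linear = torch.nn.Linear(8, 8).cuda()   # weights remain float32
x = torch.randn(2, 8, device="cuda")

with torch.autocast("cuda"):
    y = linear(x)

print(linear.weight.dtype, y.dtype)  # torch.float32 torch.float16
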
--- examples/controlnet/train_controlnet_sdxl.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 8b75cf2ae3b4..1d644086ab82 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -74,7 +74,7 @@ def image_grid(imgs, rows, cols): return grid -def log_validation(controlnet, args, accelerator, step): +def log_validation(controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") controlnet = accelerator.unwrap_model(controlnet) @@ -84,10 +84,10 @@ def log_validation(controlnet, args, accelerator, step): # https://github.com/huggingface/diffusers/pull/4038#issuecomment-1631245691 pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, - controlnet=controlnet.to(accelerator.device, dtype=torch.float16), + controlnet=controlnet, safety_checker=None, revision=args.revision, - torch_dtype=torch.float16, + torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) @@ -124,9 +124,10 @@ def log_validation(controlnet, args, accelerator, step): images = [] for _ in range(args.num_validation_images): - image = pipeline(validation_prompt, validation_image, num_inference_steps=20, generator=generator).images[ - 0 - ] + with torch.autocast("cuda"): + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] images.append(image) image_logs.append( @@ -1185,7 +1186,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - image_logs = log_validation(controlnet, args, accelerator, global_step) + image_logs = log_validation(controlnet, args, accelerator, weight_dtype, global_step) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From ccb62103e72c666596f8ca5b1f76bb4f3b41d5c5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 08:55:20 +0530 Subject: [PATCH 14/47] remove comment. --- examples/controlnet/train_controlnet_sdxl.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 1d644086ab82..f8e817e900a5 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -79,9 +79,6 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step): controlnet = accelerator.unwrap_model(controlnet) - # We're loading the pipeline in full-precision to avoid numerical instabilities. - # Using autocast also doesn't help. More details: - # https://github.com/huggingface/diffusers/pull/4038#issuecomment-1631245691 pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, controlnet=controlnet, From be13ef5ad0af1c8e0fd45c951471ccb0b0fdc3c1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 16:03:14 +0530 Subject: [PATCH 15/47] save some memory by reusing the vae and unet in the pipeline. 
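The memory saving here comes from handing the already-loaded training modules straight to `from_pretrained`, which skips re-loading those weights from disk. A rough sketch of the idea, assuming `vae`, `unet`, `controlnet`, and `weight_dtype` exist as in the training script (the checkpoint id is a placeholder):

```python
from diffusers import StableDiffusionXLControlNetPipeline

# `vae`, `unet`, `controlnet`, and `weight_dtype` are assumed to come from the
# training script; the checkpoint id below is a placeholder.
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    vae=vae,                # reused from training instead of re-loaded from disk
    unet=unet,              # reused from training instead of re-loaded from disk
    controlnet=controlnet,  # the module actually being trained
    torch_dtype=weight_dtype,
)
```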
--- examples/controlnet/train_controlnet_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index f8e817e900a5..862d9d76965d 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -74,13 +74,15 @@ def image_grid(imgs, rows, cols): return grid -def log_validation(controlnet, args, accelerator, weight_dtype, step): +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") controlnet = accelerator.unwrap_model(controlnet) pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, + vae=vae, + unet=unet, controlnet=controlnet, safety_checker=None, revision=args.revision, @@ -1183,7 +1185,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - image_logs = log_validation(controlnet, args, accelerator, weight_dtype, global_step) + image_logs = log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, global_step) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From fbb086a5aac90a8e4f888fdbc67e2fe21639b6d3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Jul 2023 16:05:52 +0530 Subject: [PATCH 16/47] apply styling. --- examples/controlnet/train_controlnet_sdxl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 862d9d76965d..0ae93f8365d0 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1185,7 +1185,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - image_logs = log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, global_step) + image_logs = log_validation( + vae, unet, controlnet, args, accelerator, weight_dtype, global_step + ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) From 65a5c452fcce914703aa641060b9cc9c26dc3117 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 21:53:35 +0200 Subject: [PATCH 17/47] Allow low precision sd xl --- src/diffusers/models/autoencoder_kl.py | 1 + .../pipeline_stable_diffusion_xl.py | 42 ++++++++++-------- .../pipeline_stable_diffusion_xl_img2img.py | 43 +++++++++++-------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index ddb9bde0ee0a..56288bfc0ca9 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -82,6 +82,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + upcast_precision: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index b3dcf1b67cda..a7eee64724f9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ 
b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -537,6 +537,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -799,25 +819,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + self.upcast_vae() + latents = latents.to(self.vae.post_quant_conv.dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 7b0cdfad8c0a..815cd5521b64 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -624,6 +624,27 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionPipelineXL.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -932,25 +953,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in 
float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + self.upcast_vae() + latents = latents.to(self.vae.post_quant_conv.dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] From 43f842cb75b4ce134873522c3de5dc581564d2f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 20:30:17 +0000 Subject: [PATCH 18/47] finish --- src/diffusers/models/autoencoder_kl.py | 2 +- .../pipeline_stable_diffusion_xl.py | 5 ++--- .../pipeline_stable_diffusion_xl_img2img.py | 15 ++++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 56288bfc0ca9..33326dd80b1b 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -82,7 +82,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, - upcast_precision: float = True, + force_upcast: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a7eee64724f9..2c513a9cfb32 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -556,7 +556,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -819,9 +818,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(self.vae.post_quant_conv.dtype) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 815cd5521b64..5c6cfdf18dd6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -542,8 +542,9 @@ def prepare_latents( else: # make sure the VAE is in float32 mode, as it overflows in float16 - image = image.float() - self.vae.to(dtype=torch.float32) + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -559,9 +560,10 @@ def prepare_latents( else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - self.vae.to(dtype) - init_latents = init_latents.to(dtype) + if self.vae.config.force_upcast: + self.vae.to(dtype) + init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and 
batch_size % init_latents.shape[0] == 0: @@ -644,7 +646,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -953,9 +954,9 @@ def __call__( callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 - if self.vae.dtype == torch.float16 and self.vae.config.upcast_precision: + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() - latents = latents.to(self.vae.post_quant_conv.dtype) + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] From 7171d423b191129be8da95014568fabd3694d54c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Jul 2023 21:24:54 +0000 Subject: [PATCH 19/47] finish --- .../pipeline_stable_diffusion_upscale.py | 42 ++++++++++--------- .../pipeline_stable_diffusion_xl.py | 1 + .../pipeline_stable_diffusion_xl_img2img.py | 2 +- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index a7255424fb46..59aa87f46c19 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -501,6 +501,25 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() def __call__( self, @@ -746,26 +765,9 @@ def __call__( # 10. 
Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # post-processing if not output_type == "latent": diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2c513a9cfb32..bcdf47b7c983 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -537,6 +537,7 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 5c6cfdf18dd6..f14e7fa457e3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -626,7 +626,7 @@ def _get_add_time_ids( return add_time_ids, add_neg_time_ids - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionPipelineXL.upcast_vae + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae def upcast_vae(self): dtype = self.vae.dtype self.vae.to(dtype=torch.float32) From fa22868cfa6f812ed2c03d0301ab4914c54562c3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 10:04:31 +0530 Subject: [PATCH 20/47] changes to accommodate the improved VAE. --- examples/controlnet/train_controlnet_sdxl.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 0ae93f8365d0..50da26d6d75e 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -241,6 +241,12 @@ def parse_args(input_args=None): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--pretrained_vae_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. 
For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) parser.add_argument( "--controlnet_model_name_or_path", type=str, @@ -823,9 +829,14 @@ def main(args): text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_name_or_path is None + else args.pretrained_vae_name_or_path + ) vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="vae", + vae_path, + subfolder="vae" if args.pretrained_vae_name_or_path is None else None, revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( From b1c83fdda8bf7c379e7abfb752199fd4249a2294 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 10:09:33 +0530 Subject: [PATCH 21/47] modifications to how we handle vae encoding in the training. --- examples/controlnet/train_controlnet_sdxl.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 50da26d6d75e..35eb6856b933 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -957,7 +957,10 @@ def load_model_hook(models, input_dir): # Move vae, unet and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. - vae.to(accelerator.device, dtype=torch.float32) + if args.pretrained_vae_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) unet.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) @@ -1109,9 +1112,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + if args.pretrained_vae_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - latents = latents.to(weight_dtype) + if args.pretrained_vae_name_or_path is None: + latents = latents.to(weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents) From a4e10d89eae7d11bd3a5daad183830f9edf85c74 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 10:17:21 +0530 Subject: [PATCH 22/47] make style --- examples/controlnet/train_controlnet_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 35eb6856b933..f007018c3251 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -960,7 +960,7 @@ def load_model_hook(models, input_dir): if args.pretrained_vae_name_or_path is not None: vae.to(accelerator.device, dtype=weight_dtype) else: - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=torch.float32) unet.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From 
2d7c45416fb6511e19606311da15479f8e708c84 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 11:13:39 +0530 Subject: [PATCH 23/47] make existing controlnet fast tests pass. --- src/diffusers/models/controlnet.py | 2 +- .../pipelines/controlnet/multicontrolnet.py | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 875348f8b027..acebee4b40a5 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -656,7 +656,7 @@ def forward( attention_mask (`torch.Tensor`, *optional*, defaults to `None`): added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. - cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py index 921895b8fd92..2df8c7edd7ae 100644 --- a/src/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -45,17 +45,17 @@ def forward( ) -> Union[ControlNetOutput, Tuple]: for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): down_samples, mid_sample = controlnet( - sample, - timestep, - encoder_hidden_states, - image, - scale, - class_labels, - timestep_cond, - attention_mask, - cross_attention_kwargs, - guess_mode, - return_dict, + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, ) # merge samples From 7296a7f156d471c90551ddb22e9d5fb7e27235b1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 16:41:22 +0530 Subject: [PATCH 24/47] change vae checkpoint cli arg. --- examples/controlnet/train_controlnet_sdxl.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index f007018c3251..9f4809cc4d9e 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -242,7 +242,7 @@ def parse_args(input_args=None): help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( - "--pretrained_vae_name_or_path", + "--pretrained_model_name_or_path", type=str, default=None, help="Path to an improved VAE to stabilize training. 
For more details check out: https://github.com/huggingface/diffusers/pull/4038.", @@ -831,12 +831,12 @@ def main(args): ) vae_path = ( args.pretrained_model_name_or_path - if args.pretrained_vae_name_or_path is None - else args.pretrained_vae_name_or_path + if args.pretrained_model_name_or_path is None + else args.pretrained_model_name_or_path ) vae = AutoencoderKL.from_pretrained( vae_path, - subfolder="vae" if args.pretrained_vae_name_or_path is None else None, + subfolder="vae" if args.pretrained_model_name_or_path is None else None, revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( @@ -957,7 +957,7 @@ def load_model_hook(models, input_dir): # Move vae, unet and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. - if args.pretrained_vae_name_or_path is not None: + if args.pretrained_model_name_or_path is not None: vae.to(accelerator.device, dtype=weight_dtype) else: vae.to(accelerator.device, dtype=torch.float32) @@ -1112,13 +1112,13 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): # Convert images to latent space - if args.pretrained_vae_name_or_path is not None: + if args.pretrained_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - if args.pretrained_vae_name_or_path is None: + if args.pretrained_model_name_or_path is None: latents = latents.to(weight_dtype) # Sample noise that we'll add to the latents From 74f74851196ed7da77472c436b992bd6786a63bf Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 14 Jul 2023 19:14:53 +0530 Subject: [PATCH 25/47] fix: vae pretrained paths. --- examples/controlnet/train_controlnet_sdxl.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 9f4809cc4d9e..d2bb2b5581cc 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -242,7 +242,7 @@ def parse_args(input_args=None): help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( - "--pretrained_model_name_or_path", + "--pretrained_vae_model_name_or_path", type=str, default=None, help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", @@ -831,12 +831,12 @@ def main(args): ) vae_path = ( args.pretrained_model_name_or_path - if args.pretrained_model_name_or_path is None - else args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( vae_path, - subfolder="vae" if args.pretrained_model_name_or_path is None else None, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, ) unet = UNet2DConditionModel.from_pretrained( @@ -957,7 +957,7 @@ def load_model_hook(models, input_dir): # Move vae, unet and text_encoder to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. 
- if args.pretrained_model_name_or_path is not None: + if args.pretrained_vae_model_name_or_path is not None: vae.to(accelerator.device, dtype=weight_dtype) else: vae.to(accelerator.device, dtype=torch.float32) @@ -1112,8 +1112,10 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): # Convert images to latent space - if args.pretrained_model_name_or_path is not None: + if args.pretrained_vae_model_name_or_path is not None: pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) else: pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample() From 7d27b1d10e1d83a950662f8aab0ab5ac3da4cd41 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 16 Jul 2023 18:13:04 +0530 Subject: [PATCH 26/47] fix: steps in get_scheduler(). --- examples/controlnet/train_controlnet_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index d2bb2b5581cc..f464af143e47 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1029,8 +1029,8 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) From 3fa4809c72eef9f3e0eac6383259d1cbf8b0659e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 07:08:03 +0530 Subject: [PATCH 27/47] debugging. 
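This commit drops in throwaway `print` statements to trace where fp16 and fp32 tensors meet. A hypothetical helper in the same spirit, not part of the patch (the name `report_dtypes` is ours):

```python
import torch

def report_dtypes(**objs):
    # print the dtype of each tensor or module passed in, e.g.
    # report_dtypes(unet=unet, noisy_latents=noisy_latents)
    for name, obj in objs.items():
        dtype = obj.dtype if isinstance(obj, torch.Tensor) else next(obj.parameters()).dtype
        print(f"{name}: {dtype}")
```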
--- examples/controlnet/train_controlnet_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index d2bb2b5581cc..0e3bb6d3fadf 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1146,6 +1146,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer ) # Predict the noise residual + print( + f"Unet: {unet.dtype}, noisy_latents: {noisy_latents.dtype}, batch['prompt_ids']: {batch['prompt_ids'].dtype}" + ) model_pred = unet( noisy_latents, timesteps, From 523c1eba28eb52116671937b3cf97d4fa37f2468 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 07:25:40 +0530 Subject: [PATCH 28/47] debugging. --- examples/controlnet/train_controlnet_sdxl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b18a1acecb12..5eff75b90474 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1125,6 +1125,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + print(f"Noise: {noise.dtype}, latents: {latents.dtype}") bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -1133,6 +1134,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + print(f"noisy_latents: {noisy_latents.dtype}") # ControlNet conditioning. controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) @@ -1146,9 +1148,7 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer ) # Predict the noise residual - print( - f"Unet: {unet.dtype}, noisy_latents: {noisy_latents.dtype}, batch['prompt_ids']: {batch['prompt_ids'].dtype}" - ) + print(f"noisy_latents: {noisy_latents.dtype}") model_pred = unet( noisy_latents, timesteps, From 98ed3c94f903bf9f5d666312abbce0b1bcebe903 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 07:29:24 +0530 Subject: [PATCH 29/47] fix: weight conversion.
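The rule this fix settles on: latents are cast down to `weight_dtype` only when the VAE itself had to stay in fp32. A sketch of the resulting logic, assuming `args`, `batch`, `vae`, and `weight_dtype` as defined in the training script:

```python
if args.pretrained_vae_model_name_or_path is not None:
    # a stabilized VAE can encode in half precision alongside the UNet
    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
else:
    # the stock SDXL VAE overflows in fp16, so encode in fp32 ...
    pixel_values = batch["pixel_values"]

latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor

if args.pretrained_vae_model_name_or_path is None:
    # ... and only then cast the latents down for the half-precision UNet
    latents = latents.to(weight_dtype)
```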
--- examples/controlnet/train_controlnet_sdxl.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 5eff75b90474..5f5a35a4e37b 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1120,13 +1120,13 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer pixel_values = batch["pixel_values"] latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor - if args.pretrained_model_name_or_path is None: + if args.pretrained_vae_model_name_or_path is None: latents = latents.to(weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - print(f"Noise: {noise.dtype}, latents: {latents.dtype}") bsz = latents.shape[0] + # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() @@ -1134,7 +1134,6 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - print(f"noisy_latents: {noisy_latents.dtype}") # ControlNet conditioning. controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) @@ -1148,7 +1147,6 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer ) # Predict the noise residual - print(f"noisy_latents: {noisy_latents.dtype}") model_pred = unet( noisy_latents, timesteps, From a397ead27febee46c9b7da3899a36decec6b04c6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 07:49:54 +0530 Subject: [PATCH 30/47] add: docs. --- examples/controlnet/README.md | 11 +- examples/controlnet/README_sdxl.md | 131 ++++++++++++++++++++++ examples/controlnet/requirements_sdxl.txt | 7 ++ 3 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 examples/controlnet/README_sdxl.md create mode 100644 examples/controlnet/requirements_sdxl.txt diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 571e9e708cf2..15b0170d5120 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained( # speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) -# remove following line if xformers is not installed +# remove following line if xformers is not installed or when using Torch 2.0. pipe.enable_xformers_memory_efficient_attention() - +# memory optimization. 
pipe.enable_model_cpu_offload() control_image = load_image("./conditioning_image_1.png") prompt = "pale golden rod circle with old lace background" # generate image generator = torch.manual_seed(0) image = pipe( - prompt, num_inference_steps=20, generator=generator, image=control_image + prompt, num_inference_steps=20, generator=generator, image=control_image ).images[0] - image.save("./output.png") ``` @@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`). Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident). + +## Support for Stable Diffusion XL + +We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details. diff --git a/examples/controlnet/README_sdxl.md b/examples/controlnet/README_sdxl.md new file mode 100644 index 000000000000..4e4035066f8c --- /dev/null +++ b/examples/controlnet/README_sdxl.md @@ -0,0 +1,131 @@ +# ControlNet training example for Stable Diffusion XL (SDXL) + +The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd into the `examples/controlnet` folder and run +```bash +pip install -r requirements_sdxl.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, specifying the torch compile mode as True can yield dramatic speedups. + +## Circle filling dataset + +The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
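For a quick sanity check of what the script will consume, the dataset can be inspected directly. A sketch, assuming the `datasets` library is installed and the column names match the script's defaults:

```python
from datasets import load_dataset

dataset = load_dataset("fusing/fill50k", split="train")
print(dataset)             # expected columns: image, conditioning_image, text
print(dataset[0]["text"])  # a caption such as "pale golden rod circle with old lace background"
```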
+ +## Training + +Our training examples use two test conditioning images. They can be downloaded by running + +```sh +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png + +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png +``` + +Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub. + +```bash +export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9" +export OUTPUT_DIR="path to save model" + +accelerate launch train_controlnet_sdxl.py \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ --dataset_name=fusing/fill50k \ --mixed_precision="fp16" \ --resolution=1024 \ --learning_rate=1e-5 \ --max_train_steps=15000 \ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ --validation_steps=100 \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --report_to="wandb" \ --seed=42 \ --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 40GB A100 GPU. + +### Inference + +Once training is done, we can perform inference like so: + +```python +from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler +from diffusers.utils import load_image +import torch + +base_model_path = "stabilityai/stable-diffusion-xl-base-0.9" +controlnet_path = "path to controlnet" + +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) +pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + base_model_path, controlnet=controlnet, torch_dtype=torch.float16 +) + +# speed up diffusion process with faster scheduler and memory optimization +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) +# remove following line if xformers is not installed or when using Torch 2.0. +pipe.enable_xformers_memory_efficient_attention() +# memory optimization. +pipe.enable_model_cpu_offload() + +control_image = load_image("./conditioning_image_1.png") +prompt = "pale golden rod circle with old lace background" + +# generate image +generator = torch.manual_seed(0) +image = pipe( + prompt, num_inference_steps=20, generator=generator, image=control_image +).images[0] +image.save("./output.png") +``` + +## Notes + +### Specifying a better VAE + +SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
\ No newline at end of file diff --git a/examples/controlnet/requirements_sdxl.txt b/examples/controlnet/requirements_sdxl.txt new file mode 100644 index 000000000000..7a9936ee4003 --- /dev/null +++ b/examples/controlnet/requirements_sdxl.txt @@ -0,0 +1,7 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +invisible-watermark>=0.2.0 \ No newline at end of file From 4c14e88f61704dbe9e99021d8ddc642cd1ef0dbd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 08:57:05 +0530 Subject: [PATCH 31/47] add: limited tests. --- .../controlnet/pipeline_controlnet_sd_xl.py | 4 +- .../controlnet/test_controlnet_sdxl.py | 181 ++++++++++++++++++ 2 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 tests/pipelines/controlnet/test_controlnet_sdxl.py diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 46ba98270d5c..88a8d263610c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1024,9 +1024,9 @@ def __call__( # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] - for i in range(num_inference_steps): + for i in range(len(timesteps)): keeps = [ - 1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e) + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py new file mode 100644 index 000000000000..010bee004e0e --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + EulerDiscreteScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, +) +from diffusers.utils import randn_tensor, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class ControlNetPipelineSDXLFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): + pipeline_class = StableDiffusionXLControlNetPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + conditioning_embedding_out_channels=(16, 32), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, 
+ "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) From 156dd27e48e231a871ea864b439b49d19d03f6ab Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 17 Jul 2023 11:27:11 +0530 Subject: [PATCH 32/47] add: datasets to the requirements. --- examples/controlnet/requirements_sdxl.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/requirements_sdxl.txt b/examples/controlnet/requirements_sdxl.txt index 7a9936ee4003..68e00742489e 100644 --- a/examples/controlnet/requirements_sdxl.txt +++ b/examples/controlnet/requirements_sdxl.txt @@ -4,4 +4,5 @@ transformers>=4.25.1 ftfy tensorboard Jinja2 -invisible-watermark>=0.2.0 \ No newline at end of file +invisible-watermark>=0.2.0 +datasets \ No newline at end of file From 6611259f009d3635d06541f67669ee3b3b8afd85 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 12:42:09 +0530 Subject: [PATCH 33/47] update docstrings and incorporate the usage of watermarking. 
--- .../controlnet/pipeline_controlnet_sd_xl.py | 136 ++++-------------- 1 file changed, 25 insertions(+), 111 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 88a8d263610c..c52429e89c35 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -21,7 +21,7 @@ import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin @@ -42,82 +42,53 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput +from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# TODO (sayakpaul): update this EXAMPLE_DOC_STRING = """ Examples: ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ... ) - >>> image = np.array(image) - - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() - - >>> pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.manual_seed(0) - >>> image = pipe( - ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image - ... ).images[0] + >>> # To be updated when there's a useful ControlNet checkpoint + >>> # compatible with SDXL. ``` """ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] Args: - TODO (sayakpaul): Update docstrings. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets @@ -126,13 +97,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -144,29 +109,10 @@ def __init__( unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, force_zeros_for_empty_prompt: bool = True, ): super().__init__() - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) @@ -179,17 +125,14 @@ def __init__( unet=unet, controlnet=controlnet, scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.register_to_config( - requires_safety_checker=requires_safety_checker, force_zeros_for_empty_prompt=force_zeros_for_empty_prompt - ) + self.watermark = StableDiffusionXLWatermarker() + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -245,9 +188,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0): for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: cpu_offload(cpu_offloaded_model, device) - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared @@ -266,10 +206,6 @@ def enable_model_cpu_offload(self, gpu_id=0): for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - if self.safety_checker is not None: - # the safety checker can offload the vae again - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - # control net hook has be manually offloaded as it alternates with unet cpu_offload_with_hook(self.controlnet, device) @@ -482,21 +418,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( @@ -890,10 +811,8 @@ def __call__( Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. 
""" controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet @@ -1144,23 +1063,18 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + return StableDiffusionXLPipelineOutput(images=image) - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (image, has_nsfw_concept) + return (image,) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionXLPipelineOutput(images=image) From 695c3ac1cac1267d21ae72231dc089afd2735366 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 12:43:53 +0530 Subject: [PATCH 34/47] incorporate fix from #4083 --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c52429e89c35..f864e0f8c6fa 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1061,6 +1061,11 @@ def __call__( self.controlnet.to("cpu") torch.cuda.empty_cache() + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: From 680e2418e798a84375b13503d4a857769aad0b70 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 13:21:49 +0530 Subject: [PATCH 35/47] fix watermarking dependency handling. 
--- src/diffusers/__init__.py | 8 +++++ .../pipelines/controlnet/__init__.py | 6 +++- ...formers_and_invisible_watermark_objects.py | 32 +------------------ .../controlnet/test_controlnet_sdxl.py | 2 -- 4 files changed, 14 insertions(+), 34 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1ee80b4d9c5b..8afeb5dc5130 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -205,6 +205,14 @@ StableDiffusionXLPipeline, ) +try: + if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 +else: + from .pipelines import StableDiffusionXLControlNetPipeline + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index d5f7eb6b4fcc..cd0b9a724e5e 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -1,11 +1,16 @@ from ...utils import ( OptionalDependencyNotAvailable, is_flax_available, + is_invisible_watermark_available, is_torch_available, is_transformers_available, ) +if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + + try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() @@ -16,7 +21,6 @@ from .pipeline_controlnet import StableDiffusionControlNetPipeline from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline - from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline if is_transformers_available() and is_flax_available(): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py index 15ef1989a58b..497afba1b886 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -2,37 +2,7 @@ from ..utils import DummyObject, requires_backends -class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - -class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", 
"invisible_watermark"]) - - -class StableDiffusionXLPipeline(metaclass=DummyObject): +class StableDiffusionXLControlNetPipelin(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 010bee004e0e..196bdc303a49 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -138,8 +138,6 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, - "safety_checker": None, - "feature_extractor": None, } return components From 842f5dd65258614f455653377bd79fa0756c1dbc Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 13:26:22 +0530 Subject: [PATCH 36/47] run make-fix-copies. --- ...my_torch_and_transformers_and_invisible_watermark_objects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py index 497afba1b886..4049e854b4cc 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -2,7 +2,7 @@ from ..utils import DummyObject, requires_backends -class StableDiffusionXLControlNetPipelin(metaclass=DummyObject): +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): From 0c904ce0287c54f729be5e1e7bceeb469c7c6694 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 13:55:06 +0530 Subject: [PATCH 37/47] Empty-Commit From ae48efcbfb36bd3ffe67ab673ec6100ca8e3fc39 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 16:32:48 +0530 Subject: [PATCH 38/47] Update requirements_sdxl.txt --- examples/controlnet/requirements_sdxl.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/requirements_sdxl.txt b/examples/controlnet/requirements_sdxl.txt index 68e00742489e..0192e03ddb4c 100644 --- a/examples/controlnet/requirements_sdxl.txt +++ b/examples/controlnet/requirements_sdxl.txt @@ -5,4 +5,5 @@ ftfy tensorboard Jinja2 invisible-watermark>=0.2.0 -datasets \ No newline at end of file +datasets +wandb From 74dda654f4fccde77aa168ca5b762866849d7ec0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 16:47:35 +0530 Subject: [PATCH 39/47] remove vae upcasting part. 
--- .../controlnet/pipeline_controlnet_sd_xl.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index f864e0f8c6fa..c03749c6504a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -686,6 +686,26 @@ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, d add_time_ids = torch.tensor([add_time_ids], dtype=dtype) return add_time_ids + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1033,27 +1053,6 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(latents.dtype) - self.vae.decoder.conv_in.to(latents.dtype) - self.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: From 08153a994524c22aa9b456df55d8c101094ca604 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 16:53:25 +0530 Subject: [PATCH 40/47] Apply suggestions from code review Co-authored-by: Patrick von Platen --- .../controlnet/pipeline_controlnet_sd_xl.py | 51 ------------------- 1 file changed, 51 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c03749c6504a..28b0380268d7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -170,24 +170,6 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. 
Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: - cpu_offload(cpu_offloaded_model, device) - def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -212,25 +194,6 @@ def enable_model_cpu_offload(self, gpu_id=0): # We'll offload the last model manually. self.final_offload_hook = hook - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( self, @@ -418,20 +381,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature From dc4839ee29ad7dc3438573313e93b695dd27821f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 16:54:18 +0530 Subject: [PATCH 41/47] run make style --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 28b0380268d7..835be9f4445b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -14,7 +14,6 @@ import inspect -import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np From 54fdb56c8a1efd98bfc3c3b76569fe1dba308051 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:02:01 +0530 Subject: [PATCH 42/47] run make fix-copies. 
--- src/diffusers/__init__.py | 1 - src/diffusers/pipelines/__init__.py | 9 ++++++++- .../utils/dummy_torch_and_transformers_objects.py | 15 --------------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8afeb5dc5130..68c6ecd24979 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -175,7 +175,6 @@ StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, - StableDiffusionXLControlNetPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3de9049afaa5..1fdc49584b49 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -50,7 +50,6 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, - StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import ( IFImg2ImgPipeline, @@ -127,6 +126,14 @@ StableDiffusionXLPipeline, ) +try: + if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 +else: + from .controlnet import StableDiffusionXLControlNetPipeline + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7a5a3310b922..016760337c69 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -737,21 +737,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 5d53cb8b1ad8f0929a5f33fd933e83efd832116c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:25:31 +0530 Subject: [PATCH 43/47] disable support for multicontrolnet.
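The constructor now accepts exactly one `ControlNetModel`; list or tuple inputs raise instead of being wrapped in `MultiControlNetModel`, and the multi-net branches in `check_inputs` and `__call__` go away. A short sketch of the resulting contract (repo ids are placeholders, not real checkpoints):

```py
import torch

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Single ControlNet: supported.
controlnet = ControlNetModel.from_pretrained("your-org/sdxl-controlnet", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "your-org/sdxl-base", controlnet=controlnet, torch_dtype=torch.float16
)

# Several ControlNets: rejected at construction time after this patch.
# StableDiffusionXLControlNetPipeline.from_pretrained(..., controlnet=[cn_a, cn_b])
# now raises ValueError("MultiControlNet is not yet supported.")
```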
--- .../controlnet/pipeline_controlnet_sd_xl.py | 76 +------------------ 1 file changed, 2 insertions(+), 74 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 835be9f4445b..fd4338946e6f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -106,14 +106,14 @@ def __init__( tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + controlnet: ControlNetModel, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, ): super().__init__() if isinstance(controlnet, (list, tuple)): - controlnet = MultiControlNetModel(controlnet) + raise ValueError("MultiControlNet is not yet supported.") self.register_modules( vae=vae, @@ -444,15 +444,6 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule @@ -463,25 +454,6 @@ def check_inputs( and isinstance(self.controlnet._orig_mod, ControlNetModel) ): self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
- ) - - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) else: assert False @@ -493,21 +465,6 @@ def check_inputs( ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) else: assert False @@ -516,12 +473,6 @@ def check_inputs( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." ) - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( @@ -822,9 +773,6 @@ def __call__( # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) @@ -866,26 +814,6 @@ def __call__( guess_mode=guess_mode, ) height, width = image.shape[-2:] - elif isinstance(controlnet, MultiControlNetModel): - images = [] - - for image_ in image: - image_ = self.prepare_image( - image=image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - images.append(image_) - - image = images - height, width = image[0].shape[-2:] else: assert False From 740944ae22932f57b3bbc84d0b808eb58252b07e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:27:53 +0530 Subject: [PATCH 44/47] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/__init__.py | 7 ------- src/diffusers/pipelines/__init__.py | 7 ------- 2 files changed, 14 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 68c6ecd24979..3d34a8877bba 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -203,13 +203,6 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) - -try: - if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 -else: from .pipelines import StableDiffusionXLControlNetPipeline try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1fdc49584b49..1edcbb89afe9 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -125,13 +125,6 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) - -try: - if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 -else: from .controlnet import StableDiffusionXLControlNetPipeline try: From e35cf2b4fe8be88075c9c7dc069db554e83d6e42 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:28:56 +0530 Subject: [PATCH 45/47] run make fix-copies. 
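`make fix-copies` regenerates the watermark dummy module so that every pipeline gated on the invisible-watermark backend (the three SDXL pipelines plus the new ControlNet one) has a placeholder class. A sketch of the user-visible effect in an environment without `invisible-watermark` (the exact error wording comes from `requires_backends`):

```py
# Assumed environment: torch and transformers installed, invisible-watermark missing.
from diffusers import StableDiffusionXLControlNetPipeline  # import still succeeds

try:
    StableDiffusionXLControlNetPipeline()  # any construction path hits the backend check
except ImportError as err:
    print(err)  # the message names the missing backend, e.g. invisible_watermark
```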
--- ...formers_and_invisible_watermark_objects.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py index 4049e854b4cc..1b42f8d88d8c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -2,6 +2,51 @@ from ..utils import DummyObject, requires_backends +class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + +class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + +class StableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): From d7aecd2d3a15303b8acbde05b11c6a0133946198 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:30:54 +0530 Subject: [PATCH 46/47] style.
--- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3d34a8877bba..445e63b4d406 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -199,11 +199,11 @@ from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .pipelines import ( + StableDiffusionXLControlNetPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) - from .pipelines import StableDiffusionXLControlNetPipeline try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1edcbb89afe9..802ae4f5bc94 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -120,12 +120,12 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: + from .controlnet import StableDiffusionXLControlNetPipeline from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) - from .controlnet import StableDiffusionXLControlNetPipeline try: if not is_onnx_available(): From f129bc4028f44a0d087f29e3363bdff781f99371 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jul 2023 17:33:51 +0530 Subject: [PATCH 47/47] fix-copies. --- ...ch_and_transformers_and_invisible_watermark_objects.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py index 1b42f8d88d8c..3b4d59c537fa 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -2,7 +2,7 @@ from ..utils import DummyObject, requires_backends -class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): @@ -17,7 +17,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) -class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): +class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): @@ -32,7 +32,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) -class StableDiffusionXLPipeline(metaclass=DummyObject): +class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs): @@ -47,7 +47,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) -class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): +class StableDiffusionXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] def __init__(self, *args, **kwargs):
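With the series applied, end-to-end usage should look roughly like the sketch below. The in-tree example docstring is intentionally left as a placeholder until a useful SDXL-compatible ControlNet checkpoint is public, so the repo ids and the conditioning-image URL here are placeholders as well; the scheduler swap, the offloading call, and the call signature mirror the code added above:

```py
import torch

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Placeholder checkpoints: substitute real ones once published.
controlnet = ControlNetModel.from_pretrained("your-org/controlnet-canny-sdxl", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "your-org/stable-diffusion-xl-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()  # sequential CPU offload was removed in this series

canny_image = load_image("https://example.com/canny.png")  # placeholder conditioning image
generator = torch.manual_seed(0)
image = pipe(
    "futuristic-looking woman",
    image=canny_image,
    num_inference_steps=20,
    generator=generator,
).images[0]
image.save("sdxl_controlnet_out.png")
```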