From 38064ed4ccdb2cb278cc7716db742649426ecf77 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:02:51 +0530 Subject: [PATCH 01/45] add: initial implementation of the pix2pix instruct training script. --- examples/instruct_pix2pix/README.md | 1 + .../train_instruct_pix2pix.py | 1003 +++++++++++++++++ 2 files changed, 1004 insertions(+) create mode 100644 examples/instruct_pix2pix/README.md create mode 100644 examples/instruct_pix2pix/train_instruct_pix2pix.py diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md new file mode 100644 index 000000000000..9c86e02ee67f --- /dev/null +++ b/examples/instruct_pix2pix/README.md @@ -0,0 +1 @@ +# Training InstructPix2Pix \ No newline at end of file diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py new file mode 100644 index 000000000000..5eae620b71fb --- /dev/null +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -0,0 +1,1003 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +"""Script to fine-tune Stable Diffusion for InstructPix2Pix.""" + +import argparse +import logging +import math +import os +import random +from pathlib import Path +from typing import Optional + +import accelerate +import datasets +import numpy as np +import PIL +import requests +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.14.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "sayakpaul/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), +} +LAYER_TO_FILL = "conv_in.weight" +NULL_PROMPT = "" +WANDB_TABLE_COL_NAMES = ["original_image", "edit_prompt", "edited_image"] + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--original_image_column", + type=str, + default="image", + help="The column of the dataset containing the original image on which edits where made.", + ) + parser.add_argument( + "--edited_image_column", + type=str, + default="image", + help="The column of the dataset containing the edited image.", + ) + parser.add_argument( + "--edit_prompt_column", + type=str, + default="text", + help="The column of the dataset containing the edit instruction.", + ) + parser.add_argument( + "--validation_image_url", + type=str, + default=None, + help="URL to the original image that you would like to edit (used during inference for debugging purposes).", + ) + parser.add_argument( + "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="instruct-pix2pix-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=256, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--conditioning_dropout_prob", + type=float, + default=None, + help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def initialize_unet(unet: UNet2DConditionModel, instruct_pix2pix_unet: UNet2DConditionModel): + pretrained_unet_state_dict = unet.state_dict() + instruct_pix2pix_unet_state_dict = instruct_pix2pix_unet.state_dict() + for k in pretrained_unet_state_dict: + if k == LAYER_TO_FILL: + instruct_pix2pix_unet_state_dict[k].zero_() + instruct_pix2pix_unet_state_dict[k][:, :4, :, :].copy_(pretrained_unet_state_dict[k]) + else: + instruct_pix2pix_unet_state_dict[k].copy_(pretrained_unet_state_dict[k]) + instruct_pix2pix_unet.load_state_dict(instruct_pix2pix_unet_state_dict) + return instruct_pix2pix_unet + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + logging_dir=logging_dir, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is + # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # from the pre-trained checkpoints. For the extra channels added to the first layer, they are + # initialized to zero. + if accelerator.is_main_process: + instruct_pix2pix_config = dict(unet.config) + instruct_pix2pix_config.update({"in_channels": 8}) + + instruct_pix2pix_unet = UNet2DConditionModel.from_config(instruct_pix2pix_config) + + if accelerator.is_main_process: + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + instruct_pix2pix_unet = initialize_unet(unet, instruct_pix2pix_unet) + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_config(instruct_pix2pix_config) + if accelerator._is_main_process: + ema_unet = initialize_unet(unet, ema_unet) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + # Remove the `unet` as we don't need it. + del unet + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + instruct_pix2pix_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + instruct_pix2pix_unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + instruct_pix2pix_unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.original_image_column is None: + original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + original_image_column = args.original_image_column + if original_image_column not in column_names: + raise ValueError( + f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edit_prompt_column is None: + edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + edit_prompt_column = args.edit_prompt_column + if edit_prompt_column not in column_names: + raise ValueError( + f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edited_image_column is None: + edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] + else: + edited_image_column = args.edited_image_column + if edited_image_column not in column_names: + raise ValueError( + f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[edit_prompt_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{edit_prompt_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + original_images = [image.convert("RGB") for image in examples[original_image_column]] + edited_images = [image.convert("RGB") for image in examples[edited_image_column]] + examples["original_pixel_values"] = [train_transforms(image) for image in original_images] + examples["edited_pixel_values"] = [train_transforms(image) for image in edited_images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples]) + original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() + edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) + edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return { + "original_pixel_values": original_pixel_values, + "edited_pixel_values": edited_pixel_values, + "input_ids": input_ids, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + instruct_pix2pix_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + instruct_pix2pix_unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("instruct-pix2pix", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + instruct_pix2pix_unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(instruct_pix2pix_unet): + # We want to learn the denoising process w.r.t the edited images which + # are conditioned on the original image (which was edited) and the edit instruction. + # So, first, convert images to latent space. + latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning. + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Get the additional image embedding for conditioning. + # Instead of getting a diagonal Gaussian here, we simply take the mode. + original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode() + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + logger.info(f"Using condiitioning dropout with prob: {args.conditioning_dropout_prob}.") + random_p = torch.rand(bsz, device=latents.device) + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = text_encoder(tokenize_captions([NULL_PROMPT]))[0] + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = original_image_embeds.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + original_image_embeds = image_mask * original_image_embeds + + # Concatenate the `original_image_embeds` with the `noisy_latents`. + concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = instruct_pix2pix_unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(instruct_pix2pix_unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(instruct_pix2pix_unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if ( + (args.validation_image_url is not None) + and (args.validation_prompt is not None) + and (epoch % args.validation_epochs == 0) + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(instruct_pix2pix_unet.parameters()) + ema_unet.copy_to(instruct_pix2pix_unet.parameters()) + pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=instruct_pix2pix_unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + original_image = download_image(args.validation_image_url) + edited_images = [] + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"validation": wandb_table}) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(instruct_pix2pix_unet.parameters()) + + del pipeline + torch.cuda.empty_cache() + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(instruct_pix2pix_unet) + if args.use_ema: + ema_unet.copy_to(instruct_pix2pix_unet.parameters()) + + pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + + if args.validation_prompt is not None: + edited_images = [] + pipeline.torch_dtype = weight_dtype + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data( + wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt + ) + tracker.log({"test": wandb_table}) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 2c825977f7d811abb7f535e0c8a61cbe04d0f42c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:12:42 +0530 Subject: [PATCH 02/45] shorten cli arg. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 5eae620b71fb..5d727fd1f73d 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -124,7 +124,7 @@ def parse_args(): help="The column of the dataset containing the edit instruction.", ) parser.add_argument( - "--validation_image_url", + "--val_image_url", type=str, default=None, help="URL to the original image that you would like to edit (used during inference for debugging purposes).", @@ -899,7 +899,7 @@ def collate_fn(examples): if accelerator.is_main_process: if ( - (args.validation_image_url is not None) + (args.val_image_url is not None) and (args.validation_prompt is not None) and (epoch % args.validation_epochs == 0) ): @@ -923,7 +923,7 @@ def collate_fn(examples): # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - original_image = download_image(args.validation_image_url) + original_image = download_image(args.val_image_url) edited_images = [] with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): for _ in range(args.num_validation_images): From 675f399b415eec05901116ecf10022b1d0d777d3 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:23:54 +0530 Subject: [PATCH 03/45] fix: main process check. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 5d727fd1f73d..9c629489b380 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -496,7 +496,7 @@ def main(): # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_config(instruct_pix2pix_config) - if accelerator._is_main_process: + if accelerator.is_main_process: ema_unet = initialize_unet(unet, ema_unet) ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) From 534aa5605ed4b43b03b587db66cf67cd44fa4037 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:29:48 +0530 Subject: [PATCH 04/45] fix: dataset column names. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 9c629489b380..3f8f44659621 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -108,19 +108,19 @@ def parse_args(): parser.add_argument( "--original_image_column", type=str, - default="image", + default="input_image", help="The column of the dataset containing the original image on which edits where made.", ) parser.add_argument( "--edited_image_column", type=str, - default="image", + default="edited_image", help="The column of the dataset containing the edited image.", ) parser.add_argument( "--edit_prompt_column", type=str, - default="text", + default="edit_prompt", help="The column of the dataset containing the edit instruction.", ) parser.add_argument( From 9b3bbcec6b9591c2a43618c31b7b7c659e5bc15f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:34:54 +0530 Subject: [PATCH 05/45] simplify tokenization. --- .../instruct_pix2pix/train_instruct_pix2pix.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 3f8f44659621..4af5df0e2e55 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -641,18 +641,7 @@ def load_model_hook(models, input_dir): # Preprocessing the datasets. # We need to tokenize input captions and transform the images. - def tokenize_captions(examples, is_train=True): - captions = [] - for caption in examples[edit_prompt_column]: - if isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - # take a random caption if there are multiple - captions.append(random.choice(caption) if is_train else caption[0]) - else: - raise ValueError( - f"Caption column `{edit_prompt_column}` should contain either strings or lists of strings." - ) + def tokenize_captions(captions): inputs = tokenizer( captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" ) @@ -674,7 +663,8 @@ def preprocess_train(examples): edited_images = [image.convert("RGB") for image in examples[edited_image_column]] examples["original_pixel_values"] = [train_transforms(image) for image in original_images] examples["edited_pixel_values"] = [train_transforms(image) for image in edited_images] - examples["input_ids"] = tokenize_captions(examples) + captions = [caption for caption in examples[edit_prompt_column]] + examples["input_ids"] = tokenize_captions(captions) return examples with accelerator.main_process_first(): From a64f1bc6a13e00d6f6cb05c57f1e779a0f1b92ba Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:40:08 +0530 Subject: [PATCH 06/45] proper placement of null conditions. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 4af5df0e2e55..d911ddf817b9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -826,7 +826,7 @@ def collate_fn(examples): prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) # Final text conditioning. - null_conditioning = text_encoder(tokenize_captions([NULL_PROMPT]))[0] + null_conditioning = text_encoder(tokenize_captions([NULL_PROMPT]).to(accelerator.device))[0] encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. From 8799d777651028826abf313916ea15a60adec461 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:41:55 +0530 Subject: [PATCH 07/45] apply styling. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index d911ddf817b9..f371c338071e 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -19,13 +19,11 @@ import logging import math import os -import random from pathlib import Path from typing import Optional import accelerate import datasets -import numpy as np import PIL import requests import torch From e7aaa2e80eb87d535e06fbba551dedb3039629b0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 11:45:19 +0530 Subject: [PATCH 08/45] remove debugging message for conditioning do. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index f371c338071e..3181033ff2e5 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -818,7 +818,6 @@ def collate_fn(examples): # Conditioning dropout to support classifier-free guidance during inference. For more details # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. if args.conditioning_dropout_prob is not None: - logger.info(f"Using condiitioning dropout with prob: {args.conditioning_dropout_prob}.") random_p = torch.rand(bsz, device=latents.device) # Sample masks for the edit prompts. prompt_mask = random_p < 2 * args.conditioning_dropout_prob From e7794436749f3f135b9d85454058e4c36a10e2b6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 13:46:07 +0530 Subject: [PATCH 09/45] complete license. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 3181033ff2e5..c96c100001f8 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -12,6 +12,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +# limitations under the License. """Script to fine-tune Stable Diffusion for InstructPix2Pix.""" From b1e9e587447939ca5eabcbd700e81bec45bc7aa1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 13:47:00 +0530 Subject: [PATCH 10/45] add: requirements.tzt --- examples/instruct_pix2pix/requirements.txt | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 examples/instruct_pix2pix/requirements.txt diff --git a/examples/instruct_pix2pix/requirements.txt b/examples/instruct_pix2pix/requirements.txt new file mode 100644 index 000000000000..d0adc48197d3 --- /dev/null +++ b/examples/instruct_pix2pix/requirements.txt @@ -0,0 +1,7 @@ +accelerate +torchvision +transformers>=4.25.1 +datasets +ftfy +tensorboard +Jinja2 \ No newline at end of file From b16bd74204c01dd2e4ce30499655d6880067d1d1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Feb 2023 14:34:06 +0530 Subject: [PATCH 11/45] wandb column name order. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index c96c100001f8..a7c50da7af9b 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -59,7 +59,7 @@ } LAYER_TO_FILL = "conv_in.weight" NULL_PROMPT = "" -WANDB_TABLE_COL_NAMES = ["original_image", "edit_prompt", "edited_image"] +WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"] def parse_args(): From fe8f5c878a36067b7cd71ee7b2bf754a5cea3bac Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Mar 2023 16:18:44 +0530 Subject: [PATCH 12/45] fix: augmentation. --- .../train_instruct_pix2pix.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index a7c50da7af9b..c286cfbdda66 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -25,6 +25,7 @@ import accelerate import datasets +import numpy as np import PIL import requests import torch @@ -386,6 +387,11 @@ def initialize_unet(unet: UNet2DConditionModel, instruct_pix2pix_unet: UNet2DCon return instruct_pix2pix_unet +def convert_to_np(image, resolution): + image = image.convert("RGB").resize((resolution, resolution)) + return np.array(image).transpose(2, 0, 1) + + def download_image(url): image = PIL.Image.open(requests.get(url, stream=True).raw) image = PIL.ImageOps.exif_transpose(image) @@ -649,19 +655,36 @@ def tokenize_captions(captions): # Preprocessing the datasets. train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), ] ) + def preprocess_images(examples): + original_images = np.concatenate([convert_to_np(image) for image in examples[original_image_column]]) + edited_images = np.concatenate([convert_to_np(image) for image in examples[edited_image_column]]) + # We need to ensure that the original and the edited images undergo the same + # augmentation transforms. + images = np.concatenate([original_images, edited_images]) + images = torch.tensor(images) + images = 2 * (images / 255) - 1 + return train_transforms(images) + def preprocess_train(examples): - original_images = [image.convert("RGB") for image in examples[original_image_column]] - edited_images = [image.convert("RGB") for image in examples[edited_image_column]] - examples["original_pixel_values"] = [train_transforms(image) for image in original_images] - examples["edited_pixel_values"] = [train_transforms(image) for image in edited_images] + # Preprocess images. + preprocessed_images = preprocess_images(examples) + # Since the original and edited images were concatenated before + # applying the transformations, we need to separate them and reshape + # them accordingly. + original_images, edited_images = preprocessed_images.chunk(2) + original_images = original_images.reshape(-1, 3, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution) + + # Collate the preprocessed images into the `examples`. + examples["original_pixel_values"] = original_images + examples["edited_pixel_values"] = edited_images + + # Preprocess the captions. captions = [caption for caption in examples[edit_prompt_column]] examples["input_ids"] = tokenize_captions(captions) return examples From 03ab82d65a438c551cf13c200f4111fb43716252 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Mar 2023 16:19:27 +0530 Subject: [PATCH 13/45] change: dataset_id. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index c286cfbdda66..46dde2456534 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -56,7 +56,7 @@ logger = get_logger(__name__, log_level="INFO") DATASET_NAME_MAPPING = { - "sayakpaul/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), + "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), } LAYER_TO_FILL = "conv_in.weight" NULL_PROMPT = "" From 0c9fd19e80920279f3b8fe450264685f6995c26e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Mar 2023 16:35:59 +0530 Subject: [PATCH 14/45] fix: convert_to_np() call. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 46dde2456534..2ec0d08b88a0 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -661,8 +661,12 @@ def tokenize_captions(captions): ) def preprocess_images(examples): - original_images = np.concatenate([convert_to_np(image) for image in examples[original_image_column]]) - edited_images = np.concatenate([convert_to_np(image) for image in examples[edited_image_column]]) + original_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[original_image_column]] + ) + edited_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[edited_image_column]] + ) # We need to ensure that the original and the edited images undergo the same # augmentation transforms. images = np.concatenate([original_images, edited_images]) From e6d09f02203735ecd015c7db74f000dab486915e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Mar 2023 16:38:18 +0530 Subject: [PATCH 15/45] fix: reshaping. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 2ec0d08b88a0..03af392a3f53 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -681,8 +681,8 @@ def preprocess_train(examples): # applying the transformations, we need to separate them and reshape # them accordingly. original_images, edited_images = preprocessed_images.chunk(2) - original_images = original_images.reshape(-1, 3, args.resolution) - edited_images = edited_images.reshape(-1, 3, args.resolution) + original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) # Collate the preprocessed images into the `examples`. examples["original_pixel_values"] = original_images From 66f4b315df6b340f5ad677d4a0387e79499d9989 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 07:19:58 +0530 Subject: [PATCH 16/45] fix: final ema copy. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 03af392a3f53..6795c821d17d 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,7 +973,7 @@ def collate_fn(examples): if accelerator.is_main_process: unet = accelerator.unwrap_model(instruct_pix2pix_unet) if args.use_ema: - ema_unet.copy_to(instruct_pix2pix_unet.parameters()) + ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, From daaff4279065948f3e016ea5f5b9bda66ac05480 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 08:03:00 +0530 Subject: [PATCH 17/45] Apply suggestions from code review Co-authored-by: Patrick von Platen --- examples/instruct_pix2pix/requirements.txt | 3 +-- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/instruct_pix2pix/requirements.txt b/examples/instruct_pix2pix/requirements.txt index d0adc48197d3..176ef92a1424 100644 --- a/examples/instruct_pix2pix/requirements.txt +++ b/examples/instruct_pix2pix/requirements.txt @@ -3,5 +3,4 @@ torchvision transformers>=4.25.1 datasets ftfy -tensorboard -Jinja2 \ No newline at end of file +tensorboard \ No newline at end of file diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 6795c821d17d..0ea9145b6a28 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -517,7 +517,7 @@ def main(): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) - instruct_pix2pix_unet.enable_xformers_memory_efficient_attention() + unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") @@ -851,7 +851,7 @@ def collate_fn(examples): prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) # Final text conditioning. - null_conditioning = text_encoder(tokenize_captions([NULL_PROMPT]).to(accelerator.device))[0] + null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) # Sample masks for the original images. From 090a03206d6bddfc85050cd86949c1aadc3f5631 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 08:18:28 +0530 Subject: [PATCH 18/45] address PR comments. --- .../train_instruct_pix2pix.py | 77 ++++++++----------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 0ea9145b6a28..ce0998fb1129 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -29,6 +29,7 @@ import PIL import requests import torch +import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint import transformers @@ -58,8 +59,6 @@ DATASET_NAME_MAPPING = { "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), } -LAYER_TO_FILL = "conv_in.weight" -NULL_PROMPT = "" WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"] @@ -374,19 +373,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: return f"{organization}/{model_id}" -def initialize_unet(unet: UNet2DConditionModel, instruct_pix2pix_unet: UNet2DConditionModel): - pretrained_unet_state_dict = unet.state_dict() - instruct_pix2pix_unet_state_dict = instruct_pix2pix_unet.state_dict() - for k in pretrained_unet_state_dict: - if k == LAYER_TO_FILL: - instruct_pix2pix_unet_state_dict[k].zero_() - instruct_pix2pix_unet_state_dict[k][:, :4, :, :].copy_(pretrained_unet_state_dict[k]) - else: - instruct_pix2pix_unet_state_dict[k].copy_(pretrained_unet_state_dict[k]) - instruct_pix2pix_unet.load_state_dict(instruct_pix2pix_unet_state_dict) - return instruct_pix2pix_unet - - def convert_to_np(image, resolution): image = image.convert("RGB").resize((resolution, resolution)) return np.array(image).transpose(2, 0, 1) @@ -412,9 +398,7 @@ def main(): ), ) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) - accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -423,6 +407,8 @@ def main(): project_config=accelerator_project_config, ) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") @@ -484,15 +470,19 @@ def main(): # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized # from the pre-trained checkpoints. For the extra channels added to the first layer, they are # initialized to zero. - if accelerator.is_main_process: - instruct_pix2pix_config = dict(unet.config) - instruct_pix2pix_config.update({"in_channels": 8}) - - instruct_pix2pix_unet = UNet2DConditionModel.from_config(instruct_pix2pix_config) - if accelerator.is_main_process: logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") - instruct_pix2pix_unet = initialize_unet(unet, instruct_pix2pix_unet) + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channel=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in # Freeze vae and text_encoder vae.requires_grad_(False) @@ -500,13 +490,7 @@ def main(): # Create EMA for the unet. if args.use_ema: - ema_unet = UNet2DConditionModel.from_config(instruct_pix2pix_config) - if accelerator.is_main_process: - ema_unet = initialize_unet(unet, ema_unet) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) - - # Remove the `unet` as we don't need it. - del unet + ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -556,7 +540,7 @@ def load_model_hook(models, input_dir): accelerator.register_load_state_pre_hook(load_model_hook) if args.gradient_checkpointing: - instruct_pix2pix_unet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -582,7 +566,7 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - instruct_pix2pix_unet.parameters(), + unet.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -735,8 +719,8 @@ def collate_fn(examples): ) # Prepare everything with our `accelerator`. - instruct_pix2pix_unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - instruct_pix2pix_unet, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) if args.use_ema: @@ -809,7 +793,7 @@ def collate_fn(examples): progress_bar.set_description("Steps") for epoch in range(first_epoch, args.num_train_epochs): - instruct_pix2pix_unet.train() + unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step @@ -818,7 +802,7 @@ def collate_fn(examples): progress_bar.update(1) continue - with accelerator.accumulate(instruct_pix2pix_unet): + with accelerator.accumulate(unet): # We want to learn the denoising process w.r.t the edited images which # are conditioned on the original image (which was edited) and the edit instruction. # So, first, convert images to latent space. @@ -846,7 +830,7 @@ def collate_fn(examples): # Conditioning dropout to support classifier-free guidance during inference. For more details # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. if args.conditioning_dropout_prob is not None: - random_p = torch.rand(bsz, device=latents.device) + random_p = torch.rand(bsz, device=latents.device, generator=generator) # Sample masks for the edit prompts. prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1) @@ -876,7 +860,7 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = instruct_pix2pix_unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). @@ -886,7 +870,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(instruct_pix2pix_unet.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -894,7 +878,7 @@ def collate_fn(examples): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: - ema_unet.step(instruct_pix2pix_unet.parameters()) + ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) @@ -925,11 +909,11 @@ def collate_fn(examples): # create pipeline if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(instruct_pix2pix_unet.parameters()) - ema_unet.copy_to(instruct_pix2pix_unet.parameters()) + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=instruct_pix2pix_unet, + unet=unet, revision=args.revision, torch_dtype=weight_dtype, ) @@ -937,7 +921,6 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) original_image = download_image(args.val_image_url) edited_images = [] with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): @@ -963,7 +946,7 @@ def collate_fn(examples): tracker.log({"validation": wandb_table}) if args.use_ema: # Switch back to the original UNet parameters. - ema_unet.restore(instruct_pix2pix_unet.parameters()) + ema_unet.restore(unet.parameters()) del pipeline torch.cuda.empty_cache() @@ -971,7 +954,7 @@ def collate_fn(examples): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(instruct_pix2pix_unet) + unet = accelerator.unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) From eb3b7cadafe94ed0b80c06116e6d256030bd0136 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:12:19 +0530 Subject: [PATCH 19/45] add: readme details. --- examples/instruct_pix2pix/README.md | 145 +++++++++++++++++- .../train_instruct_pix2pix.py | 2 +- 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 9c86e02ee67f..f328128b825f 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -1 +1,144 @@ -# Training InstructPix2Pix \ No newline at end of file +# InstructPix2Pix training example + +[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs: + +![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png) + +The output is an "edited" image that reflects the edit instruction applied on the input image. + +The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. + +***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix +training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix) we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.*** + + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. a notebook + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +### Toy example + +As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset +is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper. + +Configure environment variables such as the dataset identifier and the Stable Diffusion +checkpoint: + +```bash +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export DATASET_ID="fusing/instructpix2pix-1000-samples" +``` + +Now, we can launch training: + +```bash +accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --enable_xformers_memory_efficient_attention \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=15000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --seed=42 +``` + +Additionally, we support performing validation inference to monitor training progress +with Weights and Biases. You can enable this feature with `report_to="wandb"`: + +```bash +accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --enable_xformers_memory_efficient_attention \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=15000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \ + --validation_prompt="make the mountains snowy" \ + --seed=42 \ + --report_to=wandb + ``` + + We recommend this type of validation as it can be useful for model debugging. + + [Here] (TODO:run link), you can find an example training run that includes some validation samples and the training hyperparameters. + + Once training is complete, we can perform inference: + + ```python +import PIL +import requests +import torch +from diffusers import StableDiffusionInstructPix2PixPipeline + +model_id = "model id" +pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") +generator = torch.Generator("cuda").manual_seed(0) + +url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png" + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + +image = download_image(url) +prompt = "wipe out the lake" +num_inference_steps = 20 +image_guidance_scale = 1.5 +guidance_scale = 50 + +edited_image = pipe(prompt, + image=image, + num_inference_steps=num_inference_steps, + image_guidance_scale=image_guidance_scale, + guidance_scale=guidance_scale, + generator=generator, +).images[0] +edited_image.save("edited_image.png") +``` diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index ce0998fb1129..b92d12278d0e 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.14.0.dev0") +check_min_version("0.15.0.dev0") logger = get_logger(__name__, log_level="INFO") From 81ae46419eb67ab4cd2aef94efe7ccc173ab8a17 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:16:35 +0530 Subject: [PATCH 20/45] config fix. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index b92d12278d0e..8eb23b485f8d 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -474,7 +474,7 @@ def main(): logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") in_channels = 8 out_channels = unet.conv_in.out_channels - unet.register_to_config(in_channel=in_channels) + unet.register_to_config(in_channels=in_channels) with torch.no_grad(): new_conv_in = nn.Conv2d( From e0b17e55e5795bb4cef2bba1d4d978ea1255824c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:23:17 +0530 Subject: [PATCH 21/45] downgrade version. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 8eb23b485f8d..0de66a64f1cc 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.15.0.dev0") +check_min_version("0.14.0.dev0") logger = get_logger(__name__, log_level="INFO") From c3ebe7bda65c1ba6b312075cb9ee62e4e4304a18 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:30:38 +0530 Subject: [PATCH 22/45] reduce image width in the readme. --- examples/instruct_pix2pix/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index f328128b825f..f61e57dfc1e0 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -2,7 +2,9 @@ [InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs: -![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png) +

+ +

The output is an "edited" image that reflects the edit instruction applied on the input image. From 1807add201222ebcac34c8dd24aa52d949434689 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:46:28 +0530 Subject: [PATCH 23/45] note on hyperparameters during generation. --- examples/instruct_pix2pix/README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index f61e57dfc1e0..fb01fbaf3c51 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -3,10 +3,14 @@ [InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:

- + instructpix2pix-inputs

-The output is an "edited" image that reflects the edit instruction applied on the input image. +The output is an "edited" image that reflects the edit instruction applied on the input image: + +

+ instructpix2pix-output +

The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. @@ -133,7 +137,7 @@ image = download_image(url) prompt = "wipe out the lake" num_inference_steps = 20 image_guidance_scale = 1.5 -guidance_scale = 50 +guidance_scale = 10 edited_image = pipe(prompt, image=image, @@ -144,3 +148,13 @@ edited_image = pipe(prompt, ).images[0] edited_image.save("edited_image.png") ``` + +We encourage you to play with the following three parameters to control +speed and quality: + +* `num_inference_steps` +* `image_guidance_scale` +* `guidance_scale` + +Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact +on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example). From f12c04129686a75678553fbf634142b40042b95c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 15 Mar 2023 09:51:47 +0530 Subject: [PATCH 24/45] add: output images. --- examples/instruct_pix2pix/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index fb01fbaf3c51..514899f7d7ee 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -9,7 +9,7 @@ The output is an "edited" image that reflects the edit instruction applied on the input image:

- instructpix2pix-output + instructpix2pix-output

The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. From 3627f96a2fd2515ea90e0c7d1a2b29882deacf6f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 11:35:48 +0530 Subject: [PATCH 25/45] update readme. --- examples/instruct_pix2pix/README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 514899f7d7ee..3fa6282f08ef 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -110,7 +110,7 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ We recommend this type of validation as it can be useful for model debugging. - [Here] (TODO:run link), you can find an example training run that includes some validation samples and the training hyperparameters. + [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters. Once training is complete, we can perform inference: @@ -120,7 +120,7 @@ import requests import torch from diffusers import StableDiffusionInstructPix2PixPipeline -model_id = "model id" +model_id = "model id" # <- replace this pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") generator = torch.Generator("cuda").manual_seed(0) @@ -149,6 +149,9 @@ edited_image = pipe(prompt, edited_image.save("edited_image.png") ``` +An example model repo obtained using this training script can be found +here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix). + We encourage you to play with the following three parameters to control speed and quality: From ac15648a1f1f248e3e0bdc3b34d5bd138c07c87d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 13:40:14 +0530 Subject: [PATCH 26/45] minor edits to readme. --- examples/instruct_pix2pix/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 3fa6282f08ef..13303f058a12 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -17,7 +17,6 @@ The `train_instruct_pix2pix.py` script shows how to implement the training proce ***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix) we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.*** - ## Running locally with PyTorch ### Installing the dependencies @@ -112,6 +111,8 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters. + ## Inference + Once training is complete, we can perform inference: ```python @@ -153,7 +154,7 @@ An example model repo obtained using this training script can be found here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix). We encourage you to play with the following three parameters to control -speed and quality: +speed and quality during performance: * `num_inference_steps` * `image_guidance_scale` From a1c14ff98405b28e51f0779c28daa8e3cac67d6a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 13:43:50 +0530 Subject: [PATCH 27/45] debugging statement. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 0de66a64f1cc..cf3ac0310027 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -974,6 +974,8 @@ def collate_fn(examples): edited_images = [] pipeline.torch_dtype = weight_dtype for _ in range(args.num_validation_images): + print(f"Pipeline device: {pipeline.device}") + print(f"Text encoder device: {pipeline.text_encoder.device}") edited_images.append( pipeline( args.validation_prompt, From d2effc4341245fb55bdf4574559e3f0c19ed73c1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 13:57:24 +0530 Subject: [PATCH 28/45] explicitly placement of the pipeline. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index cf3ac0310027..29f455bda9f6 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,6 +973,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline.torch_dtype = weight_dtype + pipeline.device = accelerator.device for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print(f"Text encoder device: {pipeline.text_encoder.device}") From 172599a7787f9269ee39b47440cac4ad9c6b2783 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 13:57:43 +0530 Subject: [PATCH 29/45] bump minimum diffusers version. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 29f455bda9f6..c104d39d9670 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.14.0.dev0") +check_min_version("0.15.0.dev0") logger = get_logger(__name__, log_level="INFO") From e0440e6fa4e3ef619b884916f028dbc8bbbabe6f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:10:13 +0530 Subject: [PATCH 30/45] fix: device attribute error. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index c104d39d9670..fb1a9183f297 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,7 +973,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline.torch_dtype = weight_dtype - pipeline.device = accelerator.device + pipeline = pipeline.to(accelerator.device) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print(f"Text encoder device: {pipeline.text_encoder.device}") From 49364b72a839cd84009c8d500fbe5e148b8bda3b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:20:57 +0530 Subject: [PATCH 31/45] weight dtype. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index fb1a9183f297..8285cb444824 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -972,7 +972,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - pipeline.torch_dtype = weight_dtype + # pipeline.torch_dtype = weight_dtype pipeline = pipeline.to(accelerator.device) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") From 8b728dc0373f66f8f0741ed37f1525d5797cfb8b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:27:09 +0530 Subject: [PATCH 32/45] debugging. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 8285cb444824..411c02d7c4e3 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -972,11 +972,15 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - # pipeline.torch_dtype = weight_dtype + pipeline.torch_dtype = weight_dtype pipeline = pipeline.to(accelerator.device) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print(f"Text encoder device: {pipeline.text_encoder.device}") + print( + f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" + ) + edited_images.append( pipeline( args.validation_prompt, From f3b2f533943348e0d5a67f71142aac5af07acaf9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:34:19 +0530 Subject: [PATCH 33/45] add dtype inform. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 411c02d7c4e3..2ce9862a1947 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -960,8 +960,8 @@ def collate_fn(examples): pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=text_encoder, - vae=vae, + # text_encoder=text_encoder, + # vae=vae, unet=unet, revision=args.revision, ) @@ -973,13 +973,15 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline.torch_dtype = weight_dtype - pipeline = pipeline.to(accelerator.device) + # pipeline = pipeline.to(accelerator.device) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") - print(f"Text encoder device: {pipeline.text_encoder.device}") print( f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" ) + print( + f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" + ) edited_images.append( pipeline( From 8faaab34deba7b08e3707f6c6b3443e2a0999af7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:45:59 +0530 Subject: [PATCH 34/45] add seoarate te and vae. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 2ce9862a1947..a46356ea9e2f 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -960,8 +960,8 @@ def collate_fn(examples): pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - # text_encoder=text_encoder, - # vae=vae, + text_encoder=text_encoder, + vae=vae, unet=unet, revision=args.revision, ) From 52c75f9cf7ed41ad4323cb3087ef922f21441105 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 14:54:01 +0530 Subject: [PATCH 35/45] add: explicit casting/ --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index a46356ea9e2f..50c010e70e8e 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,7 +973,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline.torch_dtype = weight_dtype - # pipeline = pipeline.to(accelerator.device) + pipeline.unet = pipeline.unet.to(weight_dtype) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print( From e0c78de50565b6a4c6adb0695ef7aa8a2850b6c7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 15:18:09 +0530 Subject: [PATCH 36/45] remove casting. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 50c010e70e8e..4cdebcfea5c9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -972,8 +972,8 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - pipeline.torch_dtype = weight_dtype - pipeline.unet = pipeline.unet.to(weight_dtype) + # pipeline.torch_dtype = weight_dtype + # pipeline.unet = pipeline.unet.to(weight_dtype) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print( From 0074d055e9a052dbc81ed30499ff8a9954343572 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 15:22:04 +0530 Subject: [PATCH 37/45] up. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 4cdebcfea5c9..c52eaf7a21a1 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,7 +973,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] # pipeline.torch_dtype = weight_dtype - # pipeline.unet = pipeline.unet.to(weight_dtype) + pipeline.unet = pipeline.unet.to(weight_dtype) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print( From ba2855eaf3eb08d811a9bd609d6a855d893e108e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 15:27:56 +0530 Subject: [PATCH 38/45] up 2. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index c52eaf7a21a1..1422b3d58608 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -972,8 +972,8 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - # pipeline.torch_dtype = weight_dtype - pipeline.unet = pipeline.unet.to(weight_dtype) + pipeline.torch_dtype = weight_dtype + # pipeline.unet = pipeline.unet.to(weight_dtype) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print( From 59cd2da8795bac5d6b1fa149f0ecf157ace2d04f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Mar 2023 15:33:04 +0530 Subject: [PATCH 39/45] up 3. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 1422b3d58608..50c010e70e8e 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -973,7 +973,7 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] pipeline.torch_dtype = weight_dtype - # pipeline.unet = pipeline.unet.to(weight_dtype) + pipeline.unet = pipeline.unet.to(weight_dtype) for _ in range(args.num_validation_images): print(f"Pipeline device: {pipeline.device}") print( From a53cc41722cc4902368d6a889cf6b36c046c1452 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 17 Mar 2023 08:32:34 +0530 Subject: [PATCH 40/45] autocast. --- examples/instruct_pix2pix/README.md | 2 + .../train_instruct_pix2pix.py | 42 +++++++++---------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 13303f058a12..6674cc245066 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -111,6 +111,8 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters. + ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.*** + ## Inference Once training is complete, we can perform inference: diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 50c010e70e8e..da4b60cea658 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -972,27 +972,27 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - pipeline.torch_dtype = weight_dtype - pipeline.unet = pipeline.unet.to(weight_dtype) - for _ in range(args.num_validation_images): - print(f"Pipeline device: {pipeline.device}") - print( - f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" - ) - print( - f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" - ) - - edited_images.append( - pipeline( - args.validation_prompt, - image=original_image, - num_inference_steps=20, - image_guidance_scale=1.5, - guidance_scale=7, - generator=generator, - ).images[0] - ) + # pipeline.torch_dtype = weight_dtype + # pipeline.unet = pipeline.unet.to(weight_dtype) + with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + for _ in range(args.num_validation_images): + # print(f"Pipeline device: {pipeline.device}") + # print( + # f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" + # ) + # print( + # f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" + # ) + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) for tracker in accelerator.trackers: if tracker.name == "wandb": From 4c8920cf79be44cecabfa403fcebaf9ad711dc76 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 17 Mar 2023 08:40:23 +0530 Subject: [PATCH 41/45] disable mixed-precision in the final inference. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index da4b60cea658..9421ef5d32bf 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -974,7 +974,7 @@ def collate_fn(examples): edited_images = [] # pipeline.torch_dtype = weight_dtype # pipeline.unet = pipeline.unet.to(weight_dtype) - with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"): + with torch.autocast(str(accelerator.device)): for _ in range(args.num_validation_images): # print(f"Pipeline device: {pipeline.device}") # print( From fbc626a177bec118a335e7c7ad56c04ae6fa96f2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 17 Mar 2023 08:44:21 +0530 Subject: [PATCH 42/45] debugging information. --- examples/instruct_pix2pix/train_instruct_pix2pix.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 9421ef5d32bf..04a65fb59db6 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -976,10 +976,10 @@ def collate_fn(examples): # pipeline.unet = pipeline.unet.to(weight_dtype) with torch.autocast(str(accelerator.device)): for _ in range(args.num_validation_images): - # print(f"Pipeline device: {pipeline.device}") - # print( - # f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" - # ) + print(f"Pipeline device: {pipeline.device}") + print( + f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" + ) # print( # f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" # ) From 4e1545b101fbc299633ea554d5c90ec663f221d5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 20 Mar 2023 15:55:03 +0530 Subject: [PATCH 43/45] autocasting. --- .../instruct_pix2pix/train_instruct_pix2pix.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 04a65fb59db6..57430b7f150a 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -960,8 +960,8 @@ def collate_fn(examples): pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=text_encoder, - vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), unet=unet, revision=args.revision, ) @@ -972,17 +972,9 @@ def collate_fn(examples): if args.validation_prompt is not None: edited_images = [] - # pipeline.torch_dtype = weight_dtype - # pipeline.unet = pipeline.unet.to(weight_dtype) + pipeline = pipeline.to(accelerator.device) with torch.autocast(str(accelerator.device)): for _ in range(args.num_validation_images): - print(f"Pipeline device: {pipeline.device}") - print( - f"UNet: {pipeline.unet.device} Text Encoder: {pipeline.text_encoder.device} VAE: {pipeline.vae.device}" - ) - # print( - # f"UNet: {pipeline.unet.dtype} Text Encoder: {pipeline.text_encoder.dtype} VAE: {pipeline.vae.dtype}" - # ) edited_images.append( pipeline( args.validation_prompt, From b6b1d749afaec55ef52c845ef59ae4b7cf000dfa Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 21 Mar 2023 07:29:43 +0530 Subject: [PATCH 44/45] add: instructpix2pix training section to the docs. --- docs/source/en/_toctree.yml | 2 + docs/source/en/training/instructpix2pix.mdx | 181 ++++++++++++++++++++ examples/instruct_pix2pix/README.md | 4 +- 3 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 docs/source/en/training/instructpix2pix.mdx diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 301a4cccf404..4759df424b54 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -91,6 +91,8 @@ title: Text-to-image - local: training/lora title: Low-Rank Adaptation of Large Language Models (LoRA) + - local: training/instructpix2pix + title: InstructPix2Pix Training title: Training - sections: - local: conceptual/philosophy diff --git a/docs/source/en/training/instructpix2pix.mdx b/docs/source/en/training/instructpix2pix.mdx new file mode 100644 index 000000000000..e6f050b34acf --- /dev/null +++ b/docs/source/en/training/instructpix2pix.mdx @@ -0,0 +1,181 @@ + + +# InstructPix2Pix + +[InstructPix2Pix](https://arxiv.org/abs/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs: + +

+ instructpix2pix-inputs +

+ +The output is an "edited" image that reflects the edit instruction applied on the input image: + +

+ instructpix2pix-output +

+ +The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. + +***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix +training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix) we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.*** + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. a notebook + +```python +from accelerate.utils import write_basic_config + +write_basic_config() +``` + +### Toy example + +As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset +is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper. + +Configure environment variables such as the dataset identifier and the Stable Diffusion +checkpoint: + +```bash +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export DATASET_ID="fusing/instructpix2pix-1000-samples" +``` + +Now, we can launch training: + +```bash +accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --enable_xformers_memory_efficient_attention \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=15000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --seed=42 +``` + +Additionally, we support performing validation inference to monitor training progress +with Weights and Biases. You can enable this feature with `report_to="wandb"`: + +```bash +accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --enable_xformers_memory_efficient_attention \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=15000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \ + --validation_prompt="make the mountains snowy" \ + --seed=42 \ + --report_to=wandb + ``` + + We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`. + + [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters. + + ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.*** + + ## Inference + + Once training is complete, we can perform inference: + + ```python +import PIL +import requests +import torch +from diffusers import StableDiffusionInstructPix2PixPipeline + +model_id = "your_model_id" # <- replace this +pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") +generator = torch.Generator("cuda").manual_seed(0) + +url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png" + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + +image = download_image(url) +prompt = "wipe out the lake" +num_inference_steps = 20 +image_guidance_scale = 1.5 +guidance_scale = 10 + +edited_image = pipe( + prompt, + image=image, + num_inference_steps=num_inference_steps, + image_guidance_scale=image_guidance_scale, + guidance_scale=guidance_scale, + generator=generator, +).images[0] +edited_image.save("edited_image.png") +``` + +An example model repo obtained using this training script can be found +here - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix). + +We encourage you to play with the following three parameters to control +speed and quality during performance: + +* `num_inference_steps` +* `image_guidance_scale` +* `guidance_scale` + +Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact +on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example). diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 6674cc245066..02f0fed04299 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -107,7 +107,7 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ --report_to=wandb ``` - We recommend this type of validation as it can be useful for model debugging. + We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`. [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters. @@ -123,7 +123,7 @@ import requests import torch from diffusers import StableDiffusionInstructPix2PixPipeline -model_id = "model id" # <- replace this +model_id = "your_model_id" # <- replace this pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") generator = torch.Generator("cuda").manual_seed(0) From b499450ce7a914b8be68aa0dc44263a98e738a3d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 23 Mar 2023 10:05:19 +0530 Subject: [PATCH 45/45] Empty-Commit