From fb5488d991fbb7517663c87cc55ec900ebca3ce7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sat, 3 Aug 2024 18:21:53 +0300 Subject: [PATCH 01/46] initial commit - dreambooth for flux --- examples/dreambooth/train_dreambooth_flux.py | 1772 ++++++++++++++++++ 1 file changed, 1772 insertions(+) create mode 100644 examples/dreambooth/train_dreambooth_flux.py diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py new file mode 100644 index 000000000000..2739a7565572 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -0,0 +1,1772 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import itertools +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + SD3Transformer2DModel, + FluxPipeline, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.utils import ( + check_min_version, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# SD3 DreamBooth - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md). + +Was the text encoder fine-tuned? {train_text_encoder}. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +## License + +Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`. +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "sd3", + "sd3-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two, class_three): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--precondition_outputs", + type=int, + default=1, + help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " + "model `target` is calculated.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoder.device, + num_images_per_prompt=num_images_per_prompt, + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[-1].device, + ) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = StableDiffusion3Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = SD3Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + transformer.requires_grad_(True) + vae.requires_grad_(False) + if args.train_text_encoder: + text_encoder_one.requires_grad_(True) + text_encoder_two.requires_grad_(True) + else: + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=torch.float32) + if not args.train_text_encoder: + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for i, model in enumerate(models): + if isinstance(unwrap_model(model), SD3Transformer2DModel): + unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + if isinstance(unwrap_model(model), CLIPTextModelWithProjection): + hidden_size = unwrap_model(model).config.hidden_size + if hidden_size == 768: + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) + elif hidden_size == 1280: + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) + else: + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3")) + else: + raise ValueError(f"Wrong model supplied: {type(model)=}.") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + for _ in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + if isinstance(unwrap_model(model), SD3Transformer2DModel): + load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): + try: + load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder") + model(**load_model.config) + model.load_state_dict(load_model.state_dict()) + except Exception: + try: + load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2") + model(**load_model.config) + model.load_state_dict(load_model.state_dict()) + except Exception: + try: + load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3") + model(**load_model.config) + model.load_state_dict(load_model.state_dict()) + except Exception: + raise ValueError(f"Couldn't load the model of type: ({type(model)}).") + else: + raise ValueError(f"Unsupported model found: {type(model)=}") + + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_parameters_one_with_lr = { + "params": text_encoder_one.parameters(), + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_parameters_two_with_lr = { + "params": text_encoder_two.parameters(), + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + text_parameters_two_with_lr, + ] + else: + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + params_to_optimize[3]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection + del text_encoder_one, text_encoder_two + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + ( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-sd3" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder_one, text_encoder_two]) + with accelerator.accumulate(models_to_accumulate): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts) + tokens_two = tokenize_prompt(tokenizer_two, prompts) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Predict the noise residual + if not args.train_text_encoder: + model_pred = transformer( + hidden_states=noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + else: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + model_pred = transformer( + hidden_states=noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + if args.precondition_outputs: + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + if args.precondition_outputs: + target = model_input + else: + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain( + transformer.parameters(), + text_encoder_one.parameters(), + text_encoder_two.parameters() + ) + if args.train_text_encoder + else transformer.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if not args.train_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two + torch.cuda.empty_cache() + gc.collect() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_two = unwrap_model(text_encoder_two) + pipeline = StableDiffusion3Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=transformer, + text_encoder=text_encoder_one, + text_encoder_2=text_encoder_two, + ) + else: + pipeline = StableDiffusion3Pipeline.from_pretrained( + args.pretrained_model_name_or_path, transformer=transformer + ) + + # save the pipeline + pipeline.save_pretrained(args.output_dir) + + # Final inference + # Load previous pipeline + pipeline = StableDiffusion3Pipeline.from_pretrained( + args.output_dir, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 3062dafc067efbf0dd7ee6cdb6b3002300d2cd83 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 10:28:40 +0300 Subject: [PATCH 02/46] update transformer to be FluxTransformer2DModel --- examples/dreambooth/train_dreambooth_flux.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 2739a7565572..6ca1395ffa0e 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -47,7 +47,7 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - SD3Transformer2DModel, + FluxTransformer2DModel, FluxPipeline, ) from diffusers.optimization import get_scheduler @@ -1100,7 +1100,7 @@ def main(args): revision=args.revision, variant=args.variant, ) - transformer = SD3Transformer2DModel.from_pretrained( + transformer = FluxTransformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) @@ -1147,7 +1147,7 @@ def unwrap_model(model): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: for i, model in enumerate(models): - if isinstance(unwrap_model(model), SD3Transformer2DModel): + if isinstance(unwrap_model(model), FluxTransformer2DModel): unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): @@ -1170,8 +1170,8 @@ def load_model_hook(models, input_dir): model = models.pop() # load diffusers style into model - if isinstance(unwrap_model(model), SD3Transformer2DModel): - load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder="transformer") + if isinstance(unwrap_model(model), FluxTransformer2DModel): + load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) @@ -1506,6 +1506,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): models_to_accumulate.extend([text_encoder_one, text_encoder_two]) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + #latent_image_ids= prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1549,13 +1550,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = transformer( hidden_states=noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timestep / 1000, + timestep=timesteps / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] else: @@ -1568,13 +1568,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = transformer( hidden_states=noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timestep / 1000, + timestep=timesteps / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 1448aa1f84052d7ea035a131f7d520d9fe7eec3d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 13:46:33 +0300 Subject: [PATCH 03/46] update training loop and validation inference --- examples/dreambooth/train_dreambooth_flux.py | 35 ++++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 6ca1395ffa0e..82e093eb9773 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -58,7 +58,7 @@ ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module - +from diffusers.pipelines.flux.pipeline_flux import prepare_latent_image_ids,pack_latents,unpack_latents if is_wandb_available(): import wandb @@ -1506,7 +1506,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): models_to_accumulate.extend([text_encoder_one, text_encoder_two]) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - #latent_image_ids= prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1523,6 +1522,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) + latent_image_ids = prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1545,10 +1551,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + packed_noisy_model_input = pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + # Predict the noise residual if not args.train_text_encoder: model_pred = transformer( - hidden_states=noisy_model_input, + hidden_states=packed_noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timesteps / 1000, guidance=guidance, @@ -1566,7 +1580,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_input_ids_list=[tokens_one, tokens_two], ) model_pred = transformer( - hidden_states=noisy_model_input, + hidden_states=packed_noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timesteps / 1000, guidance=guidance, @@ -1577,6 +1591,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): return_dict=False, )[0] + model_pred = unpack_latents( + model_pred, + height=model_input.shape[2], + width=model_input.shape[3], + vae_scale_factor=vae.config.scaling_factor, + ) + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. if args.precondition_outputs: @@ -1710,14 +1731,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) text_encoder_two = unwrap_model(text_encoder_two) - pipeline = StableDiffusion3Pipeline.from_pretrained( + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder_one, text_encoder_2=text_encoder_two, ) else: - pipeline = StableDiffusion3Pipeline.from_pretrained( + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer ) @@ -1726,7 +1747,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Final inference # Load previous pipeline - pipeline = StableDiffusion3Pipeline.from_pretrained( + pipeline = FluxPipeline.from_pretrained( args.output_dir, revision=args.revision, variant=args.variant, From a59e0120569f44f4b2d69cde676376e3d808c30f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 14:00:04 +0300 Subject: [PATCH 04/46] fix sd3->flux docs --- examples/dreambooth/train_dreambooth_flux.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 82e093eb9773..0d726aed351c 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -87,7 +87,7 @@ def save_model_card( ) model_description = f""" -# SD3 DreamBooth - {repo_id} +# Flux [dev] DreamBooth - {repo_id} @@ -95,7 +95,7 @@ def save_model_card( These are {repo_id} DreamBooth weights for {base_model}. -The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md). +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). Was the text encoder fine-tuned? {train_text_encoder}. @@ -114,7 +114,7 @@ def save_model_card( ## License -Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`. +Please adhere to the licensing terms as described `[here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)`. """ model_card = load_or_create_model_card( repo_id_or_path=repo_id, @@ -129,8 +129,8 @@ def save_model_card( "text-to-image", "diffusers-training", "diffusers", - "sd3", - "sd3-diffusers", + "flux", + "flux-diffusers", "template:sd-lora", ] @@ -1024,7 +1024,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 - pipeline = StableDiffusion3Pipeline.from_pretrained( + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, revision=args.revision, @@ -1430,7 +1430,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-sd3" + tracker_name = "dreambooth-flux" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! From df074d43e66f1c096dab19413ba54b92763283ae Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 14:48:42 +0300 Subject: [PATCH 05/46] add guidance handling, not sure if it makes sense(?) --- examples/dreambooth/train_dreambooth_flux.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 0d726aed351c..8b9f8bab5054 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -434,6 +434,13 @@ def parse_args(input_args=None): help="Initial learning rate (after the potential warmup period) to use.", ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + parser.add_argument( "--text_encoder_lr", type=float, @@ -1559,6 +1566,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): width=model_input.shape[3], ) + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + # Predict the noise residual if not args.train_text_encoder: model_pred = transformer( From c897c31f84283c707f53d417043c9ca5cfcb65cd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 15:41:17 +0300 Subject: [PATCH 06/46] inital dreambooth lora commit --- examples/dreambooth/train_dreambooth_flux.py | 2 +- .../dreambooth/train_dreambooth_lora_flux.py | 1752 +++++++++++++++++ 2 files changed, 1753 insertions(+), 1 deletion(-) create mode 100644 examples/dreambooth/train_dreambooth_lora_flux.py diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 8b9f8bab5054..a1b9fb641c20 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -114,7 +114,7 @@ def save_model_card( ## License -Please adhere to the licensing terms as described `[here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)`. +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). """ model_card = load_or_create_model_card( repo_id_or_path=repo_id, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py new file mode 100644 index 000000000000..e35d8b36eb6a --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -0,0 +1,1752 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import gc +import itertools +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxTransformer2DModel, + FluxPipeline, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke + +- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**. + - Rename it and place it on your `models/Lora` folder. + - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/). + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--precondition_outputs", + type=int, + default=1, + help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " + "model `target` is calculated.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoder.device, + num_images_per_prompt=num_images_per_prompt, + ) + + prompt_embeds = _encode_prompt_with_t5( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[-1].device, + ) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=torch.float32) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + text_encoder_two_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + transformer_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] + else: + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds + + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection + del text_encoder_one, text_encoder_two + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + with accelerator.accumulate(models_to_accumulate): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Predict the noise residual + model_pred = transformer( + hidden_states=noisy_model_input, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + return_dict=False, + )[0] + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + if args.precondition_outputs: + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + if args.precondition_outputs: + target = model_input + else: + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = (transformer_lora_parameters) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + text_encoder_one, text_encoder_two = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two + ) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + del text_encoder_one, text_encoder_two + + torch.cuda.empty_cache() + gc.collect() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer = transformer.to(torch.float32) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers + ) + + # Final inference + # Load previous pipeline + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From c3c38a42b3a98589bdf3951ed1ddcad9b4ffa191 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 Aug 2024 15:53:55 +0300 Subject: [PATCH 07/46] fix text_ids in compute_text_embeddings --- examples/dreambooth/train_dreambooth_flux.py | 2 + .../dreambooth/train_dreambooth_lora_flux.py | 46 +++++++++++++++---- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index a1b9fb641c20..54cd12d02f64 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1375,9 +1375,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if not args.train_text_encoder: prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e35d8b36eb6a..c42d867d710f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -64,6 +64,7 @@ convert_unet_state_dict_to_peft, is_wandb_available, ) +from diffusers.pipelines.flux.pipeline_flux import prepare_latent_image_ids,pack_latents,unpack_latents from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -1391,22 +1392,23 @@ def load_model_hook(models, input_dir): def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length ) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids if not args.train_text_encoder and not train_dataset.custom_instance_prompts: - instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers ) @@ -1426,9 +1428,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if not train_dataset.custom_instance_prompts: prompt_embeds = instance_prompt_hidden_states pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) # Scheduler and math around the number of training steps. @@ -1538,7 +1542,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( prompts, text_encoders, tokenizers ) @@ -1546,6 +1550,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) + latent_image_ids = prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1568,15 +1579,34 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + packed_noisy_model_input = pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + # Predict the noise residual model_pred = transformer( - hidden_states=noisy_model_input, - timestep=timesteps, - encoder_hidden_states=prompt_embeds, + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, return_dict=False, )[0] + model_pred = unpack_latents( + model_pred, + height=model_input.shape[2], + width=model_input.shape[3], + vae_scale_factor=vae.config.scaling_factor, + ) + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. if args.precondition_outputs: From 259e44334c89affcb3b4dcf1b4694b02a31ba3fc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:08:16 +0300 Subject: [PATCH 08/46] fix imports of static methods --- examples/dreambooth/train_dreambooth_flux.py | 8 ++++---- examples/dreambooth/train_dreambooth_lora_flux.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 54cd12d02f64..49d0eb0d6d90 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -58,7 +58,7 @@ ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module -from diffusers.pipelines.flux.pipeline_flux import prepare_latent_image_ids,pack_latents,unpack_latents +from diffusers.pipelines.flux.pipeline_flux import _prepare_latent_image_ids,_pack_latents,_unpack_latents if is_wandb_available(): import wandb @@ -1531,7 +1531,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - latent_image_ids = prepare_latent_image_ids( + latent_image_ids = _prepare_latent_image_ids( model_input.shape[0], model_input.shape[2], model_input.shape[3], @@ -1560,7 +1560,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = pack_latents( + packed_noisy_model_input = _pack_latents( noisy_model_input, batch_size=model_input.shape[0], num_channels_latents=model_input.shape[1], @@ -1607,7 +1607,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): return_dict=False, )[0] - model_pred = unpack_latents( + model_pred = _unpack_latents( model_pred, height=model_input.shape[2], width=model_input.shape[3], diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c42d867d710f..3a1993f8c569 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -64,7 +64,7 @@ convert_unet_state_dict_to_peft, is_wandb_available, ) -from diffusers.pipelines.flux.pipeline_flux import prepare_latent_image_ids,pack_latents,unpack_latents +from diffusers.pipelines.flux.pipeline_flux import _prepare_latent_image_ids,_pack_latents,_unpack_latents from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -1550,7 +1550,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - latent_image_ids = prepare_latent_image_ids( + latent_image_ids = _prepare_latent_image_ids( model_input.shape[0], model_input.shape[2], model_input.shape[3], @@ -1579,7 +1579,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = pack_latents( + packed_noisy_model_input = _pack_latents( noisy_model_input, batch_size=model_input.shape[0], num_channels_latents=model_input.shape[1], @@ -1600,7 +1600,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): return_dict=False, )[0] - model_pred = unpack_latents( + model_pred = _unpack_latents( model_pred, height=model_input.shape[2], width=model_input.shape[3], From 3d5b713b3e37eb644d5c5a21643a991a8c03971c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:14:08 +0300 Subject: [PATCH 09/46] fix pipeline loading in readme, remove auto1111 docs for now --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 49d0eb0d6d90..4e46b694f612 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -108,7 +108,7 @@ def save_model_card( ```py from diffusers import AutoPipelineForText2Image import torch -pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda') +pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.bfloat16).to('cuda') image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3a1993f8c569..d03242b16a17 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -121,17 +121,11 @@ def save_model_card( ```py from diffusers import AutoPipelineForText2Image import torch -pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda') +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` -### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke - -- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**. - - Rename it and place it on your `models/Lora` folder. - - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/). - For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) ## License From cdbb69c63277e137ec9979a82f4e92a34bbe3eeb Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:18:25 +0300 Subject: [PATCH 10/46] fix pipeline loading in readme, remove auto1111 docs for now, remove some irrelevant text_encoder_3 refs --- examples/dreambooth/train_dreambooth_flux.py | 4 +--- examples/dreambooth/train_dreambooth_lora_flux.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 4e46b694f612..4e7ca63676dc 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -119,7 +119,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++", + license="flux-1-dev-non-commercial-license", base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -1163,8 +1163,6 @@ def save_model_hook(models, weights, output_dir): unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) elif hidden_size == 1280: unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) - else: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_3")) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d03242b16a17..7109bfbd053d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -135,7 +135,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++", + license="flux-1-dev-non-commercial-license", base_model=base_model, prompt=instance_prompt, model_description=model_description, @@ -1200,7 +1200,6 @@ def save_model_hook(models, weights, output_dir): output_dir, transformer_lora_layers=transformer_lora_layers_to_save, text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, - text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, ) def load_model_hook(models, input_dir): From b249d36dea62ae6d47e7979e7d103c7442f28f8f Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:27:23 +0300 Subject: [PATCH 11/46] Update examples/dreambooth/train_dreambooth_flux.py Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_flux.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 4e7ca63676dc..aae5f0b33e10 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1187,12 +1187,10 @@ def load_model_hook(models, input_dir): model.load_state_dict(load_model.state_dict()) except Exception: try: - load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder="text_encoder_2") + load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_2") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - try: - load_model = T5EncoderModel.from_pretrained(input_dir, subfolder="text_encoder_3") model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: From 37b8e6456cc45743d1fd01a0c3300867e65b813b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:32:48 +0300 Subject: [PATCH 12/46] fix te2 loading and remove te2 refs from text encoder training --- examples/dreambooth/train_dreambooth_flux.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index aae5f0b33e10..c3952a79b530 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1191,9 +1191,6 @@ def load_model_hook(models, input_dir): model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - model(**load_model.config) - model.load_state_dict(load_model.state_dict()) - except Exception: raise ValueError(f"Couldn't load the model of type: ({type(model)}).") else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1222,15 +1219,9 @@ def load_model_hook(models, input_dir): "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - text_parameters_two_with_lr = { - "params": text_encoder_two.parameters(), - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } params_to_optimize = [ transformer_parameters_with_lr, text_parameters_one_with_lr, - text_parameters_two_with_lr, ] else: params_to_optimize = [transformer_parameters_with_lr] @@ -1291,7 +1282,6 @@ def load_model_hook(models, input_dir): # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate params_to_optimize[2]["lr"] = args.learning_rate - params_to_optimize[3]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1503,12 +1493,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): transformer.train() if args.train_text_encoder: text_encoder_one.train() - text_encoder_two.train() for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] if args.train_text_encoder: - models_to_accumulate.extend([text_encoder_one, text_encoder_two]) + models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] From 0714278fe2fc7c1d1f84a0bdf643e127e60c3f99 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:34:34 +0300 Subject: [PATCH 13/46] fix tokenizer_2 initialization --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index c3952a79b530..e155a28504a8 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1079,7 +1079,7 @@ def main(args): subfolder="tokenizer", revision=args.revision, ) - tokenizer_two = CLIPTokenizer.from_pretrained( + tokenizer_two = T5TokenizerFast.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 7109bfbd053d..6cbd94efad1a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1091,7 +1091,7 @@ def main(args): subfolder="tokenizer", revision=args.revision, ) - tokenizer_two = CLIPTokenizer.from_pretrained( + tokenizer_two = T5TokenizerFast.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, From 80c3fe04bb57f3ae5fdb477d554826cea377d1e8 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 12:38:06 +0300 Subject: [PATCH 14/46] remove text_encoder training refs from lora script (for now) --- .../dreambooth/train_dreambooth_lora_flux.py | 46 +------------------ 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 6cbd94efad1a..0b5422f632b1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1148,9 +1148,6 @@ def main(args): if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() - if args.train_text_encoder: - text_encoder_one.gradient_checkpointing_enable() - text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( @@ -1161,15 +1158,6 @@ def main(args): ) transformer.add_adapter(transformer_lora_config) - if args.train_text_encoder: - text_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.rank, - init_lora_weights="gaussian", - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], - ) - text_encoder_one.add_adapter(text_lora_config) - text_encoder_two.add_adapter(text_lora_config) def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1234,21 +1222,12 @@ def load_model_hook(models, input_dir): f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) - if args.train_text_encoder: - # Do we need to call `scale_lora_layers()` here? - _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) - - _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ - ) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if args.mixed_precision == "fp16": models = [transformer_] - if args.train_text_encoder: - models.extend([text_encoder_one_, text_encoder_two_]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models) @@ -1268,37 +1247,14 @@ def load_model_hook(models, input_dir): # Make sure the trainable params are in float32. if args.mixed_precision == "fp16": models = [transformer] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - if args.train_text_encoder: - text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) - text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) # Optimization parameters transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - if args.train_text_encoder: - # different learning rate for text encoder and unet - text_lora_parameters_one_with_lr = { - "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } - text_lora_parameters_two_with_lr = { - "params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, - } - params_to_optimize = [ - transformer_parameters_with_lr, - text_lora_parameters_one_with_lr, - text_lora_parameters_two_with_lr, - ] - else: - params_to_optimize = [transformer_parameters_with_lr] + params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): From 77a023588b2984c727b1d97adbf18379ab1e1c9c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 15:55:56 +0300 Subject: [PATCH 15/46] try with vae in bfloat16, fix model hook save --- examples/dreambooth/train_dreambooth_flux.py | 12 +++++------- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index e155a28504a8..5661fde1df40 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1120,7 +1120,7 @@ def main(args): text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) - # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": @@ -1134,7 +1134,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=torch.bfloat16) if not args.train_text_encoder: text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) @@ -1158,11 +1158,9 @@ def save_model_hook(models, weights, output_dir): unwrap_model(model).save_pretrained(os.path.join(output_dir, "transformer")) elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)): if isinstance(unwrap_model(model), CLIPTextModelWithProjection): - hidden_size = unwrap_model(model).config.hidden_size - if hidden_size == 768: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) - elif hidden_size == 1280: - unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder")) + else: + unwrap_model(model).save_pretrained(os.path.join(output_dir, "text_encoder_2")) else: raise ValueError(f"Wrong model supplied: {type(model)=}.") diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 0b5422f632b1..5c5725e72d41 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1141,7 +1141,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(accelerator.device, dtype=torch.float32) + vae.to(accelerator.device, dtype=torch.bfloat16) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From a64f4a5230ce47a96960e898c09196988e6e8ced Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 16:15:06 +0300 Subject: [PATCH 16/46] fix tokenization --- examples/dreambooth/train_dreambooth_flux.py | 18 ++++++++++-------- .../dreambooth/train_dreambooth_lora_flux.py | 11 ----------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 5661fde1df40..c1fe36add5f9 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -848,12 +848,14 @@ def __getitem__(self, index): return example -def tokenize_prompt(tokenizer, prompt): +def tokenize_prompt(tokenizer, prompt, max_sequence_length=512): text_inputs = tokenizer( prompt, padding="max_length", - max_length=77, + max_length=max_sequence_length, truncation=True, + return_length=False, + return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids @@ -1367,11 +1369,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the # batch prompts on all training steps else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) - tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512) if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512) tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) @@ -1507,8 +1509,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompts, text_encoders, tokenizers ) else: - tokens_one = tokenize_prompt(tokenizer_one, prompts) - tokens_two = tokenize_prompt(tokenizer_two, prompts) + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5c5725e72d41..d89d00cc4f86 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -860,17 +860,6 @@ def __getitem__(self, index): return example -def tokenize_prompt(tokenizer, prompt): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=77, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - return text_input_ids - def _encode_prompt_with_t5( text_encoder, From 08a12968f76319456d6d07f90ff8aa7e8423838e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 18:19:45 +0300 Subject: [PATCH 17/46] fix static imports --- examples/dreambooth/train_dreambooth_flux.py | 7 +++---- examples/dreambooth/train_dreambooth_lora_flux.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index c1fe36add5f9..6215292d8f38 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -58,7 +58,6 @@ ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module -from diffusers.pipelines.flux.pipeline_flux import _prepare_latent_image_ids,_pack_latents,_unpack_latents if is_wandb_available(): import wandb @@ -1516,7 +1515,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - latent_image_ids = _prepare_latent_image_ids( + latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], model_input.shape[2], model_input.shape[3], @@ -1545,7 +1544,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = _pack_latents( + packed_noisy_model_input = FluxPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], num_channels_latents=model_input.shape[1], @@ -1592,7 +1591,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): return_dict=False, )[0] - model_pred = _unpack_latents( + model_pred = FluxPipeline._unpack_latents( model_pred, height=model_input.shape[2], width=model_input.shape[3], diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d89d00cc4f86..4ecfcb27ce75 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -64,7 +64,6 @@ convert_unet_state_dict_to_peft, is_wandb_available, ) -from diffusers.pipelines.flux.pipeline_flux import _prepare_latent_image_ids,_pack_latents,_unpack_latents from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -1488,7 +1487,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - latent_image_ids = _prepare_latent_image_ids( + latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], model_input.shape[2], model_input.shape[3], @@ -1517,7 +1516,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - packed_noisy_model_input = _pack_latents( + packed_noisy_model_input = FluxPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], num_channels_latents=model_input.shape[1], @@ -1538,7 +1537,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): return_dict=False, )[0] - model_pred = _unpack_latents( + model_pred = FluxPipeline._unpack_latents( model_pred, height=model_input.shape[2], width=model_input.shape[3], From e4746830c87fa862c4892cc0b8c684646bd2f979 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 18:24:28 +0300 Subject: [PATCH 18/46] fix CLIP import --- examples/dreambooth/train_dreambooth_flux.py | 6 +++--- examples/dreambooth/train_dreambooth_lora_flux.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 6215292d8f38..984f18046f80 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -198,10 +198,10 @@ def import_model_class_from_model_name_or_path( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] - if model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel - return CLIPTextModelWithProjection + return CLIPTextModel elif model_class == "T5EncoderModel": from transformers import T5EncoderModel diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 4ecfcb27ce75..d4c5f5949bcf 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -215,10 +215,10 @@ def import_model_class_from_model_name_or_path( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] - if model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel - return CLIPTextModelWithProjection + return CLIPTextModel elif model_class == "T5EncoderModel": from transformers import T5EncoderModel From 97f15460e5c8b01664fd81a7e47bf75065c219da Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 18:38:09 +0300 Subject: [PATCH 19/46] remove text_encoder training refs (for now) from lora script --- .../dreambooth/train_dreambooth_lora_flux.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index d4c5f5949bcf..9aff4ae964e9 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1323,34 +1323,32 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - if not args.train_text_encoder: - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] - - def compute_text_embeddings(prompt, text_encoders, tokenizers): - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders, tokenizers, prompt, args.max_sequence_length - ) - prompt_embeds = prompt_embeds.to(accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - text_ids = text_ids.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds, text_ids + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids - if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + if not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) # Clear the memory here - if not args.train_text_encoder and train_dataset.custom_instance_prompts: + if not train_dataset.custom_instance_prompts: del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection del text_encoder_one, text_encoder_two From 187e42177aa2144626b3aad785e52b663c300c24 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 6 Aug 2024 21:56:44 +0300 Subject: [PATCH 20/46] fix minor bug in encode_prompt, add guidance def in lora script, ... --- examples/dreambooth/train_dreambooth_flux.py | 6 +++--- .../dreambooth/train_dreambooth_lora_flux.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 984f18046f80..af7ce8842028 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -946,7 +946,7 @@ def encode_prompt( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, - device=device if device is not None else text_encoder.device, + device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, ) @@ -956,7 +956,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - device=device if device is not None else text_encoders[-1].device, + device=device if device is not None else text_encoders[1].device, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1554,7 +1554,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # handle guidance if transformer.config.guidance_embeds: - guidance = torch.tensor([args.guidance_scale], device=device) + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: guidance = None diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 9aff4ae964e9..85169e382a20 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -453,11 +453,12 @@ def parse_args(input_args=None): ) parser.add_argument( - "--text_encoder_lr", + "--guidance_scale", type=float, - default=5e-6, - help="Text encoder learning rate to use.", + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", ) + parser.add_argument( "--scale_lr", action="store_true", @@ -945,7 +946,7 @@ def encode_prompt( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, - device=device if device is not None else text_encoder.device, + device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, ) @@ -955,7 +956,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, - device=device if device is not None else text_encoders[-1].device, + device=device if device is not None else text_encoders[1].device, ) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) @@ -1522,6 +1523,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): width=model_input.shape[3], ) + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + # Predict the noise residual model_pred = transformer( hidden_states=packed_noisy_model_input, From b24f6732cffeccea0062d38c8b822fa9f74fedf8 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 12:22:41 +0300 Subject: [PATCH 21/46] fix unpack_latents args --- examples/dreambooth/train_dreambooth_flux.py | 6 +++--- examples/dreambooth/train_dreambooth_lora_flux.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index af7ce8842028..397488a3a6a3 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1593,9 +1593,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = FluxPipeline._unpack_latents( model_pred, - height=model_input.shape[2], - width=model_input.shape[3], - vae_scale_factor=vae.config.scaling_factor, + height=int(model_input.shape[2])*8, + width=int(model_input.shape[3])*8, + vae_scale_factor=16, #should this be 2 ** (len(vae.config.block_out_channels))? ) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 85169e382a20..1819cea8d5d5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1542,12 +1542,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=latent_image_ids, return_dict=False, )[0] - model_pred = FluxPipeline._unpack_latents( model_pred, - height=model_input.shape[2], - width=model_input.shape[3], - vae_scale_factor=vae.config.scaling_factor, + height=int(model_input.shape[2])*8, + width=int(model_input.shape[3])*8, + vae_scale_factor=16, #should this be 2 ** (len(vae.config.block_out_channels))? ) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. From ad1d236866a7e0ce5fa6eacd467bcb39cd9ef1e1 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 16:38:37 +0300 Subject: [PATCH 22/46] fix license in readme --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 397488a3a6a3..6e93a652ab8e 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -118,7 +118,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="flux-1-dev-non-commercial-license", + license="other", base_model=base_model, prompt=instance_prompt, model_description=model_description, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 1819cea8d5d5..c3be14581df9 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -134,7 +134,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="flux-1-dev-non-commercial-license", + license="other", base_model=base_model, prompt=instance_prompt, model_description=model_description, From bcb752baffa2fd3af7f03247324394aab929d337 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 17:24:55 +0300 Subject: [PATCH 23/46] add "none" to weighting_scheme options for uniform sampling --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 6e93a652ab8e..1b8817944424 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -483,7 +483,7 @@ def parse_args(input_args=None): "--weighting_scheme", type=str, default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c3be14581df9..cf37bf40c13e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -496,7 +496,7 @@ def parse_args(input_args=None): "--weighting_scheme", type=str, default="logit_normal", - choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." From df880f3ccaf8821c997a2597cc8f64c9d544c1c0 Mon Sep 17 00:00:00 2001 From: Linoy Date: Wed, 7 Aug 2024 18:01:40 +0000 Subject: [PATCH 24/46] style --- examples/dreambooth/train_dreambooth_flux.py | 39 +++++++-------- .../dreambooth/train_dreambooth_lora_flux.py | 47 +++++++------------ 2 files changed, 34 insertions(+), 52 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 1b8817944424..f2f5bfa56ebb 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -47,8 +47,8 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - FluxTransformer2DModel, FluxPipeline, + FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 @@ -59,6 +59,7 @@ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -943,12 +944,12 @@ def encode_prompt( dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( - text_encoder=text_encoders[0], - tokenizer=tokenizers[0], - prompt=prompt, - device=device if device is not None else text_encoders[0].device, - num_images_per_prompt=num_images_per_prompt, - ) + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], @@ -1099,9 +1100,7 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1190,7 +1189,7 @@ def load_model_hook(models, input_dir): model(**load_model.config) model.load_state_dict(load_model.state_dict()) except Exception: - raise ValueError(f"Couldn't load the model of type: ({type(model)}).") + raise ValueError(f"Couldn't load the model of type: ({type(model)}).") else: raise ValueError(f"Unsupported model found: {type(model)=}") @@ -1593,9 +1592,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2])*8, - width=int(model_input.shape[3])*8, - vae_scale_factor=16, #should this be 2 ** (len(vae.config.block_out_channels))? + height=int(model_input.shape[2]) * 8, + width=int(model_input.shape[3]) * 8, + vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? ) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. @@ -1642,9 +1641,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain( - transformer.parameters(), - text_encoder_one.parameters(), - text_encoder_two.parameters() + transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() ) if args.train_text_encoder else transformer.parameters() @@ -1697,9 +1694,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if not args.train_text_encoder: - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1738,9 +1733,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_2=text_encoder_two, ) else: - pipeline = FluxPipeline.from_pretrained( - args.pretrained_model_name_or_path, transformer=transformer - ) + pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer) # save the pipeline pipeline.save_pretrained(args.output_dir) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index cf37bf40c13e..ec326dbb8336 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -49,12 +49,11 @@ from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, - FluxTransformer2DModel, FluxPipeline, + FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( - _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, @@ -860,7 +859,6 @@ def __getitem__(self, index): return example - def _encode_prompt_with_t5( text_encoder, tokenizer, @@ -943,12 +941,12 @@ def encode_prompt( dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( - text_encoder=text_encoders[0], - tokenizer=tokenizers[0], - prompt=prompt, - device=device if device is not None else text_encoders[0].device, - num_images_per_prompt=num_images_per_prompt, - ) + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], @@ -1098,9 +1096,7 @@ def main(args): args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", @@ -1147,7 +1143,6 @@ def main(args): ) transformer.add_adapter(transformer_lora_config) - def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model @@ -1370,7 +1365,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) text_ids = torch.cat([text_ids, class_text_ids], dim=0) - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1390,8 +1384,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`. transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler - ) + transformer, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1479,8 +1473,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) + prompts, text_encoders, tokenizers + ) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1544,9 +1538,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2])*8, - width=int(model_input.shape[3])*8, - vae_scale_factor=16, #should this be 2 ** (len(vae.config.block_out_channels))? + height=int(model_input.shape[2]) * 8, + width=int(model_input.shape[3]) * 8, + vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? ) # Follow: Section 5 of https://arxiv.org/abs/2206.00364. @@ -1591,7 +1585,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (transformer_lora_parameters) + params_to_clip = transformer_lora_parameters accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -1638,9 +1632,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - text_encoder_one, text_encoder_two = load_text_encoders( - text_encoder_cls_one, text_encoder_cls_two - ) + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1671,10 +1663,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): transformer = transformer.to(torch.float32) transformer_lora_layers = get_peft_model_state_dict(transformer) - FluxPipeline.save_lora_weights( - save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers - ) + FluxPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) # Final inference # Load previous pipeline From e69244adc1e21be2e3db26e48174979ebf993783 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 21:24:33 +0300 Subject: [PATCH 25/46] adapt model saving - remove text encoder refs --- examples/dreambooth/train_dreambooth_lora_flux.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index ec326dbb8336..cf9cc4190607 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1152,16 +1152,10 @@ def unwrap_model(model): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: transformer_lora_layers_to_save = None - text_encoder_one_lora_layers_to_save = None - text_encoder_two_lora_layers_to_save = None for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1171,7 +1165,6 @@ def save_model_hook(models, weights, output_dir): FluxPipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, - text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, ) def load_model_hook(models, input_dir): From 66125321f3ffe39d897c4b26837f00a563c64095 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 21:26:45 +0300 Subject: [PATCH 26/46] adapt model loading - remove text encoder refs --- examples/dreambooth/train_dreambooth_lora_flux.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index cf9cc4190607..81f30f21358c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1169,18 +1169,11 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - text_encoder_one_ = None - text_encoder_two_ = None - while len(models) > 0: model = models.pop() if isinstance(model, type(unwrap_model(transformer))): transformer_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_ = model - elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") From 155dbb2f6d55a469e6c5cc9bfbb1c4f81a70a773 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 7 Aug 2024 21:56:36 +0300 Subject: [PATCH 27/46] initial commit for readme --- examples/dreambooth/README_flux.md | 153 +++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 examples/dreambooth/README_flux.md diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md new file mode 100644 index 000000000000..dc3054e31579 --- /dev/null +++ b/examples/dreambooth/README_flux.md @@ -0,0 +1,153 @@ +# DreamBooth training example for FLUX.1 [dev] + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject. + +The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script. + +> [!NOTE] +> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: + +```bash +huggingface-cli login +``` + +This will also allow us to push the trained model parameters to the Hugging Face Hub platform. + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run +```bash +pip install -r requirements_flux.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. + +Now, we can launch training using: + +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux" + +accelerate launch train_dreambooth_flux.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="fp16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +> [!NOTE] +> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. + +> [!TIP] +> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. + +## LoRA + DreamBooth + +[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. + +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + +To perform DreamBooth with LoRA, run: + +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux-lora" + +accelerate launch train_dreambooth_lora_sd3.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="fp16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-5 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +### Text Encoder Training (Coming Soon!) + + +## Other notes From 15261410fb191b2a98bec4db7a86fb7caa1c0eb3 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:39:54 +0300 Subject: [PATCH 28/46] Update examples/dreambooth/train_dreambooth_lora_flux.py Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 81f30f21358c..c0d30eb7c6fb 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1367,7 +1367,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): power=args.lr_power, ) - # Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`. transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler From 6b78e1978d06da1a89a56095db5bd2ea74adeea2 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:41:11 +0300 Subject: [PATCH 29/46] Update examples/dreambooth/train_dreambooth_lora_flux.py Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c0d30eb7c6fb..975a6515c73b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1382,7 +1382,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux-lora" + tracker_name = "dreambooth-flux-dev-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! From a2ac0eb88eed6da147fe61f5449ec3108bc4134e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 11:39:21 +0300 Subject: [PATCH 30/46] fix vae casting --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index f2f5bfa56ebb..74256ad199b2 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1134,7 +1134,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(accelerator.device, dtype=torch.bfloat16) + vae.to(accelerator.device, dtype=weight_dtype) if not args.train_text_encoder: text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 975a6515c73b..152c9bc22113 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1126,7 +1126,7 @@ def main(args): "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - vae.to(accelerator.device, dtype=torch.bfloat16) + vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) From 7f0fe8a83247c621d2b32627f7fcf5d0932e7375 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 12:19:04 +0300 Subject: [PATCH 31/46] remove precondition_outputs --- examples/dreambooth/train_dreambooth_flux.py | 17 ++--------------- .../dreambooth/train_dreambooth_lora_flux.py | 16 +--------------- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 74256ad199b2..8aa06b6ea80f 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -498,13 +498,6 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) - parser.add_argument( - "--precondition_outputs", - type=int, - default=1, - help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " - "model `target` is calculated.", - ) parser.add_argument( "--optimizer", type=str, @@ -1597,20 +1590,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? ) - # Follow: Section 5 of https://arxiv.org/abs/2206.00364. - # Preconditioning of the model outputs. - if args.precondition_outputs: - model_pred = model_pred * (-sigmas) + noisy_model_input + model_pred = model_pred * (-sigmas) + noisy_model_input # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss - if args.precondition_outputs: - target = model_input - else: - target = noise - model_input + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 152c9bc22113..2770a90792e6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -509,13 +509,6 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) - parser.add_argument( - "--precondition_outputs", - type=int, - default=1, - help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " - "model `target` is calculated.", - ) parser.add_argument( "--optimizer", type=str, @@ -1528,20 +1521,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? ) - # Follow: Section 5 of https://arxiv.org/abs/2206.00364. - # Preconditioning of the model outputs. - if args.precondition_outputs: - model_pred = model_pred * (-sigmas) + noisy_model_input # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss - if args.precondition_outputs: - target = model_input - else: - target = noise - model_input + target = noise - model_input if args.with_prior_preservation: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. From d0fb727768ccc75bcb3f6654a8fa9ad46dd35fb9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 12:54:41 +0300 Subject: [PATCH 32/46] readme --- examples/dreambooth/README_flux.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index dc3054e31579..0fe0a2bf47d2 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -3,8 +3,17 @@ [DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject. The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script. +> [!NOTE] +> **Memory consumption** +> +> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements - +> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training. +> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](documentation/quickstart/FLUX.md) + > [!NOTE] +> **Gated model** +> > As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: ```bash @@ -147,7 +156,9 @@ accelerate launch train_dreambooth_lora_sd3.py \ --push_to_hub ``` -### Text Encoder Training (Coming Soon!) +> [!TODO] +> ### Text Encoder Training (Coming Soon!) ## Other notes +Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file From dcd26d1fc80910b69a2c523da89fbc12e91f5f56 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 12:58:03 +0300 Subject: [PATCH 33/46] readme --- examples/dreambooth/README_flux.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 0fe0a2bf47d2..982608a2d143 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -156,8 +156,10 @@ accelerate launch train_dreambooth_lora_sd3.py \ --push_to_hub ``` -> [!TODO] -> ### Text Encoder Training (Coming Soon!) +### Text Encoder Training +- [x] add text encoder training support for dreambooth script +- [ ] add text encoder training support for lora script + ## Other notes From 60c5b6507f9cae33028d79df13df6fd8002e055b Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 8 Aug 2024 10:01:11 +0000 Subject: [PATCH 34/46] style --- examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2770a90792e6..3592c2020fa6 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1521,7 +1521,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? ) - # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) From aac9183b7dff338fe54036b6da70cfcf2f4d75a2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 13:14:11 +0300 Subject: [PATCH 35/46] readme --- examples/dreambooth/README_flux.md | 34 +++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 982608a2d143..99c3be193967 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -136,7 +136,7 @@ export MODEL_NAME="black-forest-labs/FLUX.1-dev" export INSTANCE_DIR="dog" export OUTPUT_DIR="trained-flux-lora" -accelerate launch train_dreambooth_lora_sd3.py \ +accelerate launch train_dreambooth_lora_flux.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --output_dir=$OUTPUT_DIR \ @@ -160,7 +160,39 @@ accelerate launch train_dreambooth_lora_sd3.py \ - [x] add text encoder training support for dreambooth script - [ ] add text encoder training support for lora script +Alongside the transformer, fine-tuning of the CLIP text encoder is also supported. +To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: +> [!NOTE] +> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL). +By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed. +> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled. + +To perform DreamBooth LoRA with text-encoder training, run: +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export OUTPUT_DIR="trained-flux-dev-dreambooth" + +accelerate launch train_dreambooth_flux.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="fp16" \ + --train_text_encoder\ + --instance_prompt="a photo of sks dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-5 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` ## Other notes Thanks to `bghira` for their help with reviewing & insight sharing ♥️ \ No newline at end of file From 56059d82b7d5cea7a6723b80bb022932e8a2d8b4 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 13:22:37 +0300 Subject: [PATCH 36/46] readme --- examples/dreambooth/README_flux.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 99c3be193967..8c4f8b53a145 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -189,7 +189,6 @@ accelerate launch train_dreambooth_flux.py \ --lr_warmup_steps=0 \ --max_train_steps=500 \ --validation_prompt="A photo of sks dog in a bucket" \ - --validation_epochs=25 \ --seed="0" \ --push_to_hub ``` From 911306aa2665f2fad373535d71c5b1946ea5fbe9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 13:40:42 +0300 Subject: [PATCH 37/46] update weighting scheme default & docs --- examples/dreambooth/train_dreambooth_flux.py | 3 ++- examples/dreambooth/train_dreambooth_lora_flux.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 8aa06b6ea80f..8209251c9870 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -483,8 +483,9 @@ def parse_args(input_args=None): parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", + default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss') ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3592c2020fa6..77eebe9ca23b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -494,8 +494,9 @@ def parse_args(input_args=None): parser.add_argument( "--weighting_scheme", type=str, - default="logit_normal", + default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss') ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." From 8e4d2300d7f5afcf665ffa98a3ddccf042e79068 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 8 Aug 2024 10:42:09 +0000 Subject: [PATCH 38/46] style --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 8209251c9870..a44d877527fe 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -485,7 +485,7 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss') + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 77eebe9ca23b..5dc11c8c7bf3 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -496,7 +496,7 @@ def parse_args(input_args=None): type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help=('We default to the "none" weighting scheme for uniform sampling and uniform loss') + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." From 573026c2a057254dc004244da4449c1f7f976015 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 17:52:20 +0300 Subject: [PATCH 39/46] add text_encoder training to lora script, change vae_scale_factor value in both --- examples/dreambooth/train_dreambooth_flux.py | 6 +- .../dreambooth/train_dreambooth_lora_flux.py | 274 ++++++++++++++---- 2 files changed, 215 insertions(+), 65 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index a44d877527fe..22e1519f4438 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -138,7 +138,7 @@ def save_model_card( model_card.save(os.path.join(repo_folder, "README.md")) -def load_text_encoders(class_one, class_two, class_three): +def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) @@ -1417,7 +1417,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux" + tracker_name = "dreambooth-flux-dev-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1588,7 +1588,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred, height=int(model_input.shape[2]) * 8, width=int(model_input.shape[3]) * 8, - vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? + vae_scale_factor=2 ** (len(vae.config.block_out_channels)), # should this be 2 ** (len(vae.config.block_out_channels))? ) model_pred = model_pred * (-sigmas) + noisy_model_input diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5dc11c8c7bf3..ade19dc50275 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -394,7 +394,11 @@ def parse_args(input_args=None): action="store_true", help="whether to randomly flip images horizontally", ) - + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) @@ -458,6 +462,12 @@ def parse_args(input_args=None): help="the FLUX.1 dev variant is a guidance distilled model", ) + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -853,6 +863,20 @@ def __getitem__(self, index): return example +def tokenize_prompt(tokenizer, prompt, max_sequence_length=512): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + def _encode_prompt_with_t5( text_encoder, tokenizer, @@ -1085,6 +1109,7 @@ def main(args): text_encoder_cls_two = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) + # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" @@ -1101,12 +1126,16 @@ def main(args): args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) - transformer.requires_grad_(False) + transformer.requires_grad_(True) vae.requires_grad_(False) - text_encoder_one.requires_grad_(False) - text_encoder_two.requires_grad_(False) + if args.train_text_encoder: + text_encoder_one.requires_grad_(True) + text_encoder_two.requires_grad_(True) + else: + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) - # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": @@ -1121,12 +1150,15 @@ def main(args): ) vae.to(accelerator.device, dtype=weight_dtype) - transformer.to(accelerator.device, dtype=weight_dtype) - text_encoder_one.to(accelerator.device, dtype=weight_dtype) - text_encoder_two.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( @@ -1146,10 +1178,13 @@ def unwrap_model(model): def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None for model in models: if isinstance(model, type(unwrap_model(transformer))): transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1159,22 +1194,27 @@ def save_model_hook(models, weights, output_dir): FluxPipeline.save_lora_weights( output_dir, transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, ) def load_model_hook(models, input_dir): transformer_ = None + text_encoder_one_ = None + while len(models) > 0: model = models.pop() if isinstance(model, type(unwrap_model(transformer))): transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict = FluxPipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") @@ -1186,12 +1226,17 @@ def load_model_hook(models, input_dir): f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " f" {unexpected_keys}. " ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. More details: # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 if args.mixed_precision == "fp16": models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models) @@ -1211,14 +1256,29 @@ def load_model_hook(models, input_dir): # Make sure the trainable params are in float32. if args.mixed_precision == "fp16": models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) # only upcast trainable parameters (LoRA) into fp32 cast_training_params(models, dtype=torch.float32) - transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - # Optimization parameters - transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - params_to_optimize = [transformer_parameters_with_lr] + transformer_parameters_with_lr = {"params": transformer.parameters(), "lr": args.learning_rate} + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_parameters_one_with_lr = { + "params": text_encoder_one.parameters(), + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] + else: + params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): @@ -1266,6 +1326,16 @@ def load_model_hook(models, input_dir): logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1299,32 +1369,37 @@ def load_model_hook(models, input_dir): num_workers=args.dataloader_num_workers, ) - tokenizers = [tokenizer_one, tokenizer_two] - text_encoders = [text_encoder_one, text_encoder_two] - - def compute_text_embeddings(prompt, text_encoders, tokenizers): - with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( - text_encoders, tokenizers, prompt, args.max_sequence_length - ) - prompt_embeds = prompt_embeds.to(accelerator.device) - pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) - text_ids = text_ids.to(accelerator.device) - return prompt_embeds, pooled_prompt_embeds, text_ids + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] - if not train_dataset.custom_instance_prompts: + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) # Handle class prompt for prior-preservation. if args.with_prior_preservation: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( - args.class_prompt, text_encoders, tokenizers - ) + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) # Clear the memory here - if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection del text_encoder_one, text_encoder_two @@ -1337,13 +1412,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # have to pass them to the dataloader. if not train_dataset.custom_instance_prompts: - prompt_embeds = instance_prompt_hidden_states - pooled_prompt_embeds = instance_pooled_prompt_embeds - text_ids = instance_text_ids - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) - text_ids = torch.cat([text_ids, class_text_ids], dim=0) + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1362,9 +1448,26 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) # Prepare everything with our `accelerator`. - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + ( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1442,18 +1545,26 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for epoch in range(first_epoch, args.num_train_epochs): transformer.train() + if args.train_text_encoder: + text_encoder_one.train() for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) + if not args.train_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) + tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1504,24 +1615,46 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): guidance = None # Predict the noise residual - model_pred = transformer( - hidden_states=packed_noisy_model_input, - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) - timestep=timesteps / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - return_dict=False, - )[0] + if not args.train_text_encoder: + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + else: + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( model_pred, height=int(model_input.shape[2]) * 8, width=int(model_input.shape[3]) * 8, - vae_scale_factor=16, # should this be 2 ** (len(vae.config.block_out_channels))? + vae_scale_factor=2 ** (len(vae.config.block_out_channels)), # should this be 2 ** (len(vae.config.block_out_channels))? ) + model_pred = model_pred * (-sigmas) + noisy_model_input + # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) @@ -1556,7 +1689,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = transformer_lora_parameters + params_to_clip = ( + itertools.chain( + transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + ) + if args.train_text_encoder + else transformer.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() @@ -1603,7 +1742,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: - text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + # create pipeline + if not args.train_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1622,10 +1763,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline_args=pipeline_args, epoch=epoch, ) - del text_encoder_one, text_encoder_two - - torch.cuda.empty_cache() - gc.collect() + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two + torch.cuda.empty_cache() + gc.collect() # Save the lora layers accelerator.wait_for_everyone() @@ -1634,7 +1775,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): transformer = transformer.to(torch.float32) transformer_lora_layers = get_peft_model_state_dict(transformer) - FluxPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers) + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + else: + text_encoder_lora_layers = None + + FluxPipeline.save_lora_weights(save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_one_lora_layers=text_encoder_lora_layers) # Final inference # Load previous pipeline @@ -1665,6 +1814,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): repo_id, images=images, base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, instance_prompt=args.instance_prompt, validation_prompt=args.validation_prompt, repo_folder=args.output_dir, From aea5d1fb9c20c07ad56db64c9319eb31c24b6517 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 8 Aug 2024 14:53:00 +0000 Subject: [PATCH 40/46] style --- examples/dreambooth/train_dreambooth_flux.py | 5 ++++- examples/dreambooth/train_dreambooth_lora_flux.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 22e1519f4438..f3742530785e 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1588,7 +1588,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred, height=int(model_input.shape[2]) * 8, width=int(model_input.shape[3]) * 8, - vae_scale_factor=2 ** (len(vae.config.block_out_channels)), # should this be 2 ** (len(vae.config.block_out_channels))? + vae_scale_factor=2 + ** ( + len(vae.config.block_out_channels) + ), # should this be 2 ** (len(vae.config.block_out_channels))? ) model_pred = model_pred * (-sigmas) + noisy_model_input diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index ade19dc50275..2237b57eecc7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1650,7 +1650,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred, height=int(model_input.shape[2]) * 8, width=int(model_input.shape[3]) * 8, - vae_scale_factor=2 ** (len(vae.config.block_out_channels)), # should this be 2 ** (len(vae.config.block_out_channels))? + vae_scale_factor=2 + ** ( + len(vae.config.block_out_channels) + ), # should this be 2 ** (len(vae.config.block_out_channels))? ) model_pred = model_pred * (-sigmas) + noisy_model_input @@ -1781,9 +1784,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: text_encoder_lora_layers = None - FluxPipeline.save_lora_weights(save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers, - text_encoder_one_lora_layers=text_encoder_lora_layers) + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_one_lora_layers=text_encoder_lora_layers, + ) # Final inference # Load previous pipeline From bde7dedc0a397d61394293426a47cee6a7700246 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 18:39:05 +0300 Subject: [PATCH 41/46] text encoder training fixes --- examples/dreambooth/train_dreambooth_flux.py | 14 +++++++------- examples/dreambooth/train_dreambooth_lora_flux.py | 7 ++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 22e1519f4438..d2890af27bf5 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1109,7 +1109,7 @@ def main(args): vae.requires_grad_(False) if args.train_text_encoder: text_encoder_one.requires_grad_(True) - text_encoder_two.requires_grad_(True) + text_encoder_two.requires_grad_(False) else: text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) @@ -1137,7 +1137,6 @@ def main(args): transformer.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() - text_encoder_two.gradient_checkpointing_enable() def unwrap_model(model): model = accelerator.unwrap_model(model) @@ -1390,14 +1389,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ( transformer, text_encoder_one, - text_encoder_two, optimizer, train_dataloader, lr_scheduler, ) = accelerator.prepare( transformer, text_encoder_one, - text_encoder_two, optimizer, train_dataloader, lr_scheduler, @@ -1629,7 +1626,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain( - transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + transformer.parameters(), text_encoder_one.parameters() ) if args.train_text_encoder else transformer.parameters() @@ -1683,6 +1680,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + else: # even when training the text encoder we're only training text encoder one + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, + variant=args.variant + ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, @@ -1713,12 +1715,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one = unwrap_model(text_encoder_one) - text_encoder_two = unwrap_model(text_encoder_two) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder_one, - text_encoder_2=text_encoder_two, ) else: pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index ade19dc50275..3d6e9d56315e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1130,7 +1130,7 @@ def main(args): vae.requires_grad_(False) if args.train_text_encoder: text_encoder_one.requires_grad_(True) - text_encoder_two.requires_grad_(True) + text_encoder_two.requires_grad_(False) else: text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) @@ -1158,7 +1158,6 @@ def main(args): transformer.enable_gradient_checkpointing() if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() - text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( @@ -1452,14 +1451,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ( transformer, text_encoder_one, - text_encoder_two, optimizer, train_dataloader, lr_scheduler, ) = accelerator.prepare( transformer, text_encoder_one, - text_encoder_two, optimizer, train_dataloader, lr_scheduler, @@ -1691,7 +1688,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.sync_gradients: params_to_clip = ( itertools.chain( - transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + transformer.parameters(), text_encoder_one.parameters() ) if args.train_text_encoder else transformer.parameters() From f7816303748890c021136de93574234235142277 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 8 Aug 2024 15:41:15 +0000 Subject: [PATCH 42/46] style --- examples/dreambooth/train_dreambooth_flux.py | 12 ++++++------ examples/dreambooth/train_dreambooth_lora_flux.py | 4 +--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 2e34b65b67fb..2877651f6ecb 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1628,9 +1628,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain( - transformer.parameters(), text_encoder_one.parameters() - ) + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) if args.train_text_encoder else transformer.parameters() ) @@ -1683,10 +1681,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) - else: # even when training the text encoder we're only training text encoder one + else: # even when training the text encoder we're only training text encoder one text_encoder_two = class_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, - variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 299cb1ce568b..aafe96a102ad 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1690,9 +1690,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( - itertools.chain( - transformer.parameters(), text_encoder_one.parameters() - ) + itertools.chain(transformer.parameters(), text_encoder_one.parameters()) if args.train_text_encoder else transformer.parameters() ) From d77b67f6dd59078abd373a16eae66a6f16a69f2f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 18:47:59 +0300 Subject: [PATCH 43/46] update readme --- examples/dreambooth/README_flux.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 8c4f8b53a145..4c0ba7fbaa5e 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -157,8 +157,6 @@ accelerate launch train_dreambooth_lora_flux.py \ ``` ### Text Encoder Training -- [x] add text encoder training support for dreambooth script -- [ ] add text encoder training support for lora script Alongside the transformer, fine-tuning of the CLIP text encoder is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: @@ -171,9 +169,9 @@ By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is perfo To perform DreamBooth LoRA with text-encoder training, run: ```bash export MODEL_NAME="black-forest-labs/FLUX.1-dev" -export OUTPUT_DIR="trained-flux-dev-dreambooth" +export OUTPUT_DIR="trained-flux-dev-dreambooth-lora" -accelerate launch train_dreambooth_flux.py \ +accelerate launch train_dreambooth_lora_flux.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --output_dir=$OUTPUT_DIR \ From 1d8e25f8541b3e20db0e31122d81cb46fbfbbcf9 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 18:56:41 +0300 Subject: [PATCH 44/46] minor fixes --- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora_flux.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 2877651f6ecb..66e7a4e97f69 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1682,7 +1682,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) else: # even when training the text encoder we're only training text encoder one - text_encoder_two = class_two.from_pretrained( + text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index aafe96a102ad..1f00b96632c1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -54,6 +54,7 @@ ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, From dc1b10e34856ccd5ddf29e72cc5b5198c4c9545b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 20:21:07 +0300 Subject: [PATCH 45/46] fix te params --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 1f00b96632c1..9439bd9576f4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1269,7 +1269,7 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: # different learning rate for text encoder and unet text_parameters_one_with_lr = { - "params": text_encoder_one.parameters(), + "params": text_lora_parameters_one, "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } From 569f2e172299dd753fbd959b70eac354c15736d5 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 8 Aug 2024 20:51:19 +0300 Subject: [PATCH 46/46] fix te params --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 9439bd9576f4..e32004a8d82c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1127,7 +1127,7 @@ def main(args): args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) - transformer.requires_grad_(True) + transformer.requires_grad_(False) vae.requires_grad_(False) if args.train_text_encoder: text_encoder_one.requires_grad_(True)