From 66a51eda09acca2faef97506277902ffa549d356 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 5 Sep 2022 11:42:35 +0530 Subject: [PATCH 01/54] begin text2image script --- examples/text_to_image/train_text_to_image.py | 451 ++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 examples/text_to_image/train_text_to_image.py diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py new file mode 100644 index 000000000000..8ed2388be477 --- /dev/null +++ b/examples/text_to_image/train_text_to_image.py @@ -0,0 +1,451 @@ +import argparse +import math +import os +from pathlib import Path +from typing import Optional + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +import PIL +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from huggingface_hub import HfFolder, Repository, whoami +from torchvision import transforms +from torchvision.io import ImageReadMode, read_image +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=5000, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=True, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--use_auth_token", + action="store_true", + help=( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" + " private models)." + ), + ) + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." 
+ ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def freeze_params(params): + for param in params: + param.requires_grad = False + + +dataset_name_mapping = { + "image_caption_dataset.py": ("image_path", "caption"), +} + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + use_auth_token=args.use_auth_token, + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token + ) + + # Freeze vae and text_encoder + freeze_params(vae.parameters()) + freeze_params(text_encoder.parameters()) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + optimizer = torch.optim.AdamW( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # TODO (patil-suraj): laod scheduler using args + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + ) + + if args.dataset_name is not None: + train_dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + 
cache_dir=args.cache_dir, + use_auth_token=True if args.use_auth_token else None, + split="train", + ) + else: + train_dataset = load_dataset( + "imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train" + ) + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = train_dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples): + captions = [caption for caption in examples[caption_column]] + examples["input_ids"] = tokenizer( + captions, max_length=tokenizer.model_max_length, padding=True, truncation=True + ).input_ids + return examples + + preprocess = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=PIL.InterpolationMode.BILINEAR), + # transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + lambda tensor: tensor.to(torch.uint8), # TODO: why ? + lambda tensor: (tensor / 127.5 - 1.0).to(torch.float32), + ] + ) + + def preprocess_images(examples): + images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] + examples["pixel_values"] = [preprocess(image) for image in images] + return examples + + train_dataset.set_transform(tokenize_captions) + train_dataset.set_transform(preprocess_images) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long) + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + } + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) + + # Scheduler and math around the number of training steps. 
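The comment above introduces the step/epoch bookkeeping the next hunk implements: one optimizer update happens every `gradient_accumulation_steps` batches, so update steps per epoch is the dataloader length divided by the accumulation factor, rounded up, and a missing --max_train_steps is derived from the epoch count. A minimal sketch of the same arithmetic, with the dataloader length and settings as illustrative assumptions, not values from the patch:

    import math

    # Illustrative values: a 1000-batch dataloader, 4 accumulation steps,
    # and no explicit --max_train_steps.
    num_batches = 1000
    gradient_accumulation_steps = 4
    num_train_epochs = 100
    max_train_steps = None

    num_update_steps_per_epoch = math.ceil(num_batches / gradient_accumulation_steps)  # 250
    if max_train_steps is None:
        # Mirror of the script's fallback: derive total steps from the epoch count.
        max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 25000
    print(num_update_steps_per_epoch, max_train_steps)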
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # Move vae and unet to device + vae.to(accelerator.device) + text_encoder.to(accelerator.device) + + # Keep vae and unet in eval model as we don't train these + vae.eval() + text_encoder.eval() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
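The `total_batch_size` logged just above and the `--scale_lr` rule from earlier in main() are the same product read two ways: samples consumed per optimizer update, and the factor by which the base learning rate is multiplied. A sketch under assumed values (a 2-process run, otherwise the script's defaults):

    learning_rate = 1e-4
    train_batch_size = 16
    gradient_accumulation_steps = 1
    num_processes = 2  # assumption: a 2-GPU run

    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps  # 32
    scaled_lr = learning_rate * gradient_accumulation_steps * train_batch_size * num_processes
    print(total_batch_size, scaled_lr)  # 32 0.0032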
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual and compute loss + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"] + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline( + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ), + safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub( + args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From d062da6a95ce792df55a9188c66050a94db6d909 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 5 Sep 2022 12:51:02 +0200 Subject: [PATCH 02/54] loading the datasets, preprocessing & transforms --- examples/text_to_image/train_text_to_image.py | 148 ++++++++++++++---- 1 file changed, 116 insertions(+), 32 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8ed2388be477..56d7e6ba848f 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -36,13 +36,47 @@ def parse_args(): help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( - "--tokenizer_name", + "--dataset_name", type=str, default=None, - help="Pretrained tokenizer name or path if not the same as model_name", + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset)." 
+ ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument( + "--validation_data_dir", type=str, default=None, help="A folder containing the validation data." + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), ) parser.add_argument( - "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + "--max_eval_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ), + ) + parser.add_argument( + "--train_val_split", + type=float, + default=0.15, + help="Percent to split off of train for validation", ) parser.add_argument( "--output_dir", @@ -50,6 +84,12 @@ def parse_args(): default="sd-model-finetuned", help="The output directory where the model predictions and checkpoints will be written.", ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( "--resolution", @@ -61,11 +101,16 @@ def parse_args(): ), ) parser.add_argument( - "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + "--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution (if not set, use random crop)", ) parser.add_argument( "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." ) + parser.add_argument( + "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader." 
+ ) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( "--max_train_steps", @@ -150,8 +195,9 @@ def parse_args(): if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank - if args.train_data_dir is None: - raise ValueError("You must specify a train data directory.") + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None: + raise ValueError("Need either a dataset name or a training/validation folder.") return args @@ -208,17 +254,12 @@ def main(): elif args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) - # Load the tokenizer and add the placeholder token as a additional special token - if args.tokenizer_name: - tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) - elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - use_auth_token=args.use_auth_token, - ) - # Load models and create wrapper for stable diffusion + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + use_auth_token=args.use_auth_token, + ) text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token ) @@ -247,27 +288,48 @@ def main(): eps=args.adam_epsilon, ) - # TODO (patil-suraj): laod scheduler using args + # TODO (patil-suraj): load scheduler using args noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" ) + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. if args.dataset_name is not None: - train_dataset = load_dataset( + # Downloading and loading a dataset from the hub. + dataset = load_dataset( args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, use_auth_token=True if args.use_auth_token else None, - split="train", ) else: - train_dataset = load_dataset( - "imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train" + data_files = {} + if args.train_dir is not None: + data_files["train"] = os.path.join(args.train_dir, "**") + if args.validation_dir is not None: + data_files["validation"] = os.path.join(args.validation_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_process#imagefolder. + + # If we don't have a validation split, split off a percentage of train as validation. + args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split + if isinstance(args.train_val_split, float) and args.train_val_split > 0.0: + split = dataset["train"].train_test_split(args.train_val_split) + dataset["train"] = split["train"] + dataset["validation"] = split["test"] # Preprocessing the datasets. # We need to tokenize inputs and targets. - column_names = train_dataset["train"].column_names + column_names = dataset["train"].column_names # 6. Get the column names for input/target. 
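Patch 02's loading branch above relies on the `imagefolder` builder from the `datasets` library and on splitting a validation set off of train when none is provided. The same flow in isolation; the directory glob is a placeholder:

    from datasets import load_dataset

    # Hypothetical local image folder; "imagefolder" is the datasets builder the patch uses.
    dataset = load_dataset("imagefolder", data_files={"train": "path/to/images/**"})

    # Mirror of the fallback split: peel 15% off of train if validation is absent.
    if "validation" not in dataset:
        split = dataset["train"].train_test_split(0.15)
        dataset["train"], dataset["validation"] = split["train"], split["test"]
    print(dataset)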
dataset_columns = dataset_name_mapping.get(args.dataset_name, None) @@ -297,23 +359,42 @@ def tokenize_captions(examples): ).input_ids return examples - preprocess = transforms.Compose( + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=PIL.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + val_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=PIL.InterpolationMode.BILINEAR), - # transforms.CenterCrop(args.resolution), + transforms.CenterCrop(args.resolution), transforms.ToTensor(), - lambda tensor: tensor.to(torch.uint8), # TODO: why ? - lambda tensor: (tensor / 127.5 - 1.0).to(torch.float32), + transforms.Normalize([0.5], [0.5]), ] ) - def preprocess_images(examples): + def preprocess_train(examples): images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] - examples["pixel_values"] = [preprocess(image) for image in images] + examples["pixel_values"] = [train_transforms(image) for image in images] return examples - train_dataset.set_transform(tokenize_captions) - train_dataset.set_transform(preprocess_images) + def preprocess_val(examples): + images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] + examples["pixel_values"] = [val_transforms(image) for image in images] + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(tokenize_captions).with_transform(preprocess_train) + if args.max_eval_samples is not None: + dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples)) + # Set the validation transforms + eval_dataset = dataset["validation"].with_transform(tokenize_captions).with_transform(preprocess_val) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) @@ -324,7 +405,10 @@ def collate_fn(examples): } train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = torch.utils.data.DataLoader( + eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size ) # Scheduler and math around the number of training steps. @@ -424,7 +508,7 @@ def collate_fn(examples): accelerator.wait_for_everyone() - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. 
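The switch above from the hand-rolled `/ 127.5 - 1.0` lambdas to `transforms.Normalize([0.5], [0.5])` keeps the same target range: `ToTensor` scales pixels to [0, 1], and normalizing with mean 0.5 and std 0.5 maps that interval onto the [-1, 1] the VAE expects. A quick check:

    import torch
    from torchvision import transforms

    norm = transforms.Normalize([0.5], [0.5])
    x = torch.tensor([[[0.0, 0.5, 1.0]]])  # a 1x1x3 "image" already scaled to [0, 1]
    print(norm(x))  # tensor([[[-1.,  0.,  1.]]])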
if accelerator.is_main_process: pipeline = StableDiffusionPipeline( text_encoder=accelerator.unwrap_model(text_encoder), From 3ed3a34265910a20217de545e49ff8fa48c49190 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 5 Sep 2022 14:38:23 +0200 Subject: [PATCH 03/54] handle input features correctly --- examples/text_to_image/train_text_to_image.py | 63 +++++++++++++------ 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 56d7e6ba848f..c396ae8acfb3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1,9 +1,11 @@ import argparse import math import os +import random from pathlib import Path from typing import Optional +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -54,6 +56,15 @@ def parse_args(): parser.add_argument( "--validation_data_dir", type=str, default=None, help="A folder containing the validation data." ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) parser.add_argument( "--max_train_samples", type=int, @@ -352,16 +363,24 @@ def main(): # Preprocessing the datasets. # We need to tokenize input captions and transform the images. - def tokenize_captions(examples): - captions = [caption for caption in examples[caption_column]] - examples["input_ids"] = tokenizer( - captions, max_length=tokenizer.model_max_length, padding=True, truncation=True - ).input_ids - return examples + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." 
+ ) + input_ids = tokenizer(captions, max_length=tokenizer.model_max_length, padding=True, truncation=True).input_ids + return input_ids train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=PIL.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -369,7 +388,7 @@ def tokenize_captions(examples): ) val_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=PIL.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -377,39 +396,47 @@ def tokenize_captions(examples): ) def preprocess_train(examples): - images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] + images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples def preprocess_val(examples): - images = [read_image(image_file, mode=ImageReadMode.RGB) for image_file in examples[image_column]] + images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [val_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples, is_train=False) return examples with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms - train_dataset = dataset["train"].with_transform(tokenize_captions).with_transform(preprocess_train) + train_dataset = dataset["train"].with_transform(preprocess_train) if args.max_eval_samples is not None: dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples)) # Set the validation transforms - eval_dataset = dataset["validation"].with_transform(tokenize_captions).with_transform(preprocess_val) + eval_dataset = dataset["validation"].with_transform(preprocess_val) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) - input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long) + input_ids = [example["input_ids"] for example in examples] + padded_tokens = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ) return { "pixel_values": pixel_values, - "input_ids": input_ids, + "input_ids": padded_tokens.input_ids, + "attention_mask": padded_tokens.attention_mask, } train_dataloader = torch.utils.data.DataLoader( - train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.per_device_train_batch_size - ) - eval_dataloader = torch.utils.data.DataLoader( - eval_dataset, collate_fn=collate_fn, batch_size=args.per_device_eval_batch_size + train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size ) + eval_dataloader = torch.utils.data.DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False From ce569a1053bda8f726b98e50af866b6c174404e0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 22 Sep 2022 17:05:26 +0200 Subject: [PATCH 04/54] add gradient checkpointing support --- examples/text_to_image/train_text_to_image.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c396ae8acfb3..d7e06962c9ee 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -10,7 +10,6 @@ import torch.nn.functional as F import torch.utils.checkpoint -import PIL from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -20,7 +19,6 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami from torchvision import transforms -from torchvision.io import ImageReadMode, read_image from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -135,6 +133,11 @@ def parse_args(): default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) parser.add_argument( "--learning_rate", type=float, @@ -285,6 +288,9 @@ def main(): freeze_params(vae.parameters()) freeze_params(text_encoder.parameters()) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes @@ -416,7 +422,7 @@ def preprocess_val(examples): if args.max_eval_samples is not None: dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples)) # Set the validation transforms - eval_dataset = dataset["validation"].with_transform(preprocess_val) + # eval_dataset = dataset["validation"].with_transform(preprocess_val) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) @@ -436,7 +442,7 @@ def collate_fn(examples): train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size ) - eval_dataloader = torch.utils.data.DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size) + # eval_dataloader = torch.utils.data.DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size) # Scheduler and math around the number of training steps. 
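Gradient checkpointing, switched on above via `unet.enable_gradient_checkpointing()`, trades compute for memory: activations inside checkpointed blocks are dropped during the forward pass and recomputed in backward, cutting activation memory at the cost of an extra forward through each block. The underlying PyTorch mechanism on a toy module:

    import torch
    from torch.utils.checkpoint import checkpoint

    block = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
    )
    x = torch.randn(8, 64, requires_grad=True)

    # Activations inside `block` are not kept; backward recomputes them.
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([8, 64])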
overrode_max_train_steps = False From 837a586a221f9887bccccf5b38a3db20a39b9139 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 22 Sep 2022 18:08:30 +0200 Subject: [PATCH 05/54] fix output names --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d7e06962c9ee..11c321cfcfdc 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -502,7 +502,7 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).sample().detach() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -519,7 +519,7 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual and compute loss - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"] + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) From 3893029994b7089016a4ebd30048d60bb26a66f9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 22 Sep 2022 18:29:38 +0200 Subject: [PATCH 06/54] run unet in train mode not text encoder --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 11c321cfcfdc..50d2bac87ebf 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -498,7 +498,7 @@ def collate_fn(examples): global_step = 0 for epoch in range(args.num_train_epochs): - text_encoder.train() + unet.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space From 61513b04a1f17b5790284045e239b421b205ee9d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 11:31:08 +0200 Subject: [PATCH 07/54] use no_grad instead of freezing params --- examples/text_to_image/train_text_to_image.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 50d2bac87ebf..8bee545c1b69 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -285,8 +285,8 @@ def main(): ) # Freeze vae and text_encoder - freeze_params(vae.parameters()) - freeze_params(text_encoder.parameters()) + # freeze_params(vae.parameters()) + # freeze_params(text_encoder.parameters()) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -458,13 +458,13 @@ def collate_fn(examples): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler + text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler ) # Move vae and unet to device - vae.to(accelerator.device) - text_encoder.to(accelerator.device) + # vae.to(accelerator.device) + # 
text_encoder.to(accelerator.device) # Keep vae and unet in eval model as we don't train these vae.eval() @@ -502,7 +502,8 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + with torch.no_grad(): + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -516,7 +517,8 @@ def collate_fn(examples): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + with torch.no_grad(): + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual and compute loss noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -545,8 +547,8 @@ def collate_fn(examples): if accelerator.is_main_process: pipeline = StableDiffusionPipeline( text_encoder=accelerator.unwrap_model(text_encoder), - vae=vae, - unet=unet, + vae=accelerator.unwrap_model(vae), + unet=accelerator.unwrap_model(unet), tokenizer=tokenizer, scheduler=PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True From ed8f4dd7f1f85db3ea07592177e5fab2b6a9c3da Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 11:42:08 +0200 Subject: [PATCH 08/54] default max steps None --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8bee545c1b69..2ee1fd8587ef 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -124,7 +124,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=5000, + default=None, help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", ) parser.add_argument( From e4fb47845d1f554bcf8a10b403bcf9aebc4d6dde Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 11:47:51 +0200 Subject: [PATCH 09/54] pad to longest --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 2ee1fd8587ef..c8f6a4b2518b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -429,7 +429,7 @@ def collate_fn(examples): input_ids = [example["input_ids"] for example in examples] padded_tokens = tokenizer.pad( {"input_ids": input_ids}, - padding="max_length", + padding=True, max_length=tokenizer.model_max_length, return_tensors="pt", ) From 7414de10c3993be19100e66a6223eaaad766d50b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 11:52:45 +0200 Subject: [PATCH 10/54] don't pad when tokenizing --- examples/text_to_image/train_text_to_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c8f6a4b2518b..4f776d190e89 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -381,7 +381,9 @@ def tokenize_captions(examples, is_train=True): raise ValueError( f"Caption column `{caption_column}` should contain either strings or lists of strings." ) - input_ids = tokenizer(captions, max_length=tokenizer.model_max_length, padding=True, truncation=True).input_ids + input_ids = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True + ).input_ids return input_ids train_transforms = transforms.Compose( From ce4a7a2d8da223170923c187c0f3b09966cc938e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 12:15:57 +0200 Subject: [PATCH 11/54] fix encode on multi gpu --- examples/text_to_image/train_text_to_image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4f776d190e89..dd3ea850fba6 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -505,7 +505,10 @@ def collate_fn(examples): with accelerator.accumulate(unet): # Convert images to latent space with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + if accelerator.num_processes > 1: + vae.module.encode(batch["pixel_values"]).latent_dist.sample().detach() + else: + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents From 95d78360b4d2c89c4b0b0c73f5bcc4a5da7f85c7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 12:18:29 +0200 Subject: [PATCH 12/54] fix stupid bug --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index dd3ea850fba6..1a1dd4e19176 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -506,7 +506,7 @@ def collate_fn(examples): # Convert images to latent space with torch.no_grad(): if accelerator.num_processes > 1: - 
vae.module.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = vae.module.encode(batch["pixel_values"]).latent_dist.sample().detach() else: latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 From 54b700d85abeae0bb22cbe6af6933285d38208bc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 15:25:29 +0200 Subject: [PATCH 13/54] add random flip --- examples/text_to_image/train_text_to_image.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1a1dd4e19176..c554f5d3ef3e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -114,6 +114,11 @@ def parse_args(): action="store_true", help="Whether to center crop images before resizing to resolution (if not set, use random crop)", ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) parser.add_argument( "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." ) @@ -390,6 +395,7 @@ def tokenize_captions(examples, is_train=True): [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] From 725fb9646a7f32bc2931acdc805b9af24c6b09ba Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 16:21:57 +0200 Subject: [PATCH 14/54] add ema --- examples/text_to_image/train_text_to_image.py | 68 ++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c554f5d3ef3e..860a46b79d98 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1,4 +1,5 @@ import argparse +import copy import math import os import random @@ -167,6 +168,7 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -241,6 +243,65 @@ def freeze_params(params): } +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + model, + decay=0.9999, + device=None, + ): + self.averaged_model = copy.deepcopy(model).eval() + self.averaged_model.requires_grad_(False) + + self.decay = decay + + if device is not None: + self.averaged_model = self.averaged_model.to(device=device) + + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. 
+ """ + value = (1 + optimization_step) / (10 + optimization_step) + return 1 - min(self.decay, value) + + @torch.no_grad() + def step(self, new_model): + ema_state_dict = {} + ema_params = self.averaged_model.state_dict() + + self.optimization_step += 1 + self.decay = self.get_decay(self.optimization_step) + + for key, param in new_model.named_parameters(): + if isinstance(param, dict): + continue + try: + ema_param = ema_params[key] + except KeyError: + ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ema_params[key] = ema_param + + if not param.requires_grad: + ema_param.sub_(self.decay * (ema_param - param.data.to(dtype=ema_param.dtype))) + else: + ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) + ema_param = ema_params[key] + + ema_state_dict[key] = ema_param + + for key, param in new_model.named_buffers(): + ema_state_dict[key] = param + + self.averaged_model.load_state_dict(ema_state_dict, strict=False) + + def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) @@ -470,6 +531,9 @@ def collate_fn(examples): text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler ) + if args.use_ema: + ema_unet = EMAModel(unet) + # Move vae and unet to device # vae.to(accelerator.device) # text_encoder.to(accelerator.device) @@ -544,6 +608,8 @@ def collate_fn(examples): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 + if args.use_ema: + ema_unet.step(unet) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -559,7 +625,7 @@ def collate_fn(examples): pipeline = StableDiffusionPipeline( text_encoder=accelerator.unwrap_model(text_encoder), vae=accelerator.unwrap_model(vae), - unet=accelerator.unwrap_model(unet), + unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet), tokenizer=tokenizer, scheduler=PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True From 584b3f7be9d0df04a0ed8d81d405af0361429a12 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 16:30:09 +0200 Subject: [PATCH 15/54] fix ema --- examples/text_to_image/train_text_to_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 860a46b79d98..082383e76b30 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -288,10 +288,11 @@ def step(self, new_model): ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) ema_params[key] = ema_param - if not param.requires_grad: + if param.requires_grad: + param = param.data.to(dtype=ema_param.dtype).to(device=ema_param.device) ema_param.sub_(self.decay * (ema_param - param.data.to(dtype=ema_param.dtype))) else: - ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) + ema_params[key].copy_(param.to(dtype=ema_param.dtype).data.to(device=ema_param.device)) ema_param = ema_params[key] ema_state_dict[key] = ema_param From 0f0b09869c213b4833e224f4a55e0b30aa037818 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 16:35:30 +0200 Subject: [PATCH 16/54] put ema on cpu --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 082383e76b30..98aee00de093 100644 --- 
a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -533,7 +533,7 @@ def collate_fn(examples): ) if args.use_ema: - ema_unet = EMAModel(unet) + ema_unet = EMAModel(unet, device="cpu") # Move vae and unet to device # vae.to(accelerator.device) From 56a9fd0d934ada2be787b010be62a5cc22064dcc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 20:31:40 +0200 Subject: [PATCH 17/54] improve EMA model --- examples/text_to_image/train_text_to_image.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 98aee00de093..2fb985ea38c5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -273,8 +273,7 @@ def get_decay(self, optimization_step): @torch.no_grad() def step(self, new_model): - ema_state_dict = {} - ema_params = self.averaged_model.state_dict() + ema_state_dict = self.averaged_model.state_dict() self.optimization_step += 1 self.decay = self.get_decay(self.optimization_step) @@ -283,24 +282,23 @@ def step(self, new_model): if isinstance(param, dict): continue try: - ema_param = ema_params[key] + ema_param = ema_state_dict[key] except KeyError: ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) - ema_params[key] = ema_param + ema_state_dict[key] = ema_param + + param = param.clone().detach().to(ema_param.dtype) if param.requires_grad: - param = param.data.to(dtype=ema_param.dtype).to(device=ema_param.device) - ema_param.sub_(self.decay * (ema_param - param.data.to(dtype=ema_param.dtype))) + ema_state_dict[key].sub_(self.decay * (ema_param - param)) else: - ema_params[key].copy_(param.to(dtype=ema_param.dtype).data.to(device=ema_param.device)) - ema_param = ema_params[key] - - ema_state_dict[key] = ema_param + ema_state_dict[key].copy_(param) for key, param in new_model.named_buffers(): ema_state_dict[key] = param self.averaged_model.load_state_dict(ema_state_dict, strict=False) + torch.cuda.empty_cache() def main(): @@ -533,7 +531,7 @@ def collate_fn(examples): ) if args.use_ema: - ema_unet = EMAModel(unet, device="cpu") + ema_unet = EMAModel(unet) # Move vae and unet to device # vae.to(accelerator.device) From 5c0540130accccfd5a8dd54c736bdf27d2203158 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 21:29:39 +0200 Subject: [PATCH 18/54] contiguous_format --- examples/text_to_image/train_text_to_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 2fb985ea38c5..c3ea58d84273 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -287,7 +287,7 @@ def step(self, new_model): ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) ema_state_dict[key] = ema_param - param = param.clone().detach().to(ema_param.dtype) + param = param.clone().detach().to(ema_param.dtype).to(ema_param.device) if param.requires_grad: ema_state_dict[key].sub_(self.decay * (ema_param - param)) @@ -494,6 +494,7 @@ def preprocess_val(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = [example["input_ids"] for example in examples] padded_tokens = tokenizer.pad( {"input_ids": 
input_ids}, From ad42acb515256e0ea8a8cd7178652c9865ecf9c6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 22:29:02 +0200 Subject: [PATCH 19/54] don't warp vae and text encode in accelerate --- examples/text_to_image/train_text_to_image.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c3ea58d84273..a1b294d1cde1 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -350,8 +350,8 @@ def main(): ) # Freeze vae and text_encoder - # freeze_params(vae.parameters()) - # freeze_params(text_encoder.parameters()) + freeze_params(vae.parameters()) + freeze_params(text_encoder.parameters()) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -527,20 +527,16 @@ def collate_fn(examples): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, vae, unet, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) if args.use_ema: ema_unet = EMAModel(unet) # Move vae and unet to device - # vae.to(accelerator.device) - # text_encoder.to(accelerator.device) - - # Keep vae and unet in eval model as we don't train these - vae.eval() - text_encoder.eval() + vae.to(accelerator.device) + text_encoder.to(accelerator.device) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -575,10 +571,7 @@ def collate_fn(examples): with accelerator.accumulate(unet): # Convert images to latent space with torch.no_grad(): - if accelerator.num_processes > 1: - latents = vae.module.encode(batch["pixel_values"]).latent_dist.sample().detach() - else: - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -623,8 +616,8 @@ def collate_fn(examples): # Create the pipeline using the trained modules and save it. 
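The `EMAModel` assembled in patches 14-17 and attached above keeps a shadow copy of the UNet whose weights trail the trained ones, applying a warmed-up blend factor `d = 1 - min(decay, (1 + step) / (10 + step))` toward the raw weights: large early on (fast tracking), settling toward `1 - decay` (0.0001) per step later. The update rule on bare tensors:

    import torch

    def ema_update(ema_param, param, step, decay=0.9999):
        # Schedule modeled on EMAModel.get_decay; here `decay` stays fixed
        # at its initial 0.9999 cap on every call.
        d = 1 - min(decay, (1 + step) / (10 + step))
        ema_param.sub_(d * (ema_param - param))

    ema = torch.zeros(3)
    for step in range(1, 5):
        ema_update(ema, torch.ones(3), step)
    print(ema)  # approaches 1 quickly at first, then ever more slowly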
if accelerator.is_main_process: pipeline = StableDiffusionPipeline( - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), + text_encoder=text_encoder, + vae=vae, unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet), tokenizer=tokenizer, scheduler=PNDMScheduler( From 4e54ae27ffba2cbd1a9140ad4dba349ab31a9d97 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 22:31:18 +0200 Subject: [PATCH 20/54] remove no_grad --- examples/text_to_image/train_text_to_image.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a1b294d1cde1..409ae33f0380 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -570,8 +570,7 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -585,8 +584,7 @@ def collate_fn(examples): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual and compute loss noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample From 9cf8d2bf90c8b221573d8b815145354a95388f33 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 23 Sep 2022 22:32:17 +0200 Subject: [PATCH 21/54] use randn_like --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 409ae33f0380..8860b0350c2d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -574,7 +574,7 @@ def collate_fn(examples): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() From 2feec1908de66b43bc602cd958d63918d22bd00c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 24 Sep 2022 18:25:11 +0200 Subject: [PATCH 22/54] fix resize --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8860b0350c2d..4820eb7fac1b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -453,7 +453,7 @@ def tokenize_captions(examples, is_train=True): train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if 
args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), From 7044b2d72f407536315b3048543c0c24775cd3b3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 24 Sep 2022 19:56:34 +0200 Subject: [PATCH 23/54] improve few things --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4820eb7fac1b..61897ede0bc6 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -570,7 +570,7 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -588,7 +588,7 @@ def collate_fn(examples): # Predict the noise residual and compute loss noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred, noise, reduction="mean") accelerator.backward(loss) optimizer.step() From 4809770bff4c1a3f3f49c5381303c13d23c57225 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 1 Oct 2022 16:33:18 +0200 Subject: [PATCH 24/54] log epoch loss --- examples/text_to_image/train_text_to_image.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 61897ede0bc6..085501629761 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -567,6 +567,7 @@ def collate_fn(examples): for epoch in range(args.num_train_epochs): unet.train() + total_loss = 0.0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space @@ -589,6 +590,7 @@ def collate_fn(examples): # Predict the noise residual and compute loss noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(noise_pred, noise, reduction="mean") + total_loss += loss.detach().float() accelerator.backward(loss) optimizer.step() @@ -609,9 +611,11 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - accelerator.wait_for_everyone() + logger.info(f"epoch {epoch}: train_loss: {total_loss.item() / len(train_dataloader)}") + accelerator.log({"epoch": epoch, "train_loss": total_loss.item() / len(train_dataloader)}, step=global_step) # Create the pipeline using the trained modules and save it. 
+ accelerator.wait_for_everyone() if accelerator.is_main_process: pipeline = StableDiffusionPipeline( text_encoder=text_encoder, From fdfbad331cebc1e44609a331ae9ba27c0b4fb7c1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 1 Oct 2022 16:44:27 +0200 Subject: [PATCH 25/54] set log level --- examples/text_to_image/train_text_to_image.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 085501629761..049d189969bd 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1,5 +1,6 @@ import argparse import copy +import logging import math import os import random @@ -312,6 +313,12 @@ def main(): logging_dir=logging_dir, ) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) From 03d124bca77cf3f60f9a3c4cefdff2be35bdcd47 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 1 Oct 2022 16:47:50 +0200 Subject: [PATCH 26/54] don't log each step --- examples/text_to_image/train_text_to_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 049d189969bd..cc791c79e626 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -613,7 +613,6 @@ def collate_fn(examples): logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break From abebd23cf82bc70fdde672d018faab73a02f1ea3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 1 Oct 2022 17:38:50 +0200 Subject: [PATCH 27/54] remove max_length from collate --- examples/text_to_image/train_text_to_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index cc791c79e626..c35898e1594d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -506,7 +506,6 @@ def collate_fn(examples): padded_tokens = tokenizer.pad( {"input_ids": input_ids}, padding=True, - max_length=tokenizer.model_max_length, return_tensors="pt", ) return { From 47798197a475c7b08539fd25687107b7c6915f07 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 1 Oct 2022 17:40:33 +0200 Subject: [PATCH 28/54] style --- examples/text_to_image/train_text_to_image.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c35898e1594d..1953eca67917 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -503,11 +503,7 @@ def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = [example["input_ids"] for example in examples] - padded_tokens = tokenizer.pad( - {"input_ids": input_ids}, - padding=True, - return_tensors="pt", - ) + padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt") return { "pixel_values": pixel_values, "input_ids": padded_tokens.input_ids, From 
c6ad72392773ad8bced3ce2d7090886223e19389 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Sat, 1 Oct 2022 17:43:09 +0200
Subject: [PATCH 29/54] add report_to option

---
 examples/text_to_image/train_text_to_image.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 1953eca67917..08c3dc2f1b9f 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -210,6 +210,16 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to'
+            " all supported integrations."
+        ),
+    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

     args = parser.parse_args()
@@ -309,7 +319,7 @@ def main():
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with="tensorboard",
+        log_with=args.report_to,
         logging_dir=logging_dir,
     )

From f4cd6ff53ac53091344b840d6f1e70da3854b9eb Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Sat, 1 Oct 2022 19:00:18 +0200
Subject: [PATCH 30/54] make scale_lr false by default

---
 examples/text_to_image/train_text_to_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 08c3dc2f1b9f..9deb1e7fd160 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -154,7 +154,7 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
+        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(

From 4cc238dc0c7060cd0aac57800db613ef2176da46 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Sat, 1 Oct 2022 19:43:35 +0200
Subject: [PATCH 31/54] add grad clipping

---
 examples/text_to_image/train_text_to_image.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 9deb1e7fd160..9879aec9a7a7 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -174,6 +174,7 @@ def parse_args():
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument(
         "--use_auth_token",
@@ -605,6 +606,8 @@ def collate_fn(examples):
                 total_loss += loss.detach().float()

                 accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()

From c643e9419b65a0572ad7d66587b8458cfc1a2e5e Mon Sep 17 00:00:00 2001
From:
patil-suraj Date: Sun, 2 Oct 2022 13:07:18 +0200 Subject: [PATCH 32/54] add an option to use 8bit adam --- examples/text_to_image/train_text_to_image.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 9879aec9a7a7..997d2748fd8d 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -169,6 +169,9 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") @@ -380,7 +383,19 @@ def main(): ) # Initialize the optimizer - optimizer = torch.optim.AdamW( + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( unet.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), From 3caf7c683fab40ea046e113dbcd8223f5ca238d2 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 3 Oct 2022 16:15:23 +0200 Subject: [PATCH 33/54] fix logging in multi-gpu, log every step --- examples/text_to_image/train_text_to_image.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 997d2748fd8d..fbc900811111 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -595,7 +595,7 @@ def collate_fn(examples): for epoch in range(args.num_train_epochs): unet.train() - total_loss = 0.0 + train_loss = 0.0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space @@ -618,7 +618,9 @@ def collate_fn(examples): # Predict the noise residual and compute loss noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(noise_pred, noise, reduction="mean") - total_loss += loss.detach().float() + + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps accelerator.backward(loss) if accelerator.sync_gradients: @@ -629,20 +631,19 @@ def collate_fn(examples): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 if args.use_ema: ema_unet.step(unet) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: break - logger.info(f"epoch {epoch}: train_loss: {total_loss.item() / 
len(train_dataloader)}") - accelerator.log({"epoch": epoch, "train_loss": total_loss.item() / len(train_dataloader)}, step=global_step) - # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: From 518448c5cc9c707c46f74b91fd6a90b654bf133c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 4 Oct 2022 13:50:36 +0200 Subject: [PATCH 34/54] more comments --- examples/text_to_image/train_text_to_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index fbc900811111..f13880b3802f 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -588,6 +588,7 @@ def collate_fn(examples): logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") @@ -619,9 +620,11 @@ def collate_fn(examples): noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(noise_pred, noise, reduction="mean") + # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() train_loss += avg_loss.item() / args.gradient_accumulation_steps + # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) From 926c20eb26556287bdb694bc519eb9f4daecf6dd Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 4 Oct 2022 13:59:30 +0200 Subject: [PATCH 35/54] remove eval for now --- examples/text_to_image/train_text_to_image.py | 53 +------------------ 1 file changed, 2 insertions(+), 51 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index f13880b3802f..620117cf2d00 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -53,9 +53,6 @@ def parse_args(): help="The config of the Dataset, leave as None if there's only one config.", ) parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") - parser.add_argument( - "--validation_data_dir", type=str, default=None, help="A folder containing the validation data." - ) parser.add_argument( "--image_column", type=str, default="image", help="The column of the dataset containing an image." ) @@ -74,21 +71,6 @@ def parse_args(): "value if set." ), ) - parser.add_argument( - "--max_eval_samples", - type=int, - default=None, - help=( - "For debugging purposes or quicker training, truncate the number of evaluation examples to this " - "value if set." - ), - ) - parser.add_argument( - "--train_val_split", - type=float, - default=0.15, - help="Percent to split off of train for validation", - ) parser.add_argument( "--output_dir", type=str, @@ -124,9 +106,6 @@ def parse_args(): parser.add_argument( "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 
) - parser.add_argument( - "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader." - ) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( "--max_train_steps", @@ -232,8 +211,8 @@ def parse_args(): args.local_rank = env_local_rank # Sanity checks - if args.dataset_name is None and args.train_data_dir is None and args.validation_data_dir is None: - raise ValueError("Need either a dataset name or a training/validation folder.") + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") return args @@ -425,8 +404,6 @@ def main(): data_files = {} if args.train_dir is not None: data_files["train"] = os.path.join(args.train_dir, "**") - if args.validation_dir is not None: - data_files["validation"] = os.path.join(args.validation_dir, "**") dataset = load_dataset( "imagefolder", data_files=data_files, @@ -435,13 +412,6 @@ def main(): # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.4.0/en/image_process#imagefolder. - # If we don't have a validation split, split off a percentage of train as validation. - args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split - if isinstance(args.train_val_split, float) and args.train_val_split > 0.0: - split = dataset["train"].train_test_split(args.train_val_split) - dataset["train"] = split["train"] - dataset["validation"] = split["test"] - # Preprocessing the datasets. # We need to tokenize inputs and targets. column_names = dataset["train"].column_names @@ -493,14 +463,6 @@ def tokenize_captions(examples, is_train=True): transforms.Normalize([0.5], [0.5]), ] ) - val_transforms = transforms.Compose( - [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] @@ -509,21 +471,11 @@ def preprocess_train(examples): return examples - def preprocess_val(examples): - images = [image.convert("RGB") for image in examples[image_column]] - examples["pixel_values"] = [val_transforms(image) for image in images] - examples["input_ids"] = tokenize_captions(examples, is_train=False) - return examples - with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) - if args.max_eval_samples is not None: - dataset["validation"] = dataset["validation"].shuffle(seed=args.seed).select(range(args.max_eval_samples)) - # Set the validation transforms - # eval_dataset = dataset["validation"].with_transform(preprocess_val) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) @@ -539,7 +491,6 @@ def collate_fn(examples): train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size ) - # eval_dataloader = torch.utils.data.DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=args.eval_batch_size) # Scheduler and math around the number of training steps. 
    overrode_max_train_steps = False

From 12d19dfef4d869e80a114eb93bbd0816ee402f91 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 4 Oct 2022 14:23:16 +0200
Subject: [PATCH 36/54] address review comments

---
 examples/text_to_image/train_text_to_image.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 620117cf2d00..e0f111f5f15c 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -227,11 +227,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
     return f"{organization}/{model_id}"


-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
 dataset_name_mapping = {
     "image_caption_dataset.py": ("image_path", "caption"),
 }
@@ -350,8 +345,8 @@ def main():
     )

     # Freeze vae and text_encoder
-    freeze_params(vae.parameters())
-    freeze_params(text_encoder.parameters())
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -449,9 +444,8 @@ def tokenize_captions(examples, is_train=True):
                 raise ValueError(
                     f"Caption column `{caption_column}` should contain either strings or lists of strings."
                 )
-        input_ids = tokenizer(
-            captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True
-        ).input_ids
+        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
+        input_ids = inputs.input_ids
         return input_ids

     train_transforms = transforms.Compose(

From ac0b09e792b440f53de2d1ddc520cd814465f1d6 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 4 Oct 2022 17:53:26 +0200
Subject: [PATCH 37/54] add requirements file

---
 examples/text_to_image/requirements.txt | 6 ++++++
 1 file changed, 6 insertions(+)
 create mode 100644 examples/text_to_image/requirements.txt

diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
new file mode 100644
index 000000000000..c0649bbe2bef
--- /dev/null
+++ b/examples/text_to_image/requirements.txt
@@ -0,0 +1,6 @@
+accelerate
+torchvision
+transformers>=4.21.0
+ftfy
+tensorboard
+modelcards
\ No newline at end of file

From 48930b792fe4f750c41d274c575e6759f1e555f6 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 4 Oct 2022 18:48:52 +0200
Subject: [PATCH 38/54] begin readme

---
 examples/text_to_image/README.md | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 examples/text_to_image/README.md

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
new file mode 100644
index 000000000000..e69de29bb2d1

From 7964342131b283ddebd9badb5efb5c23d9d23bc9 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 4 Oct 2022 18:49:17 +0200
Subject: [PATCH 39/54] begin readme

---
 examples/text_to_image/README.md | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index e69de29bb2d1..884d7115c95d 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -0,0 +1,21 @@
+# Stable Diffusion text-to-image fine-tuning
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
+The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+
+
+## Running locally
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+```bash
+pip install git+https://github.com/huggingface/diffusers.git
+pip install -U -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
\ No newline at end of file

From 1c8387cf09f9921c865b2d92c51ebf94cd4ccbd1 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 6 Oct 2022 11:42:34 +0200
Subject: [PATCH 40/54] fix typo

---
 examples/text_to_image/train_text_to_image.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index e0f111f5f15c..378a5aed825b 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -397,8 +397,8 @@ def main():
         )
     else:
         data_files = {}
-        if args.train_dir is not None:
-            data_files["train"] = os.path.join(args.train_dir, "**")
+        if args.train_data_dir is not None:
+            data_files["train"] = os.path.join(args.train_data_dir, "**")
         dataset = load_dataset(
             "imagefolder",
             data_files=data_files,

From 3eea0dbfadc973be1a9095e2b0772c8cabe75e2c Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 6 Oct 2022 11:52:28 +0200
Subject: [PATCH 41/54] fix push to hub

---
 examples/text_to_image/train_text_to_image.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 378a5aed825b..1ee507f71114 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -609,9 +609,7 @@ def collate_fn(examples):
         pipeline.save_pretrained(args.output_dir)

         if args.push_to_hub:
-            repo.push_to_hub(
-                args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
-            )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

     accelerator.end_training()

From eb8e6c30c9d18bb8874e1bde0b214a489ba1d231 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 6 Oct 2022 16:24:59 +0200
Subject: [PATCH 42/54] populate readme

---
 examples/text_to_image/README.md | 35 ++++++++++++++++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 884d7115c95d..bbdc67428691 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -1,7 +1,6 @@
 # Stable Diffusion text-to-image fine-tuning

-[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
-The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.
+The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.

 ## Running locally
 ### Installing the dependencies
@@ -18,4 +17,36 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
 ```bash
 accelerate config
+```
+
+### Pokemon example
+
+You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
+
+You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
+
+Run the following command to authenticate your token:
+
+```bash
+huggingface-cli login
+```
+
+If you have already cloned the repo, then you won't need to go through these steps.
+
+ +``` +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch ../diffusers/examples/text_to_image/train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ + --dataset_name=$dataset_name \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 --gradient_checkpointing \ + --gradient_accumulation_steps=4 --max_grad_norm=1 \ + --max_train_steps=15000 \ + --learning_rate=1e-05 --use_ema \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ ``` \ No newline at end of file From 2cb7c431a4399f91775e819f50187363e5ff6263 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 6 Oct 2022 16:25:15 +0200 Subject: [PATCH 43/54] update readme --- examples/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index bbdc67428691..bf800f79e48c 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -35,7 +35,7 @@ If you have already cloned the repo, then you won't need to go through these ste
-``` +```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" export dataset_name="lambdalabs/pokemon-blip-captions" From 7228818dbc9035904c29eb8b6021ed151f513922 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 6 Oct 2022 16:27:51 +0200 Subject: [PATCH 44/54] remove use_auth_token from the script --- examples/text_to_image/train_text_to_image.py | 27 +++---------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1ee507f71114..092d5ee2dd74 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -158,14 +158,6 @@ def parse_args(): parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument( - "--use_auth_token", - action="store_true", - help=( - "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" - " private models)." - ), - ) parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", @@ -329,20 +321,10 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) # Load models and create wrapper for stable diffusion - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="tokenizer", - use_auth_token=args.use_auth_token, - ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token - ) + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") # Freeze vae and text_encoder vae.requires_grad_(False) @@ -393,7 +375,6 @@ def main(): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, - use_auth_token=True if args.use_auth_token else None, ) else: data_files = {} From b08d85d2f10c352480bd521d105621b51f0c77ad Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 11 Oct 2022 12:06:39 +0200 Subject: [PATCH 45/54] address some review comments --- examples/text_to_image/README.md | 2 +- examples/text_to_image/train_text_to_image.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index bf800f79e48c..7452011e1e67 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -40,7 +40,7 @@ export MODEL_NAME="CompVis/stable-diffusion-v1-4" export dataset_name="lambdalabs/pokemon-blip-captions" accelerate launch ../diffusers/examples/text_to_image/train_text_to_image.py \ - --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ + 
--pretrained_model_name_or_path=$MODEL_NAME \
   --dataset_name=$dataset_name \
   --resolution=512 --center_crop --random_flip \
   --train_batch_size=1 --gradient_checkpointing \
   --gradient_accumulation_steps=4 --max_grad_norm=1 \
   --max_train_steps=15000 \
   --learning_rate=1e-05 --use_ema \
   --lr_scheduler="constant" --lr_warmup_steps=0 \
   --output_dir="sd-pokemon-model" \
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 092d5ee2dd74..77196d9dbaf7 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -386,7 +386,7 @@ def main():
             cache_dir=args.cache_dir,
         )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_process#imagefolder.
+        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.

From db8e31aa032c8ae68764f6f8a83b0b25d8858b74 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 11 Oct 2022 12:13:13 +0200
Subject: [PATCH 46/54] better mixed precision support

---
 examples/text_to_image/train_text_to_image.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 77196d9dbaf7..3b9319e05fbc 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -488,6 +488,18 @@ def collate_fn(examples):
     if args.use_ema:
         ema_unet = EMAModel(unet)

+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to gpu.
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+ text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + # Move vae and unet to device vae.to(accelerator.device) text_encoder.to(accelerator.device) @@ -526,14 +538,15 @@ def collate_fn(examples): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) From 17cb6e7f27522e52018f7cda3d684c41b3c929b1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 11 Oct 2022 14:26:03 +0200 Subject: [PATCH 47/54] remove redundant to --- examples/text_to_image/train_text_to_image.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 3b9319e05fbc..00ddf1e88485 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -500,10 +500,6 @@ def collate_fn(examples): text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - # Move vae and unet to device - vae.to(accelerator.device) - text_encoder.to(accelerator.device) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: From 25625cec1043d32aca60ea6965cc046e416d19c0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 11 Oct 2022 16:13:47 +0200 Subject: [PATCH 48/54] create ema model early --- examples/text_to_image/train_text_to_image.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 00ddf1e88485..b7ad45f9070e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -326,6 +326,9 @@ def main(): vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + if args.use_ema: + ema_unet = EMAModel(unet) + # Freeze vae and text_encoder vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -485,9 +488,6 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler ) - if args.use_ema: - ema_unet = EMAModel(unet) - weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 @@ -500,6 +500,9 @@ def collate_fn(examples): text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) + # Move the ema_unet to gpu. + ema_unet.averaged_model.to(accelerator.device) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: From 5d71880c33754c83f649e15d4311545926b191d9 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 11 Oct 2022 18:15:50 +0200 Subject: [PATCH 49/54] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- examples/text_to_image/train_text_to_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b7ad45f9070e..3fa99406ae3a 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -43,7 +43,8 @@ def parse_args(): default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset)." + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." ), ) parser.add_argument( @@ -96,7 +97,7 @@ def parse_args(): parser.add_argument( "--center_crop", action="store_true", - help="Whether to center crop images before resizing to resolution (if not set, use random crop)", + help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", ) parser.add_argument( "--random_flip", From 3a6e4f2c9445cc7859d114f85dd3b00ea5387231 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 11 Oct 2022 18:18:02 +0200 Subject: [PATCH 50/54] better description for train_data_dir --- examples/text_to_image/train_text_to_image.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b7ad45f9070e..7f411cf2a509 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -52,7 +52,16 @@ def parse_args(): default=None, help="The config of the Dataset, leave as None if there's only one config.", ) - parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) parser.add_argument( "--image_column", type=str, default="image", help="The column of the dataset containing an image." 
    )

From 1c8b0260c065aeb97b8c2de066d12763a996a678 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 11 Oct 2022 18:30:36 +0200
Subject: [PATCH 51/54] add diffusers in requirements

---
 examples/text_to_image/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
index c0649bbe2bef..a80836a32027 100644
--- a/examples/text_to_image/requirements.txt
+++ b/examples/text_to_image/requirements.txt
@@ -1,3 +1,4 @@
+diffusers==0.4.1
 accelerate
 torchvision
 transformers>=4.21.0

From 5b22178644ab62817dc69f988f2306e9fe7d8143 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 11 Oct 2022 18:34:16 +0200
Subject: [PATCH 52/54] update dataset_name_mapping

---
 examples/text_to_image/train_text_to_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index bd37d91c71bf..e4a91ff5c8b3 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -230,7 +230,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:


 dataset_name_mapping = {
-    "image_caption_dataset.py": ("image_path", "caption"),
+    "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }

From f0b43574b2440cb4292a4b98350a2a73a7ed3810 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 11 Oct 2022 18:57:37 +0200
Subject: [PATCH 53/54] update readme

---
 examples/text_to_image/README.md | 48 ++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 7452011e1e67..9689c4285b6e 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -2,6 +2,10 @@
 The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.

+___Note___:
+
+___This script is experimental. The script fine-tunes the whole model, and often the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
+
 ## Running locally
 ### Installing the dependencies

 If you have already cloned the repo, then you won't need to go through these steps.

+#### Hardware
+With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24GB GPU. For higher `batch_size` and faster training, it's better to use GPUs with >30GB memory.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export dataset_name="lambdalabs/pokemon-blip-captions"
+
+accelerate launch train_text_to_image.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$dataset_name \
+  --use_ema \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --mixed_precision="fp16" \
+  --max_train_steps=15000 \
+  --learning_rate=1e-05 \
+  --max_grad_norm=1 \
+  --lr_scheduler="constant" --lr_warmup_steps=0 \
+  --output_dir="sd-pokemon-model"
+```
+
+
+To run on your own training files, prepare the dataset according to the format required by `datasets`. You can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
+If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export TRAIN_DIR="path_to_your_dataset"
+
+accelerate launch train_text_to_image.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$TRAIN_DIR \
+  --use_ema \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
+  --mixed_precision="fp16" \
+  --max_train_steps=15000 \
+  --learning_rate=1e-05 \
+  --max_grad_norm=1 \
+  --lr_scheduler="constant" --lr_warmup_steps=0 \
+  --output_dir="sd-pokemon-model
+```

From f9a40255948e24a21d78efb672231ba79c71b284 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Tue, 11 Oct 2022 19:01:42 +0200
Subject: [PATCH 54/54] add inference example

---
 examples/text_to_image/README.md | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 9689c4285b6e..6aca642cda4a 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -83,6 +83,20 @@ accelerate launch train_text_to_image.py \
   --learning_rate=1e-05 \
   --max_grad_norm=1 \
   --lr_scheduler="constant" --lr_warmup_steps=0 \
-  --output_dir="sd-pokemon-model
+  --output_dir="sd-pokemon-model"
 ```
+
+Once the training is finished, the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
+
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+model_path = "path_to_saved_model"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+
+image = pipe(prompt="yoda").images[0]
+image.save("yoda-pokemon.png")
+```
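For reference, the training step that the loop-related commits above converge on (after removing `no_grad`, switching to `torch.randn_like`, using a mean-reduced MSE loss, adding mixed precision, and adding gradient clipping) can be condensed into a single function. The sketch below is illustrative rather than part of the patch series: the `training_step` name is our own, and it assumes the frozen `vae` and `text_encoder`, the trainable `unet`, the `noise_scheduler`, and the `weight_dtype` objects that the script constructs, while leaving gradient accumulation, clipping, and EMA updates to the surrounding Accelerate loop.

```python
import torch
import torch.nn.functional as F


def training_step(batch, vae, text_encoder, unet, noise_scheduler, weight_dtype):
    # Convert images to latent space and apply the Stable Diffusion latent scaling factor.
    # The vae is frozen, so no gradients flow through this encode step.
    latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
    latents = latents * 0.18215

    # Sample noise and a random timestep for each image in the batch.
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()

    # Forward diffusion: add noise to the latents according to each sampled timestep.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning; the text encoder is frozen as well.
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

    # Predict the noise residual and regress it against the sampled noise.
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    return F.mse_loss(noise_pred, noise, reduction="mean")
```

The caller is expected to run this inside `with accelerator.accumulate(unet):`, back-propagate the returned loss with `accelerator.backward`, clip gradients via `accelerator.clip_grad_norm_` when `accelerator.sync_gradients` is set, and then step the optimizer, the LR scheduler, and (when enabled) the EMA model, exactly as the final script does.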