"""
Ported from Paella
"""
import torch
from torch import nn


class PaellaDiscriminator(nn.Module):
    """Patch discriminator for VQGAN training, ported from Paella.

    A stack of strided, spectrally-normalized convolutions downsamples the
    input; an optional conditioning vector is broadcast over the spatial grid
    and concatenated before a 1x1 convolution produces a single-channel score,
    squashed through a sigmoid.
    """

    def __init__(self, config):
        in_channels = config.discriminator.channels
        cond_channels = config.discriminator.cond_channels
        hidden_channels = config.discriminator.hidden_channels
        depth = config.discriminator.depth
        super().__init__()

        # Number of halvings applied to `hidden_channels` in the stem conv.
        shrink = max(depth - 3, 3)
        blocks = [
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, hidden_channels // (2**shrink), kernel_size=3, stride=2, padding=1)
            ),
            nn.LeakyReLU(0.2),
        ]
        # Each stage halves the spatial resolution and doubles the channel
        # count until it saturates at `hidden_channels`.
        for stage in range(depth - 1):
            stage_in = hidden_channels // (2 ** max(shrink - stage, 0))
            stage_out = hidden_channels // (2 ** max(shrink - 1 - stage, 0))
            blocks += [
                nn.utils.spectral_norm(nn.Conv2d(stage_in, stage_out, kernel_size=3, stride=2, padding=1)),
                nn.InstanceNorm2d(stage_out),
                nn.LeakyReLU(0.2),
            ]
        self.encoder = nn.Sequential(*blocks)
        # NOTE(review): for depth == 4 the encoder ends at hidden_channels // 2,
        # which would not match this head — presumably depth >= 5 is assumed.
        head_channels = (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels
        self.shuffle = nn.Conv2d(head_channels, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        features = self.encoder(x)
        if cond is not None:
            # Broadcast the conditioning vector across the feature map's grid.
            spatial_cond = cond.view(cond.size(0), cond.size(1), 1, 1)
            spatial_cond = spatial_cond.expand(-1, -1, features.size(-2), features.size(-1))
            features = torch.cat([features, spatial_cond], dim=1)
        return self.logits(self.shuffle(features))
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.22.0.dev0")

logger = get_logger(__name__, log_level="INFO")

# Maps a dataset name to its (image column, caption column) pair.
DATASET_NAME_MAPPING = {
    "lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def get_vq_model_class():
    """Return the autoencoder class trained by this script."""
    return VQModel


def get_discriminator_class():
    """Return the discriminator class used for adversarial training."""
    return PaellaDiscriminator


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def _map_layer_to_idx(backbone, layers, offset=0):
    """Maps set of layer names to indices of model. Ported from anomalib

    Args:
        backbone: name of the timm model to inspect.
        layers: iterable of child-module names to locate.
        offset: constant subtracted from each located index.

    Returns:
        List of child indices (minus `offset`) for the requested layers.

    Raises:
        ValueError: if a requested layer is not a child of the backbone.
    """
    idx = []
    features = timm.create_model(
        backbone,
        pretrained=False,
        features_only=False,
        exportable=True,
    )
    # Compute the child-name list once instead of rebuilding it per layer.
    child_names = list(dict(features.named_children()).keys())
    for i in layers:
        try:
            idx.append(child_names.index(i) - offset)
        except ValueError as err:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(
                f"Layer {i} not found in model {backbone}. Select layer from {child_names}. "
                f"The network architecture is {features}"
            ) from err
    return idx


def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, timm_model_normalization):
    """Mean MSE between `timm_model` features of the originals and reconstructions.

    Args:
        pixel_values: original images, shape (b, c, h, w).
        fmap: reconstructed images, same shape as `pixel_values`.
        timm_model: feature extractor returning a list of feature maps.
        timm_model_resolution: spatial size the extractor expects.
        timm_model_normalization: normalization transform for the extractor.

    Returns:
        Scalar perceptual loss averaged over all returned feature maps.
    """
    img_timm_model_input = timm_model_normalization(F.interpolate(pixel_values, timm_model_resolution))
    fmap_timm_model_input = timm_model_normalization(F.interpolate(fmap, timm_model_resolution))

    if pixel_values.shape[1] == 1:
        # Tile grayscale inputs to 3 channels for the RGB backbone.
        # Equivalent to einops repeat "b 1 ... -> b c ..." with c=3, without
        # needing einops or a data copy (`expand` returns a broadcast view).
        img_timm_model_input = img_timm_model_input.expand(-1, 3, *img_timm_model_input.shape[2:])
        fmap_timm_model_input = fmap_timm_model_input.expand(-1, 3, *fmap_timm_model_input.shape[2:])

    img_timm_model_feats = timm_model(img_timm_model_input)
    recon_timm_model_feats = timm_model(fmap_timm_model_input)
    perceptual_loss = F.mse_loss(img_timm_model_feats[0], recon_timm_model_feats[0])
    for i in range(1, len(img_timm_model_feats)):
        perceptual_loss += F.mse_loss(img_timm_model_feats[i], recon_timm_model_feats[i])
    perceptual_loss /= len(img_timm_model_feats)
    return perceptual_loss
def grad_layer_wrt_loss(loss, layer):
    """Return the (detached) gradient of `loss` with respect to `layer`.

    The graph is retained so the caller can still backpropagate through
    `loss` afterwards (used to compute the adaptive generator-loss weight).
    """
    return torch.autograd.grad(
        outputs=loss,
        inputs=layer,
        grad_outputs=torch.ones_like(loss),
        retain_graph=True,
    )[0].detach()


def gradient_penalty(images, output, weight=10):
    """Gradient penalty: `weight * mean((||d output / d images||_2 - 1) ** 2)`."""
    gradients = torch.autograd.grad(
        outputs=output,
        inputs=images,
        grad_outputs=torch.ones(output.size(), device=images.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # Flatten per-sample gradients ("b ... -> b (...)") without einops.
    gradients = gradients.reshape(gradients.size(0), -1)
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


@torch.no_grad()
def log_validation(model, args, validation_transform, accelerator, global_step):
    """Reconstruct the validation images through the VQ model and log originals
    next to reconstructions on wandb."""
    logger.info("Generating images...")
    original_images = []
    # `--validation_image` is declared with nargs="+", so it is already a list
    # of paths. (The original read a non-existent `args.validation_images`
    # attribute and split it on "|", which raised AttributeError.)
    for image_path in args.validation_image:
        image = PIL.Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = validation_transform(image)
        original_images.append(image[None])
    original_images = torch.cat(original_images, dim=0)
    # Generate images
    model.eval()
    dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        dtype = torch.bfloat16

    with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"):
        _, enc_token_ids = accelerator.unwrap_model(model).encode(original_images)
        # In the beginning of training, the model is not fully trained and the
        # generated token ids can be out of range, so we clamp them to the
        # correct range.
        enc_token_ids = torch.clamp(enc_token_ids, max=accelerator.unwrap_model(model).config.num_embeddings - 1)
        images = accelerator.unwrap_model(model).decode_code(enc_token_ids)
    model.train()

    # Convert to PIL images
    images = torch.clamp(images, 0.0, 1.0)
    original_images = torch.clamp(original_images, 0.0, 1.0)
    images *= 255.0
    original_images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    images = np.concatenate([original_images, images], axis=2)
    pil_images = [Image.fromarray(image) for image in images]

    # Log images
    wandb_images = [wandb.Image(image, caption="Original, Generated") for image in pil_images]
    wandb.log({"vae_images": wandb_images}, step=global_step)


def save_checkpoint(model, discriminator, args, accelerator, global_step):
    """Save model + discriminator weights and the full accelerator state under
    `{output_dir}/checkpoint-{global_step}`."""
    save_path = Path(args.output_dir) / f"checkpoint-{global_step}"

    # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet)
    # XXX: could also make this conditional on deepspeed
    state_dict = accelerator.get_state_dict(model)
    discr_state_dict = accelerator.get_state_dict(discriminator)

    if accelerator.is_main_process:
        # Ensure the checkpoint directory exists before writing loose files
        # into it (torch.save does not create parent directories).
        save_path.mkdir(parents=True, exist_ok=True)
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_path / "unwrapped_model",
            save_function=accelerator.save,
            state_dict=state_dict,
        )
        torch.save(discr_state_dict, save_path / "unwrapped_discriminator")
        # Context manager so the metadata file handle is closed deterministically.
        with (save_path / "metadata.json").open("w+") as metadata_file:
            json.dump({"global_step": global_step}, metadata_file)
        logger.info(f"Saved state to {save_path}")

    accelerator.save_state(save_path)


def log_grad_norm(model, accelerator, global_step):
    """Log the size-normalized L2 gradient norm of every parameter that
    currently has a gradient."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            grads = param.grad.detach().data
            grad_norm = (grads.norm(p=2) / grads.numel()).item()
            accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
def parse_args():
    """Parse command-line arguments for VQGAN training.

    Returns:
        argparse.Namespace with all training options. `--pretrained_model_name_or_path`
        is required, and one of `--dataset_name` / `--train_data_dir` must be given.

    Note: the original version registered `--pretrained_model_name_or_path`
    three times, which makes argparse raise `ArgumentError` before any
    parsing can happen; it is registered exactly once here.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--log_grad_norm_steps",
        type=int,
        default=500,
        help=("Print logs of gradient norms every X steps."),
    )
    parser.add_argument(
        "--log_steps",
        type=int,
        default=50,
        help=("Print logs every X steps."),
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the images in"
            " `--validation_image` and logging the reconstructed images."
        ),
    )
    parser.add_argument(
        "--vae_loss",
        type=str,
        default="l2",
        help="The loss function for vae reconstruction loss.",
    )
    parser.add_argument(
        "--timm_model_offset",
        type=int,
        default=0,
        help="Offset of timm layers to indices.",
    )
    parser.add_argument(
        "--timm_model_layers",
        type=str,
        default="head",
        help="The layers to get output from in the timm model.",
    )
    parser.add_argument(
        "--timm_model_backend",
        type=str,
        default="vgg19",
        help="Timm model used to get the lpips loss",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=None,
        nargs="+",
        help=("A set of validation images evaluated every `--validation_steps` and logged to `--report_to`."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--discr_learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--discr_lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--prediction_type",
        type=str,
        default=None,
        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="vqgan-training",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    args = parser.parse_args()
    # An explicit LOCAL_RANK environment variable (set by torch.distributed
    # launchers) overrides the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args
tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + vq_class = get_vq_model_class() + model = VQModel.from_pretrained(args.pretrained_model_name_or_path) + if args.use_ema: + ema_model = EMAModel(model.parameters(), model_cls=vq_class, model_config=model.config) + discriminator_class = get_discriminator_class() + discriminator = discriminator_class() + # TODO: Add timm_model_backend to config.training. Set default to vgg16 + idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split("|"), args.timm_model_offset) + + timm_model = timm.create_model( + args.timm_model_backend, + pretrained=True, + features_only=True, + exportable=True, + out_indices=idx, + ) + timm_model = timm_model.to(accelerator.device) + timm_model.requires_grad = False + timm_model.eval() + timm_transform = create_transform(**resolve_data_config(timm_model.pretrained_cfg, model=timm_model)) + try: + # Gets the resolution of the timm transformation after centercrop + timm_centercrop_transform = timm_transform.transforms[1] + assert isinstance( + timm_centercrop_transform, transforms.CenterCrop + ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." 
+ timm_centercrop_transform.size[0] + # Gets final normalization + timm_model_normalization = timm_transform.transforms[-1] + assert isinstance( + timm_model_normalization, transforms.Normalize + ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." + except AssertionError as e: + raise NotImplementedError(e) + # Enable flash attention if asked + if args.enable_xformers_memory_efficient_attention: + model.enable_xformers_memory_efficient_attention() + + learning_rate = args.learning_rate + if args.scale_lr: + learning_rate = ( + learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + list(model.parameters()), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + discr_optimizer = optimizer_cls( + list(discriminator.parameters()), + lr=args.discr_learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + ################################## + # DATLOADER and LR-SCHEDULER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + args.train_batch_size * accelerator.num_processes + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + # DataLoaders creation: + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. 
+ dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + # Preprocessing the datasets. 
+ train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + ] + ) + validation_transform = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + return {"pixel_values": pixel_values} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps=args.max_train_steps, + num_warmup_steps=args.lr_warmup_steps, + ) + discr_lr_scheduler = get_scheduler( + args.discr_lr_scheduler, + optimizer=discr_optimizer, + num_training_steps=args.max_train_steps, + num_warmup_steps=args.lr_warmup_steps, + ) + + # Prepare everything with accelerator + logger.info("Preparing model, optimizer and dataloaders") + # The dataloader are already aware of distributed training, so we don't 
need to prepare them. + model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare( + model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs. + # Note: We are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num training steps = {args.max_train_steps}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + resume_from_checkpoint = args.resume_from_checkpoint + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = resume_from_checkpoint + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + path = os.path.join(args.output_dir, path) + + if path is None: + accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. 
Starting a new training run.") + resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(path) + accelerator.wait_for_everyone() + global_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + avg_gen_loss, avg_discr_loss = None, None + print("gradient accumulation steps", args.gradient_accumulation_steps) + for epoch in range(first_epoch, num_train_epochs): + model.train() + for i, batch in tqdm(enumerate(train_dataloader)): + pixel_values = batch["pixel_values"] + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + generator_step = ((i // args.gradient_accumulation_steps) % 2) == 0 + # Train Step + # The behavior of accelerator.accumulate is to + # 1. Check if gradients are synced(reached gradient-accumulation_steps) + # 2. If so sync gradients by stopping the not syncing process + if generator_step: + optimizer.zero_grad(set_to_none=True) + else: + discr_optimizer.zero_grad(set_to_none=True) + # encode images to the latent space and get the commit loss from vq tokenization + # Return commit loss + fmap, _, _, commit_loss = model(pixel_values, return_loss=True) + if accelerator.sync_gradients: + global_step += 1 + if generator_step: + with accelerator.accumulate(model): + # reconstruction loss. Pixel level differences between input vs output + if args.vae_loss == "l2": + loss = F.mse_loss(pixel_values, fmap) + else: + loss = F.l1_loss(pixel_values, fmap) + # perceptual loss. 
The high level feature mean squared error loss + perceptual_loss = get_perceptual_loss(pixel_values, fmap, timm_model) + # generator loss + gen_loss = -discriminator(fmap).mean() + last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) + + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-8) + adaptive_weight = adaptive_weight.clamp(max=1e4) + loss += commit_loss + loss += perceptual_loss + loss += adaptive_weight * gen_loss + # Gather thexd losses across all processes for logging (if we use distributed training). + avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean() + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and global_step % args.log_grad_norm_steps == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step) + else: + # Return discriminator loss + with accelerator.accumulate(discriminator): + fmap.detach_() + pixel_values.requires_grad_() + real = discriminator(pixel_values) + fake = discriminator(fmap) + loss = (F.relu(1 + fake) + F.relu(1 - real)).mean() + gp = gradient_penalty(pixel_values, real) + loss += gp + avg_discr_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + + discr_optimizer.step() + discr_lr_scheduler.step() + if ( + accelerator.sync_gradients + and global_step % args.log_grad_norm_steps == 0 + and 
accelerator.is_main_process + ): + log_grad_norm(discriminator, accelerator, global_step) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients and not generator_step and accelerator.is_main_process: + if args.use_ema: + ema_model.step(model.parameters()) + # wait for both generator and discriminator to settle + batch_time_m.update(time.time() - end) + end = time.time() + # Log metrics + if global_step % args.log_steps == 0: + samples_per_second_per_gpu = ( + args.gradient_accumulation_steps * args.train_batch_size / batch_time_m.val + ) + logs = { + "step_discr_loss": avg_discr_loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + if avg_gen_loss is not None: + logs["step_gen_loss"] = avg_gen_loss.item() + accelerator.log(logs, step=global_step) + logger.info( + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f} " + f"Step: {global_step} " + f"Discriminator Loss: {avg_discr_loss.item():0.4f} " + ) + if avg_gen_loss is not None: + logger.info(f"Generator Loss: {avg_gen_loss.item():0.4f} ") + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if global_step % args.checkpointing_steps == 0: + save_checkpoint(model, discriminator, args, accelerator, global_step) + + # Generate images + if global_step % args.validation_steps == 0: + log_validation(model, args, validation_transform, accelerator, global_step) + + # Stop training if max steps is reached + if global_step >= args.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, discriminator, args, accelerator, global_step) + + # Save the final trained 
checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + if args.use_ema: + ema_model.copy_to(model.parameters()) + model.save_pretrained(args.output_dir) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 36983eefc01f..b8ec8c1fbbe5 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -36,6 +36,7 @@ class DecoderOutput(BaseOutput): """ sample: torch.FloatTensor + commit_loss: torch.FloatTensor class Encoder(nn.Module): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 0c15300af213..241c9cb81a41 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -133,18 +133,19 @@ def decode( ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: - quant, _, _ = self.quantize(h) + quant, commit_loss, _ = self.quantize(h) else: quant = h + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) quant2 = self.post_quant_conv(quant) dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) if not return_dict: - return (dec,) + return (dec, commit_loss,) - return DecoderOutput(sample=dec) + return DecoderOutput(sample=dec, commit_loss=commit_loss) - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def forward(self, sample: torch.FloatTensor, return_dict: bool = True, return_loss: bool = False) -> Union[DecoderOutput, torch.FloatTensor]: r""" The [`VQModel`] forward method. @@ -152,7 +153,8 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[ sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. 
- + return_loss (`bool`, *optional*, defaults to `False`): + Whether or not to return a commit loss. Returns: [`~models.vq_model.VQEncoderOutput`] or `tuple`: If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` @@ -160,9 +162,12 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[ """ x = sample h = self.encode(x).latents - dec = self.decode(h).sample + dec = self.decode(h) if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) + if return_loss: + return (dec.sample, dec.commit_loss,) + return (dec.sample,) + if return_loss: + return dec + return DecoderOutput(sample=dec.sample) From c0e44b004c6228c92787f2ae1fc118f8f9ace2b5 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 17:43:21 -0400 Subject: [PATCH 02/33] Removed einops --- examples/vqgan/train_vqgan.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 83490baa61a4..23dc4f27e18b 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -32,7 +32,6 @@ from accelerate.utils import DistributedType, ProjectConfiguration, set_seed from datasets import load_dataset from discriminator import PaellaDiscriminator -from einops import rearrange, repeat from huggingface_hub import create_repo from PIL import Image from timm.data import resolve_data_config @@ -115,7 +114,7 @@ def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, t if pixel_values.shape[1] == 1: # handle grayscale for timm_model - img_timm_model_input, fmap_timm_model_input = (repeat(t, "b 1 ... 
-> b c ...", c=3) for t in (img_timm_model_input, fmap_timm_model_input)) + img_timm_model_input, fmap_timm_model_input = (t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input)) img_timm_model_feats = timm_model(img_timm_model_input) recon_timm_model_feats = timm_model(fmap_timm_model_input) @@ -144,8 +143,8 @@ def gradient_penalty(images, output, weight=10): retain_graph=True, only_inputs=True, )[0] - - gradients = rearrange(gradients, "b ... -> b (...)") + bsz = gradients.shape[0] + gradients = torch.reshape(gradients, (bsz, -1)) return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() @@ -275,20 +274,6 @@ def parse_args(): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) parser.add_argument( "--revision", type=str, From ccaf3938c8a713f6d0efda27fb4ab2555e430550 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 18:14:18 -0400 Subject: [PATCH 03/33] Added default movq config for training --- examples/vqgan/train_vqgan.py | 46 +++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 23dc4f27e18b..98bca98910f8 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -271,9 +271,14 @@ def parse_args(): "--pretrained_model_name_or_path", type=str, default=None, - required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the 
UNet model to train, leave as None to use standard DDPM configuration.", + ) parser.add_argument( "--revision", type=str, @@ -608,7 +613,44 @@ def main(): logger.info("Loading models and optimizer") vq_class = get_vq_model_class() - model = VQModel.from_pretrained(args.pretrained_model_name_or_path) + if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None: + # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder + model = VQModel( + act_fn="silu", + block_out_channels=[ + 128, + 256, + 256, + 512 + ], + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "AttnDownEncoderBlock2D" + ], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + norm_type="spatial", + num_vq_embeddings=16384, + out_channels=3, + sample_size=32, + scaling_factor=0.18215, + up_block_types=[ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + vq_embed_dim=4 + ) + elif args.pretrained_model_name_or_path is None: + model = VQModel.from_pretrained(args.pretrained_model_name_or_path) + else: + config = VQModel.load_config(args.model_config_name_or_path) + model = VQModel.from_config(config) if args.use_ema: ema_model = EMAModel(model.parameters(), model_cls=vq_class, model_config=model.config) discriminator_class = get_discriminator_class() From 4b361ccbecab7a853b390f2efbc21512af35dfb3 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 18:24:03 -0400 Subject: [PATCH 04/33] Update explanation of prompts --- examples/vqgan/train_vqgan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 98bca98910f8..609c77a6d7e1 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -239,7 +239,7 @@ def parse_args(): type=int, default=100, help=( - "Run validation every X steps. 
Validation consists of running the images in" + "Run validation every X steps. Validation consists of running reconstruction on images in" " `args.validation_images` and logging the reconstructed images." ), ) @@ -590,7 +590,6 @@ def main(): if accelerator.is_main_process: tracker_config = dict(vars(args)) - tracker_config.pop("validation_prompts") accelerator.init_trackers(args.tracker_project_name, tracker_config) # If passed along, set the training seed now. From a726069764ca521c0dbb59b79d4444d9d582f750 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 20:28:26 -0400 Subject: [PATCH 05/33] Fixed inheritance of discriminator and init_tracker --- examples/vqgan/discriminator.py | 21 +++++++++++------- examples/vqgan/train_vqgan.py | 38 ++++++++++++++++----------------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/examples/vqgan/discriminator.py b/examples/vqgan/discriminator.py index 3113a70fb580..9b5b1f864df1 100644 --- a/examples/vqgan/discriminator.py +++ b/examples/vqgan/discriminator.py @@ -3,19 +3,24 @@ """ import torch from torch import nn +from diffusers.models.modeling_utils import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config - -class PaellaDiscriminator(nn.Module): - def __init__(self, config): - channels = config.discriminator.channels - cond_channels = config.discriminator.cond_channels - hidden_channels = config.discriminator.hidden_channels - depth = config.discriminator.depth +# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py +class Discriminator(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels=3, + cond_channels=0, + hidden_channels=512, + depth=6 + ): super().__init__() d = max(depth - 3, 3) layers = [ nn.utils.spectral_norm( - nn.Conv2d(channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1) + nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, 
stride=2, padding=1) ), nn.LeakyReLU(0.2), ] diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 609c77a6d7e1..ccd78b46ce21 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -31,7 +31,7 @@ from accelerate.logging import get_logger from accelerate.utils import DistributedType, ProjectConfiguration, set_seed from datasets import load_dataset -from discriminator import PaellaDiscriminator +from discriminator import Discriminator from huggingface_hub import create_repo from PIL import Image from timm.data import resolve_data_config @@ -57,15 +57,6 @@ "lambdalabs/pokemon-blip-captions": ("image", "text"), } - -def get_vq_model_class(): - return VQModel - - -def get_discriminator_class(): - return PaellaDiscriminator - - class AverageMeter(object): """Computes and stores the average and current value""" @@ -152,7 +143,7 @@ def gradient_penalty(images, output, weight=10): def log_validation(model, args, validation_transform, accelerator, global_step): logger.info("Generating images...") original_images = [] - for image_path in args.validation_images.split("|"): + for image_path in args.validation_images: image = PIL.Image.open(image_path) if not image.mode == "RGB": image = image.convert("RGB") @@ -277,7 +268,13 @@ def parse_args(): "--model_config_name_or_path", type=str, default=None, - help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", + help="The config of the Vq model to train, leave as None to use standard DDPM configuration.", + ) + parser.add_argument( + "--discriminator_config_name_or_path", + type=str, + default=None, + help="The config of the discriminator model to train, leave as None to use standard DDPM configuration.", ) parser.add_argument( "--revision", @@ -325,7 +322,7 @@ def parse_args(): ), ) parser.add_argument( - "--validation_image", + "--validation_images", type=str, default=None, nargs="+", @@ -334,7 +331,7 @@ def parse_args(): 
parser.add_argument( "--output_dir", type=str, - default="sd-model-finetuned", + default="vqgan-output", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( @@ -590,6 +587,7 @@ def main(): if accelerator.is_main_process: tracker_config = dict(vars(args)) + tracker_config.pop("validation_images") accelerator.init_trackers(args.tracker_project_name, tracker_config) # If passed along, set the training seed now. @@ -611,7 +609,6 @@ def main(): ######################### logger.info("Loading models and optimizer") - vq_class = get_vq_model_class() if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None: # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder model = VQModel( @@ -651,10 +648,13 @@ def main(): config = VQModel.load_config(args.model_config_name_or_path) model = VQModel.from_config(config) if args.use_ema: - ema_model = EMAModel(model.parameters(), model_cls=vq_class, model_config=model.config) - discriminator_class = get_discriminator_class() - discriminator = discriminator_class() - # TODO: Add timm_model_backend to config.training. 
Set default to vgg16 + ema_model = EMAModel(model.parameters(), model_cls=VQModel, model_config=model.config) + if args.discriminator_config_name_or_path is None: + discriminator = Discriminator() + else: + config = Discriminator.load_config(args.discriminator_config_name_or_path) + discriminator = Discriminator.from_config(config) + idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split("|"), args.timm_model_offset) timm_model = timm.create_model( From 0b0cea3b9a467b00ca6b40dfbb94a836f7ce2d27 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 20:35:21 -0400 Subject: [PATCH 06/33] Fixed incompatible api between muse and here --- examples/vqgan/train_vqgan.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index ccd78b46ce21..9a6a656b5887 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -823,22 +823,17 @@ def collate_fn(examples): model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare( model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) - # Afterwards we recalculate our number of training epochs. - # Note: We are not doing epoch based training here, but just using this for book keeping and being able to - # reuse the same training loop with other datasets/loaders. - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - # Train! 
logger.info("***** Running training *****") - logger.info(f" Num training steps = {args.max_train_steps}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) # Potentially load in the weights and states from a previous save resume_from_checkpoint = args.resume_from_checkpoint @@ -866,12 +861,10 @@ def collate_fn(examples): batch_time_m = AverageMeter() data_time_m = AverageMeter() end = time.time() - # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to # reuse the same training loop with other datasets/loaders. 
avg_gen_loss, avg_discr_loss = None, None - print("gradient accumulation steps", args.gradient_accumulation_steps) - for epoch in range(first_epoch, num_train_epochs): + for epoch in range(first_epoch, args.num_train_epochs): model.train() for i, batch in tqdm(enumerate(train_dataloader)): pixel_values = batch["pixel_values"] From 68be3c5a18cf943260dfa12fb2e27adf3665e909 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 21:32:55 -0400 Subject: [PATCH 07/33] Fixed output --- examples/vqgan/train_vqgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 9a6a656b5887..8a6d60a829a3 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -881,7 +881,7 @@ def collate_fn(examples): discr_optimizer.zero_grad(set_to_none=True) # encode images to the latent space and get the commit loss from vq tokenization # Return commit loss - fmap, _, _, commit_loss = model(pixel_values, return_loss=True) + fmap, commit_loss = model(pixel_values, return_dict=False, return_loss=True) if accelerator.sync_gradients: global_step += 1 if generator_step: From 3072e5dc9e6d0a2ce1810f5fdbc3cd0c50c1cd73 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 21:45:21 -0400 Subject: [PATCH 08/33] Setup init training --- examples/vqgan/train_vqgan.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 8a6d60a829a3..8e67db884d25 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -674,7 +674,7 @@ def main(): assert isinstance( timm_centercrop_transform, transforms.CenterCrop ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19." 
- timm_centercrop_transform.size[0] + timm_model_resolution = timm_centercrop_transform.size[0] # Gets final normalization timm_model_normalization = timm_transform.transforms[-1] assert isinstance( @@ -833,7 +833,16 @@ def collate_fn(examples): logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # Potentially load in the weights and states from a previous save resume_from_checkpoint = args.resume_from_checkpoint @@ -892,7 +901,7 @@ def collate_fn(examples): else: loss = F.l1_loss(pixel_values, fmap) # perceptual loss. 
The high level feature mean squared error loss - perceptual_loss = get_perceptual_loss(pixel_values, fmap, timm_model) + perceptual_loss = get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution=timm_model_resolution, timm_model_normalization=timm_model_normalization) # generator loss gen_loss = -discriminator(fmap).mean() last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight From a3022018be38243dddd2fdfd40beca77213e7206 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 29 Oct 2023 22:27:11 -0400 Subject: [PATCH 09/33] Basic structure done --- examples/vqgan/train_vqgan.py | 37 ++++++++++++----------------------- src/diffusers/models/vae.py | 2 +- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 8e67db884d25..1272e8f535cc 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -142,29 +142,27 @@ def gradient_penalty(images, output, weight=10): @torch.no_grad() def log_validation(model, args, validation_transform, accelerator, global_step): logger.info("Generating images...") + dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + dtype = torch.bfloat16 original_images = [] for image_path in args.validation_images: image = PIL.Image.open(image_path) if not image.mode == "RGB": image = image.convert("RGB") - image = validation_transform(image) + image = validation_transform(image).to(accelerator.device, dtype=dtype) original_images.append(image[None]) - original_images = torch.cat(original_images, dim=0) # Generate images model.eval() - dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - dtype = torch.bfloat16 - - with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"): - _, enc_token_ids = 
accelerator.unwrap_model(model).encode(original_images) - # In the beginning of training, the model is not fully trained and the generated token ids can be out of range - # so we clamp them to the correct range. - enc_token_ids = torch.clamp(enc_token_ids, max=accelerator.unwrap_model(model).config.num_embeddings - 1) - images = accelerator.unwrap_model(model).decode_code(enc_token_ids) + images = [] + for original_image in original_images: + image = accelerator.unwrap_model(model)(original_image).sample + images.append(image) model.train() + original_images = torch.cat(original_images, dim=0) + images = torch.cat(images, dim=0) # Convert to PIL images images = torch.clamp(images, 0.0, 1.0) @@ -178,7 +176,7 @@ def log_validation(model, args, validation_transform, accelerator, global_step): # Log images wandb_images = [wandb.Image(image, caption="Original, Generated") for image in pil_images] - wandb.log({"vae_images": wandb_images}, step=global_step) + accelerator.log({"vae_images": wandb_images}, step=global_step) def save_checkpoint(model, discriminator, args, accelerator, global_step): @@ -975,15 +973,6 @@ def collate_fn(examples): if avg_gen_loss is not None: logs["step_gen_loss"] = avg_gen_loss.item() accelerator.log(logs, step=global_step) - logger.info( - f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " - f"Batch (t): {batch_time_m.val:0.4f} " - f"LR: {lr_scheduler.get_last_lr()[0]:0.6f} " - f"Step: {global_step} " - f"Discriminator Loss: {avg_discr_loss.item():0.4f} " - ) - if avg_gen_loss is not None: - logger.info(f"Generator Loss: {avg_gen_loss.item():0.4f} ") # resetting batch / data time meters per log window batch_time_m.reset() diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b8ec8c1fbbe5..f072fc3af209 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -36,7 +36,7 @@ class DecoderOutput(BaseOutput): """ sample: torch.FloatTensor - commit_loss: torch.FloatTensor + 
commit_loss: Optional[torch.FloatTensor] = None class Encoder(nn.Module): From 388f880b3edb3b10897d545734ad3f1e24dd9dd2 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 31 Oct 2023 22:26:43 +0900 Subject: [PATCH 10/33] Removed attention for quick tests --- examples/vqgan/train_vqgan.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 1272e8f535cc..b30725acdd72 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -608,20 +608,18 @@ def main(): logger.info("Loading models and optimizer") if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None: - # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder + # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder but without the attention layers model = VQModel( act_fn="silu", block_out_channels=[ 128, 256, - 256, - 512 + 512, ], down_block_types=[ "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", - "AttnDownEncoderBlock2D" ], in_channels=3, latent_channels=4, @@ -633,7 +631,6 @@ def main(): sample_size=32, scaling_factor=0.18215, up_block_types=[ - "AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D" From fca82c53c0f9fd9b3da34631a0ae231fa5a942a5 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 31 Oct 2023 09:30:23 -0400 Subject: [PATCH 11/33] Style fixes --- examples/vqgan/discriminator.py | 12 ++++-------- examples/vqgan/train_vqgan.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/vqgan/discriminator.py b/examples/vqgan/discriminator.py index 9b5b1f864df1..00038ffb0481 100644 --- a/examples/vqgan/discriminator.py +++ b/examples/vqgan/discriminator.py @@ -3,19 +3,15 @@ """ import torch from torch import nn -from diffusers.models.modeling_utils import ModelMixin + from diffusers.configuration_utils import ConfigMixin, 
register_to_config +from diffusers.models.modeling_utils import ModelMixin + # Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py class Discriminator(ModelMixin, ConfigMixin): @register_to_config - def __init__( - self, - in_channels=3, - cond_channels=0, - hidden_channels=512, - depth=6 - ): + def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6): super().__init__() d = max(depth - 3, 3) layers = [ diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index b30725acdd72..93fa11c3c70b 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -57,6 +57,7 @@ "lambdalabs/pokemon-blip-captions": ("image", "text"), } + class AverageMeter(object): """Computes and stores the average and current value""" @@ -105,7 +106,9 @@ def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, t if pixel_values.shape[1] == 1: # handle grayscale for timm_model - img_timm_model_input, fmap_timm_model_input = (t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input)) + img_timm_model_input, fmap_timm_model_input = ( + t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input) + ) img_timm_model_feats = timm_model(img_timm_model_input) recon_timm_model_feats = timm_model(fmap_timm_model_input) @@ -630,12 +633,8 @@ def main(): out_channels=3, sample_size=32, scaling_factor=0.18215, - up_block_types=[ - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D" - ], - vq_embed_dim=4 + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + vq_embed_dim=4, ) elif args.pretrained_model_name_or_path is None: model = VQModel.from_pretrained(args.pretrained_model_name_or_path) @@ -896,7 +895,13 @@ def collate_fn(examples): else: loss = F.l1_loss(pixel_values, fmap) # perceptual loss. 
The high level feature mean squared error loss - perceptual_loss = get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution=timm_model_resolution, timm_model_normalization=timm_model_normalization) + perceptual_loss = get_perceptual_loss( + pixel_values, + fmap, + timm_model, + timm_model_resolution=timm_model_resolution, + timm_model_normalization=timm_model_normalization, + ) # generator loss gen_loss = -discriminator(fmap).mean() last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight From 1924fab72d3d6a7179ae59d5d3234b89b713a855 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 31 Oct 2023 09:40:19 -0400 Subject: [PATCH 12/33] Fixed vae/vqgan styles --- src/diffusers/models/vae.py | 2 +- src/diffusers/models/vq_model.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index f072fc3af209..029f4ca3aa83 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -36,7 +36,7 @@ class DecoderOutput(BaseOutput): """ sample: torch.FloatTensor - commit_loss: Optional[torch.FloatTensor] = None + commit_loss: Optional[torch.FloatTensor] = None class Encoder(nn.Module): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 241c9cb81a41..1683038e86f5 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -141,11 +141,16 @@ def decode( dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) if not return_dict: - return (dec, commit_loss,) + return ( + dec, + commit_loss, + ) return DecoderOutput(sample=dec, commit_loss=commit_loss) - def forward(self, sample: torch.FloatTensor, return_dict: bool = True, return_loss: bool = False) -> Union[DecoderOutput, torch.FloatTensor]: + def forward( + self, sample: torch.FloatTensor, return_dict: bool = True, return_loss: bool = False + ) -> Union[DecoderOutput, torch.FloatTensor]: r""" The 
[`VQModel`] forward method. @@ -166,7 +171,10 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True, return_lo if not return_dict: if return_loss: - return (dec.sample, dec.commit_loss,) + return ( + dec.sample, + dec.commit_loss, + ) return (dec.sample,) if return_loss: return dec From 563744484e6d72f60ed72f8f427b19dc403cc2cb Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 31 Oct 2023 09:45:55 -0400 Subject: [PATCH 13/33] Removed redefinition of wandb --- examples/vqgan/train_vqgan.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 93fa11c3c70b..e6a894cb2376 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -26,7 +26,6 @@ import timm import torch import torch.nn.functional as F -import wandb from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedType, ProjectConfiguration, set_seed From 2f5421de106077ad3e713bf97636e657b7ddae15 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 31 Oct 2023 10:23:30 -0400 Subject: [PATCH 14/33] Fixed log_validation and tqdm --- examples/vqgan/train_vqgan.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index e6a894cb2376..23cecfa8d836 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -174,11 +174,23 @@ def log_validation(model, args, validation_transform, accelerator, global_step): images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) images = np.concatenate([original_images, images], axis=2) - pil_images = [Image.fromarray(image) for image in images] + images = [Image.fromarray(image) for image in images] # Log images - wandb_images = [wandb.Image(image, caption="Original, Generated") for image 
in pil_images] - accelerator.log({"vae_images": wandb_images}, step=global_step) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images) + ] + } + ) + torch.cuda.empty_cache() + return images def save_checkpoint(model, discriminator, args, accelerator, global_step): @@ -863,12 +875,19 @@ def collate_fn(examples): batch_time_m = AverageMeter() data_time_m = AverageMeter() end = time.time() + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to # reuse the same training loop with other datasets/loaders. avg_gen_loss, avg_discr_loss = None, None for epoch in range(first_epoch, args.num_train_epochs): model.train() - for i, batch in tqdm(enumerate(train_dataloader)): + for i, batch in enumerate(train_dataloader): pixel_values = batch["pixel_values"] pixel_values = pixel_values.to(accelerator.device, non_blocking=True) data_time_m.update(time.time() - end) @@ -884,8 +903,7 @@ def collate_fn(examples): # encode images to the latent space and get the commit loss from vq tokenization # Return commit loss fmap, commit_loss = model(pixel_values, return_dict=False, return_loss=True) - if accelerator.sync_gradients: - global_step += 1 + if generator_step: with accelerator.accumulate(model): # reconstruction loss. 
Pixel level differences between input vs output @@ -953,9 +971,12 @@ def collate_fn(examples): ): log_grad_norm(discriminator, accelerator, global_step) # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients and not generator_step and accelerator.is_main_process: + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) if args.use_ema: ema_model.step(model.parameters()) + if accelerator.sync_gradients and not generator_step and accelerator.is_main_process: # wait for both generator and discriminator to settle batch_time_m.update(time.time() - end) end = time.time() From e318ca86f4287f19a63c54ebb9291bf90912cfb0 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 31 Oct 2023 23:25:37 +0900 Subject: [PATCH 15/33] Nothing commit --- examples/vqgan/train_vqgan.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index e6a894cb2376..9d0f9a7f8ea8 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -981,7 +981,6 @@ def collate_fn(examples): # Save model checkpoint if global_step % args.checkpointing_steps == 0: save_checkpoint(model, discriminator, args, accelerator, global_step) - # Generate images if global_step % args.validation_steps == 0: log_validation(model, args, validation_transform, accelerator, global_step) From ce00b6ccb1c5513d5090adba39dad630e808486c Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Wed, 28 Feb 2024 08:07:46 -0500 Subject: [PATCH 16/33] Added commit loss to lookup_from_codebook --- src/diffusers/models/vq_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 219687dba866..73b0c00662e9 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -145,6 +145,7 @@ def decode( quant, commit_loss, _ = self.quantize(h) elif self.config.lookup_from_codebook: quant = 
self.quantize.get_codebook_entry(h, shape) + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) else: quant = h commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) From 79bdc26b29af46538eba8d9298a78f92116f5cfa Mon Sep 17 00:00:00 2001 From: Isamu Isozaki Date: Thu, 7 Mar 2024 10:48:51 -0500 Subject: [PATCH 17/33] Update src/diffusers/models/vq_model.py Co-authored-by: Sayak Paul --- src/diffusers/models/vq_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 73b0c00662e9..b08c04be5b8f 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -153,10 +153,7 @@ def decode( dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) if not return_dict: - return ( - dec, - commit_loss, - ) + return dec, commit_loss return DecoderOutput(sample=dec, commit_loss=commit_loss) From d16fea1fc295911348ef95c1c9dc8b30bbf2f9af Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Thu, 7 Mar 2024 22:14:47 -0500 Subject: [PATCH 18/33] Adding perliminary README --- examples/vqgan/README.md | 119 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 examples/vqgan/README.md diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md new file mode 100644 index 000000000000..2a2e09a99a4e --- /dev/null +++ b/examples/vqgan/README.md @@ -0,0 +1,119 @@ +## Training an VQGAN VAE + +Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). 
+ +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### Training on CIFAR10 + +The command to train a VQGAN model on cifar10 dataset: + +```bash +accelerate launch train_vqgan.py \ + --dataset_name=cifar10 \ + --image_column=img \ + --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \ + --resolution=128 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=8 \ + --report_to=wandb +``` + +The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocaburary size which at most can be around 16384. How to do this is shown below. + +# Modifying the architecture + +To modify the architecture of the vqgan model you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide that to the script with the option --model_config_name_or_path. 
This config is below +``` +{ + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 256, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "AttnDownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 16384, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "vq_embed_dim": 4 +} +``` +To lower the amount of layers in a VQGan, you can remove layers by modifying the block_out_channels, down_block_types, and up_block_types like below +``` +``` +{ + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 256, + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 16384, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "vq_embed_dim": 4 +} +``` +For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that. 
\ No newline at end of file From 42504e55ca911fa383a566c64e4e3aef969aa996 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 26 Mar 2024 01:05:12 +0900 Subject: [PATCH 19/33] Fixed one typo --- examples/vqgan/train_vqgan.py | 2 +- src/diffusers/models/vq_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 2b89365a34d2..4cd047a41347 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -930,7 +930,7 @@ def collate_fn(examples): loss += commit_loss loss += perceptual_loss loss += adaptive_weight * gen_loss - # Gather thexd losses across all processes for logging (if we use distributed training). + # Gather the losses across all processes for logging (if we use distributed training). avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean() accelerator.backward(loss) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index b08c04be5b8f..876dc11c2ff3 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -176,7 +176,7 @@ def forward( """ h = self.encode(sample).latents - dec = self.decode(h).sample + dec = self.decode(h) if not return_dict: if return_loss: From 7c6aeecaa30392572750c628c5dc22099a5f0e65 Mon Sep 17 00:00:00 2001 From: unknown Date: Sat, 27 Apr 2024 02:26:02 +0900 Subject: [PATCH 20/33] Local changes --- examples/vqgan/train_vqgan.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 4cd047a41347..cc30838a3c22 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -187,7 +187,8 @@ def log_validation(model, args, validation_transform, accelerator, global_step): "validation": [ wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images) ] - } + }, + step=global_step ) torch.cuda.empty_cache() return 
images @@ -461,7 +462,7 @@ def parse_args(): parser.add_argument( "--dataloader_num_workers", type=int, - default=0, + default=4, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), @@ -783,7 +784,6 @@ def main(): transforms.ToTensor(), ] ) - def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] @@ -792,7 +792,6 @@ def preprocess_train(examples): with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) - # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) def collate_fn(examples): @@ -970,6 +969,7 @@ def collate_fn(examples): and accelerator.is_main_process ): log_grad_norm(discriminator, accelerator, global_step) + batch_time_m.update(time.time() - end) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: global_step += 1 @@ -978,8 +978,6 @@ def collate_fn(examples): ema_model.step(model.parameters()) if accelerator.sync_gradients and not generator_step and accelerator.is_main_process: # wait for both generator and discriminator to settle - batch_time_m.update(time.time() - end) - end = time.time() # Log metrics if global_step % args.log_steps == 0: samples_per_second_per_gpu = ( @@ -1005,7 +1003,7 @@ def collate_fn(examples): # Generate images if global_step % args.validation_steps == 0: log_validation(model, args, validation_transform, accelerator, global_step) - + end = time.time() # Stop training if max steps is reached if global_step >= args.max_train_steps: break From 4ad7a22e775d9719372a5b990a84a33665598ce0 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 28 Apr 2024 00:40:29 -0400 Subject: [PATCH 21/33] Fixed main issues --- examples/vqgan/README.md 
| 3 +- examples/vqgan/discriminator.py | 1 + examples/vqgan/requirements.txt | 8 ++ examples/vqgan/train_vqgan.py | 105 ++++++++++++-------- src/diffusers/models/vq_model.py | 2 +- tests/models/autoencoders/test_models_vq.py | 16 +++ 6 files changed, 91 insertions(+), 44 deletions(-) create mode 100644 examples/vqgan/requirements.txt diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md index 2a2e09a99a4e..34b3ad8ac5f0 100644 --- a/examples/vqgan/README.md +++ b/examples/vqgan/README.md @@ -1,4 +1,6 @@ ## Training an VQGAN VAE +VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and was combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is it's a type of a variational auto encoder with tokens as the latent space similar to tokens for LLMs. This script was adapted from a [pr to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52) with general code following [lucidrian's implementation of the vqgan training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py) but both of these implementation follow from the [taming transformer repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file). + Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). 
@@ -84,7 +86,6 @@ To modify the architecture of the vqgan model you can save the config taken from ``` To lower the amount of layers in a VQGan, you can remove layers by modifying the block_out_channels, down_block_types, and up_block_types like below ``` -``` { "_class_name": "VQModel", "_diffusers_version": "0.17.0.dev0", diff --git a/examples/vqgan/discriminator.py b/examples/vqgan/discriminator.py index 00038ffb0481..eb31c3cb0165 100644 --- a/examples/vqgan/discriminator.py +++ b/examples/vqgan/discriminator.py @@ -1,6 +1,7 @@ """ Ported from Paella """ + import torch from torch import nn diff --git a/examples/vqgan/requirements.txt b/examples/vqgan/requirements.txt new file mode 100644 index 000000000000..f204a70f1e0e --- /dev/null +++ b/examples/vqgan/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +datasets +timm +numpy +tqdm +tensorboard \ No newline at end of file diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 23cecfa8d836..ce1c51d7ffd8 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -14,12 +14,13 @@ # limitations under the License. import argparse -import json import math import os +import shutil import time from pathlib import Path +import accelerate import numpy as np import PIL import PIL.Image @@ -32,6 +33,7 @@ from datasets import load_dataset from discriminator import Discriminator from huggingface_hub import create_repo +from packaging import version from PIL import Image from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform @@ -48,14 +50,10 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.22.0.dev0") +check_min_version("0.27.0.dev0") logger = get_logger(__name__, log_level="INFO") -DATASET_NAME_MAPPING = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} - class AverageMeter(object): """Computes and stores the average and current value""" @@ -193,28 +191,6 @@ def log_validation(model, args, validation_transform, accelerator, global_step): return images -def save_checkpoint(model, discriminator, args, accelerator, global_step): - save_path = Path(args.output_dir) / f"checkpoint-{global_step}" - - # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) - # XXX: could also make this conditional on deepspeed - state_dict = accelerator.get_state_dict(model) - discr_state_dict = accelerator.get_state_dict(discriminator) - - if accelerator.is_main_process: - unwrapped_model = accelerator.unwrap_model(model) - unwrapped_model.save_pretrained( - save_path / "unwrapped_model", - save_function=accelerator.save, - state_dict=state_dict, - ) - torch.save(discr_state_dict, save_path / "unwrapped_discriminator") - json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) - logger.info(f"Saved state to {save_path}") - - accelerator.save_state(save_path) - - def log_grad_norm(model, accelerator, global_step): for name, param in model.named_parameters(): if param.grad is not None: @@ -691,6 +667,33 @@ def main(): if args.enable_xformers_memory_efficient_attention: model.enable_xformers_memory_efficient_attention() + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + vqmodel = models[0] + discriminator = models[1] + vqmodel.save_pretrained(os.path.join(output_dir, 
"vqmodel")) + discriminator.save_pretrained(os.path.join(output_dir, "discriminator")) + weights.pop() + weights.pop() + + def load_model_hook(models, input_dir): + discriminator = models.pop() + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") + discriminator.register_to_config(**load_model.config) + discriminator.load_state_dict(load_model.state_dict()) + del load_model + vqmodel = models.pop() + load_model = VQModel.from_pretrained(input_dir, subfolder="vqmodel") + vqmodel.register_to_config(**load_model.config) + vqmodel.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + learning_rate = args.learning_rate if args.scale_lr: learning_rate = ( @@ -759,15 +762,10 @@ def main(): column_names = dataset["train"].column_names # 6. Get the column names for input/target. - dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) - if args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] - else: - image_column = args.image_column - if image_column not in column_names: - raise ValueError( - f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" - ) + assert args.image_column is not None + image_column = args.image_column + if image_column not in column_names: + raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}") # Preprocessing the datasets. 
train_transforms = transforms.Compose( [ @@ -887,6 +885,7 @@ def collate_fn(examples): avg_gen_loss, avg_discr_loss = None, None for epoch in range(first_epoch, args.num_train_epochs): model.train() + discriminator.train() for i, batch in enumerate(train_dataloader): pixel_values = batch["pixel_values"] pixel_values = pixel_values.to(accelerator.device, non_blocking=True) @@ -1001,7 +1000,30 @@ def collate_fn(examples): data_time_m.reset() # Save model checkpoint if global_step % args.checkpointing_steps == 0: - save_checkpoint(model, discriminator, args, accelerator, global_step) + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") # Generate images if global_step % args.validation_steps == 0: @@ -1014,15 +1036,14 @@ def collate_fn(examples): accelerator.wait_for_everyone() - # Evaluate and save checkpoint at the end of training - save_checkpoint(model, discriminator, 
args, accelerator, global_step) - # Save the final trained checkpoint if accelerator.is_main_process: model = accelerator.unwrap_model(model) + discriminator = accelerator.unwrap_model(discriminator) if args.use_ema: ema_model.copy_to(model.parameters()) - model.save_pretrained(args.output_dir) + model.save_pretrained(os.path.join(args.output_dir, "vqmodel")) + discriminator.save_pretrained(os.path.join(args.output_dir, "discriminator")) accelerator.end_training() diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index b08c04be5b8f..876dc11c2ff3 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -176,7 +176,7 @@ def forward( """ h = self.encode(sample).latents - dec = self.decode(h).sample + dec = self.decode(h) if not return_dict: if return_loss: diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 8b138bf67f41..35d26cb761b5 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -89,6 +89,7 @@ def test_output_pretrained(self): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) + print(help(VQModel.forward)) with torch.no_grad(): output = model(image).sample @@ -97,3 +98,18 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + def test_loss_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.to(torch_device).eval() + + torch.manual_seed(0) + backend_manual_seed(torch_device, 0) + + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + output = model(image, 
return_loss=True).commit_loss.cpu() + # fmt: off + expected_output = torch.tensor([0.1936]) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) \ No newline at end of file From 0e79a25b0cbba990052b06daa2afda3b34b4432d Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 28 Apr 2024 00:42:42 -0400 Subject: [PATCH 22/33] Merging --- examples/vqgan/train_vqgan.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 0494abb11586..6d0dcc4ac162 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -186,7 +186,7 @@ def log_validation(model, args, validation_transform, accelerator, global_step): wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images) ] }, - step=global_step + step=global_step, ) torch.cuda.empty_cache() return images @@ -782,6 +782,7 @@ def load_model_hook(models, input_dir): transforms.ToTensor(), ] ) + def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] @@ -998,7 +999,6 @@ def collate_fn(examples): data_time_m.reset() # Save model checkpoint if global_step % args.checkpointing_steps == 0: -<<<<<<< HEAD if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1024,9 +1024,6 @@ def collate_fn(examples): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") -======= - save_checkpoint(model, discriminator, args, accelerator, global_step) ->>>>>>> e1ed9ee1238da0eccbabf158b3b5d6688a49d3ae # Generate images if global_step % args.validation_steps == 0: log_validation(model, args, validation_transform, accelerator, global_step) From 508764498dd0d7137c4ddf31d327877ecca138b5 Mon Sep 17 00:00:00 2001 From: Isamu Isozaki Date: Sun, 
28 Apr 2024 10:30:09 -0400 Subject: [PATCH 23/33] Update src/diffusers/models/vq_model.py Co-authored-by: Sayak Paul --- src/diffusers/models/vq_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 876dc11c2ff3..2fd6fcf28981 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -180,10 +180,7 @@ def forward( if not return_dict: if return_loss: - return ( - dec.sample, - dec.commit_loss, - ) + return dec.sample, dec.commit_loss return (dec.sample,) if return_loss: return dec From 45abf09afcd088ec84130e079ea37140a3b1f520 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 28 Apr 2024 17:18:00 -0400 Subject: [PATCH 24/33] Testing+Fixed bugs in training script --- examples/vqgan/README.md | 8 +- examples/vqgan/test_vqgan.py | 385 ++++++++++++++++++++++++++++++++++ examples/vqgan/train_vqgan.py | 24 ++- 3 files changed, 412 insertions(+), 5 deletions(-) create mode 100644 examples/vqgan/test_vqgan.py diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md index 34b3ad8ac5f0..a8417354cfa7 100644 --- a/examples/vqgan/README.md +++ b/examples/vqgan/README.md @@ -117,4 +117,10 @@ To lower the amount of layers in a VQGan, you can remove layers by modifying the "vq_embed_dim": 4 } ``` -For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that. \ No newline at end of file +For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that. + +## Extra training tips/ideas +During logging take care to make sure data_time is low. 
data_time is the amount spent loading the data and where the GPU is not active. So essentially, it's the time wasted. The easiest way to lower data time is to increase the --dataloader_num_workers to a higher number like 4. Due to a bug in Pytorch, this only works on linux based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646) +Secondly, training should seem to be done when both the discriminator and the generator loss converges. +Thirdly, another low hanging fruit is just using ema using the --use_ema parameter. This tends to make the output images smoother. This has a con where you have to lower your batch size by 1 but it may be worth it. +Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, becareful with the feature map norms since this can easily overdominate the loss. \ No newline at end of file diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py new file mode 100644 index 000000000000..76bc2c88e503 --- /dev/null +++ b/examples/vqgan/test_vqgan.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import shutil +import sys +import tempfile +import torch + +from diffusers import VQModel +import json +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class TextToImage(ExamplesTestsAccelerate): + @property + def test_vqmodel_config(self): + return { + "_class_name": "VQModel", + "_diffusers_version": "0.17.0.dev0", + "act_fn": "silu", + "block_out_channels": [ + 32, + ], + "down_block_types": [ + "DownEncoderBlock2D", + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "norm_type": "spatial", + "num_vq_embeddings": 32, + "out_channels": 3, + "sample_size": 32, + "scaling_factor": 0.18215, + "up_block_types": [ + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4 + } + @property + def test_discriminator_config(self): + return { + "_class_name": "Discriminator", + "_diffusers_version": "0.27.0.dev0", + "in_channels": 3, + "cond_channels": 0, + "hidden_channels": 8, + "depth": 4 + } + def get_vq_and_discriminator_configs(self, tmpdir): + vqmodel_config_path = os.path.join(tmpdir, 'vqmodel.json') + discriminator_config_path = os.path.join(tmpdir, 'discriminator.json') + with open(vqmodel_config_path, 'w') as fp: + json.dump(self.test_vqmodel_config, fp) + with open(discriminator_config_path, 'w') as fp: + json.dump(self.test_discriminator_config, fp) + return vqmodel_config_path, discriminator_config_path + def test_vqmodel(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + test_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + 
--gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors"))) + + def test_vqmodel_checkpointing(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + 
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4"}, + ) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=1 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + # In the current script, checkpointing_steps 1 is equivalent to checkpointing_steps 2 as after the generator gets trained for one step, + # the discriminator gets trained and loss and saving happens after that. 
Thus we do not expect to get a checkpoint-5 + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_vqmodel_checkpointing_use_ema(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # check can run an intermediate checkpoint + model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming + shutil.rmtree(os.path.join(tmpdir, "checkpoint-2")) + + # Run training script for 2 total steps resuming from checkpoint 4 + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name 
hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=1 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --output_dir {tmpdir} + --use_ema + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + # check can run new fully trained pipeline + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # no checkpoint-2 -> check old checkpoints do not exist + # check new checkpoints exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_vqmodel_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2 + # Should create checkpoints at steps 2, 4, 6 + # with checkpoint at step 2 deleted + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 6 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + --checkpointing_steps=2 + --checkpoints_total_limit=2 + --seed=0 + """.split() + + 
run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + # checkpoint-2 should have been deleted + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) + + def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) + # Run training script with checkpointing + # max_train_steps == 4, checkpointing_steps == 2 + # Should create checkpoints at steps 2, 4 + + initial_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + --resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 4 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --checkpointing_steps=2 + --output_dir {tmpdir} + --seed=0 + """.split() + + run_command(self._launch_args + initial_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-2", "checkpoint-4"}, + ) + + # resume and we should try to checkpoint at 6, where we'll have to remove + # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint + + resume_run_args = f""" + examples/vqgan/train_vqgan.py + --dataset_name hf-internal-testing/dummy_image_text_data + 
--resolution 32 + --image_column image + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 8 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --model_config_name_or_path {vqmodel_config_path} + --discriminator_config_name_or_path {discriminator_config_path} + --output_dir {tmpdir} + --checkpointing_steps=2 + --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')} + --checkpoints_total_limit=2 + --seed=0 + """.split() + + run_command(self._launch_args + resume_run_args) + + model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel") + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + _ = model(image) + + # check checkpoint directories exist + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-6", "checkpoint-8"}, + ) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 6d0dcc4ac162..93093badc1bd 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -257,13 +257,13 @@ def parse_args(): "--model_config_name_or_path", type=str, default=None, - help="The config of the Vq model to train, leave as None to use standard DDPM configuration.", + help="The config of the Vq model to train, leave as None to use standard Vq model configuration.", ) parser.add_argument( "--discriminator_config_name_or_path", type=str, default=None, - help="The config of the discriminator model to train, leave as None to use standard DDPM configuration.", + help="The config of the discriminator model to train, leave as None to use standard Vq model configuration.", ) parser.add_argument( "--revision", @@ -438,7 +438,7 @@ def parse_args(): parser.add_argument( "--dataloader_num_workers", type=int, - default=4, + default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
), @@ -624,7 +624,7 @@ def main(): up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], vq_embed_dim=4, ) - elif args.pretrained_model_name_or_path is None: + elif args.pretrained_model_name_or_path is not None: model = VQModel.from_pretrained(args.pretrained_model_name_or_path) else: config = VQModel.load_config(args.model_config_name_or_path) @@ -673,6 +673,8 @@ def main(): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "vqmodel_ema")) vqmodel = models[0] discriminator = models[1] vqmodel.save_pretrained(os.path.join(output_dir, "vqmodel")) @@ -681,6 +683,11 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vqmodel_ema"), VQModel) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model discriminator = models.pop() load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") discriminator.register_to_config(**load_model.config) @@ -826,6 +833,8 @@ def collate_fn(examples): model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare( model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler ) + if args.use_ema: + ema_model.to(accelerator.device) # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") @@ -1026,7 +1035,14 @@ def collate_fn(examples): # Generate images if global_step % args.validation_steps == 0: + if args.use_ema: + # Store the VQGAN parameters temporarily and load the EMA parameters to perform inference. 
+ ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) log_validation(model, args, validation_transform, accelerator, global_step) + if args.use_ema: + # Switch back to the original VQGAN parameters. + ema_model.restore(model.parameters()) end = time.time() # Stop training if max steps is reached if global_step >= args.max_train_steps: From eb596840b8d3d794dbd20a9ab4f582e538534292 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 28 Apr 2024 17:26:09 -0400 Subject: [PATCH 25/33] Some style fixes --- examples/vqgan/test_vqgan.py | 24 ++++++++++++++------- tests/models/autoencoders/test_models_vq.py | 3 ++- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py index 76bc2c88e503..ed70873292c5 100644 --- a/examples/vqgan/test_vqgan.py +++ b/examples/vqgan/test_vqgan.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import shutil import sys import tempfile + import torch from diffusers import VQModel -import json + + sys.path.append("..") from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 @@ -59,8 +62,9 @@ def test_vqmodel_config(self): "up_block_types": [ "UpDecoderBlock2D", ], - "vq_embed_dim": 4 + "vq_embed_dim": 4, } + @property def test_discriminator_config(self): return { @@ -69,16 +73,18 @@ def test_discriminator_config(self): "in_channels": 3, "cond_channels": 0, "hidden_channels": 8, - "depth": 4 + "depth": 4, } + def get_vq_and_discriminator_configs(self, tmpdir): - vqmodel_config_path = os.path.join(tmpdir, 'vqmodel.json') - discriminator_config_path = os.path.join(tmpdir, 'discriminator.json') - with open(vqmodel_config_path, 'w') as fp: + vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json") + discriminator_config_path = os.path.join(tmpdir, "discriminator.json") + with open(vqmodel_config_path, "w") as fp: 
json.dump(self.test_vqmodel_config, fp) - with open(discriminator_config_path, 'w') as fp: + with open(discriminator_config_path, "w") as fp: json.dump(self.test_discriminator_config, fp) return vqmodel_config_path, discriminator_config_path + def test_vqmodel(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -101,7 +107,9 @@ def test_vqmodel(self): run_command(self._launch_args + test_args) # save_pretrained smoke test - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))) + self.assertTrue( + os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors")) + ) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors"))) def test_vqmodel_checkpointing(self): diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 426599ce8409..7e13adb83a4b 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -99,6 +99,7 @@ def test_output_pretrained(self): expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + def test_loss_pretrained(self): model = VQModel.from_pretrained("fusing/vqgan-dummy") model.to(torch_device).eval() @@ -113,4 +114,4 @@ def test_loss_pretrained(self): # fmt: off expected_output = torch.tensor([0.1936]) # fmt: on - self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) \ No newline at end of file + self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) From 97367cbbf646e4992cb223c43a266aa2fba6fb46 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sun, 28 Apr 2024 17:32:47 -0400 Subject: [PATCH 26/33] Added wandb to docs --- 
examples/vqgan/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md index a8417354cfa7..0b0f3589baf5 100644 --- a/examples/vqgan/README.md +++ b/examples/vqgan/README.md @@ -44,6 +44,7 @@ accelerate launch train_vqgan.py \ --report_to=wandb ``` +An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images). The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but are not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size which at most can be around 16384. How to do this is shown below. # Modifying the architecture From d797bcd3b709da96c70de8ce6b079ca750da77e9 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 29 Apr 2024 13:16:55 -0400 Subject: [PATCH 27/33] Fixed timm test --- examples/vqgan/test_vqgan.py | 32 ++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py index ed70873292c5..b24706bccb48 100644 --- a/examples/vqgan/test_vqgan.py +++ b/examples/vqgan/test_vqgan.py @@ -14,12 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util import json import logging import os import shutil import sys import tempfile +import unittest import torch @@ -36,6 +38,31 @@ stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) +# The package importlib_metadata is in a different place, depending on the python version. 
+if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + + +def require_timm(test_case): + """ + Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. + """ + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + class TextToImage(ExamplesTestsAccelerate): @property @@ -85,6 +112,7 @@ def get_vq_and_discriminator_configs(self, tmpdir): json.dump(self.test_discriminator_config, fp) return vqmodel_config_path, discriminator_config_path + @require_timm def test_vqmodel(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -112,6 +140,7 @@ def test_vqmodel(self): ) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors"))) + @require_timm def test_vqmodel_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -196,6 +225,7 @@ def test_vqmodel_checkpointing(self): {"checkpoint-4", "checkpoint-6"}, ) + @require_timm def test_vqmodel_checkpointing_use_ema(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -280,6 +310,7 @@ def test_vqmodel_checkpointing_use_ema(self): {"checkpoint-4", "checkpoint-6"}, ) + @require_timm def test_vqmodel_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, 
discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -318,6 +349,7 @@ def test_vqmodel_checkpointing_checkpoints_total_limit(self): # checkpoint-2 should have been deleted self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) + @require_timm def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) From 1149dfbfa98b6ccc1e92479d569f5582d764b126 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 30 Apr 2024 08:20:31 +0530 Subject: [PATCH 28/33] get testing suite ready. --- .github/workflows/pr_tests.yml | 2 +- .github/workflows/push_tests.yml | 1 + .github/workflows/push_tests_fast.yml | 2 +- examples/vqgan/test_vqgan.py | 34 ++------------------------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 12 ++++++++++ src/diffusers/utils/testing_utils.py | 8 +++++++ 7 files changed, 26 insertions(+), 34 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index b1bed6568aa4..d5d1fc719305 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -156,7 +156,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install peft + python -m uv pip install peft timm python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index a6cb123a7035..90dbb570c9cf 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -424,6 +424,7 @@ jobs: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python 
-m uv pip install timm python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ - name: Failure short reports diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 7c50da7b5c34..54ff48993768 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -107,7 +107,7 @@ jobs: if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" - python -m uv pip install peft + python -m uv pip install peft timm python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ --make-reports=tests_${{ matrix.config.report }} \ examples diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py index b24706bccb48..664a7f7365b0 100644 --- a/examples/vqgan/test_vqgan.py +++ b/examples/vqgan/test_vqgan.py @@ -14,18 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib.util import json import logging import os import shutil import sys import tempfile -import unittest import torch from diffusers import VQModel +from diffusers.utils.testing_utils import require_timm sys.path.append("..") @@ -38,32 +37,8 @@ stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) -# The package importlib_metadata is in a different place, depending on the python version. 
-if sys.version_info < (3, 8): - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - -_timm_available = importlib.util.find_spec("timm") is not None -if _timm_available: - try: - _timm_version = importlib_metadata.version("timm") - logger.info(f"Timm version {_timm_version} available.") - except importlib_metadata.PackageNotFoundError: - _timm_available = False - - -def is_timm_available(): - return _timm_available - - -def require_timm(test_case): - """ - Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. - """ - return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) - +@require_timm class TextToImage(ExamplesTestsAccelerate): @property def test_vqmodel_config(self): @@ -112,7 +87,6 @@ def get_vq_and_discriminator_configs(self, tmpdir): json.dump(self.test_discriminator_config, fp) return vqmodel_config_path, discriminator_config_path - @require_timm def test_vqmodel(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -140,7 +114,6 @@ def test_vqmodel(self): ) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors"))) - @require_timm def test_vqmodel_checkpointing(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -225,7 +198,6 @@ def test_vqmodel_checkpointing(self): {"checkpoint-4", "checkpoint-6"}, ) - @require_timm def test_vqmodel_checkpointing_use_ema(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -310,7 +282,6 @@ def test_vqmodel_checkpointing_use_ema(self): {"checkpoint-4", "checkpoint-6"}, ) - @require_timm def test_vqmodel_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as 
tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) @@ -349,7 +320,6 @@ def test_vqmodel_checkpointing_checkpoints_total_limit(self): # checkpoint-2 should have been deleted self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"}) - @require_timm def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): with tempfile.TemporaryDirectory() as tmpdir: vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cdc92036613d..68c567a7b41e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -72,6 +72,7 @@ is_peft_version, is_scipy_available, is_tensorboard_available, + is_timm_available, is_torch_available, is_torch_npu_available, is_torch_version, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f5f57d8a5c5f..102a8c0c2cbb 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -295,6 +295,18 @@ except importlib_metadata.PackageNotFoundError: _torchvision_available = False +_timm_available = importlib.util.find_spec("timm") is not None +if _timm_available: + try: + _timm_version = importlib_metadata.version("timm") + logger.info(f"Timm version {_timm_version} available.") + except importlib_metadata.PackageNotFoundError: + _timm_available = False + + +def is_timm_available(): + return _timm_available + def is_torch_available(): return _torch_available diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index c1756d6590d1..8a6afd768428 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -33,6 +33,7 @@ is_onnx_available, is_opencv_available, is_peft_available, + is_timm_available, is_torch_available, is_torch_version, 
is_torchsde_available, @@ -340,6 +341,13 @@ def require_peft_backend(test_case): return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) +def require_timm(test_case): + """ + Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. + """ + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + + def require_peft_version_greater(peft_version): """ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific From d705ed493e03530c1999bd92f27c8d568009819a Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 29 Apr 2024 23:21:52 -0400 Subject: [PATCH 29/33] remove return loss --- examples/vqgan/train_vqgan.py | 2 +- src/diffusers/models/vq_model.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index 93093badc1bd..b7beee1f3b26 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -910,7 +910,7 @@ def collate_fn(examples): discr_optimizer.zero_grad(set_to_none=True) # encode images to the latent space and get the commit loss from vq tokenization # Return commit loss - fmap, commit_loss = model(pixel_values, return_dict=False, return_loss=True) + fmap, commit_loss = model(pixel_values, return_dict=False) if generator_step: with accelerator.accumulate(model): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 2fd6fcf28981..590c9f665d99 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -158,7 +158,7 @@ def decode( return DecoderOutput(sample=dec, commit_loss=commit_loss) def forward( - self, sample: torch.FloatTensor, return_dict: bool = True, return_loss: bool = False + self, sample: torch.FloatTensor, return_dict: bool = True ) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]: r""" The [`VQModel`] forward method. 
@@ -179,9 +179,5 @@ def forward( dec = self.decode(h) if not return_dict: - if return_loss: - return dec.sample, dec.commit_loss - return (dec.sample,) - if return_loss: - return dec - return DecoderOutput(sample=dec.sample) + return dec.sample, dec.commit_loss + return dec From 75d36b6cf16a9ddb526752ec0f318fec13d23641 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 29 Apr 2024 23:24:04 -0400 Subject: [PATCH 30/33] remove return_loss --- tests/models/autoencoders/test_models_vq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 7e13adb83a4b..ec1764b1b265 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -110,7 +110,7 @@ def test_loss_pretrained(self): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): - output = model(image, return_loss=True).commit_loss.cpu() + output = model(image).commit_loss.cpu() # fmt: off expected_output = torch.tensor([0.1936]) # fmt: on From 6e3ef01e48e881717696ddc159928e84ace741f2 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 29 Apr 2024 23:53:32 -0400 Subject: [PATCH 31/33] Remove diffs --- src/diffusers/models/vq_model.py | 2 -- tests/models/autoencoders/test_models_vq.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 590c9f665d99..60e1bfdab485 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -167,8 +167,6 @@ def forward( sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. - return_loss (`bool`, *optional*, defaults to `False`): - Whether or not to return a commit loss. 
Returns: [`~models.vq_model.VQEncoderOutput`] or `tuple`: If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index ec1764b1b265..c61ae1bdf0ff 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -90,7 +90,6 @@ def test_output_pretrained(self): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) - print(help(VQModel.forward)) with torch.no_grad(): output = model(image).sample From adbed4591a8fd0c01473b1df339014ea6ac98f71 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Mon, 29 Apr 2024 23:54:28 -0400 Subject: [PATCH 32/33] Remove diffs --- src/diffusers/models/vq_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 60e1bfdab485..468ed19e795d 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -167,6 +167,7 @@ def forward( sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. 
+ Returns: [`~models.vq_model.VQEncoderOutput`] or `tuple`: If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` From 9f461212e857658071e3bf86cee1c22010208f95 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 14 May 2024 23:01:40 -0400 Subject: [PATCH 33/33] fix ruff format --- src/diffusers/utils/import_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 464802cb40f9..b8ce2d7c0466 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -307,6 +307,7 @@ def is_timm_available(): return _timm_available + _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None try: _bitsandbytes_version = importlib_metadata.version("bitsandbytes")