From d7ed1239a12b02d2b3d01958e5dec320441124e0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 23 Jun 2023 02:22:58 -0700 Subject: [PATCH] initial commit (in progress) --- examples/ddpo/finetune_ddpo.py | 860 +++++++++++++++++++++++++++++++++ examples/ddpo/requirements.txt | 4 + 2 files changed, 864 insertions(+) create mode 100644 examples/ddpo/finetune_ddpo.py create mode 100644 examples/ddpo/requirements.txt diff --git a/examples/ddpo/finetune_ddpo.py b/examples/ddpo/finetune_ddpo.py new file mode 100644 index 000000000000..ddd195ba5378 --- /dev/null +++ b/examples/ddpo/finetune_ddpo.py @@ -0,0 +1,860 @@ +import argparse +import inspect +import logging +import math +import os +import shutil +from collections import deque +from pathlib import Path +from typing import Optional + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import HfFolder, Repository, create_repo, whoami +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +# Copied from original implementation: https://github.com/jannerm/ddpo/blob/main/ddpo/utils/stat_tracking.py +# Uses a ring buffer to keep track of rewards seen per prompt for reward normalization. +class PerPromptStatTracker: + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return { + k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} + for k, v in self.stats.items() + } + + +# TODO: add helper functions to get prompts for data collection +# TODO: implement reward functions and integrate with training loop + + +def parse_args(): + parser = argparse.ArgumentParser(description="Example of finetuning a diffusion model using DDPO.") + # ---------- Model/Pipeline Args ---------- + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="ddpo-finetune", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + # ---------- Dataset Args ------------ + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that HF Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + # ---------- Random Seeds and Reproducibility ---------- + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ---------- Image Preprocessing Args ---------- + parser.add_argument( + "--resolution", + type=int, + default=64, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + default=False, + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ---------- Training and Evaluation Args ---------- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" + " process." + ), + ) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.") + parser.add_argument( + "--save_model_epochs", type=int, default=10, help="How often to save the model during training." + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + # ---------- DDPO-Specific Training Args ---------- + # Data Collection from policy (that is, from the diffusion model + scheduler) + parser.add_argument( + "--sample_batch_size", + type=int, + default=64, + help="The number of images to generate when collecting samples from the diffusion model for RL training.", + ) + parser.add_argument( + "--num_sample_batches_per_epoch", + type=int, + default=1, + help=( + "The number of batches to sample from the diffusion model when collecting data in each epoch of training." + ), + ) + # Classifier-Free Guidance (CFG) Training + parser.add_argument( + "--train_cfg", + action="store_true", + help="Whether to use classifier-free guidance (CFG) training as described in Appendix B.1 of the DDPO paper.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=5.0, + help=( + "Guidance scale (w) for use in classifier-free guidance (CFG) training as described in Appendix B.1 of the" + " DDPO paper." + ), + ) + # PPO-Clip parameters + parser.add_argument( + "--num_inner_epochs", + type=int, + default=1, + help="Number of times to loop over each batch of collected data during PPO training.", + ) + parser.add_argument( + "--train_timestep_ratio", + type=float, + default=1.0, + help="The fraction of collected timesteps to use for PPO training.", + ) + parser.add_argument( + "--ppo_clip_range", + type=float, + default=1e-4, + help=( + "Clip range (centered around 1.0) to clip the surrogate objective in PPO-Clip. This limits the size of" + " the update made by the PPO algorithm." + ), + ) + # Reward Normalization + parser.add_argument( + "--per_prompt_stats_bufsize", + type=int, + default=32, + help=( + "The size of the ring buffer for tracking per-prompt reward statistics, which are use to normalize the" + " reward for RL training." + ), + ) + parser.add_argument( + "--per_prompt_stats_min_count", + type=int, + default=16, + help=( + "When calculating reward statistics for a prompt, if there are at least min_count samples for that" + " prompt, we will use those samples to calculate the stats; otherwise, we will back off to the stats for" + " all reward samples (regardless of prompt)." + ), + ) + # ---------- Sampling Args (for data collection and evaluation) ---------- + parser.add_argument( + "--num_inference_steps", + type=int, + default=50, + help="Num inference steps when generating samples (for both data collection and evaluation).", + ) + parser.add_argument("--eta", type=float, default=1.0, help="Eta parameter for the DDIM scheduler.") + # ---------- Learning Rate Schedule Args ---------- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + # ---------- Optimizer Args ---------- + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-4, help="Weight decay magnitude for the Adam optimizer." + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.") + # ---------- EMA Args ---------- + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") + parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.") + parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.") + # ---------- Huggingface Hub Args ---------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--hub_private_repo", action="store_true", help="Whether or not to create a private repository." + ) + # ---------- Logging Args ---------- + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + help=( + "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)" + " for experiment tracking and logging of model metrics and model checkpoints" + ), + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + # ---------- Distributed Training Args ----------- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ---------- Checkpointing Args ---------- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + return args + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.logger, + project_config=accelerator_project_config, + ) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if args.logger == "tensorboard": + if not is_tensorboard_available(): + raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.") + + elif args.logger == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the pretrained models, scheduler, and tokenizer. + noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Create EMA for the model. + if args.use_ema: + ema_unet = EMAModel( + unet.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + model_cls=UNet2DConditionModel, + model_config=unet.config, + ) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + optimizer = torch.optim.AdamW( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # TODO: actually, do we need a dataloader at all? + # In each loop of training we collect data from the policy (diffusion model) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + split="train", + ) + else: + # TODO: get default dataset??? (need to check this logic) + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets and DataLoaders creation. + augmentations = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + # Tokenize input prompts: + def tokenize_captions(captions): + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + def transform_images(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + logger.info(f"Dataset size: {len(dataset)}") + + # TODO: need to load the dataset and apply appropriate transforms to text and images. + # Can we do this with set_transform?? + # See https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py#L656 + dataset.set_transform(transform_images) + train_dataloader = torch.utils.data.DataLoader( + dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + + # Initialize the learning rate scheduler + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs), + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encoder and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + # DDPO-specific setup + + # TODO: actually, can this be handled by StableDiffusionPipeline internally? + if args.train_cfg: + # TODO: prepare the unconditional prompt embedding for CFG + pass + + # Initialize the stat tracker for reward normalization + per_prompt_stats = PerPromptStatTracker( + args.per_prompt_stats_bufsize, args.per_prompt_stats_min_count + ) + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Train! + for epoch in range(first_epoch, args.num_epochs): + unet.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + # TODO: Need to do anything before data collection? + + # Prepare a pipeline using current models for inference + # TODO: check whether we need to rewrap unet after unwrapping for data collection + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=accelerator.unwrap_model(vae), + unet=unet, + revision=args.revision, + ) + + # Loop over args.num_sample_batches_per_epoch and get samples from policy + # (that is, samples from the diffusion model and scheduler). + for sample_epoch in range(args.num_sample_batches_per_epoch): + data_collection_progress_bar = tqdm( + total=args.num_sample_batches_per_epoch, disable=not accelerator.is_local_main_process + ) + data_collection_progress_bar.set_description(f"Epoch {epoch}: sample epoch {sample_epoch}") + + # TODO: get prompts + # TODO: Using prompts, generate a batch of samples from pipeline + # TODO: is there a way to get the entire sampling trajectory from __call__? + # See https://github.com/jannerm/ddpo/blob/main/ddpo/diffusers_patch/scheduling_ddim_flax.py + + data_collection_progress_bar.update(1) + + if args.use_ema: + ema_unet.restore(unet.parameters()) + + # TODO: calculate rewards for samples (implemented via callbacks in the original code) + + # TODO: want to accumulate gradients along trajectories, how does that interact with + # the accelerator.accumulate context manager? + with accelerator.accumulate(unet): + # Loop over inner epochs of PPO training. + for inner_epoch in range(args.num_inner_epochs): + ppo_progress_bar = tqdm( + total=args.num_inner_epochs, disable=not accelerator.is_local_main_process + ) + ppo_progress_bar.set_description(f"Epoch {epoch}: PPO epoch {inner_epoch}") + + # TODO: implement PPO RL training on collected sampling trajectories + + ppo_progress_bar.update(1) + + accelerator.backward(loss) + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_unet.cur_decay_value + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # TODO: generate reward curve + + if accelerator.is_main_process: + # If it's time to save samples, save the samples gathered from the data collection loop. + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + # TODO: save collected samples + pass + + # If it's time to save a checkpoint, create the pipeline using the trained modules and save it. + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + model = accelerator.unwrap_model(unet) + + if args.use_ema: + ema_unet.store(model.parameters()) + ema_unet.copy_to(model.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + unet=unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + + pipeline.save_pretrained(args.output_dir) + + if args.use_ema: + ema_unet.restore(model.parameters()) + + if args.push_to_hub: + repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/examples/ddpo/requirements.txt b/examples/ddpo/requirements.txt new file mode 100644 index 000000000000..598b43daf5f2 --- /dev/null +++ b/examples/ddpo/requirements.txt @@ -0,0 +1,4 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +datasets \ No newline at end of file