In [1]:
"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""

import argparse
import logging
import math
import os
import random
import sys
import time
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from datetime import datetime
from threading import Lock

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers import Mel

import fma_utils

import IPython.display as ipd

# Uncomment to error out early if the installed diffusers is older than this notebook expects.
# Remove at your own risk.
#check_min_version("0.15.0.dev0")

# accelerate's logger; its calls accept main_process_only=... to choose which processes emit.
logger = get_logger(name=__name__, log_level="INFO")


In [2]:
class FMADataset(torch.utils.data.Dataset):
    """Free Music Archive (FMA) tracks exposed as (mel-spectrogram image, text caption) pairs.

    Each item loads one track's audio into a diffusers ``Mel`` helper, picks a random slice,
    renders it to an image, and builds a caption from the track's echonest audio features,
    genres, and most frequent tags.

    Args:
        fma_dir: Root directory holding ``fma_metadata/`` and ``audio_subdir``.
            ``~`` and environment variables are expanded.
        subset: One of ``'small'``, ``'medium'``, ``'large'``. Tracks whose
            ``('set', 'subset')`` value compares ``<=`` this subset are kept.
        fs: Sample rate passed to ``Mel``.
        audio_subdir: Sub-directory of ``fma_dir`` containing the audio files.
        resolution: Image width and height; also used as the ``Mel`` hop length.
        num_top_tags: How many of the most frequent tags may appear in captions.
    """

    def __init__(self, fma_dir='~/data2', subset='small', fs=22050, audio_subdir='fma_large', resolution=512,
                 num_top_tags=200):
        assert subset in ['small', 'medium', 'large']
        self.fma_dir = os.path.expanduser(os.path.expandvars(fma_dir))
        self.audio_subdir = audio_subdir
        self.subset = subset
        self.fs = fs
        # Read metadata from the *expanded* path; the raw argument may still contain '~' or '$VARS'.
        metadata_dir = os.path.join(self.fma_dir, 'fma_metadata')
        tracks = fma_utils.load(os.path.join(metadata_dir, 'tracks.csv'))
        self.tracks = tracks[tracks['set', 'subset'] <= self.subset]
        self.genres = fma_utils.load(os.path.join(metadata_dir, 'genres.csv'))
        # self.features = fma_utils.load(os.path.join(metadata_dir, 'features.csv'))
        self.echonest = fma_utils.load(os.path.join(metadata_dir, 'echonest.csv'))
        self.mel = Mel(x_res=resolution, y_res=resolution, sample_rate=self.fs, n_fft=2048,
                       hop_length=resolution, top_db=80, n_iter=32)
        # Rank every tag by its number of occurrences across the selected tracks (most frequent
        # first) and keep the top ones as the caption tag vocabulary.
        alltags = [tag for track_tags in self.tracks['track', 'tags'] for tag in track_tags]
        tags, counts = np.unique(alltags, return_counts=True)
        self.top_tags = tags[np.argsort(counts)[::-1]][:num_top_tags]
        # Serializes use of the shared self.mel: load_audio() replaces the audio that the
        # subsequent slice/image calls read, so concurrent items would otherwise interleave.
        # NOTE(review): threading.Lock cannot be pickled, so DataLoader workers started with
        # the 'spawn' method cannot receive this dataset.
        self.lock = Lock()

    def get_echonest_description(self, track_id):
        """Space-separated adjectives derived from echonest audio features ('' if the track has none)."""
        try:
            en = self.echonest.loc[track_id, ('echonest', 'audio_features')]
        except KeyError:
            # Tracks missing from echonest.csv get no feature words.
            return ""
        words = []
        if en.acousticness > 0.7:
            words.append("acoustic")
        if en.danceability > 0.6:
            words.append("danceable")
        if en.energy > 0.6:
            words.append("energetic")
        if en.instrumentalness > 0.8:
            words.append("instrumental")
        if en.liveness > 0.5:
            words.append("live")
        if en.speechiness > 0.4:
            words.append("spoken")
        if en.tempo < 81:
            words.append("down-tempo")
        elif en.tempo > 150:
            words.append("up-tempo")
        return " ".join(words)

    def get_genre_description(self, track_id):
        """Sub-genre titles followed by top-level genre titles joined with ' and '.

        Example: ``"Indie-Rock Rock"``. Returns '' when the track has no genres. There is
        never a leading, trailing, or doubled space.
        """
        subgenres = set()
        topgenres = set()
        for genre in self.tracks.loc[track_id, ('track', 'genres_all')]:
            parent = self.genres.loc[genre, 'parent']
            if parent == 0:
                # A parent id of 0 marks a top-level genre.
                topgenres.add(genre)
            else:
                subgenres.add(genre)
                topgenres.add(parent)
        subnames = " ".join(self.genres.loc[g, 'title'] for g in subgenres)
        topnames = " and ".join(self.genres.loc[g, 'title'] for g in topgenres)
        return " ".join(part for part in (subnames, topnames) if part)

    def get_caption(self, track_id):
        """Build ``"a <echonest words> <genres> song[, tagged t1, t2]"`` for ``track_id``.

        Empty parts are skipped so words are always separated by exactly one space.
        """
        parts = ["a", self.get_echonest_description(track_id), self.get_genre_description(track_id), "song"]
        caption = " ".join(part for part in parts if part)
        # Only tags from the most frequent vocabulary are mentioned.
        to_tag = [t for t in self.tracks.loc[track_id, ('track', 'tags')] if t in self.top_tags]
        if to_tag:
            caption += ", tagged " + ", ".join(to_tag)
        return caption

    def __len__(self):
        return len(self.tracks)

    def _audio_path(self, track_id):
        """Path of the audio file for ``track_id`` under ``fma_dir/audio_subdir``."""
        return fma_utils.get_audio_path(os.path.join(self.fma_dir, self.audio_subdir), track_id)

    def _load_random_slice(self, track_id):
        """Load ``track_id`` into ``self.mel`` and render a random slice.

        Returns ``(slice_index, image)``. The caller must hold ``self.lock``.
        """
        self.mel.load_audio(self._audio_path(track_id))
        islice = np.random.choice(self.mel.get_number_of_slices())
        return islice, self.mel.audio_slice_to_image(islice)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for a random slice of the ``idx``-th track."""
        # `with` releases the lock even if loading raises (e.g. an undecodable file);
        # a leaked lock would deadlock every other thread using this dataset.
        with self.lock:
            track_id = self.get_track_id(idx)
            _, image = self._load_random_slice(track_id)
            caption = self.get_caption(track_id)
        return image, caption

    def get_track_id(self, idx):
        """FMA track id (index label) of the ``idx``-th selected track."""
        return self.tracks.iloc[idx].name

    def get_audio_plus_stft(self, idx):
        """Return ``(image, audio_slice)`` for the same random slice of the ``idx``-th track."""
        track_id = self.get_track_id(idx)
        # Same shared self.mel as __getitem__, so take the same lock.
        with self.lock:
            islice, image = self._load_random_slice(track_id)
            audio = self.mel.get_audio_slice(islice)
        return image, audio

    def get_unique_genres(self):
        """Set of distinct genre descriptions over all selected tracks."""
        return {self.get_genre_description(self.get_track_id(i)) for i in range(len(self))}

In [3]:
# Dataset over the FMA 'small' subset in ~/data2 (the constructor defaults, spelled out).
ds = FMADataset(fma_dir='~/data2', subset='small')

In [4]:
# Inspect one example track: print its caption, show one random spectrogram slice,
# and play that audio slice next to the audio reconstructed from the image.
song_idx = 230  # arbitrary example index into the dataset
# Fetch only the caption; ds[song_idx] would also decode the audio and render an image
# that get_audio_plus_stft below immediately replaces.
cap = ds.get_caption(ds.get_track_id(song_idx))
print(cap)
Im, audio = ds.get_audio_plus_stft(song_idx)
ipd.display(Im)
ipd.display(ipd.Audio(audio, rate=ds.fs))
reconstructed = ds.mel.image_to_audio(Im)
ipd.display(ipd.Audio(reconstructed, rate=ds.fs))

In [5]:
def preprocess_file(dataset, idx, output_dir, verbose):
    """Render item ``idx`` of ``dataset`` to a PNG file under ``output_dir``.

    Files are sharded by the first three digits of the zero-padded track id,
    e.g. track 1234 -> ``<output_dir>/001/001234.png``.

    Args:
        dataset: Object supporting ``dataset[idx] -> (image, caption)`` and ``get_track_id(idx)``.
        idx: Dataset index to render.
        output_dir: Root directory for the PNG shards.
        verbose: When true, print a timestamped progress line.

    Returns:
        ``(path relative to output_dir, caption)``, i.e. one metadata row.
    """
    if verbose:
        stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
        print(f"{stamp} >> idx {idx}", flush=True)
    image, caption = dataset[idx]
    padded_id = f"{dataset.get_track_id(idx):06}"
    shard = padded_id[:3]
    relative_name = f"{shard}/{padded_id}.png"
    os.makedirs(f"{output_dir}/{shard}", exist_ok=True)
    image.save(f"{output_dir}/{relative_name}")
    return relative_name, caption


def preprocess_fma_audio(data_dir, fma_subset, output_dir, resolution=512, sample_rate=22050, skip_idxs=(),
                         start_from=0, verbose=False, max_workers=4):
    """Render FMA tracks to spectrogram PNGs plus an image-folder ``metadata.csv``.

    Args:
        data_dir: FMA root directory (passed to ``FMADataset`` as ``fma_dir``).
        fma_subset: ``'small'``, ``'medium'`` or ``'large'``.
        output_dir: Where the PNG shards and ``metadata.csv`` are written.
        resolution: Spectrogram image size in pixels.
        sample_rate: Audio sample rate.
        skip_idxs: Dataset indices not to render (any iterable; ``None`` means none).
        start_from: First dataset index to render.
        verbose: Print one line per item instead of showing a progress bar.
        max_workers: Number of worker threads.

    Returns:
        DataFrame with columns ``file_name`` and ``text``, in dataset-index order
        (also written to ``<output_dir>/metadata.csv``).

    Raises:
        Re-raises the first exception from any item, after cancelling work that has not started.
    """
    import concurrent.futures

    import pandas as pd

    ds = FMADataset(fma_dir=data_dir, subset=fma_subset, resolution=resolution, fs=sample_rate)
    skip = set(skip_idxs or ())
    idxs = [i for i in range(start_from, len(ds)) if i not in skip]
    os.makedirs(output_dir, exist_ok=True)

    results = {}
    # The bar counts only the items actually submitted (not len(ds)), so it reaches 100%.
    progress = None if verbose else tqdm(total=len(idxs))
    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {executor.submit(preprocess_file, ds, i, output_dir, verbose): i for i in idxs}
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    print(f"Index {idx} generated exception: {e}", file=sys.stderr)
                    # Cancel queued items; otherwise the executor's shutdown would process
                    # the rest of the dataset before the error propagates.
                    for pending in future_to_idx:
                        pending.cancel()
                    raise
                if progress is not None:
                    progress.update(1)
    finally:
        if progress is not None:
            progress.close()

    # Emit rows in dataset-index order so metadata.csv is deterministic across runs.
    metadata = pd.DataFrame([results[i] for i in sorted(results)], columns=["file_name", "text"])
    metadata.to_csv(f"{output_dir}/metadata.csv", index=False)
    return metadata


In [6]:
# small_skip_idxs_all = [490, 901, 1181, 2265, 2267, 4423, 4424, 4425, 4470, 4903, 6965]
# small_skip_idxs_err = [4470, 4903, 6965, ]
# preprocess_fma_audio("~/data2", fma_subset='small', output_dir='test', skip_idxs=small_skip_idxs_all, start_from=0, verbose=True)

In [7]:
@dataclass
class LoRATrainConfig:
    pretrained_model_name_or_path: str = 'runwayml/stable-diffusion-v1-5'
    revision: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_config_name: Optional[str] = None
    train_data_path: str = './data/fma_preprocessed'
    image_column: str = 'image'
    caption_column: str = 'text'
    preprocess_audio: bool = False
    # Required if preprocess_audio is True
    fma_path: Optional[str] = '~/data2'
    validation_prompt: str = 'a energetic hip hop song'
    num_validation_images: int = 4
    validation_epochs: int = 1
    max_train_samples: Optional[int] = None
    output_dir: str = 'sd-stft-lora'
    cache_dir: Optional[str] = None
    seed: Optional[int] = 42
    resolution: int = 512
    sample_rate: int = 22050
    # center_crop: bool = False
    # random_flip: bool = False
    train_batch_size: int = 1
    num_train_epochs: int = 20
    max_train_steps: Optional[int] = None
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = False
    learning_rate: float = 0.0003
    scale_lr: bool = False
    lr_scheduler: str = 'constant'
    lr_warmup_steps: int = 0
    use_8bit_adam: bool = False
    allow_tf32: bool = False
    dataloader_num_workers: int = 2
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 0.01
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    push_to_hub: bool = False
    hub_token: Optional[str] = None
    hub_model_id: Optional[str] = None
    logging_dir: str = 'logs'
    mixed_precision: Optional[str] = 'fp16'
    report_to: str = 'tensorboard'
    local_rank: int =-1
    checkpointing_steps: int = 5000
    checkpoints_total_limit: Optional[int] = None
    resume_from_checkpoint: Optional[int] = None
    enable_xformers_memory_efficient_attention: bool = True

In [8]:
def get_lora_layers(args, unet):
    """Install a LoRA attention processor on every attention layer of ``unet``.

    The processor sizes depend on two UNet config values:
      1) the "hidden size", taken from ``unet.config.block_out_channels`` for the block
         containing the layer (up blocks use the list in reverse order);
      2) the "cross attention size", ``unet.config.cross_attention_dim``. It is used only for
         layers whose name does not end in ``attn1.processor``; those get ``None``.

    For Stable Diffusion this should give
      - down blocks: 2 attention layers * 2 transformer layers * 3 down blocks = 12
      - mid block:   2 attention layers * 1 transformer layer  * 1 mid block   = 2
      - up blocks:   2 attention layers * 3 transformer layers * 3 up blocks   = 18
      => 32 processors.

    Args:
        args: Config; only ``enable_xformers_memory_efficient_attention`` is read.
        unet: The (frozen) UNet to modify in place.

    Returns:
        ``AttnProcsLayers`` wrapping ``unet.attn_processors``, i.e. the trainable LoRA modules.

    Raises:
        ValueError: if a processor name is not in a down/mid/up block, or if xFormers is
            requested but not installed.
    """
    block_out_channels = unet.config.block_out_channels
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Name looks like "up_blocks.<id>.…"; split so multi-digit ids parse correctly.
            block_id = int(name.split(".")[1])
            hidden_size = list(reversed(block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_out_channels[block_id]
        else:
            # Without this, hidden_size would be unbound or silently reused from the previous layer.
            raise ValueError(f"Unexpected attention processor name: {name}")

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

    unet.set_attn_processor(lora_attn_procs)

    if args.enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly")
        import xformers

        if version.parse(xformers.__version__) == version.parse("0.0.16"):
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    return AttnProcsLayers(unet.attn_processors)

In [9]:
def _resolve_column(requested, default, column_names, option):
    """Pick a dataset column: ``requested`` when given (it must exist), otherwise ``default``.

    Raises:
        ValueError: if ``requested`` is not one of ``column_names``.
    """
    if requested is None:
        return default
    if requested not in column_names:
        raise ValueError(f"'{option}' value '{requested}' needs to be one of: {', '.join(column_names)}")
    return requested


def get_dataloader(args, tokenizer, accelerator):
    """Build the training dataset and a DataLoader of ``{"pixel_values", "input_ids"}`` batches.

    If ``args.preprocess_audio`` is set, FMA audio is first rendered into
    ``args.train_data_path``. Data comes from the Hub (``args.dataset_name``) or from
    ``args.train_data_path`` as an 'imagefolder'.

    Returns:
        ``(train_dataset, train_dataloader)``
    """
    if args.preprocess_audio:
        assert os.path.exists(os.path.expanduser(args.fma_path)), f"FMA data path {args.fma_path} does not exist"
        # Indices of the FMA 'small' subset excluded from preprocessing.
        # NOTE(review): presumably tracks whose audio fails to load/decode; confirm.
        small_skip_idxs_all = [490, 901, 1181, 2265, 2267, 4423, 4424, 4425, 4470, 4903, 6965]
        # Render once, on the main process; the others wait instead of racing on the same files.
        if accelerator.is_main_process:
            preprocess_fma_audio(
                args.fma_path,
                fma_subset='small',
                output_dir=args.train_data_path,
                resolution=args.resolution,
                sample_rate=args.sample_rate,
                skip_idxs=small_skip_idxs_all,
                start_from=0,
            )
        accelerator.wait_for_everyone()

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        dataset = load_dataset(
            'imagefolder',
            data_dir=args.train_data_path
        )

    column_names = dataset["train"].column_names
    image_column = _resolve_column(args.image_column, "image", column_names, "--image_column")
    caption_column = _resolve_column(args.caption_column, "text", column_names, "--caption_column")

    def tokenize_captions(examples, is_train=True):
        """Tokenize captions, padded/truncated to the tokenizer's max length."""
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    # The previous Normalize([0.0], [0.5]) produced [0, 2].
    # NOTE(review): [-1, 1] is the input range the Stable Diffusion VAE is trained on
    # (and what the upstream diffusers training scripts use).
    train_transforms = transforms.Compose(
        [
            # transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            # transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            # transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms (applied lazily, on access)
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    return train_dataset, train_dataloader

In [10]:
def get_models_scheduler(args, weight_dtype, device):
    """Load the frozen Stable Diffusion components from ``args.pretrained_model_name_or_path``.

    Every component, including the noise-scheduler config, is read at the same
    ``args.revision``. Previously the scheduler ignored the revision, so it could come from
    a different snapshot than the weights. All models are frozen (``requires_grad=False``)
    and moved to ``device`` in ``weight_dtype``.

    Returns:
        ``(noise_scheduler, tokenizer, text_encoder, vae, unet)``
    """
    model_path = args.pretrained_model_name_or_path
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler", revision=args.revision)
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", revision=args.revision)
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", revision=args.revision)
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=args.revision)

    # Freeze parameters of the base models to save memory, then move/cast them.
    for model in (unet, vae, text_encoder):
        model.requires_grad_(False)
        model.to(device, dtype=weight_dtype)

    return noise_scheduler, tokenizer, text_encoder, vae, unet

In [11]:
def _get_full_repo_name(model_id, organization=None, token=None):
    """Return ``"<namespace>/<model_id>"`` for a Hugging Face Hub repository.

    Without ``organization`` the namespace is the user owning ``token``; the stored
    local token is used when ``token`` is None.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        organization = whoami(token)["name"]
    return f"{organization}/{model_id}"


def _save_model_card(repo_id, images=None, base_model=None, dataset_name=None, repo_folder=None):
    """Write a minimal README.md model card, plus the sample images, into ``repo_folder``."""
    img_lines = []
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_lines.append(f"![img_{i}](./image_{i}.png)")
    card = (
        "---\n"
        f"base_model: {base_model}\n"
        "tags:\n- stable-diffusion\n- text-to-image\n- diffusers\n- lora\n"
        "---\n\n"
        f"# LoRA text2image fine-tuning - {repo_id}\n\n"
        f"These are LoRA adaptation weights for {base_model}, fine-tuned on "
        f"{dataset_name or 'a local image folder'}. Some example images are below.\n\n"
        + "\n".join(img_lines) + "\n"
    )
    with open(os.path.join(repo_folder, "README.md"), "w") as readme:
        readme.write(card)


def _generate_images(pipeline, args, device):
    """Generate ``args.num_validation_images`` images for ``args.validation_prompt``.

    The generator is seeded with ``args.seed`` when set; a None seed is left unseeded
    instead of crashing in ``manual_seed``.
    """
    generator = torch.Generator(device=device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)
    return [
        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
        for _ in range(args.num_validation_images)
    ]


def _log_images(accelerator, images, tag, step, prompt):
    """Send generated PIL images to the active TensorBoard / W&B trackers under ``tag``."""
    if not images:
        return
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(tag, np_images, step, dataformats="NHWC")
        if tracker.name == "wandb":
            import wandb

            tracker.log({tag: [wandb.Image(image, caption=f"{i}: {prompt}") for i, image in enumerate(images)]})


def main(args):
    """Fine-tune LoRA attention processors of a Stable Diffusion UNet on image/caption pairs.

    Args:
        args: A ``LoRATrainConfig`` (or any object with the same attributes, e.g. an
            ``argparse.Namespace``). It is shallow-copied and never mutated.

    Side effects:
        Writes checkpoints, tracker logs and the LoRA weights (``unet.save_attn_procs``) under
        ``args.output_dir``, and optionally pushes them to the Hugging Face Hub.
    """
    import copy

    # Work on a copy: below we rewrite learning_rate (scale_lr), max_train_steps and num_train_epochs.
    # Mutating the caller's config meant a second main(config) call in the same kernel compounded
    # scale_lr and reused the first run's max_train_steps.
    args = copy.copy(args)

    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )
    if args.report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation (main process only).
    repo = None
    repo_name = None
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = _get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo_name = create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name)
            # Keep intermediate step/epoch folders out of the pushed repository.
            with open(os.path.join(args.output_dir, ".gitignore"), "w") as gitignore:
                gitignore.write("step_*\nepoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision:
    # these models are only used for inference, so full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    noise_scheduler, tokenizer, text_encoder, vae, unet = get_models_scheduler(args, weight_dtype, accelerator.device)

    # Add new LoRA weights to the attention layers; only these are trained.
    lora_layers = get_lora_layers(args, unet)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset, train_dataloader = get_dataloader(args, tokenizer, accelerator)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # The size of the training dataloader may have changed after prepare(); recompute.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0
    resume_step = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Defined up front so the push/final-inference code below works even when no epoch runs.
    images = []
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image (read from config: direct attribute access is deprecated)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet=accelerator.unwrap_model(unet),
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                images = _generate_images(pipeline, args, accelerator.device)
                _log_images(accelerator, images, "validation", epoch, args.validation_prompt)

                del pipeline
                torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

        if args.push_to_hub:
            _save_model_card(
                repo_name,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                dataset_name=args.dataset_name,
                repo_folder=args.output_dir,
            )
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

        # Final inference from the saved LoRA weights. It runs on the main process only,
        # because other processes could otherwise read output_dir before save_attn_procs finishes.
        if args.validation_prompt is not None:
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
            )
            pipeline = pipeline.to(accelerator.device)
            pipeline.unet.load_attn_procs(args.output_dir)

            images = _generate_images(pipeline, args, accelerator.device)
            _log_images(accelerator, images, "test", epoch, args.validation_prompt)

    accelerator.end_training()


In [12]:
# preprocess_audio=False: train on the PNGs + metadata.csv already present in train_data_path.
config = LoRATrainConfig(preprocess_audio=False, resume_from_checkpoint=None)
print(config, "*" * 40, sep="\n")
main(config)