# Latent Consistency Distillation

**Latent Consistency Models (LCMs)** can generate high-quality images in just a few steps, a big leap forward given that many pipelines require 25 or more steps.

LCMs are produced by applying the latent consistency distillation method to any Stable Diffusion model. This method applies *one-stage guided distillation* in the latent space and incorporates a *skipping-step* method that consistently skips timesteps to accelerate the distillation process.

We will explore the [`train_lcm_distill_sd_wds.py`](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) training script to see how latent consistency distillation is implemented.


As always, make sure to install the `diffusers` library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

We need to navigate to the following folder and install the corresponding dependencies:
```bash
cd examples/consistency_distillation
pip install -r requirements.txt
```
Hugging Face Accelerate also helps with training on multiple GPUs and with mixed precision:
```bash
pip install accelerate
```

Now we can initialize a Hugging Face Accelerate environment:
```bash
accelerate config
```
To set up a default Accelerate environment without choosing any configurations:
```bash
accelerate config default
```
Or, if our environment does not support an interactive shell (for example, a notebook), we can use:
```python
from accelerate.utils import write_basic_config
write_basic_config()
```

## Script parameters

The training script provides many parameters to customize the training run. We can find all of the parameters and their descriptions in the `parse_args()` function.

Many parameters are described in the **Text-to-image** training guide, so here we focus only on the parameters relevant to latent consistency distillation (LCD):
* `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model
* `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter lets us specify an alternative VAE
* `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
* `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
* `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers
* `--huber_c`: the Huber loss parameter
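
To make these parameters more concrete, here is a minimal pure-Python sketch of how they might enter training. It is an illustration under stated assumptions, not the script's code: the helper names (`sample_guidance_scale`, `ddim_timestep_grid`, `l2_loss`, `pseudo_huber_loss`) and the numeric values are made up for the example, and the smooth pseudo-Huber formula is an assumption about the exact loss form. Check `parse_args()` and the training loop for the real defaults and implementation.

```python
import math
import random


def sample_guidance_scale(w_min, w_max, rng):
    """Draw a guidance scale uniformly from [w_min, w_max]."""
    return (w_max - w_min) * rng.random() + w_min


def ddim_timestep_grid(num_train_timesteps, num_ddim_timesteps):
    """Evenly spaced timesteps a DDIM solver would visit, skipping the ones in between."""
    step_ratio = num_train_timesteps // num_ddim_timesteps
    return [i * step_ratio - 1 for i in range(1, num_ddim_timesteps + 1)]


def l2_loss(pred, target):
    """Mean squared error over a flat list of values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def pseudo_huber_loss(pred, target, huber_c):
    """Smooth Huber-style loss: quadratic near zero, roughly linear for large errors."""
    terms = [math.sqrt((p - t) ** 2 + huber_c**2) - huber_c for p, t in zip(pred, target)]
    return sum(terms) / len(terms)


rng = random.Random(0)
w = sample_guidance_scale(5.0, 15.0, rng)  # assumed values for --w_min / --w_max
grid = ddim_timestep_grid(1000, 50)        # assumed 1000 training steps, 50 DDIM steps

pred = [0.0, 0.0, 0.0, 0.0]
target = [0.1, -0.1, 0.1, 10.0]            # the last value plays the role of an outlier
l2 = l2_loss(pred, target)
huber = pseudo_huber_loss(pred, target, huber_c=0.001)

print(f"guidance scale: {w:.2f}")
print(f"first/last DDIM timesteps: {grid[0]}, {grid[-1]}")
print(f"L2: {l2:.4f}, pseudo-Huber: {huber:.4f}")
```

In this toy example the single outlier dominates the L2 loss, while the Huber-style loss grows only linearly with it, which is the robustness property mentioned above.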

## Training script

The training script starts by creating a dataset class `Text2ImageDataset` for preprocessing the images and creating a training dataset:
```python
        def transform(example):
            # resize image
            image = example["image"]
            image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

            # get crop coordinates and crop image
            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
            image = TF.crop(image, c_top, c_left, resolution, resolution)
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.5], [0.5])

            example["image"] = image
            return example
```
For improved performance on reading and writing large datasets stored in the cloud, the script uses the `WebDataset` format to create a preprocessing pipeline to apply transforms and create a dataset and dataloader for training. Images are processed and fed to the training loop without having to download the full dataset first:
```python
        processing_pipeline = [
            wds.decode("pil", handler=wds.ignore_and_continue),
            wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
            wds.map(filter_keys({"image", "text"})),
            wds.map(transform),
            wds.to_tuple("image", "text"),
        ]

        # Create train dataset and loader
        pipeline = [
            wds.ResampledShards(train_shards_path_or_url),
            tarfile_to_samples_nothrow,
            wds.shuffle(shuffle_buffer_size),
            *processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
```
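
The pipeline above relies on a `filter_keys` helper that is not shown in the excerpt. A minimal sketch of what such a helper could look like, assuming it simply returns a function that drops every entry except the requested keys (the sample dictionary below is invented for illustration, and the script's own helper may differ):

```python
def filter_keys(key_set):
    """Return a function that keeps only the entries of a sample whose key is in key_set."""

    def _keep(sample):
        return {key: value for key, value in sample.items() if key in key_set}

    return _keep


# A made-up sample, shaped like a decoded and renamed WebDataset record
sample = {
    "__key__": "000001",
    "image": "<PIL image>",
    "text": "a caption",
    "json": {"width": 640},
}

filtered = filter_keys({"image", "text"})(sample)
print(sorted(filtered))
```

Returning a closure lets the same helper be passed directly to `wds.map`, which expects a function of a single sample.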

In the `main()` function, all the necessary components like the noise scheduler, tokenizers, text encoders, and VAE are loaded. The teacher UNet is also loaded here, and the student UNet is then created from the teacher's configuration and weights. Only the student UNet is updated by the optimizer during training; the teacher, VAE, and text encoder are frozen.
```python
    # 5. Load teacher U-Net from SD-XL checkpoint
    teacher_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
    )

    # 6. Freeze teacher vae, text_encoder, and teacher_unet
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    teacher_unet.requires_grad_(False)

    # 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
    # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
    if teacher_unet.config.time_cond_proj_dim is None:
        teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
    unet = UNet2DConditionModel(**teacher_unet.config)
    # load teacher_unet weights into unet
    unet.load_state_dict(teacher_unet.state_dict(), strict=False)
    unet.train()
```
Then we can create the optimizer to update the UNet parameters:
```python
    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # 12. Optimizer creation
    optimizer = optimizer_class(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
```
Next, we create the dataset:
```python
    dataset = Text2ImageDataset(
        train_shards_path_or_url=args.train_shards_path_or_url,
        num_train_examples=args.max_train_samples,
        per_gpu_batch_size=args.train_batch_size,
        global_batch_size=args.train_batch_size * accelerator.num_processes,
        num_workers=args.dataloader_num_workers,
        resolution=args.resolution,
        shuffle_buffer_size=1000,
        pin_memory=True,
        persistent_workers=True,
    )
    train_dataloader = dataset.train_dataloader
```
Next, we are ready to set up the training loop and implement the latent consistency distillation method. This section takes care of adding noise to the latents, sampling and creating a guidance scale embedding, and predicting the original image from the noise.
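
As a rough illustration of those steps, the following pure-Python sketch works on plain lists instead of tensors. It assumes the usual forward process `x_t = alpha_t * x_0 + sigma_t * noise` and a sinusoidal embedding for the guidance scale; the function names, the `1000.0` scaling, the `10000.0` frequency base, and the tiny `embedding_dim` are assumptions for the example, not values taken from the script.

```python
import math


def guidance_scale_embedding(w, embedding_dim=8):
    """Sinusoidal embedding of a scalar guidance scale: sin terms followed by cos terms."""
    half_dim = embedding_dim // 2
    scale = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    scaled_w = w * 1000.0
    return [math.sin(scaled_w * f) for f in freqs] + [math.cos(scaled_w * f) for f in freqs]


def add_noise(x0, noise, alpha_t, sigma_t):
    """Forward diffusion: mix the clean latents with noise."""
    return [alpha_t * x + sigma_t * n for x, n in zip(x0, noise)]


def predict_x0_from_noise(x_t, noise_pred, alpha_t, sigma_t):
    """Invert the forward process using a noise prediction."""
    return [(x - sigma_t * n) / alpha_t for x, n in zip(x_t, noise_pred)]


alpha_t, sigma_t = 0.8, 0.6  # alpha_t**2 + sigma_t**2 == 1
x0 = [0.5, -1.0]
noise = [0.3, -0.2]

x_t = add_noise(x0, noise, alpha_t, sigma_t)
# With a perfect noise prediction, the original latents are recovered
recovered = predict_x0_from_noise(x_t, noise, alpha_t, sigma_t)
emb = guidance_scale_embedding(7.5)

print(x_t)
print(recovered)
print(len(emb))
```

In training, the noise prediction comes from a UNet rather than from the true noise, so the recovered latents are only an estimate of the original image.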

The training loop gets the teacher model predictions and the LCM predictions, calculates the loss, and then backpropagates it to the LCM.
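
Two ingredients commonly used in this kind of distillation can be sketched with scalars: combining the teacher's conditional and unconditional estimates with classifier-free guidance, and parameterizing the student's output so that it returns its input unchanged at timestep 0. This is a hedged sketch, not the script's loop: the function names and the constants `sigma_data=0.5` and `timestep_scaling=10.0` are assumptions for illustration.

```python
import math


def cfg_estimate(cond, uncond, w):
    """Classifier-free guidance: push the conditional estimate away from the unconditional one."""
    return cond + w * (cond - uncond)


def boundary_scalings(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Weights chosen so that c_skip = 1 and c_out = 0 at timestep 0."""
    s = timestep_scaling * timestep
    c_skip = sigma_data**2 / (s**2 + sigma_data**2)
    c_out = s / math.sqrt(s**2 + sigma_data**2)
    return c_skip, c_out


def consistency_output(x_t, pred_x0, timestep):
    """Blend the noisy input and the predicted clean sample with the boundary weights."""
    c_skip, c_out = boundary_scalings(timestep)
    return c_skip * x_t + c_out * pred_x0


guided = cfg_estimate(cond=1.0, uncond=0.5, w=8.0)
at_zero = consistency_output(x_t=0.7, pred_x0=123.0, timestep=0)
at_late = consistency_output(x_t=0.7, pred_x0=0.2, timestep=999)

print(guided)
print(at_zero)
print(at_late)
```

At timestep 0 the output equals the input regardless of the prediction, while at large timesteps it is dominated by the predicted clean sample. The loss then compares the student's output at one timestep with a target computed from the teacher-guided estimate at an earlier timestep.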

## Launch the script

Now we are ready to launch the training script. We will use the `--train_shards_path_or_url` parameter to specify the path to the **Conceptual Captions 12M** dataset stored on the Hub.
```bash
export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub
```

Once training is complete, we can use the new LCM for inference:

```python
from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
import torch

unet = UNet2DConditionModel.from_pretrained(
    'our-username/our-model',
    torch_dtype=torch.float16,
    variant='fp16'
)

pipeline = DiffusionPipeline.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-v1-5',
    unet=unet,
    torch_dtype=torch.float16,
    variant='fp16'
)
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
pipeline.to('cuda')
```

```python
prompt = "sushi rolls in the form of panda heads, sushi platter"

image = pipeline(
    prompt,
    num_inference_steps=4,
    guidance_scale=1.0
).images[0]
image
```

## LoRA

We can use the [`train_lcm_distill_lora_sd_wds.py`](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [`train_lcm_distill_lora_sdxl_wds.py`](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train SD or SDXL with LoRA, respectively.

## SDXL

We can use the [`train_lcm_distill_sdxl_wds.py`](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py) script to train an SDXL model with LCD.