<a href="https://colab.research.google.com/github/junhsss/consistency-models/blob/main/examples/consistency_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Consistency Models** 🌃
*...using `consistency`*

**Consistency Models** are a new family of generative models that achieve high sample quality without adversarial training. They support *fast one-step generation* by design, while still allowing for few-step sampling to trade compute for sample quality. It's amazing!
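
One-step generation comes from how a consistency model is parameterized. Below is a minimal scalar sketch of the skip-connection form described in the original Consistency Models paper (Song et al., 2023): `f(x, t) = c_skip(t) * x + c_out(t) * F(x, t)`. The coefficients force `f(x, EPS) = x` at the smallest noise level. The constants `SIGMA_DATA = 0.5` and `EPS = 0.002` and the helper names are assumptions taken from the paper, not read from the `consistency` package. Real images are tensors, not floats.

```python
import math

# Assumed constants from Song et al. (2023); not read from the `consistency` package.
SIGMA_DATA = 0.5  # assumed standard deviation of the data
EPS = 0.002       # smallest noise level, where f must be the identity


def c_skip(t: float) -> float:
    """Weight on the noisy input x; equals exactly 1 at t = EPS."""
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t: float) -> float:
    """Weight on the network output F(x, t); equals exactly 0 at t = EPS."""
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)


def consistency_fn(x: float, t: float, network) -> float:
    """f(x, t) = c_skip(t) * x + c_out(t) * F(x, t), on a scalar stand-in for an image."""
    return c_skip(t) * x + c_out(t) * network(x, t)


# The boundary condition holds whatever the (here arbitrary) network returns.
print(consistency_fn(0.7, EPS, lambda x, t: 123.0))
```

Because `c_skip(EPS) == 1` and `c_out(EPS) == 0`, the boundary condition holds by construction. At large `t`, almost all of the output comes from the network.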

### Setup

Please make sure you are using a GPU runtime to run this notebook. If the following command fails, use the `Runtime` menu above and select `Change runtime type`.

In [None]:
!nvidia-smi

In [None]:
!pip install datasets wandb consistency==0.2.4

In [None]:
!wandb login

In [None]:
DATASET_NAME = "cifar10"
RESOLUTION = 32
BATCH_SIZE = 128
MAX_EPOCHS = 200
LEARNING_RATE = 1e-4

SAMPLES_PATH = "./samples"
NUM_SAMPLES = 64
SAMPLE_STEPS = 1  # Set this value larger if you want higher sample quality.
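
These settings give a rough training budget. The arithmetic below assumes the CIFAR-10 train split (50,000 images) and one optimizer step per batch, since the Trainer cell does not set gradient accumulation. It also assumes the `DataLoader` default `drop_last=False`. `NUM_TRAIN_IMAGES` is a name introduced here for illustration.

```python
import math

# Assumption: CIFAR-10's train split has 50,000 images; the other values mirror the config cell.
NUM_TRAIN_IMAGES = 50_000
BATCH_SIZE = 128
MAX_EPOCHS = 200

# DataLoader defaults to drop_last=False, so the final partial batch is still a step.
steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * MAX_EPOCHS
print(steps_per_epoch, total_steps)  # 391 78200
```

So `MAX_EPOCHS = 200` corresponds to roughly 78k optimizer steps. If you change `DATASET_NAME`, update `NUM_TRAIN_IMAGES` to match.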

In [None]:
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_name: str, dataset_config_name=None):
        self.dataset = load_dataset(
            dataset_name,
            dataset_config_name,
            split="train",
        )
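        # Hub datasets name the image column differently (CIFAR-10 uses "img").
        self.image_key = [
            key for key in ("image", "img") if key in self.dataset[0]
        ][0]
        # Resize the shorter side to RESOLUTION, crop to a square, flip at random,
        # then map pixel values from [0, 255] to [-1, 1].
        self.augmentations = transforms.Compose(
            [
                transforms.Resize(
                    RESOLUTION,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.CenterCrop(RESOLUTION),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )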

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.augmentations(self.dataset[index][self.image_key].convert("RGB"))

dataloader = DataLoader(
    Dataset(DATASET_NAME),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)
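
`ToTensor` scales 8-bit pixels to [0, 1]. `Normalize([0.5], [0.5])` then computes (x - 0.5) / 0.5 on every channel, so the model sees images in [-1, 1]. The per-pixel sketch below mirrors that mapping and its inverse in plain Python. The helper names are hypothetical. Applying the inverse to generated samples assumes they come out in the same [-1, 1] range.

```python
def to_model_range(pixel: int) -> float:
    """Mimic ToTensor (uint8 -> [0, 1]) followed by Normalize(mean=0.5, std=0.5)."""
    x = pixel / 255.0       # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - 0.5) / 0.5  # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


def to_pixel(value: float) -> int:
    """Hypothetical inverse: undo Normalize, clamp to [0, 1], and convert back to 8-bit."""
    x = value * 0.5 + 0.5
    x = min(max(x, 0.0), 1.0)
    return round(x * 255)


print(to_model_range(0), to_model_range(255))  # -1.0 1.0
```

Clamping before conversion matters for generated samples, which can overshoot [-1, 1] slightly.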

### Define Models

`Consistency` accepts any UNet-like model as its backbone, that is, a network that maps a noisy image and a noise level to a tensor of the same shape.
We recommend `UNet2DModel` from `diffusers` 🤗 as a default option.

In [None]:
from diffusers import UNet2DModel
from consistency import Consistency
from consistency.loss import PerceptualLoss

consistency = Consistency(
    model=UNet2DModel(
        sample_size=RESOLUTION,
        in_channels=3,
        out_channels=3,
        layers_per_block=1,
        block_out_channels=(128, 128, 256, 256),
        down_block_types=(
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
            "DownBlock2D"
        ),
        up_block_types=(
            "UpBlock2D",
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
    ),
    # LPIPS can combine several backbone networks.
    # The recommended setting is "squeeze" + "vgg":
    # loss_fn=PerceptualLoss(net_type=("squeeze", "vgg"))
    # See https://github.com/richzhang/PerceptualSimilarity
    loss_fn=PerceptualLoss(net_type="squeeze"), 
    learning_rate=LEARNING_RATE,
    samples_path=SAMPLES_PATH,
    save_samples_every_n_epoch=1,
    num_samples=NUM_SAMPLES,
    sample_steps=SAMPLE_STEPS,
    sample_ema=True,
    sample_seed=42,
)
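
To see where attention is applied, track the spatial size at each level. In `diffusers`' `UNet2DModel`, every down block except the last ends with a downsampler. With `RESOLUTION = 32`, the four blocks therefore work at 32, 16, 8 and 4 pixels. The single `AttnDownBlock2D`, and its mirror `AttnUpBlock2D`, run self-attention at 16x16. The helper `block_resolutions` below is a hypothetical sketch of that bookkeeping and is not part of either library.

```python
def block_resolutions(sample_size: int, num_blocks: int) -> list[int]:
    """Spatial size each down block works at, assuming every block except the
    last halves the resolution (the UNet2DModel default)."""
    sizes = []
    size = sample_size
    for i in range(num_blocks):
        sizes.append(size)
        if i < num_blocks - 1:  # the final block has no downsampler
            if size % 2:
                raise ValueError(f"size {size} cannot be halved after block {i}")
            size //= 2
    return sizes


# Values copied from the model cell above.
down_block_types = ("DownBlock2D", "AttnDownBlock2D", "DownBlock2D", "DownBlock2D")
block_out_channels = (128, 128, 256, 256)

for kind, channels, size in zip(
    down_block_types, block_out_channels, block_resolutions(32, len(down_block_types))
):
    print(f"{kind:16s} {channels:4d} channels at {size}x{size}")
```

This also shows why `RESOLUTION` must be divisible by 8 for this four-block layout.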

### Training

Generated samples are saved to `SAMPLES_PATH` every epoch (`save_samples_every_n_epoch=1`). They also appear in your **W&B workspace** as training progresses.

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger

trainer = Trainer(
    accelerator="auto",
    logger=WandbLogger(project="consistency", log_model=True),
    callbacks=[
        ModelCheckpoint(
            dirpath="ckpt", 
            save_top_k=3, 
            monitor="loss",
        )
    ],
    max_epochs=MAX_EPOCHS,
    precision=16 if torch.cuda.is_available() else 32,
    log_every_n_steps=30,
    gradient_clip_algorithm="norm",
    gradient_clip_val=1.0,
)

trainer.fit(consistency, dataloader)
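
Consistency training compares the model's outputs at two adjacent noise levels. The scalar sketch below follows Song et al. (2023). Noise levels come from the Karras et al. discretization with `EPS = 0.002`, `T_MAX = 80` and `RHO = 7`. The same Gaussian noise `z` is added to a training sample at levels t_n and t_{n+1}. The student f_θ gets the noisier input, and an EMA copy of the weights gets the other one. The loss is a distance between their two outputs, presumably the `PerceptualLoss` configured above. The constants and function names are assumptions from the paper. The `consistency` package may use a different schedule or grow the number of levels N during training.

```python
import random

# Assumed constants from Song et al. (2023); not read from the `consistency` package.
EPS, T_MAX, RHO = 0.002, 80.0, 7.0


def karras_times(n: int) -> list[float]:
    """n increasing noise levels from EPS to T_MAX, spaced more densely near EPS."""
    lo, hi = EPS ** (1 / RHO), T_MAX ** (1 / RHO)
    return [(lo + i / (n - 1) * (hi - lo)) ** RHO for i in range(n)]


def training_pair(x: float, times: list[float], rng: random.Random):
    """Pick adjacent levels t_n < t_{n+1} and noise x with the SAME z at both.

    Returns (student_input, t_{n+1}) and (teacher_input, t_n). In training, the
    loss would be d(f_theta(student_input, t_{n+1}), f_ema(teacher_input, t_n)).
    """
    n = rng.randrange(len(times) - 1)
    z = rng.gauss(0.0, 1.0)
    t_n, t_next = times[n], times[n + 1]
    return (x + t_next * z, t_next), (x + t_n * z, t_n)


times = karras_times(18)
student, teacher = training_pair(0.3, times, random.Random(0))
print(student, teacher)
```

Sharing `z` is what makes the two inputs lie on the same noising trajectory. The loss therefore pushes the model to map both points to the same output.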

### Generate samples 

You can now `sample` high-quality images! 🎉

In [None]:
consistency.sample(64, sample_steps=20)
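
Few-step sampling alternates denoising and re-noising. The scalar sketch below follows the multistep sampling procedure (Algorithm 1) of Song et al. (2023). First, one consistency-function call maps pure noise at `T_MAX` to a sample. Each extra level τ then adds fresh noise with standard deviation sqrt(τ² - EPS²), and the model denoises again. `multistep_sample`, the constants and the choice of levels are assumptions for illustration. This notebook does not show how `consistency.sample` picks its levels for a given `sample_steps`.

```python
import math
import random

# Assumed constants from Song et al. (2023).
EPS, T_MAX = 0.002, 80.0


def multistep_sample(f, times, rng: random.Random) -> float:
    """Multistep consistency sampling on a scalar.

    f(x, t) is a trained consistency function; times are noise levels,
    each strictly between EPS and T_MAX, usually in decreasing order.
    """
    x = f(rng.gauss(0.0, 1.0) * T_MAX, T_MAX)  # one-step sample from pure noise
    for tau in times:
        if not EPS < tau < T_MAX:
            raise ValueError(f"noise level {tau} must lie in ({EPS}, {T_MAX})")
        z = rng.gauss(0.0, 1.0)
        x_noisy = x + math.sqrt(tau**2 - EPS**2) * z  # re-noise to level tau
        x = f(x_noisy, tau)                           # denoise again in one call
    return x


# Toy consistency function that always returns 0.0 and records each call's level.
calls = []


def toy_f(x: float, t: float) -> float:
    calls.append(t)
    return 0.0


print(multistep_sample(toy_f, [40.0, 10.0, 1.0], random.Random(0)), calls)
```

With an empty `times` list this reduces to one-step generation. Each extra level costs one more network evaluation, which is the compute-for-quality trade-off controlled by `sample_steps`.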