# Improved Techniques for Training Consistency Models

[![arXiv](https://img.shields.io/badge/arXiv-2310.14189-b31b1b.svg)](https://arxiv.org/abs/2310.14189)
<a target="_blank" href="https://colab.research.google.com/github/leakedweights/mincy/blob/main/notebooks/ict_mlp_mixer.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

JAX & Flax implementation of [Improved Consistency Training](https://arxiv.org/abs/2310.14189).
This code is partially based on [smsharma/consistency-models](https://github.com/smsharma/consistency-models/).

In [None]:
%pip install torch torchvision ipykernel einops wandb imageio
%pip install --upgrade jax flax

In [None]:
import os
import multiprocessing

# XLA reads XLA_FLAGS only when JAX initializes its backends, which happens on the
# first backend query (jax.default_backend() below). Setting the flag afterwards
# (as this cell previously did) has no effect, so configure it before importing jax.
# The flag only affects the host (CPU) platform, so it is harmless on GPU/TPU.
# A power of two (<= the core count) is used because the config and trainer split
# batch_size evenly across jax.device_count() devices (e.g. 512 in this notebook).
cpu_devices = 1 << (multiprocessing.cpu_count().bit_length() - 1)
existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if "--xla_force_host_platform_device_count" not in existing_xla_flags:
    os.environ["XLA_FLAGS"] = (
        f"{existing_xla_flags} --xla_force_host_platform_device_count={cpu_devices}"
    ).strip()

import jax

hardware = jax.default_backend()

if hardware == "tpu":
    pass  # setup tpu

elif hardware == "gpu":
    pass  # setup gpu

elif hardware == "cpu":
    jax.config.update('jax_platform_name', 'cpu')

## 🚚 Imports

In [3]:
import jax
import einops
from jax import random
import jax.numpy as jnp
import flax.linen as nn
from jax.scipy.special import erf

import optax
import numpy as np
from optax.losses import l2_loss
from jax.tree_util import tree_map
from flax.training import checkpoints
from torchvision.datasets import MNIST
from flax.training.train_state import TrainState
from flax.jax_utils import replicate, unreplicate
from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import Compose, ToTensor, Lambda

import wandb
import imageio
from tqdm import trange
from functools import partial
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any



## 🎲 Backbone

In [5]:
class MLPBlock(nn.Module):
    """Feed-forward block: Dense(mlp_dim) -> GELU -> Dense back to the input width."""

    mlp_dim: int  # hidden width of the block

    @nn.compact
    def __call__(self, x):
        out_features = x.shape[-1]
        hidden = nn.gelu(nn.Dense(self.mlp_dim)(x))
        return nn.Dense(out_features)(hidden)


class MixerBlock(nn.Module):
    """One MLP-Mixer layer: token mixing then channel mixing, each pre-norm + residual."""

    tokens_mlp_dim: int    # hidden width of the token-mixing MLP
    channels_mlp_dim: int  # hidden width of the channel-mixing MLP

    @nn.compact
    def __call__(self, x):
        # Token mixing: move the token axis (1) last so the MLP acts across tokens.
        normed = nn.LayerNorm()(x)
        mixed_tokens = MLPBlock(self.tokens_mlp_dim)(jnp.swapaxes(normed, 1, 2))
        x = x + jnp.swapaxes(mixed_tokens, 1, 2)

        # Channel mixing: the MLP acts on the feature (last) axis directly.
        normed = nn.LayerNorm()(x)
        return x + MLPBlock(self.channels_mlp_dim)(normed)


class MLPMixer(nn.Module):
    """MLP-Mixer backbone conditioned on a time embedding and a class label.

    Attributes:
        patch_size: side length of the square, non-overlapping patches.
        num_blocks: number of MixerBlocks.
        hidden_dim: channel width of the patch embedding.
        tokens_mlp_dim: hidden width of the token-mixing MLPs (also used by the
            conditioning MLP).
        channels_mlp_dim: hidden width of the channel-mixing MLPs.
        num_classes: number of labels accepted by the label embedding.
    """
    patch_size: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    num_classes: int

    @nn.compact
    def __call__(self, x, t, context):
        """Return a tensor with the same shape as ``x``.

        Args:
            x: images of shape (batch, H, W, C); H and W must be multiples of
                ``patch_size``.
            t: time embedding of shape (batch, d_t).
            context: integer class labels of shape (batch,).

        Raises:
            ValueError: if H or W is not divisible by ``patch_size``.
        """
        _, h, w, c = x.shape
        p = self.patch_size
        if h % p or w % p:
            raise ValueError(f"Image size {(h, w)} is not divisible by patch_size={p}.")
        d_t_emb = t.shape[-1]

        # The conditioning vector depends only on (label, t), not on pixel position.
        # Run its MLP once per sample and broadcast over pixels afterwards. This gives
        # the same result as evaluating it per pixel, but is H*W times cheaper.
        # Module creation order (Embed, Dense, Dense) is unchanged, so parameter names
        # and checkpoints stay compatible.
        label_emb = nn.Embed(self.num_classes, d_t_emb)(context)
        cond = jnp.concatenate([label_emb, t], axis=-1)
        cond = nn.gelu(nn.Dense(self.tokens_mlp_dim)(cond))
        cond = nn.Dense(d_t_emb)(cond)
        cond = einops.repeat(cond, "b d -> b h w d", h=h, w=w)

        x = jnp.concatenate([x, cond], axis=-1)

        # Patchify with a strided convolution, then flatten the patch grid into tokens.
        x = nn.Conv(self.hidden_dim, [p, p], strides=[p, p])(x)
        x = einops.rearrange(x, "n h w c -> n (h w) c")

        for _ in range(self.num_blocks):
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
        x = nn.LayerNorm()(x)

        # Project every token back to a p x p x c patch and reassemble the image.
        x = nn.Dense(p * p * c)(x)
        return einops.rearrange(
            x,
            "b (hp wp) (ph pw c) -> b (hp ph) (wp pw) c",
            hp=h // p, wp=w // p, ph=p, pw=p, c=c,
        )

## ⚙️ Setup

In [20]:
@dataclass
class ModelConfig:
    batch_size: int
    input_size: int
    channel_size: int
    time_embed_size: int
    max_time: float = 80.0
    s0: int = 10
    s1: int = 1280
    c=5.4e-4
    eps: float = 2e-3
    sigma: float = 0.5
    rho: float = 7.0
    sigma_min: float = 0.002
    sigma_max: float = 80
    p_mean: float = -1.1
    p_std: float = 2.0
    dtype: Any = jnp.float32

    def __post_init__(self):
        data_dim = self.channel_size * self.input_size ** 2
        self.c_data = float(self.c * jnp.sqrt(data_dim))
        self.device_batch_size = self.batch_size // jax.device_count()
        self.init_shape = (self.device_batch_size,
                      self.input_size,
                      self.input_size,
                      self.channel_size)

    @abstractmethod
    def _create_backbone(self) -> nn.Module:
        pass


@dataclass
class MLPMixerConfig(ModelConfig):
    """ModelConfig whose backbone is an MLPMixer."""

    patch_size: int = 4         # side length of the square patches
    num_blocks: int = 4         # number of MixerBlocks
    hidden_dim: int = 256       # patch-embedding width
    tokens_mlp_dim: int = 256   # token-mixing MLP width
    channels_mlp_dim: int = 256 # channel-mixing MLP width
    num_classes: int = 10       # size of the label embedding table

    def _create_backbone(self):
        """Build the MLPMixer described by this config's fields."""
        mixer_fields = ("patch_size", "num_blocks", "hidden_dim",
                        "tokens_mlp_dim", "channels_mlp_dim", "num_classes")
        return MLPMixer(**{name: getattr(self, name) for name in mixer_fields})

@dataclass
class TrainerConfig:
    """Settings for ConsistencyTrainer: optimizer, logging, sampling and checkpoints.

    The ``*_granularity`` values are step periods: an action fires when
    ``step % granularity == granularity - 1``.
    """
    lr: float = 3e-4
    log_wandb: bool = True
    log_granularity: int = 100
    generation_granularity: int = 1000
    ckpt_granularity: int = 5000
    ckpt_path: str = "/tmp/mincy/checkpoints"
    ckpts_to_keep: int = 1
    # The three generation settings below lacked annotations, so they were plain
    # class attributes rather than dataclass fields and could not be passed to the
    # constructor. They are appended after the original fields so existing
    # positional arguments keep their meaning.
    generation_save_path: str = "/tmp/mincy/train_samples"
    # Noise levels for multistep sampling; the first one scales the initial noise.
    # A tuple avoids a shared mutable default.
    # NOTE(review): 500 exceeds ModelConfig's default sigma_max (80), the largest
    # noise level seen in training - confirm these values are intended.
    generation_timesteps: tuple = (500, 100, 10)
    generation_classes: int = 1

    def __post_init__(self):
        self.ckpt_path = os.path.abspath(self.ckpt_path)

## ✨ Consistency Models

In [17]:
def pseudo_huber_loss(x: jax.Array, y: jax.Array, c_data: float):
    """Per-sample pseudo-Huber distance from iCT: ``sqrt(||x - y||_2^2 + c^2) - c``.

    The squared L2 norm is taken over every non-batch axis (axis 0 is the batch),
    matching the iCT metric that ``c = 0.00054 * sqrt(data_dim)`` is scaled for. The
    previous version applied an elementwise ``0.5 * (x - y)**2`` instead. Reduced
    axes are kept with size 1, so the result broadcasts against per-sample weights
    such as ``weight[:, :, None, None]``.

    Args:
        x: array with the batch on axis 0.
        y: array with the same shape as ``x``.
        c_data: the constant ``c`` (this notebook passes ``ModelConfig.c_data``).

    Returns:
        Non-negative array of shape ``(batch, 1, ..., 1)`` with the same rank as ``x``.
    """
    reduce_axes = tuple(range(1, x.ndim))
    sq_dist = jnp.sum(jnp.square(x - y), axis=reduce_axes, keepdims=True)
    return jnp.sqrt(sq_dist + c_data ** 2) - c_data

In [18]:
def sinusoidal_embedding(time: jax.Array, embedding_size: int):
    """Sinusoidal (transformer-style) embedding of noise levels.

    Args:
        time: array whose last axis has size 1, e.g. shape (batch, 1); that axis
            is dropped before embedding.
        embedding_size: number of output features. For odd sizes, the last
            feature is a zero pad.

    Returns:
        Array of shape ``time.shape[:-1] + (embedding_size,)`` laid out as
        ``[sin(t * f_0..f_{k-1}), cos(t * f_0..f_{k-1})(, 0)]`` with
        ``k = embedding_size // 2`` geometrically spaced frequencies.
    """
    # Levels are multiplied by 1e3 before embedding (presumably so that small
    # sigmas still span several periods).
    time = time[..., 0] * 1e3
    half_dim = embedding_size // 2
    # Guard against dividing by zero when half_dim == 1 (embedding_size 2 or 3),
    # which previously filled the embedding with NaNs.
    emb_scale = jnp.log(1e4) / max(half_dim - 1, 1)
    freqs = jnp.exp(jnp.arange(half_dim) * -emb_scale)

    # Broadcasting over arbitrary leading axes (previously only 1-D after indexing).
    angles = time[..., None] * freqs
    embedding = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

    if embedding_size % 2 == 1:
        # Pad only the feature axis. The old spec had one more entry than the
        # array's rank and raised for every odd size.
        pad_width = [(0, 0)] * (embedding.ndim - 1) + [(0, 1)]
        embedding = jnp.pad(embedding, pad_width, mode='constant')

    return embedding

In [19]:
class ConsistencyModel:
    """Consistency model: backbone network, iCT noise schedules and multistep sampling.

    The backbone is built by ``config._create_backbone()``. Its parameters live in
    ``self.state`` (a flax ``TrainState``) once :meth:`create_state` has run.
    """

    def __init__(self, config: "ModelConfig", key: jax.dtypes.prng_key):
        self.config = config
        self.random_key = key
        self.backbone = config._create_backbone()
        # Set by create_state(); None until then.
        self.state = None

    @staticmethod
    @partial(jax.jit, static_argnums=(3, 4))
    def consistency_fn(x: jax.Array,
                       h: jax.Array,
                       t: float,
                       sigma: float,
                       eps: float):
        """Return ``c_skip(t) * x + c_out(t) * h``.

        c_skip(t) = sigma^2 / ((t - eps)^2 + sigma^2) and
        c_out(t) = sigma (t - eps) / sqrt(t^2 + sigma^2), so the result is exactly
        ``x`` at ``t == eps``. ``t`` is indexed as ``t[:, :, None, None]``: it is
        expected to have shape (batch, 1) and to broadcast against 4-D ``x``/``h``.
        """
        cskip = lambda t: (sigma ** 2 / ((t-eps)**2 + sigma ** 2))[:, :, None, None]
        cout = lambda t: (sigma * (t-eps) / jnp.sqrt(t**2 + sigma ** 2))[:, :, None, None]
        return x * cskip(t) + h * cout(t)

    def get_param_count(self):
        """Return the total number of scalar parameters in ``self.state.params``.

        Raises:
            RuntimeError: if :meth:`create_state` has not been called yet.
        """
        # Previously this read a nonexistent `self.train_state` and raised a bare
        # string (a TypeError).
        if getattr(self, "state", None) is None:
            raise RuntimeError("Model not initialized. Call `create_state` first to initialize.")
        return sum(leaf.size for leaf in jax.tree_util.tree_leaves(self.state.params))

    def ict_discretize(self, step: int, max_steps):
        """Return the number of noise levels N(k) for training step ``step``.

        iCT curriculum: N(k) = min(s0 * 2**floor(k / K'), s1) + 1 with
        K' = floor(K / (log2(floor(s1 / s0)) + 1)) and K = ``max_steps``.
        """
        s0, s1 = self.config.s0, self.config.s1
        k_prime = int(np.floor(max_steps / (np.log2(np.floor(s1 / s0)) + 1)))
        # With very few total steps K' would be 0 and floor(k / K') undefined.
        k_prime = max(k_prime, 1)
        num_intervals = min(s0 * 2 ** (int(step) // k_prime), s1)
        return int(num_intervals + 1)

    def karras_levels(self, N):
        """Return the N Karras noise levels, increasing from sigma_min to sigma_max.

        sigma_i = (sigma_min^(1/rho) + (i-1)/(N-1) * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho
        for i = 1..N. The previous version produced only N-1 levels, shifted one
        step down, so it started below sigma_min and never reached sigma_max.

        Raises:
            ValueError: if N < 2.
        """
        n = int(N)
        if n < 2:
            raise ValueError(f"Need at least 2 noise levels, got N={n}.")
        inv_rho = 1.0 / self.config.rho
        lo = self.config.sigma_min ** inv_rho
        hi = self.config.sigma_max ** inv_rho
        ramp = jnp.arange(n) / (n - 1)
        return (lo + ramp * (hi - lo)) ** self.config.rho

    def sample_timesteps(self, key, noise_levels, shape):
        """Sample adjacent noise-level pairs (sigma_i, sigma_{i+1}).

        Interval i is drawn with probability proportional to
        erf((log sigma_{i+1} - P_mean) / (sqrt(2) P_std)) - erf((log sigma_i - P_mean) / (sqrt(2) P_std)),
        i.e. the mass a lognormal(P_mean, P_std) assigns to that interval. The
        previous version omitted erf and sliced the levels twice, so each
        probability belonged to the interval one step above the pair it returned.

        Returns:
            (t1, t2): arrays of shape ``shape`` with t1 = sigma_i < t2 = sigma_{i+1}.
        """
        def erf_arg(sigma):
            return (jnp.log(sigma) - self.config.p_mean) / (jnp.sqrt(2.0) * self.config.p_std)

        interval_mass = erf(erf_arg(noise_levels[1:])) - erf(erf_arg(noise_levels[:-1]))
        probs = interval_mass / jnp.sum(interval_mass)
        idx = random.choice(key, probs.shape[0], shape=shape, replace=True, p=probs)
        return noise_levels[idx], noise_levels[idx + 1]

    def sample(self, timesteps, target_classes):
        """Multistep consistency sampling.

        Args:
            timesteps: sequence of noise levels (expected to be decreasing);
                ``timesteps[0]`` scales the initial Gaussian noise.
            target_classes: class labels, one per sample (device_batch_size,).

        Returns:
            Samples with shape ``config.init_shape``.
        """
        self.random_key, noise_key = random.split(self.random_key)
        noise_shape = self.config.init_shape
        batch_size = self.config.device_batch_size
        embed_size = self.config.time_embed_size
        sigma, eps = self.config.sigma, self.config.eps

        def denoise(x_t, timestep):
            # Map a noisy input at level `timestep` to the consistency-function output.
            batch_t = jnp.full((batch_size, 1), timestep, dtype=jnp.float32)
            t_emb = sinusoidal_embedding(batch_t, embed_size)
            h = self.state.apply_fn(self.state.params, x_t, t_emb, target_classes)
            return self.consistency_fn(x_t, h, batch_t, sigma, eps)

        x = denoise(jax.random.normal(noise_key, noise_shape) * timesteps[0], timesteps[0])

        for timestep in timesteps[1:]:
            noise_key, _ = random.split(noise_key)
            noise = jax.random.normal(noise_key, noise_shape)
            x_noisy = x + jnp.sqrt(timestep**2 - eps**2) * noise
            # The skip connection must use the same noisy input the network sees
            # (it previously used the clean `x`).
            x = denoise(x_noisy, timestep)

        return x

    def create_state(self, key: jax.dtypes.prng_key, optimizer: Optional[Any] = None):
        """Initialize backbone params from dummy per-device inputs; store and return a TrainState."""
        noise_key, init_key = random.split(key)
        x = jax.random.normal(noise_key, self.config.init_shape)
        t = jnp.ones((self.config.device_batch_size, self.config.time_embed_size))

        y = jnp.ones((self.config.device_batch_size,)).astype(jnp.int32)
        params = self.backbone.init(init_key, x, t, y)
        train_state = TrainState.create(apply_fn=self.backbone.apply, params=params, tx=optimizer)
        self.state = train_state

        return train_state

In [10]:
class ConsistencyTrainer:
    """Improved Consistency Training loop around a ConsistencyModel.

    Creates the optimizer state, runs data-parallel (``jax.pmap``) training steps
    over ``jax.device_count()`` devices, and periodically checkpoints, logs to wandb
    and writes sample images.
    """

    def __init__(self, model: ConsistencyModel, dataloader: DataLoader, config: TrainerConfig, random_key: jax.dtypes.prng_key):
        self.model = model
        self.dataloader = dataloader
        self.config = config
        self.random_key, init_key = jax.random.split(random_key)
        # Persistent iterator over `dataloader`; created lazily by _next_batch.
        self._data_iter = None

        tx = optax.radam(learning_rate=self.config.lr)
        self.model_state = self.model.create_state(optimizer=tx, key=init_key)

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.config.ckpt_path, exist_ok=True)
        os.makedirs(self.config.generation_save_path, exist_ok=True)

    def _next_batch(self):
        """Return the next batch, starting a new pass over the dataloader when one ends.

        Previously each step built a fresh iterator, and StopIteration silently
        skipped the step.

        Raises:
            RuntimeError: if the dataloader yields no batches at all.
        """
        if getattr(self, "_data_iter", None) is None:
            self._data_iter = iter(self.dataloader)
        try:
            return next(self._data_iter)
        except StopIteration:
            self._data_iter = iter(self.dataloader)
        try:
            return next(self._data_iter)
        except StopIteration:
            raise RuntimeError("Dataloader yielded no batches (dataset smaller than batch_size with drop_last?).") from None

    def handle_ckpt(self, step: int, state: TrainState):
        """Save ``state`` every ``ckpt_granularity`` steps, keeping ``ckpts_to_keep`` checkpoints."""
        if step % self.config.ckpt_granularity == self.config.ckpt_granularity - 1:
            checkpoints.save_checkpoint(ckpt_dir=self.config.ckpt_path,
                                        target=state,
                                        step=step,
                                        overwrite=True,
                                        keep=self.config.ckpts_to_keep)

    def restore_ckpt(self, step: Optional[int] = None, load=True):
        """Restore a checkpoint (latest if ``step`` is None); optionally install it on the model."""
        restored_state = checkpoints.restore_checkpoint(ckpt_dir=self.config.ckpt_path,
                                                        target=self.model.state,
                                                        step=step)
        if load:
            self.model.state = restored_state
        return restored_state

    def handle_metrics(self, step: int, metrics: dict):
        """Log ``metrics`` to wandb every ``log_granularity`` steps when ``log_wandb`` is on."""
        if not self.config.log_wandb:
            return

        if step % self.config.log_granularity == self.config.log_granularity - 1:
            wandb.log(metrics, step=step)

    def handle_sampling(self, step: int, state: TrainState):
        """Every ``generation_granularity`` steps, sample one batch and save it as PNGs."""
        if step % self.config.generation_granularity != self.config.generation_granularity - 1:
            return

        self.model.state = state
        target_classes = jnp.repeat(self.config.generation_classes, self.model.config.device_batch_size)
        outputs = np.asarray(self.model.sample(self.config.generation_timesteps, target_classes))

        for i, image in enumerate(outputs):
            # Training images are scaled to [-1, 1]. Assuming samples live in the same
            # range, map them to 8-bit pixels explicitly instead of handing floats to
            # imageio.
            pixels = np.clip((image + 1.0) * 127.5, 0, 255).astype(np.uint8)
            if pixels.ndim == 3 and pixels.shape[-1] == 1:
                pixels = pixels[..., 0]  # grayscale: drop the singleton channel axis
            image_path = os.path.join(self.config.generation_save_path, f"sample_{step}_{i}.png")
            imageio.imwrite(image_path, pixels)
            print(f"Saved image to {image_path}")

    @staticmethod
    @partial(jax.jit, static_argnums=(1, 2, 6, 9, 10, 11))
    def loss_fn(params, apply_fn, consistency_fn, x, t1, t2, time_embed_size, y, key, sigma, eps, c_data):
        """iCT loss: weighted pseudo-Huber distance between consistency outputs at adjacent levels.

        ``t1 < t2`` are adjacent noise levels shaped (batch, 1). Both noisy inputs
        share the same Gaussian noise ``z``. The t1 branch is the stop-gradient
        target (theta^- = stop_grad(theta)); the t2 branch carries the gradient.
        Each sample is weighted by 1 / (t2 - t1).
        """
        z = jax.random.normal(key, shape=x.shape)
        x_t1 = x + z * t1[:, :, None, None]
        x_t2 = x + z * t2[:, :, None, None]

        t1_emb = sinusoidal_embedding(t1, time_embed_size)
        t2_emb = sinusoidal_embedding(t2, time_embed_size)

        # The network must see the *noisy* inputs; it previously received the clean x.
        h1 = jax.lax.stop_gradient(apply_fn(params, x_t1, t1_emb, y))  # theta^-
        target = consistency_fn(x_t1, h1, t1, sigma, eps)

        h2 = apply_fn(params, x_t2, t2_emb, y)
        prediction = consistency_fn(x_t2, h2, t2, sigma, eps)

        loss = pseudo_huber_loss(target, prediction, c_data)
        weight = (1 / (t2 - t1))[:, :, None, None]

        return jnp.mean(weight * loss)

    @staticmethod
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(5, 6, 7, 8, 9, 10))
    def train_step(state, batch, t1, t2, key, consistency_fn, loss_fn, time_embed_size, sigma, eps, c_data):
        """One data-parallel optimizer step; gradients and loss are averaged across devices."""
        x, y = batch
        params = state.params
        apply_fn = state.apply_fn

        loss, grads = jax.value_and_grad(loss_fn)(params, apply_fn, consistency_fn, x, t1, t2, time_embed_size, y, key, sigma, eps, c_data)

        grads = jax.lax.pmean(grads, "batch")
        loss = jax.lax.pmean(loss, "batch")

        state = state.apply_gradients(grads=grads)
        return state, loss

    def train(self, timesteps: int):
        """Run ``timesteps`` training steps, then store the final state on the model.

        Raises:
            ValueError: if the batch size is not divisible by the device count.
        """
        batch_size = self.model.config.batch_size
        num_devices = jax.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"Batch size must be divisible by the number of devices, but got {batch_size} and {num_devices}."
            )
        device_batch_size = batch_size // num_devices

        state = replicate(self.model.state)
        sigma = self.model.config.sigma
        eps = self.model.config.eps
        c_data = self.model.config.c_data
        time_embed_size = self.model.config.time_embed_size

        with trange(timesteps) as steps:
            for step in steps:
                x_batch, y_batch = self._next_batch()

                _, h, w, c = x_batch.shape
                # Leading axis = device axis mapped over by pmap.
                x_batch = x_batch.reshape(num_devices, device_batch_size, h, w, c)
                y_batch = y_batch.reshape(num_devices, device_batch_size)

                self.random_key, time_key, step_key = random.split(self.random_key, 3)
                train_keys = random.split(step_key, num_devices)  # one key per device

                N = self.model.ict_discretize(step, timesteps)
                noise_levels = self.model.karras_levels(N)
                t1, t2 = self.model.sample_timesteps(time_key, noise_levels, (num_devices, device_batch_size, 1))

                state, loss = self.train_step(state,
                                              (x_batch, y_batch),
                                              t1, t2,
                                              train_keys,
                                              self.model.consistency_fn,
                                              self.loss_fn,
                                              time_embed_size,
                                              sigma, eps, c_data)
                loss = float(unreplicate(loss))
                steps.set_postfix(val=loss)

                host_state = unreplicate(state)
                self.handle_ckpt(step, host_state)
                self.handle_metrics(step, {"training loss": loss})
                self.handle_sampling(step, host_state)

        self.model.state = unreplicate(state)

## 🏋️ Training

In [None]:
# Start the wandb run; ConsistencyTrainer.handle_metrics logs to it when log_wandb is True.
wandb.init(**{
    "project": "minimal-consistency-jax",
    "job_type": "simple-train-loop",
})

In [None]:
# Separate PRNG streams for the model (sampling noise) and the trainer.
random_key, model_key, trainer_key = random.split(random.key(0), 3)

model_config = MLPMixerConfig(
    batch_size=512,
    input_size=28,
    channel_size=1,
    time_embed_size=16,
)

model = ConsistencyModel(model_config, model_key)


def numpy_collate(batch):
    """Collate like PyTorch's default collate, then convert every tensor leaf to NumPy."""
    return tree_map(np.asarray, default_collate(batch))


# ToTensor yields CHW tensors (values assumed in [0, 1]); move channels last (HWC)
# and rescale to [-1, 1].
transform = Compose([
    ToTensor(),
    Lambda(lambda img: img.permute(1, 2, 0)),
    Lambda(lambda img: img * 2 - 1),
])

mnist_dataset = MNIST('/tmp/mnist', download=True, transform=transform)

dataloader = DataLoader(
    dataset=mnist_dataset,
    batch_size=model_config.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    drop_last=True,
)

trainer = ConsistencyTrainer(
    model=model,
    dataloader=dataloader,
    config=TrainerConfig(),
    random_key=trainer_key,
)

trainer.train(timesteps=10_000)