# Consistency Training, but with JAX 🎆

Improved Consistency Training (iCT) from [arxiv.org/abs/2310.14189](https://arxiv.org/abs/2310.14189), with JAX & Flax, partially based on [smsharma/consistency-models](https://github.com/smsharma/consistency-models/).

In [40]:
%%capture
%pip install jax flax torch torchvision ipykernel
%pip install ml_collections matplotlib einops wandb tqdm

### Imports

In [123]:
import jax
import einops
import jax.numpy as jnp
from jax import random
from flax.training.train_state import TrainState
import flax.linen as nn

import optax
from tqdm import trange
import numpy as np
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor

from functools import partial
from typing import Optional, Sequence, Callable, Any
from dataclasses import dataclass
from abc import abstractmethod

Shape = int | Sequence

## Backbones

### MLP Mixer

In [124]:
class MLPBlock(nn.Module):
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.Dense(self.mlp_dim)(x)
        y = nn.gelu(y)
        return nn.Dense(x.shape[-1])(y)


class MixerBlock(nn.Module):
    tokens_mlp_dim: int
    channels_mlp_dim: int

    @nn.compact
    def __call__(self, x):
        y = nn.LayerNorm()(x)
        y = jnp.swapaxes(y, 1, 2)
        y = MLPBlock(self.tokens_mlp_dim)(y)
        y = jnp.swapaxes(y, 1, 2)
        x = x + y
        y = nn.LayerNorm()(x)
        y = MLPBlock(self.channels_mlp_dim)(y)
        return x + y


class MLPMixer(nn.Module):
    patch_size: int
    num_blocks: int
    hidden_dim: int
    tokens_mlp_dim: int
    channels_mlp_dim: int
    num_classes: int

    @nn.compact
    def __call__(self, x, t, context):
        b, h, w, c = x.shape

        d_t_emb = t.shape[-1]

        context = nn.Embed(self.num_classes, t.shape[-1])(context)
        context = einops.repeat(context, "b t -> b (h p1) (w p2) t", h=h // self.patch_size, w=w // self.patch_size, p1=self.patch_size, p2=self.patch_size)

        t = einops.repeat(t, "b t -> b (h p1) (w p2) t", h=h // self.patch_size, w=w // self.patch_size, p1=self.patch_size, p2=self.patch_size)
        context = jnp.concatenate([context, t], axis=-1)

        context = nn.gelu(nn.Dense(self.tokens_mlp_dim)(context))
        context = nn.Dense(d_t_emb)(context)

        x = jnp.concatenate([x, context], axis=-1)

        x = nn.Conv(self.hidden_dim, [self.patch_size, self.patch_size], strides=[self.patch_size, self.patch_size])(x)
        x = einops.rearrange(x, "n h w c -> n (h w) c")

        for _ in range(self.num_blocks):
            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
        x = nn.LayerNorm()(x)

        x = nn.Dense(self.patch_size * self.patch_size * c)(x)
        x = einops.rearrange(x, "b (hp wp) (ph pw c) -> b (hp ph) (wp pw) c", hp=h // self.patch_size, wp=w // self.patch_size, ph=self.patch_size, pw=self.patch_size, c=c)

        return x

### Config classes

In [125]:
@dataclass
class ModelConfig:
    batch_size: int
    input_size: int
    channel_size: int
    time_embed_size: int
    max_time: float = 80.0
    s0: int = 2
    s1: int = 150
    eps: float = 2e-3
    sigma: float = 0.5
    dtype: Any = jnp.float32

    @abstractmethod
    def _create_backbone(self) -> nn.Module:
        pass

@dataclass
class MLPMixerConfig(ModelConfig):
    patch_size: int = 4
    num_blocks: int = 4
    hidden_dim: int = 256
    tokens_mlp_dim: int = 256
    channels_mlp_dim: int = 256
    num_classes: int = 10
    
    def _create_backbone(self):
        return MLPMixer(patch_size=self.patch_size,
                        num_blocks=self.num_blocks,
                        hidden_dim=self.hidden_dim,
                        tokens_mlp_dim=self.tokens_mlp_dim,
                        channels_mlp_dim=self.channels_mlp_dim,
                        num_classes=self.num_classes)

@dataclass
class TrainerConfig:
    lr: float = 3e-4

## Consistency Models

In [129]:
class ConsistencyModel:
    def __init__(self, config: ModelConfig, key: jax.dtypes.prng_key):
        self.config = config
        self.random_key = key
        self.backbone = config._create_backbone()
    
    @staticmethod
    @partial(jax.jit, static_argnums=(1, 4, 5))
    def consistency_fn(apply_params: Any,
                       apply_fn: Callable,
                       x: jax.Array,
                       timestep_emb: jax.Array,
                       sigma_data: float,
                       eps: float):
        
        cskip = lambda t: sigma_data ** 2 / ((t-eps)**2 + sigma_data ** 2)[:, :, None, None]
        cout = lambda t: sigma_data * (t-eps) / jnp.sqrt(t**2 + sigma_data ** 2)[:, :, None, None]

        x = apply_fn(apply_params, x, timestep_emb)

        return x * cskip(timestep_emb) + x * cout(timestep_emb)

    def sample(self, timesteps: Sequence):
        initial_variance = self.config.max_time ** 2

        self.random_key, normal_key = random.split(self.random_key)
        xt = random.normal(normal_key, self.config.dim) * initial_variance

        x = self.consistency_fn(xt, self.config.max_time, self.backbone.apply_fn, self.backbone.params, self.config.sigma_data, self.config.eps)

        for t in timesteps:
            normal_key, _ = random.split(normal_key)
            z = random.normal(normal_key, self.config.dim)
            xt = x + jnp.sqrt(t**2 - self.config.eps**2) * z
            x = self.consistency_fn(xt, self.config.max_time, self.model.apply_fn, self.model.params, self.config.sigma_data, self.config.eps)
    
    def param_count(self):
        if self.train_state is None:
            raise "Model not initialized. Call `create_state` first to initialize."
        
        return sum(x.size for x in jax.tree_leaves(self.state.params))
    
    def discretize(self, sigma: float, eps: float, N: int):
        idx = jnp.arange(N)
        return (eps ** (1 / sigma) + idx / (N - 1) * (self.config.max_time ** (1 / sigma) - eps ** (1 / sigma))) ** sigma
    
    def get_boundaries(self, step: int):
        N = (step * ((self.config.s1 + 1) ** 2 - self.config.s0 ** 2) / self.config.max_time)
        N = jnp.ceil(jnp.sqrt(N + self.config.s0 ** 2) - 1) + 1
        boundaries = self.discretize(self.config.sigma, self.config.eps, N)
        return boundaries, N
    
    def create_state(self, key: jax.dtypes.prng_key, optimizer: Optional[Any] = None, keep_state: bool = True):
        noise_key, init_key = random.split(key)
        init_shape = (self.config.batch_size,
                      self.config.input_size,
                      self.config.input_size,
                      self.config.channel_size)
        x = jax.random.normal(noise_key, init_shape)
        t = jnp.ones((self.config.batch_size, self.config.time_embed_size))
        
        context = jnp.ones((self.config.batch_size,)).astype(jnp.int32)
        params = self.backbone.init(init_key, x, t, context)
        train_state = TrainState.create(apply_fn=self.backbone.apply, params=params, tx=optimizer)

        if keep_state:
            self.state = train_state

In [130]:
@dataclass
class ConsistencyTrainer:
    model: ConsistencyModel
    dataloader: DataLoader
    config: TrainerConfig
    random_key: jax.dtypes.prng_key

    def __post_init__(self):

        if isinstance(self.dataloader, DataLoader):
            self.dataloader = iter(self.dataloader)

        self.random_key, init_key = jax.random.split(self.random_key)
        tx = optax.radam(learning_rate=self.config.lr)
        self.train_state = self.model.create_state(optimizer=tx, key=init_key)


    """
    TODO: loss_fn and train_step are full of shit and aren't iCT.
    """
    
    @staticmethod
    @partial(jax.jit, static_argnums=(5,))
    def loss_fn(params, consistency_fn, x, t1, t2, score, key, y):
        z = jax.random.normal(key, shape=x.shape)

        x2 = x + z * t2[:, :, None, None]
        x2 = consistency_fn(params, score, x2, t2, y)

        x1 = x + z * t1[:, :, None, None]
        x1 = consistency_fn(score, x1, t1, y)

        return np.mean((x1 - x2) ** 2)

    @staticmethod
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(5, 6, 7, 8, 9, 10))
    def train_step(train_state: TrainState,
                   batch: jax.Array,
                   t1: jax.Array,
                   t2: jax.Array,
                   key: jax.dtypes.prng_key,
                   model: ConsistencyModel,
                   loss_fn: Callable,
                   sigma_data: float,
                   eps: float,
                   time_embed_size: float):
                   
        
        x_batch, y_batch = batch
        
        loss, grads = jax.value_and_grad(loss_fn)(train_state.params, x_batch, t1, t2, model, key, y_batch, sigma_data, eps, time_embed_size)
        grads = jax.lax.pmean(grads, "batch")
        loss = jax.lax.pmean(loss, "batch")

        train_state = train_state.apply_gradients(grads=grads)

        metrics = {"loss": loss}

        return train_state, metrics

    def train(self, timesteps: int):
        state = self.train_state
        
        with trange(timesteps) as steps:
            for step in steps:
                self.random_key, time_key, step_key = jax.random.split(self.random_key, 3)
                x_batch, y_batch = next(self.dataloader)

                boundaries, N = self.model.get_boundaries(step)
                n_batch = jax.random.randint(time_key, minval=0, maxval=N - 1, shape=(x_batch.shape[0], 1))

                state, loss = self.train_step(state,
                                              (x_batch, y_batch),
                                              boundaries[n_batch],
                                              boundaries[n_batch + 1],
                                              step_key,
                                              score,
                                              self.loss_fn)

                steps.set_postfix(val=loss)
                
        self.train_state = state

In [131]:
random_key = random.key(0)
random_key, model_key, trainer_key = random.split(random_key, 3)

model_config = MLPMixerConfig(batch_size=4,
                        input_size=28, 
                        channel_size=1,
                        time_embed_size=16)

model = ConsistencyModel(model_config, model_key)

transform = Compose([ToTensor()])
collate_fn = lambda x: tree_map(np.asarray, default_collate(x))
mnist_dataset = MNIST('/tmp/mnist', download=True, transform=transform)
dataloader = DataLoader(dataset=mnist_dataset,
                         batch_size=model_config.batch_size,
                         shuffle=True,
                         collate_fn=collate_fn)

trainer = ConsistencyTrainer(model=model,
                             dataloader=dataloader,
                             config=TrainerConfig(),
                             random_key=trainer_key)

trainer.train(timesteps=int(1e4))

  0%|          | 0/10000 [00:00<?, ?it/s]


NameError: name 'score' is not defined