# Improved Consistency Training on CIFAR-10

[![arXiv](https://img.shields.io/badge/arXiv-2310.14189-b31b1b.svg)](https://arxiv.org/abs/2310.14189)
<a target="_blank" href="https://colab.research.google.com/github/leakedweights/mincy/blob/main/notebooks/ict_cifar.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

JAX & Flax implementation of [Improved Consistency Training](https://arxiv.org/abs/2310.14189).

In [3]:
%load_ext autoreload
%autoreload 2

In [None]:
%%capture
%pip install torch torchvision ipykernel einops wandb imageio
%pip install --upgrade jax[tpu] jaxlib flax

In [4]:
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [17]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

from mincy.models.unet import UNet

cifar_config = {"channel_mults": (1, 2, 4, 8),
     "attention_mults": (2,),
     "kernel_size": (3,3),
     "dropout": 0.1,
     "num_init_channels": 16,
     "num_res_blocks": 4,
     "pos_emb_type": "fourier",
     "pos_emb_dim": 16,
     "rescale_skip_conns": True,
     "resblock_variant":"BigGAN++",
     "fourier_scale": 16,
     "nonlinearity": nn.swish,}

model = UNet(**cifar_config)

key = random.key(0)
img_shape = (32, 32, 3)
batch_size = 64
input_shape = (batch_size, *img_shape)

init_image = jnp.ones(input_shape)
init_t = jnp.ones((batch_size,))
variables = model.init(key, init_image, init_t, train=True)

Number of values in x before input layer outside the range [-1, 1] = 0
Number of values in x after input layer outside the range [-1, 1] = 428480
nans = 0
Number of values in x after block outside the range [-1, 1] = 25102
Number of values in x after block outside the range [-1, 1] = 12558
Number of values in x after block outside the range [-1, 1] = 4638
Number of values in x after block outside the range [-1, 1] = 8247


In [21]:
from mincy.models.consistency import training_consistency, pseudo_huber_loss

t1 = jnp.repeat(10, batch_size)
t2 = jnp.repeat(20, batch_size)
noise = random.normal(key, input_shape)
denoising_fn = jax.tree_util.Partial(model.apply, variables, rngs=random.key(0))
xt1, xt2 = training_consistency(t1=t1, t2=t2,
                     x0=init_image,
                     noise=noise,
                     denoising_fn=denoising_fn,
                     sigma_data=0.5,
                     sigma_min=0.002)

data_dim =jnp.prod(jnp.array(xt1.shape[1:]))
c_data = 0.00054 * jnp.sqrt(data_dim)
jax.debug.print('cdata = {cdata}', cdata=c_data)
pseudo_huber_loss(xt1, xt2, c_data=c_data)

Number of values in x before input layer outside the range [-1, 1] = 62497
Number of values in x after input layer outside the range [-1, 1] = 310536
nans = 0
Number of values in x after block outside the range [-1, 1] = 23275
Number of values in x after block outside the range [-1, 1] = 10868
Number of values in x after block outside the range [-1, 1] = 5253
Number of values in x after block outside the range [-1, 1] = 8639
Number of values in x before input layer outside the range [-1, 1] = 128425
Number of values in x after input layer outside the range [-1, 1] = 934783
nans = 0
Number of values in x after block outside the range [-1, 1] = 117999
Number of values in x after block outside the range [-1, 1] = 14612
Number of values in x after block outside the range [-1, 1] = 5289
Number of values in x after block outside the range [-1, 1] = 9192
Number of values in xt1_consistency outside the range [-1, 1] = 181218
consistency_nans = 0
cdata = 0.029929835349321365


Array([[[[9.08507729e+00, 3.10851715e+02, 3.28020787e+00],
         [1.52168226e+01, 3.17006435e+01, 3.26010529e+02],
         [1.18831627e+02, 8.97385406e+00, 5.03558167e+02],
         ...,
         [2.30082428e+02, 1.84222126e+01, 2.17187667e+00],
         [4.65404449e+02, 1.43835771e+00, 1.61056352e+00],
         [9.21093994e+02, 1.00721169e+00, 8.46399403e+00]],

        [[7.87340879e-01, 3.58519821e+01, 7.74445190e+01],
         [1.01172492e-01, 5.04633808e+00, 8.72286072e+01],
         [1.14588809e+00, 7.25892029e+01, 3.26478546e+02],
         ...,
         [6.91541016e-01, 7.59067688e+01, 1.69847927e+01],
         [8.85904908e-01, 7.57903039e-01, 1.17337494e+01],
         [5.58566589e+01, 7.50650101e+01, 5.57220650e+01]],

        [[1.16658417e+02, 1.68610596e+02, 3.10204067e+01],
         [2.94496574e+01, 7.62707672e+01, 6.14444876e+00],
         [3.95574005e+02, 2.50672531e+01, 1.23690498e+02],
         ...,
         [2.57764130e+02, 1.69032516e+01, 9.67942119e-01],
         [