# Improved Consistency Training on CIFAR-10

[![arXiv](https://img.shields.io/badge/arXiv-2310.14189-b31b1b.svg)](https://arxiv.org/abs/2310.14189)
<a target="_blank" href="https://colab.research.google.com/github/leakedweights/mincy/blob/main/notebooks/ict_cifar.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

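JAX & Flax implementation of Improved Consistency Training (iCT), as introduced in [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189) (Song & Dhariwal, 2023).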

In [None]:
%%capture
# Install notebook dependencies; `%%capture` hides the verbose pip output of this cell.
# `jax[tpu]` selects the TPU build of JAX, so this cell targets a TPU runtime.
%pip install torch torchvision ipykernel einops wandb imageio
%pip install --upgrade jax[tpu] jaxlib flax

In [1]:
# Enable IPython autoreload so edits to project modules (e.g. `mincy`) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# Make the repository root importable so `mincy` (imported below) resolves.
# The root is taken to be the parent of the current working directory, i.e. this
# assumes the kernel was started from the repo's `notebooks/` folder.
# The membership check keeps re-running this cell from appending duplicates.
repo_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [13]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn

from mincy.models.unet import UNet

# UNet hyperparameters used for 32x32 CIFAR-10 images.
cifar_config = {
    "channel_mults": (1, 2, 4, 8),
    "attention_mults": (2,),
    "kernel_size": (3, 3),
    "dropout": 0.1,
    "num_init_channels": 16,
    "num_res_blocks": 4,
    "pos_emb_type": "fourier",
    "pos_emb_dim": 16,
    "rescale_skip_conns": True,
    "resblock_variant": "BigGAN++",
    "fourier_scale": 16,
    "nonlinearity": nn.swish,
}

model = UNet(**cifar_config)

key = random.key(0)
img_shape = (32, 32, 3)  # channels-last: (height, width, 3)
batch_size = 64
input_shape = (batch_size, *img_shape)

# Dummy inputs: only their shapes/dtypes matter for parameter initialisation.
init_image = jnp.ones(input_shape)
init_t = jnp.ones((batch_size,))  # one time value per batch element

# Initialising with train=True can activate dropout (cifar_config["dropout"] = 0.1),
# which needs its own RNG stream. Passing a bare key only provides the "params"
# stream, so supply "dropout" explicitly instead of relying on version-dependent
# Flax fallback behaviour. "params" still uses `key`, as before.
# NOTE(review): assumes UNet's dropout layers use Flax's default "dropout" stream name.
init_rngs = {"params": key, "dropout": random.fold_in(key, 1)}
variables = model.init(init_rngs, init_image, init_t, train=True)

num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 4, channels: 16
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 8, channels: 32
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 16, channels: 64
num groups: 32, channels: 128
num groups: 32, channels: 128
num groups: 32, channels: 128
num groups: 32, channels: 128
nu