# Improved Consistency Training on CIFAR-10

[![arXiv](https://img.shields.io/badge/arXiv-2310.14189-b31b1b.svg)](https://arxiv.org/abs/2310.14189)
<a target="_blank" href="https://colab.research.google.com/github/leakedweights/mincy/blob/main/notebooks/ict_cifar.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

JAX & Flax implementation of [Improved Consistency Training](https://arxiv.org/abs/2310.14189).

## Setup

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

!git clone https://github.com/leakedweights/mincy.git
%pip install torch torchvision ipykernel einops wandb imageio
%pip install --upgrade jax[tpu] jaxlib flax

os.chdir('/content/mincy/notebooks')
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [2]:
import os
import sys

# Ask XLA to expose 8 devices on the host platform (per the flag name). This
# is set here, before the cell that imports jax.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

# Put the parent of the current working directory on the import path so the
# `mincy` package next to the notebooks directory can be imported.
_parent_dir = os.path.join(os.getcwd(), '..')
sys.path.append(os.path.abspath(_parent_dir))

In [41]:
import jax
from jax import random
import optax

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

from mincy.models.unet import UNet
from mincy.configs.ict_cifar import cifar_config, cifar_trainer_config
from mincy.configs.ict import consistency_config
from mincy.training.trainer import ConsistencyTrainer
from mincy.training.dataloader import *

In [None]:
from google.colab import drive

# Mount Google Drive and point the trainer's output directories at a folder
# inside it.
drive.mount('/content/drive/')
drive_base_dir = "/content/drive/MyDrive/mincy"

# exist_ok=True makes this a single, re-runnable call and removes the
# check-then-create gap of a separate os.path.exists() test.
os.makedirs(drive_base_dir, exist_ok=True)

# Override the imported trainer config in place; later cells read these keys.
cifar_trainer_config["checkpoint_dir"] = f"{drive_base_dir}/checkpoints"
cifar_trainer_config["sample_dir"] = f"{drive_base_dir}/samples"

In [None]:
# Make sure the checkpoint and sample directories named in the trainer config
# exist. exist_ok=True keeps the cell idempotent and avoids the race between
# a separate existence check and the directory creation.
for _dir_key in ("checkpoint_dir", "sample_dir"):
    os.makedirs(cifar_trainer_config[_dir_key], exist_ok=True)

In [6]:
# Notebook-level settings shared by the cells below.
batch_size = 64          # batch size handed to the DataLoader
img_shape = (32, 32, 3)  # presumably (height, width, channels) of one image; confirm against the trainer

# PRNG key seeded with 0.
# NOTE(review): no visible later cell reads `key`; the trainer cell builds its
# own key with random.PRNGKey(0).
key = random.key(0)

# U-Net built from the CIFAR model config.
# NOTE(review): the trainer cell constructs `model` again from the same config.
model = UNet(**cifar_config)

In [32]:
# CIFAR-10 stored under /tmp/cifar and downloaded on demand. `transform` and
# `numpy_collate` are not defined in any visible cell; presumably they come
# from the star import of mincy.training.dataloader (verify).
dataset = CIFAR10(root='/tmp/cifar', transform=transform, download=True)

# Shuffled batches of `batch_size` samples; drop_last=True discards a trailing
# incomplete batch so every batch has the same size.
_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    drop_last=True,
)
dataloader = DataLoader(dataset, **_loader_kwargs)

Files already downloaded and verified


In [42]:
# Build the U-Net from the CIFAR model config and an RAdam optimizer using the
# learning rate stored in the trainer config.
model = UNet(**cifar_config)
optimizer = optax.radam(cifar_trainer_config["learning_rate"])

# `img_shape` and `batch_size` are the notebook-level settings defined in the
# configuration cell. The trainer previously received a hard-coded
# batch_size=256 while the DataLoader passed in here was built with
# `batch_size` (64), so the two could silently disagree. To train with a
# different batch size, change `batch_size` in the configuration cell and
# rebuild the DataLoader.
# NOTE(review): assumes ConsistencyTrainer expects `batch_size` to match the
# batches yielded by `dataloader`; confirm against the trainer implementation.
trainer = ConsistencyTrainer(random_key=random.PRNGKey(0),
                             model=model,
                             optimizer=optimizer,
                             dataloader=dataloader,
                             img_shape=img_shape,
                             batch_size=batch_size,
                             # One replica per locally visible JAX device.
                             num_devices=jax.local_device_count(),
                             config=cifar_trainer_config,
                             consistency_config=consistency_config)

In [43]:
# Total passed to trainer.train; the saved progress bar counts up to this
# number of iterations.
# NOTE(review): the saved output of this cell ends in a
# ConcretizationTypeError raised while tracing `discretize` in
# mincy/components/schedule.py (an `int` conversion of a traced value under
# jit), so the fix belongs in that module rather than in this notebook.
total_iterations = 400_000
trainer.train(total_iterations)

  0%|          | 0/400000 [00:00<?, ?it/s]

device_key_shape = (8, 2)





ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function discretize at /home/betonitcso/projects/mincy/mincy/components/schedule.py:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[] = floor b
    from line /home/betonitcso/projects/mincy/mincy/components/schedule.py:11 (discretize)

  operation a:f32[] = log b
    from line /home/betonitcso/projects/mincy/mincy/components/schedule.py:11 (discretize)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /home/betonitcso/projects/mincy/mincy/components/schedule.py:11 (discretize)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /home/betonitcso/projects/mincy/mincy/components/schedule.py:11 (discretize)

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /home/betonitcso/projects/mincy/mincy/components/schedule.py:12 (discretize)

(Additional originating lines are not shown.)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError