# Improved Consistency Training on CIFAR-10

[![arXiv](https://img.shields.io/badge/arXiv-2310.14189-b31b1b.svg)](https://arxiv.org/abs/2310.14189)
<a target="_blank" href="https://colab.research.google.com/github/leakedweights/mincy/blob/main/notebooks/ict_cifar.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

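JAX & Flax implementation of Improved Consistency Training (iCT), introduced in [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189) (Song & Dhariwal, 2023), trained here on CIFAR-10.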

## Setup

In [1]:
# Enable IPython's autoreload so edits to imported modules (e.g. the mincy
# package) are picked up before each cell executes, without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

!git clone https://github.com/leakedweights/mincy.git
%pip install torch torchvision ipykernel einops wandb imageio clean-fid
%pip install --upgrade jax[tpu] jaxlib flax

os.chdir('/content/mincy/notebooks')
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [44]:
import jax
from jax import random
import optax

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import wandb

from mincy.models.unet import UNet
from mincy.configs.ict_cifar_config import cifar_config, cifar_trainer_config
from mincy.configs.ict_config import consistency_config
from mincy.training.trainer import ConsistencyTrainer
from mincy.training.dataloader import *

In [None]:
from google.colab import drive

# Mount Google Drive and keep all run outputs under MyDrive/mincy.
drive.mount('/content/drive/')
drive_base_dir = "/content/drive/MyDrive/mincy"

# exist_ok=True makes re-running this cell a no-op and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(drive_base_dir, exist_ok=True)

# Point the trainer's checkpoint and sample-snapshot outputs at the Drive folder.
cifar_trainer_config["checkpoint_dir"] = f"{drive_base_dir}/checkpoints"
cifar_trainer_config["snapshot_dir"] = f"{drive_base_dir}/samples"

In [None]:
# Ensure the trainer's output directories exist. exist_ok=True keeps this
# idempotent on re-runs and avoids a check-then-create race.
os.makedirs(cifar_trainer_config["checkpoint_dir"], exist_ok=True)
os.makedirs(cifar_trainer_config["snapshot_dir"], exist_ok=True)

In [None]:
batch_size = 512

# CIFAR-10 is downloaded into /tmp/cifar on first use. `transform` and
# `numpy_collate` come from the star import of mincy.training.dataloader.
dataset = CIFAR10('/tmp/cifar', transform=transform, download=True)

# drop_last=True discards the final partial batch, so every batch holds
# exactly `batch_size` samples (presumably so the trainer sees fixed shapes
# across devices — confirm in ConsistencyTrainer).
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    collate_fn=numpy_collate,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Start a Weights & Biases run, recording the model and trainer configs
# (including the Drive output paths set above) alongside it.
run = wandb.init(
    project="mincy-cifar",
    config=dict(model=cifar_config, trainer=cifar_trainer_config),
)

In [45]:
# Fixed seed for the trainer's PRNG stream.
training_key = random.PRNGKey(0)

# UNet hyperparameters come straight from the CIFAR model config.
model = UNet(**cifar_config)
optimizer = optax.radam(learning_rate=cifar_trainer_config["learning_rate"])

# One trainer drives all local accelerator devices; images are 32x32 with
# 3 channels (channels-last, as given by the shape tuple).
trainer = ConsistencyTrainer(
    random_key=training_key,
    model=model,
    optimizer=optimizer,
    dataloader=dataloader,
    img_shape=(32, 32, 3),
    num_devices=jax.local_device_count(),
    config=cifar_trainer_config,
    consistency_config=consistency_config,
)

In [None]:
# Argument to trainer.train(); presumably the total number of training
# steps — confirm against ConsistencyTrainer.train.
total_train_steps = 400_000

# Restore state from the checkpoint directory before training (behavior when
# no checkpoint exists yet is up to ConsistencyTrainer.load_checkpoint — verify).
trainer.load_checkpoint()
trainer.train(total_train_steps)