# All in one sampling ipynb

This file is for evaluation of CM-SH model that measure the FID score of 
1. one-step sampling for CM
2. multi-step sampling for CM
3. multi-step sampling for SH 

In [5]:
import jax
import jax.numpy as jnp
import flax

import matplotlib.pyplot as plt
import yaml
import omegaconf
import hydra
from hydra import initialize, compose

from framework.unifying_framework import UnifyingFramework
from framework.diffusion.consistency_framework import CMFramework
from framework.diffusion.edm_framework import EDMFramework

from utils import common_utils
from utils.fid_utils import FIDUtils
from utils.fs_utils import FSUtils

from tqdm import tqdm

from functools import partial

import os 

In [6]:
# Make new directory for storing samples for each sampling mode
default_path = "tmp/"
edm_multistep = default_path+ "edm_multistep"
cm_onestep = default_path+ "cm_onestep"
cm_multistep = default_path+ "cm_multistep"
dir_list = [cm_onestep, cm_multistep, edm_multistep]
for dir_elem in dir_list:
    os.makedirs(dir_elem, exist_ok=True)

In [7]:
sampling_batch_num = 512
total_sampling_num = 50000
sigma_max = 80.0
sigma_min = 0.002
rho = 7

In [8]:
# Setting CM framework for sampling



config_path = "configs"
default_config_path = "config"
rng = jax.random.PRNGKey(42)

with initialize(version_base=None, config_path=config_path):
    default_config = compose(config_name=default_config_path)
default_config["do_training"] = False
model_type = default_config.type

rng, denoiser_rng = jax.random.split(rng, 2)
rng, consistency_rng = jax.random.split(rng, 2)
# denoiser_framework = diffusion_framework.framework
fid_utils = FIDUtils(default_config)
fs_utils = FSUtils(default_config)

consistency_framework = CMFramework(default_config, consistency_rng, fs_utils, None)

In [9]:
# First, save SH and CM simultaneously if CM-SH mode
current_num = 0
print("Sampling CM onestep and SH multistep")
sampling_shape = (sampling_batch_num, 32, 32, 3)
while total_sampling_num > current_num:
    effective_batch_size = total_sampling_num - current_num if current_num + sampling_batch_num > total_sampling_num else sampling_batch_num
    edm, cm = consistency_framework.sampling_edm_and_cm(sampling_batch_num)
    edm = jnp.reshape(edm, sampling_shape)
    cm = jnp.reshape(cm, sampling_shape)
    fs_utils.save_images_to_dir(edm, starting_pos=current_num, save_path_dir=edm_multistep)
    fs_utils.save_images_to_dir(cm, starting_pos=current_num, save_path_dir=cm_onestep)
    current_num += sampling_batch_num

Sampling CM onestep and SH multistep


0it [00:00, ?it/s]

18it [01:00,  3.38s/it]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.10it/s]
18it [00:08,  2.10it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.10it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.11it/s]
18it [00:08,  2.12it/s]
18it [00:08,  2.

In [10]:
edm_multistep_fid_score = fid_utils.calculate_fid(edm_multistep)
print(f"EDM multistep score: {edm_multistep_fid_score:.3f}")
cm_onestep_fid_score = fid_utils.calculate_fid(cm_onestep)
print(f"CM onestep score: {cm_onestep_fid_score:.3f}")

Loading cifar10 statistics


  0%|          | 50/50176 [00:00<01:41, 492.25it/s]

100%|██████████| 50176/50176 [02:06<00:00, 397.07it/s]
100%|██████████| 1003/1003 [03:29<00:00,  4.80it/s]


EDM multistep score: 4.585
Loading cifar10 statistics


100%|██████████| 50176/50176 [01:42<00:00, 487.97it/s]
100%|██████████| 1003/1003 [02:33<00:00,  6.53it/s]


CM onestep score: 26.541


In [None]:
# Next, do some CM multistep sampling
multisampling_fid_dict = {}

In [11]:


def stochastic_iterative_sampler(
    rng,
    num_sampling,
    distiller: CMFramework,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)

    params = flax.jax_utils.replicate(distiller.torso_state.params_ema)
    sampling_fn = distiller.p_sample_cm # params: sampling_params, latent_sample, rng_key, gamma, t_max, t_min
    
    # Sampling x from sampler
    input_shape = (jax.local_device_count(), num_sampling // jax.local_device_count(), 32, 32, 3)
    rng, sampling_key = jax.random.split(rng, 2)
    x = jax.random.normal(sampling_key, input_shape) * t_max
    

    for i in range(len(ts) - 1):
        t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho

        # x0 = distiller(x, t * s_in)
        sampling_key, p_sample_key = jax.random.split(sampling_key, 2)
        p_sample_key = jax.random.split(p_sample_key, jax.local_device_count())
        
        t_param = jnp.asarray([t] * jax.local_device_count())
        t_min_param = jnp.asarray([t_min] * jax.local_device_count())
        gamma = jnp.zeros((jax.local_device_count(),))

        x0 = sampling_fn(params, x, p_sample_key, gamma, t_param, t_min_param)
        next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        next_t = jnp.clip(next_t, t_min, t_max)

        rng, normal_rng = jax.random.split(rng, 2)
        x = x0 + jax.random.normal(normal_rng, input_shape) * jnp.sqrt(next_t**2 - t_min**2)
    
    return x

In [12]:
def get_fid(rng, sampling_fn, p, begin=(0,), end=(17, ), sample_dir="."):
    total_size = 50000
    batch_size = sampling_batch_num
    current_sampling_num = 0

    ts = begin + (p,) + end

    if multisampling_fid_dict.get(ts, None) is not None:
        return multisampling_fid_dict[ts]

    tmp_dir = os.path.join(sample_dir, f"{p}")
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)

    filelist = os.listdir(tmp_dir)
    
    if len(filelist) == total_size:
        fid = fid_utils.calculate_fid(tmp_dir)
        multisampling_fid_dict[ts] = fid
        return fid

    while current_sampling_num < total_size:
        effective_sampling = batch_size \
            if (total_size - current_sampling_num) // batch_size != 0 \
            else total_size - current_sampling_num

        rng, sampling_rng = jax.random.split(rng, 2)
        x0 = sampling_fn(rng=sampling_rng, num_sampling=effective_sampling, ts=ts)
        x0 = x0.reshape(-1, *x0.shape[-3:])
        fs_utils.save_images_to_dir(x0, tmp_dir, current_sampling_num)
        current_sampling_num += effective_sampling
    
    fid = fid_utils.calculate_fid(tmp_dir)
    multisampling_fid_dict[ts] = fid
    return fid

In [13]:
def ternery_search(rng, sampling_fn, fid_obj, before, after, sample_dir="."):
    right = after[0]
    left = before[-1]
    while right - left >= 3:
        m1 = int(left + (right - left) / 3.0)
        m2 = int(right - (right - left) / 3.0)

        f1 = get_fid(rng, sampling_fn, m1, before, after, sample_dir=sample_dir)

        # logger.log(f"fid at m1 = {m1} is {f1}, IS is {is1}")
        print(f"fid at m1 = {m1} is {f1}")
        f2 = get_fid(rng, sampling_fn, m2, before, after, sample_dir=sample_dir)

        print(f"fid at m2 = {m2} is {f2}")

        if f1 < f2:
            right = m2
        else:
            left = m1

        print(f"new interval is [{left}, {right}]")

    if right == left:
        p = right
    elif right - left == 1:
        f1 = get_fid(rng, sampling_fn, left, before, after, sample_dir=sample_dir)
        f2 = get_fid(rng, sampling_fn, right, before, after, sample_dir=sample_dir)
        p = m1 if f1 < f2 else m2
    elif right - left == 2:
        mid = left + 1
        f1 = get_fid(rng, sampling_fn, left, before, after, sample_dir=sample_dir)
        f2 = get_fid(rng, sampling_fn, right, before, after, sample_dir=sample_dir)
        fmid = get_fid(rng, sampling_fn, mid, before, after, sample_dir=sample_dir)

        print(f"fmid at mid = {mid} is {fmid}")

        if fmid < f1 and fmid < f2:
            p = mid
        elif f1 < f2:
            p = m1
        else:
            p = m2

    return p

In [14]:
steps = 18
begin = (0,)
end = (steps - 1,)

In [15]:
sampler = partial(stochastic_iterative_sampler, # need rng, num_sampling, ts
    distiller=consistency_framework,
    t_min=sigma_min,
    t_max=sigma_max,
    rho=rho,
    steps=steps)
rng, cm_multistep_sampling = jax.random.split(rng, 2)

ternery_search(cm_multistep_sampling, sampler, fid_utils, begin, end, cm_multistep)

Loading cifar10 statistics


100%|██████████| 50000/50000 [01:43<00:00, 484.42it/s]
100%|██████████| 1000/1000 [02:21<00:00,  7.09it/s]


fid at m1 = 5 is 26.628729914619782
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:42<00:00, 486.77it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.20it/s]


fid at m2 = 11 is 25.03028232655697
new interval is [5, 17]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:43<00:00, 484.68it/s]
100%|██████████| 1000/1000 [02:24<00:00,  6.93it/s]


fid at m1 = 9 is 27.916672329356174
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:42<00:00, 485.50it/s]
100%|██████████| 1000/1000 [02:21<00:00,  7.08it/s]


fid at m2 = 13 is 25.78110128157232
new interval is [9, 17]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 493.28it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.21it/s]


fid at m1 = 11 is 25.03028232655697
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:42<00:00, 487.89it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.16it/s]


fid at m2 = 14 is 25.975722885808466
new interval is [9, 14]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:42<00:00, 486.79it/s]
100%|██████████| 1000/1000 [02:20<00:00,  7.11it/s]


fid at m1 = 10 is 26.308213812628082
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:42<00:00, 488.12it/s]
100%|██████████| 1000/1000 [02:17<00:00,  7.30it/s]


fid at m2 = 12 is 24.775052015719893
new interval is [10, 14]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 492.61it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.20it/s]


fid at m1 = 11 is 25.03028232655697
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 493.31it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.20it/s]


fid at m2 = 12 is 24.775052015719893
new interval is [11, 14]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 493.69it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.19it/s]


fid at m1 = 12 is 24.775052015719893
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 491.04it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.21it/s]


fid at m2 = 13 is 25.78110128157232
new interval is [11, 13]
Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 491.48it/s]
100%|██████████| 1000/1000 [02:18<00:00,  7.21it/s]


Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 493.05it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.18it/s]


Loading cifar10 statistics


100%|██████████| 50000/50000 [01:41<00:00, 490.58it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.15it/s]


fmid at mid = 12 is 24.775052015719893


12