# All in one sampling ipynb

This file is for evaluation of CM-SH model that measure the FID score of 
1. one-step sampling for CM
2. multi-step sampling for CM
3. multi-step sampling for SH 

In [1]:
import jax
import jax.numpy as jnp
import flax

import matplotlib.pyplot as plt
import yaml
import omegaconf
import hydra
from hydra import initialize, compose

from framework.unifying_framework import UnifyingFramework
from framework.diffusion.consistency_framework import CMFramework
from framework.diffusion.edm_framework import EDMFramework

from utils import common_utils
from utils.fid_utils import FIDUtils
from utils.fs_utils import FSUtils

from tqdm import tqdm

from functools import partial

import os 

2023-12-07 07:56:49.816411: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-07 07:56:49.816472: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-07 07:56:49.818848: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make new directory for storing samples for each sampling mode
default_path = "tmp/"
edm_multistep = default_path+ "edm_multistep"
cm_onestep = default_path+ "cm_onestep"
cm_multistep = default_path+ "cm_multistep"
dir_list = [cm_onestep, cm_multistep, edm_multistep]
for dir_elem in dir_list:
    os.makedirs(dir_elem, exist_ok=True)

In [3]:
sampling_batch_num = 512
total_sampling_num = 50000
sigma_max = 80.0
sigma_min = 0.002
rho = 7

In [4]:
# Setting CM framework for sampling

config_path = "configs"
default_config_path = "config"
rng = jax.random.PRNGKey(42)

with initialize(version_base=None, config_path=config_path):
    default_config = compose(config_name=default_config_path)
default_config["do_training"] = False
model_type = default_config.type

rng, denoiser_rng = jax.random.split(rng, 2)
rng, consistency_rng = jax.random.split(rng, 2)
# denoiser_framework = diffusion_framework.framework
fid_utils = FIDUtils(default_config)
fs_utils = FSUtils(default_config)

consistency_framework = CMFramework(default_config, consistency_rng, fs_utils, None)



Checkpoint 600010 loaded




Checkpoint 600010 loaded


In [5]:
# First, save SH and CM simultaneously if CM-SH mode
current_num = 0
print("Sampling CM onestep and SH multistep")
sampling_shape = (sampling_batch_num, 32, 32, 3)
while total_sampling_num > current_num:
    effective_batch_size = total_sampling_num - current_num if current_num + sampling_batch_num > total_sampling_num else sampling_batch_num
    edm, cm = consistency_framework.sampling_edm_and_cm(sampling_batch_num)
    edm = jnp.reshape(edm, sampling_shape)
    cm = jnp.reshape(cm, sampling_shape)
    fs_utils.save_images_to_dir(edm, starting_pos=current_num, save_path_dir=edm_multistep)
    fs_utils.save_images_to_dir(cm, starting_pos=current_num, save_path_dir=cm_onestep)
    current_num += sampling_batch_num

Sampling CM onestep and SH multistep


18it [01:16,  4.24s/it]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.40it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.39it/s]
18it [00:07,  2.

In [6]:
edm_multistep_fid_score = fid_utils.calculate_fid(edm_multistep)
print(f"EDM multistep score: {edm_multistep_fid_score:.3f}")
cm_onestep_fid_score = fid_utils.calculate_fid(cm_onestep)
print(f"CM onestep score: {cm_onestep_fid_score:.3f}")

Loading cifar10 statistics


  0%|          | 0/50176 [00:00<?, ?it/s]

100%|██████████| 50176/50176 [03:03<00:00, 273.57it/s]
100%|██████████| 1003/1003 [04:58<00:00,  3.36it/s]


EDM multistep score: 5.840
Loading cifar10 statistics


100%|██████████| 50176/50176 [02:31<00:00, 330.71it/s]
100%|██████████| 1003/1003 [04:00<00:00,  4.17it/s]


CM onestep score: 16.930


In [17]:
# Next, do some CM multistep sampling

def stochastic_iterative_sampler(
    rng,
    num_sampling,
    distiller: CMFramework,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)

    params = flax.jax_utils.replicate(distiller.torso_state.params_ema)
    sampling_fn = distiller.p_sample_cm # params: sampling_params, latent_sample, rng_key, gamma, t_max, t_min
    
    # Sampling x from sampler
    input_shape = (jax.local_device_count(), num_sampling // jax.local_device_count(), 32, 32, 3)
    rng, sampling_key = jax.random.split(rng, 2)
    x = jax.random.normal(sampling_key, input_shape) * t_max
    

    for i in range(len(ts) - 1):
        t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho

        # x0 = distiller(x, t * s_in)
        sampling_key, p_sample_key = jax.random.split(sampling_key, 2)
        p_sample_key = jax.random.split(p_sample_key, jax.local_device_count())
        
        t_param = jnp.asarray([t] * jax.local_device_count())
        t_min_param = jnp.asarray([t_min] * jax.local_device_count())
        gamma = jnp.zeros((jax.local_device_count(),))

        x0 = sampling_fn(params, x, p_sample_key, gamma, t_param, t_min_param)
        next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        next_t = jnp.clip(next_t, t_min, t_max)

        rng, normal_rng = jax.random.split(rng, 2)
        x = x0 + jax.random.normal(normal_rng, input_shape) * jnp.sqrt(next_t**2 - t_min**2)
    
    return x

In [18]:
def get_fid(rng, sampling_fn, p, begin=(0,), end=(17, ), sample_dir="."):
    total_size = 50000
    batch_size = sampling_batch_num
    current_sampling_num = 0

    ts = begin + (p,) + end
    tmp_dir = os.path.join(sample_dir, f"{p}")
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)

    filelist = os.listdir(tmp_dir)
    
    if len(filelist) == total_size:
        fid = fid_utils.calculate_fid(tmp_dir)
        return fid

    while current_sampling_num < total_size:
        effective_sampling = batch_size \
            if (total_size - current_sampling_num) // batch_size != 0 \
            else total_size - current_sampling_num

        rng, sampling_rng = jax.random.split(rng, 2)
        x0 = sampling_fn(rng=sampling_rng, num_sampling=effective_sampling, ts=ts)
        x0 = x0.reshape(-1, *x0.shape[-3:])
        fs_utils.save_images_to_dir(x0, tmp_dir, current_sampling_num)
        current_sampling_num += effective_sampling
    
    fid = fid_utils.calculate_fid(tmp_dir)
    return fid

In [19]:
def ternery_search(rng, sampling_fn, fid_obj, before, after, sample_dir="."):
    right = after[0]
    left = before[-1]
    while right - left >= 3:
        m1 = int(left + (right - left) / 3.0)
        m2 = int(right - (right - left) / 3.0)
        f1 = get_fid(rng, sampling_fn, m1, before, after, sample_dir=sample_dir)

        # logger.log(f"fid at m1 = {m1} is {f1}, IS is {is1}")
        print(f"fid at m1 = {m1} is {f1}")
        f2 = get_fid(rng, sampling_fn, m2, before, after, sample_dir=sample_dir)

        print(f"fid at m2 = {m2} is {f2}")

        if f1 < f2:
            right = m2
        else:
            left = m1

        print(f"new interval is [{left}, {right}]")

    if right == left:
        p = right
    elif right - left == 1:
        f1 = get_fid(rng, sampling_fn, left, before, after, sample_dir=sample_dir)
        f2 = get_fid(rng, sampling_fn, right, before, after, sample_dir=sample_dir)
        p = m1 if f1 < f2 else m2
    elif right - left == 2:
        mid = left + 1
        f1 = get_fid(rng, sampling_fn, left, before, after, sample_dir=sample_dir)
        f2 = get_fid(rng, sampling_fn, right, before, after, sample_dir=sample_dir)
        fmid = get_fid(rng, sampling_fn, mid, before, after, sample_dir=sample_dir)

        print(f"fmid at mid = {mid} is {fmid}")

        if fmid < f1 and fmid < f2:
            p = mid
        elif f1 < f2:
            p = m1
        else:
            p = m2

    return p

In [20]:
steps = 18
begin = (0,)
end = (steps - 1,)

In [21]:
sampler = partial(stochastic_iterative_sampler, # need rng, num_sampling, ts
    distiller=consistency_framework,
    t_min=sigma_min,
    t_max=sigma_max,
    rho=rho,
    steps=steps)
rng, cm_multistep_sampling = jax.random.split(rng, 2)

ternery_search(cm_multistep_sampling, sampler, fid_utils, begin, end, cm_multistep)

Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 332.18it/s]
100%|██████████| 1000/1000 [03:37<00:00,  4.59it/s]


fid at m1 = 5 is 14.208711181519902
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:32<00:00, 328.13it/s]
100%|██████████| 1000/1000 [03:45<00:00,  4.44it/s]


fid at m2 = 11 is 13.182310649979286
new interval is [5, 17]
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:32<00:00, 328.25it/s]
100%|██████████| 1000/1000 [03:44<00:00,  4.45it/s]


fid at m1 = 9 is 12.208069572387103
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:29<00:00, 333.48it/s]
100%|██████████| 1000/1000 [03:43<00:00,  4.48it/s]


fid at m2 = 13 is 16.016940955749988
new interval is [5, 13]
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 333.04it/s]
100%|██████████| 1000/1000 [03:42<00:00,  4.50it/s]


fid at m1 = 7 is 13.245693092105
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 331.43it/s]
100%|██████████| 1000/1000 [03:44<00:00,  4.46it/s]


fid at m2 = 10 is 12.416339673587913
new interval is [7, 13]
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 332.27it/s]
100%|██████████| 1000/1000 [03:44<00:00,  4.46it/s]


fid at m1 = 9 is 12.208069572387103
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 332.14it/s]
100%|██████████| 1000/1000 [03:43<00:00,  4.48it/s]


fid at m2 = 11 is 13.182310649979286
new interval is [7, 11]
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 331.53it/s]
100%|██████████| 1000/1000 [03:45<00:00,  4.44it/s]


fid at m1 = 8 is 12.763328256248315
Loading cifar10 statistics


100%|██████████| 50000/50000 [02:30<00:00, 331.98it/s]
100%|██████████| 1000/1000 [03:43<00:00,  4.48it/s]


fid at m2 = 9 is 12.208069572387103
new interval is [8, 11]
