# All in one sampling ipynb

This file is for evaluation of CM-SH model that measure the FID score of 
1. one-step sampling for CM
2. multi-step sampling for CM
3. multi-step sampling for SH 

In [1]:
import jax
import jax.numpy as jnp
import flax

import matplotlib.pyplot as plt
import yaml
import omegaconf
import hydra
from hydra import initialize, compose

from framework.unifying_framework import UnifyingFramework
from framework.diffusion.consistency_framework import CMFramework
from framework.diffusion.edm_framework import EDMFramework

from utils import common_utils
from utils.fid_utils import FIDUtils
from utils.fs_utils import FSUtils

from tqdm import tqdm

from functools import partial

import os 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make new directory for storing samples for each sampling mode
default_path = "tmp/1024_20/"
edm_multistep = default_path+ "edm_multistep"
cm_onestep = default_path+ "cm_onestep"
cm_multistep = default_path+ "cm_multistep"
dir_list = [cm_onestep, cm_multistep, edm_multistep]

for dir_elem in dir_list:
    os.makedirs(dir_elem, exist_ok=True)

In [3]:
sampling_batch_num = 512
total_sampling_num = 50000
sigma_max = 80.0
sigma_min = 0.002
rho = 7

In [4]:
# Setting CM framework for sampling

config_path = "configs"
default_config_path = "config_1024_20_2560"
rng = jax.random.PRNGKey(42)

with initialize(version_base=None, config_path=config_path):
    default_config = compose(config_name=default_config_path)
default_config["do_training"] = False
model_type = default_config.type

rng, denoiser_rng = jax.random.split(rng, 2)
rng, consistency_rng = jax.random.split(rng, 2)
# denoiser_framework = diffusion_framework.framework
fid_utils = FIDUtils(default_config)
fs_utils = FSUtils(default_config)

consistency_framework = CMFramework(default_config, consistency_rng, fs_utils, None)

In [5]:
print(default_config)

{'available_gpus': '0, 1, 2, 3, 4, 5, 6, 7', 'dataset': {'data_size': [32, 32, 3], 'name': 'cifar10'}, 'do_sampling': True, 'do_training': False, 'ema': {'beta': 0.99993}, 'exp': {'autoencoder_prefix': 'ae_', 'best_dir': 'experiments/0114_iCT_1024_20_2560/checkpoints/best', 'checkpoint_dir': 'experiments/0114_iCT_1024_20_2560/checkpoints', 'cm_prefix': 'cm_', 'current_exp_dir': 'experiments/0114_iCT_1024_20_2560', 'diffusion_prefix': 'diffusion_', 'discriminator_prefix': 'discriminator_', 'exp_dir': 'experiments', 'in_process_dir': 'experiments/0114_iCT_1024_20_2560/in_process', 'sampling_dir': 'experiments/0114_iCT_1024_20_2560/sampling'}, 'exp_name': '0114_iCT_1024_20_2560', 'fid_during_training': True, 'framework': {'diffusion': {'alignment_loss': False, 'alignment_loss_scale': 1.0, 'alignment_loss_weight': 'lognormal', 'alignment_threshold': 50000, 'gradient_flow_from_head': True, 'joint_training_weight': 1, 'learn_sigma': False, 'loss': 'huber', 'n_timestep': 18, 'params_ema_for_t

In [6]:
# First, save SH and CM simultaneously if CM-SH mode
current_num = 0
print("Sampling CM onestep and SH multistep")
sampling_shape = (sampling_batch_num, 32, 32, 3)

while total_sampling_num > current_num:
    effective_batch_size = total_sampling_num - current_num if current_num + sampling_batch_num > total_sampling_num else sampling_batch_num
    # if not default_config.framework.diffusion.only_cm_training:
    #     edm, cm = consistency_framework.sampling_edm_and_cm(sampling_batch_num)
    #     edm = jnp.reshape(edm, sampling_shape)
    #     fs_utils.save_images_to_dir(edm, starting_pos=current_num, save_path_dir=edm_multistep)
    # else:
    #     cm = consistency_framework.sampling_cm(sampling_batch_num)
    edm, cm = consistency_framework.sampling_edm_and_cm(sampling_batch_num)
    edm = jnp.reshape(edm, sampling_shape)
    fs_utils.save_images_to_dir(edm, starting_pos=current_num, save_path_dir=edm_multistep)
    
    cm = jnp.reshape(cm, sampling_shape)
    fs_utils.save_images_to_dir(cm, starting_pos=current_num, save_path_dir=cm_onestep)
    current_num += sampling_batch_num

Sampling CM onestep and SH multistep


18it [01:29,  4.96s/it]
18it [00:36,  2.02s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.04s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.03s/it]
18it [00:36,  2.

In [None]:
fs_utils.delete_images_from_dir(edm_multistep, starting_pos=total_sampling_num)
fs_utils.delete_images_from_dir(cm_onestep, starting_pos=total_sampling_num)

In [7]:
# if not default_config.framework.diffusion.only_cm_training:
edm_multistep_fid_score = fid_utils.calculate_fid(edm_multistep)
print(f"EDM multistep score: {edm_multistep_fid_score:.3f}")
cm_onestep_fid_score = fid_utils.calculate_fid(cm_onestep)
print(f"CM onestep score: {cm_onestep_fid_score:.3f}")

Loading cifar10 statistics


100%|██████████| 50176/50176 [04:30<00:00, 185.65it/s]


EDM multistep score: 2.888
Loading cifar10 statistics


100%|██████████| 50176/50176 [04:22<00:00, 191.23it/s]


CM onestep score: 3.865


In [13]:
# Next, do some CM multistep sampling
multisampling_fid_dict = {}

In [14]:


def stochastic_iterative_sampler(
    rng,
    num_sampling,
    distiller: CMFramework,
    ts,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    # t_max_rho = t_max ** (1 / rho)
    # t_min_rho = t_min ** (1 / rho)

    params = flax.jax_utils.replicate(distiller.torso_state.params_ema)
    sampling_fn = distiller.p_sample_cm # params: sampling_params, latent_sample, rng_key, gamma, t_max, t_min
    
    # Sampling x from sampler
    input_shape = (jax.local_device_count(), num_sampling // jax.local_device_count(), 32, 32, 3)
    rng, sampling_key = jax.random.split(rng, 2)
    x = jax.random.normal(sampling_key, input_shape) * t_max
    
    for i in range(len(ts) - 1):
        # t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        t = ts[i]

        # x0 = distiller(x, t * s_in)
        sampling_key, p_sample_key = jax.random.split(sampling_key, 2)
        p_sample_key = jax.random.split(p_sample_key, jax.local_device_count())
        
        t_param = jnp.asarray([t] * jax.local_device_count())
        t_min_param = jnp.asarray([t_min] * jax.local_device_count())
        gamma = jnp.zeros((jax.local_device_count(),))

        x0 = sampling_fn(params, x, p_sample_key, gamma, t_param, t_min_param)
        # next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        next_t = ts[i+1]
        next_t = jnp.clip(next_t, t_min, t_max)

        rng, normal_rng = jax.random.split(rng, 2)
        x = x0 + jax.random.normal(normal_rng, input_shape) * jnp.sqrt(next_t**2 - t_min**2)
    
    return x

In [15]:
def get_fid(rng, sampling_fn, p, begin=(0,), end=(17, ), sample_dir="."):
    total_size = 50000
    batch_size = sampling_batch_num
    current_sampling_num = 0

    ts = begin + (p,) + end

    if multisampling_fid_dict.get(ts, None) is not None:
        return multisampling_fid_dict[ts]

    tmp_dir = os.path.join(sample_dir, f"{p}")
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)

    filelist = os.listdir(tmp_dir)
    
    if len(filelist) == total_size:
        fid = fid_utils.calculate_fid(tmp_dir)
        multisampling_fid_dict[ts] = fid
        return fid

    while current_sampling_num < total_size:
        effective_sampling = batch_size \
            if (total_size - current_sampling_num) // batch_size != 0 \
            else total_size - current_sampling_num

        rng, sampling_rng = jax.random.split(rng, 2)
        x0 = sampling_fn(rng=sampling_rng, num_sampling=effective_sampling, ts=ts)
        x0 = x0.reshape(-1, *x0.shape[-3:])
        fs_utils.save_images_to_dir(x0, tmp_dir, current_sampling_num)
        current_sampling_num += effective_sampling
    
    fid = fid_utils.calculate_fid(tmp_dir)
    multisampling_fid_dict[ts] = fid
    return fid

In [16]:
steps = 18
# begin = (0,)
# end = (steps - 1,)
begin = (80.0, )
end = (0.002, )

In [17]:
sampler = partial(stochastic_iterative_sampler, # need rng, num_sampling, ts
    distiller=consistency_framework,
    t_min=sigma_min,
    t_max=sigma_max,
    rho=rho,
    steps=steps)
rng, cm_multistep_sampling = jax.random.split(rng, 2)

# ternery_search(cm_multistep_sampling, sampler, fid_utils, begin, end, cm_multistep)
get_fid(cm_multistep_sampling, sampler, 0.821, begin, end, cm_multistep)

AttributeError: module 'os' has no attribute 'makdeirs'