In [1]:
import jax
import orbax
import flax
import tensorflow as tf

import jax.numpy as jnp
from jax import random
import numpy as np

import sys
import os
sys.path.append('../')
import datasets

from flax.training import checkpoints
from models import utils as mutils
from models import ddpm
import eval_utils as eutils
import evaluation
import train_utils as tutils
import diffrax
from functools import partial
from tqdm import trange
from run_lib import init_model
from dynamics import get_vpsde

2024-11-18 12:02:36.942077: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-18 12:02:37.215399: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-18 12:02:37.298047: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from configs.sm.cifar import vpsde as config
config = config.get_config()

In [3]:
def get_pools(sample_dir):
    all_pools = []
    stats = tf.io.gfile.glob(os.path.join(sample_dir, "statistics_*.npz"))
    for stat_file in stats:
        with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            all_pools.append(stat["pool_3"])
    all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]
    return all_pools

def load_dataset_stats(config, eval=False):
  """Load the pre-computed dataset statistics."""
  suffix = 'test' if eval else 'train'
  if config.data.dataset == 'CIFAR10':
    filename = f'../assets/stats/cifar10_{suffix}_stats.npz'
  else:
    raise ValueError(f'Dataset {config.data.dataset} stats not found.')

  with tf.io.gfile.GFile(filename, 'rb') as fin:
    stats = np.load(fin)
    return stats

train_pools = load_dataset_stats(config, eval=False)
test_pools = load_dataset_stats(config, eval=True)

def get_fid(pools):
    train_fid = evaluation.fid(train_pools["pool_3"], pools)
    test_fid = evaluation.fid(test_pools["pool_3"], pools)
    print(f'train FID: {train_fid}, test FID: {test_fid}', flush=True)

## FID Uncond

In [4]:
pools_det = get_pools('/network/scratch/k/kirill.neklyudov/5617628/eval/samples')
pools_stoch = get_pools('/network/scratch/k/kirill.neklyudov/5617628/eval/samples_stoch')

In [7]:
get_fid(pools_det)
get_fid(pools_stoch)
get_IS(pools_det)
get_IS(pools_stoch)

train FID: 6.0028710668863265, test FID: 8.321379992849277
train FID: 3.4969647777767126, test FID: 5.67004534558804
IS: [8.947068]
IS: [9.142855]


## FID Cond

In [4]:
pools_joint_det = get_pools('../checkpoint/inv_ab_joint_vf/eval/samples_stoch/')
pools_joint_stoch = get_pools('../checkpoint/high_temp_disjoint_joint_vf/eval/samples_stoch/')

In [6]:
get_fid(pools_joint_det)
get_fid(pools_joint_stoch)
get_IS(pools_joint_det)
get_IS(pools_joint_stoch)

train FID: 4.40918801546627, test FID: 6.471772195304419
train FID: 4.0057497035083145, test FID: 6.075971518519671
IS: [9.383149]
IS: [9.4856415]


In [4]:
pools_a = get_pools('/network/scratch/k/kirill.neklyudov/5294839/eval/samples')
pools_b = get_pools('/network/scratch/k/kirill.neklyudov/5294900/eval/samples')
pools_a_stoch = get_pools('/network/scratch/k/kirill.neklyudov/5294839/eval/samples_stoch')
pools_b_stoch = get_pools('/network/scratch/k/kirill.neklyudov/5294900/eval/samples_stoch')
pools_joint_det = get_pools('../checkpoint/cond_joint_vf/eval/samples/')
pools_joint_stoch = get_pools('../checkpoint/cond_joint_vf/eval/samples_stoch/')

In [5]:
get_fid(pools_a)
get_fid(pools_b)
get_fid(pools_a_stoch)
get_fid(pools_b_stoch)
get_fid(pools_joint_det)
get_fid(pools_joint_stoch)

train FID: 5.298048498542919, test FID: 7.583245178049728
train FID: 4.684625393252006, test FID: 6.7744546936673915
train FID: 2.834219423529825, test FID: 4.983764363001981
train FID: 4.857014731474025, test FID: 6.911617939993418
train FID: 4.032961132269818, test FID: 6.2216255184786675
train FID: 3.4742407122197836, test FID: 5.585484557959523


In [6]:
pools_mixed = jnp.vstack([pools_a, pools_b])
get_fid(pools_mixed[::2])
pools_mixed = jnp.vstack([pools_a[:25000], pools_b[:25000]])
get_fid(pools_mixed)
pools_mixed_stoch = jnp.vstack([pools_a_stoch[:25000], pools_b_stoch[:25000]])
get_fid(pools_mixed_stoch)

2024-09-22 21:23:18.008910: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.6.68). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


train FID: 19.805255889892578, test FID: 21.804908752441406
train FID: 4.090967178344727, test FID: 6.273442268371582
train FID: 3.55703067779541, test FID: 5.653024673461914


## IS

In [6]:
model = tf.keras.applications.InceptionV3(
    include_top=True,
    weights='imagenet',
    input_shape=(299, 299, 3),
    pooling='avg')

def get_IS(pools):
    def h(p):
        return -(jnp.log(p)*p).sum(1)
    W,b = model.layers[-1].get_weights()
    probs = jax.nn.softmax(pools@W + b)
    IS = jnp.exp(h(probs.mean(0,keepdims=True)) - h(probs).mean())
    print(f'IS: {IS}')

### sanity check

In [8]:
get_IS(train_pools['pool_3'])
get_IS(jax.random.normal(jax.random.PRNGKey(1), shape=train_pools['pool_3'].shape))

IS: [10.851412]
IS: [3.3749776]


## IS cond

In [9]:
get_IS(pools_a)
get_IS(pools_b)
get_IS(pools_a_stoch)
get_IS(pools_b_stoch)
get_IS(pools_joint_det)
get_IS(pools_joint_stoch)

IS: [9.042656]
IS: [9.149275]
IS: [9.444546]
IS: [9.538941]
IS: [9.061275]
IS: [9.533243]


In [10]:
pools_mixed = jnp.vstack([pools_a, pools_b])
get_IS(pools_mixed[::2])
pools_mixed = jnp.vstack([pools_a[:25000], pools_b[:25000]])
get_IS(pools_mixed)
pools_mixed_stoch = jnp.vstack([pools_a_stoch[:25000], pools_b_stoch[:25000]])
get_IS(pools_mixed_stoch)

IS: [8.255488]
IS: [9.13329]
IS: [9.499198]
