In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# (such as the `cone` helpers used below) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from flax import linen as nn
import optax
from tqdm.auto import tqdm

from jax import jit

import tensorflow as tf
import tensorflow_datasets as tfds
from cone.utils.misc import pshape
from einops import rearrange
from cone.utils.plot import scatter_movie, imshow_movie
from cone.utils.misc import meanvmap, tracewrap
from jax import jacrev, jacfwd, vmap

In [None]:
# Load the EMNIST "balanced" train/test splits as (image, label) pairs
# (`as_supervised=True` yields 2-tuples instead of feature dictionaries).
ds_train = tfds.load("emnist/balanced", split="train", as_supervised=True)
ds_test = tfds.load("emnist/balanced", split="test", as_supervised=True)


def preprocess_emnist(image, label):
    """Re-orient one EMNIST image and rescale its pixel values.

    The image is rotated by three quarter turns (`rot90` with `k=3`), mirrored
    left-to-right, cast to float32 and divided by 255. The label is passed
    through untouched.

    Returns:
        `(image, label)` with the transformed image.
    """
    upright = tf.image.flip_left_right(tf.image.rot90(image, k=3))
    scaled = tf.cast(upright, tf.float32) / 255.0
    return scaled, label


batch_size = 512

# Training pipeline: preprocess -> shuffle (buffer of 10000) -> batch -> prefetch.
ds_train = (
    ds_train.map(preprocess_emnist, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(10000)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: same preprocessing and batching, but no shuffling.
ds_test = (
    ds_test.map(preprocess_emnist, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Wrap both pipelines so that iterating them yields NumPy arrays for JAX.
ds_train, ds_test = tfds.as_numpy(ds_train), tfds.as_numpy(ds_test)

In [5]:
def get_loss_fn(vae, beta):
    """Build the beta-weighted VAE training objective for `vae`.

    Args:
        vae: module whose `apply(params, x, rngs={"latent": key})` returns
            `(logits, mu, logvar)`.
        beta: multiplier applied to the KL term.

    Returns:
        `loss_fn(params, x, key) -> (total_loss, (recon_loss, kl_loss))`,
        i.e. a scalar plus auxiliary outputs, as expected by
        `jax.value_and_grad(..., has_aux=True)`.
    """

    def loss_fn(params, x, key):
        logits, mu, logvar = vae.apply(params, x, rngs={"latent": key})

        # Sigmoid cross-entropy against the input itself, summed over
        # axes 1-3 of each example and averaged over the leading axis.
        bce = optax.sigmoid_binary_cross_entropy(logits=logits, labels=x)
        recon_loss = jnp.mean(jnp.sum(bce, axis=(1, 2, 3)))

        # Closed-form KL between N(mu, exp(logvar)) and a standard normal,
        # summed over the last axis and averaged over the remaining ones.
        per_example = jnp.sum(1 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
        kl_loss = -0.5 * jnp.mean(per_example)

        total_loss = recon_loss + beta * kl_loss
        return total_loss, (recon_loss, kl_loss)

    return loss_fn

In [None]:
from cone.net.vae import VAE
from cone.utils.misc import count_params


# Architecture hyperparameters handed to the project's VAE module.
latent_features = 32
encoder_features = [1, 64, 48, 32]
decoder_features = [32, 48, 64, 1]
kernel = 2
vae = VAE(
    latent_features=latent_features,
    encoder_features=encoder_features,
    decoder_features=decoder_features,
    kernel_size=kernel,
    padding="SAME",
)


# Placeholder batch with the image layout (batch, 28, 28, 1); it is only used
# to initialise parameters and trace shapes, its values do not matter.
x_dummy = jnp.zeros((batch_size, 28, 28, 1))

# Single seed for the notebook's JAX randomness. (An earlier, unused
# PRNGKey(1) assignment was dropped: it was overwritten here before any use.)
key = jax.random.PRNGKey(0)
# Two sub-keys feed `vae.init` (parameter and latent streams); the third
# replaces `key` and is carried forward by later cells.
p_key, l_key, key = jax.random.split(key, num=3)
init_rng = {"params": p_key, "latent": l_key}

In [7]:
def get_samplers(vae, params, x):
    """Create jitted encode/decode helpers bound to `vae`.

    One eager call to `vae.encode` is made up front; its third output is kept
    and passed unchanged to every later `vae.decode` call.

    Args:
        vae: module exposing `encode`, `encode_latent` and `decode` methods.
        params: parameters used for the up-front `encode` call.
        x: example input used for the up-front `encode` call.

    Returns:
        `(encode, decode)` where `encode(params, x, key)` returns the result of
        `vae.encode_latent` and `decode(params, x)` returns the sigmoid of the
        decoder logits.
    """
    encoded = vae.apply(params, x, method=vae.encode)
    _, _, shape = encoded

    def _encode(params, x, key):
        return vae.apply(params, x, key, method=vae.encode_latent)

    def _decode(params, x):
        decoder_logits = vae.apply(params, x, shape, method=vae.decode)
        return nn.sigmoid(decoder_logits)

    return jit(_encode), jit(_decode)

In [None]:
# Initialise the VAE parameters and build the jitted encode/decode helpers.
params_init = vae.init(init_rng, x_dummy)
encoder, decoder = get_samplers(vae, params_init, x_dummy)
print(count_params(params_init))

n_epochs = 25
# set up the optimizer
n_batches = len(ds_train)
n_steps = n_epochs * n_batches
# `n_steps` already counts every optimizer update (epochs * batches), so it is
# the full horizon of the cosine schedule. Multiplying by `n_epochs` again
# would stretch the schedule so far that the learning rate barely decays.
learning_rate = optax.cosine_decay_schedule(1e-3, decay_steps=n_steps)
optimizer = optax.adam(learning_rate)
optimizer_state = optimizer.init(params_init)


# Weight of the KL term in the VAE objective.
beta = 0.01
loss_fn = get_loss_fn(vae, beta)


@jit
def train_step(params, x, optimizer_state, key):
    """Run one optimizer update of the VAE on the batch `x`.

    Args:
        params: current model parameters.
        x: input batch forwarded to `loss_fn`.
        optimizer_state: current optax optimizer state.
        key: PRNG key; it is split so one half drives this step's loss and the
            other half is handed back to the caller.

    Returns:
        `(params, optimizer_state, loss, aux, key)` with updated parameters and
        optimizer state, the scalar loss, the auxiliary outputs of `loss_fn`,
        and the fresh key for the next step.
    """
    step_key, next_key = jax.random.split(key)
    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, aux), grads = value_and_grad_fn(params, x, step_key)
    updates, new_state = optimizer.update(grads, optimizer_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state, loss, aux, next_key


# Start optimisation from the freshly initialised parameters.
opt_params = params_init
# training loop
# One progress bar is reused for every epoch (reset at the top of each epoch).
pbar = tqdm(total=n_batches)
for epoch in range(n_epochs):

    pbar.reset(total=n_batches)
    pbar.set_description(f"Epoch {epoch + 1}/{n_epochs}")

    for batch in ds_train:

        # Labels are unpacked but not used by the VAE objective.
        # NOTE: `x` (the last batch seen) and `key` stay bound after the loop
        # and are read by later cells.
        x, y = batch

        opt_params, optimizer_state, loss, aux, key = train_step(
            opt_params, x, optimizer_state, key
        )
        rc_loss, kl_loss = aux
        pbar.update(1)
        # Formatting the losses converts them to Python values on every step.
        pbar.set_postfix(
            {
                "loss": f"{loss:.2E}",
                "rc_loss": f"{rc_loss:.2E}",
                "kl_loss": f"{kl_loss:.2E}",
            }
        )

In [None]:
# Reconstruct the first `n_plt` images of `x` and show them as a square grid.
# NOTE: `x` is whatever batch the training loop left bound last.
n_plt = 25
n_grid = int(np.sqrt(n_plt))
logits, mean, logvar = vae.apply(opt_params, x[:n_plt], rngs={"latent": key})
# Tile the n_grid * n_grid images into one (n_grid*h, n_grid*w, c) mosaic.
logits = rearrange(logits, "(n1 n2) h w c -> (n1 h) (n2 w) c", n1=n_grid, n2=n_grid)
# The decoder emits logits; the sigmoid maps them to pixel intensities.
recon = nn.sigmoid(logits)
plt.imshow(recon)
plt.colorbar()
plt.show()

In [None]:
# Decode `n_plt` latent vectors drawn from a standard normal and tile them.
# NOTE: uses NumPy's global RNG (not the JAX `key`), which is not seeded in
# this notebook, so the samples differ between runs. `n_plt` and `n_grid`
# come from the reconstruction cell.
z = np.random.randn(n_plt, latent_features)
samples = decoder(opt_params, z)
samples = rearrange(samples, "(n1 n2) h w c -> (n1 h) (n2 w) c", n1=n_grid, n2=n_grid)
plt.imshow(samples)
plt.colorbar()
plt.show()

In [None]:
# Encode batches with the trained VAE and collect latents with their labels.
# NOTE(review): despite the name `n_test_batches`, the loop iterates over
# `ds_train` — confirm that the training split is the intended source.
latents = []
labels = []
n_test_batches = 512  # maximum number of batches to encode
for i, (x, y) in enumerate(ds_train):
    # Check the cap before doing any work so that exactly `n_test_batches`
    # batches are processed at most (a post-append `i > n` test lets n + 2 in).
    if i >= n_test_batches:
        break
    # Fresh sub-key per batch for the stochastic encoder.
    key, skey = jax.random.split(key)
    z = encoder(opt_params, x, skey)
    latents.append(z)
    labels.append(y)

# Stack the per-batch results along the leading (sample) axis.
latents = np.concatenate(latents)
labels = np.concatenate(labels)
latents.shape, labels.shape

In [12]:
# Split the latents by class label; the sorted class ids become the "time" axis.
classes = np.sort(np.unique(labels))
t_latents = [latents[labels == c] for c in classes]

# Truncate every class to the size of the smallest one so that the groups
# stack into a single rectangular array with a leading class axis.
min_amt = min((len(group) for group in t_latents), default=np.inf)
t_latents = np.asarray([tl[:min_amt] for tl in t_latents])

# Class ids divided by the largest id.
t_eval = jnp.asarray(classes) / classes.max()

In [13]:
from cone.utils.misc import get_rand_idx, interplate_in_t
from cone.integrate.quad import get_simpson_quadrature, get_gauss_quadrature


def get_sample_fn(
    X_data, t_data, bs_tau=16, bs_t=256, bs_n=256, quad="simp", jit_fn=True
):
    """Build a batch sampler over a (T, N, D) trajectory array.

    NOTE(review): a later cell defines another `get_sample_fn` with a different
    signature and return value; whichever cell ran last is the one in effect.

    Args:
        X_data: array unpacked as (T, N, D) inside the sampler.
        t_data: time stamps; divided by their last entry before use.
        bs_tau: number of uniform `tau` draws per call.
        bs_t: number of time nodes (quadrature nodes, or for `quad="mc"` the
            number of randomly chosen time indices, capped at T - 1).
        bs_n: sample count passed to `get_rand_idx` for the N axis.
        quad: `"simp"`, `"gauss"` or `"mc"`. For the two quadrature modes the
            data is interpolated with `interplate_in_t` onto the quadrature
            nodes plus the end points 0 and 1; the quadrature weights are not
            used by this sampler.
        jit_fn: wrap the returned sampler in `jax.jit`.

    Returns:
        `sample_fn(key) -> (tau_batch, X_batch, t_batch)`.

    Raises:
        ValueError: if `quad` is not one of the supported modes.
    """
    if quad not in ("simp", "gauss", "mc"):
        raise ValueError(f"quad must be 'simp', 'gauss' or 'mc', got {quad!r}")

    t_data = t_data / t_data[-1]  # normalize in [0,1]

    # if we are doing monte carlo, we dont need to interpolate
    if quad != "mc":
        if quad == "simp":
            # odd number of points necessary for simpsons
            if (bs_t - 1) % 2 != 0:
                bs_t += 1
            t_nodes, _ = get_simpson_quadrature(bs_t)
        else:
            t_nodes, _ = get_gauss_quadrature(bs_t)

        # add start and end points for boundary term
        start, end = jnp.asarray([0]), jnp.asarray([1.0])
        t_nodes = jnp.concatenate([start, t_nodes, end])
        X_all = jnp.asarray(interplate_in_t(X_data, t_data, t_nodes))
        t_all = jnp.asarray(t_nodes)
    else:
        X_all = jnp.asarray(X_data)
        t_all = jnp.asarray(t_data)

    t_all = t_all.reshape(-1, 1)

    def sample_fn(in_key):
        # Work on local names only. Rebinding the enclosing variables (via
        # `nonlocal`) would shrink the data pool after the first call and
        # store traced values in the closure when the sampler is jitted.
        X_batch, t_batch = X_all, t_all

        T, N, D = X_batch.shape
        T, _ = t_batch.shape

        in_key, tau_key = jax.random.split(in_key)

        tau_batch = jax.random.uniform(tau_key, (bs_tau, 1))
        # Sort along the batch axis; the default (last) axis has length 1,
        # which would leave the draws unsorted.
        tau_batch = jnp.sort(tau_batch, axis=0)
        bs_t_a = min(bs_t, T - 1)

        if quad == "mc":
            # Random, ordered subset of time indices drawn from [0, T - 1).
            in_key, key_t = jax.random.split(in_key)
            t_idx = jax.random.choice(key_t, T - 1, shape=(bs_t_a,), replace=False)
            t_idx = jnp.sort(t_idx)
            t_batch = t_batch[t_idx]
            X_batch = X_batch[t_idx]

        # The same sample indices are used at every time slice.
        in_key, x_key = jax.random.split(in_key)
        sample_idx = get_rand_idx(x_key, N, bs_n)
        X_batch = X_batch[:, sample_idx]

        return tau_batch, X_batch, t_batch

    if jit_fn:
        sample_fn = jit(sample_fn)

    return sample_fn

In [16]:
from cone.utils.misc import meanvmap, tracewrap, sqwrap, key_tensor, fold_in_data
from jax import jacrev, jacfwd, grad


def get_flow_loss_fn(s_fn):
    """Build the flow-matching loss for the model `s_fn`.

    The loss interpolates between two consecutive states `xt` and `x_tp1` with
    time-dependent weights, perturbs the interpolant with Gaussian noise scaled
    by `gamma(tau)`, and regresses `s_fn` onto the interpolant's `tau`
    derivative using the expanded quadratic form `|s|^2 - 2 <s, target>`.

    Args:
        s_fn: callable `s_fn(tau, x, t, params)`; it is called with `tau` and
            `t` reshaped to length-1 vectors and a 1-D state `x`.

    Returns:
        `loss_fn(params, tau_batch, X_tp1, X_t, t, key)` returning the loss
        reduced by `meanvmap` over the sample, `tau` and time axes (presumably
        a vmap followed by a mean — confirm against `cone.utils.misc`).
    """

    @sqwrap
    def alpha(tau):
        # Weight on the later state: cos^2(pi * tau) for tau >= 0.5, else 0.
        t_fn = lambda tau: jnp.cos(jnp.pi * tau) ** 2
        f_fn = lambda tau: tau * 0.0
        return jax.lax.cond(tau >= 0.5, t_fn, f_fn, tau)

    @sqwrap
    def beta(tau):
        # Weight on the earlier state: cos^2(pi * tau) for tau < 0.5, else 0.
        t_fn = lambda tau: jnp.cos(jnp.pi * tau) ** 2
        f_fn = lambda tau: tau * 0.0
        return jax.lax.cond(tau < 0.5, t_fn, f_fn, tau)

    @sqwrap
    def gamma(tau):
        # Noise scale; vanishes at tau = 0 and tau = 1.
        return jnp.sin(jnp.pi * tau) ** 2

    # Single definition of the interpolant. A second, undecorated definition
    # with the same name used to precede this one and was silently shadowed.
    @sqwrap
    def interpolant(tau, xt_p1, xt):
        return alpha(tau) * xt_p1 + beta(tau) * xt

    # Derivatives with respect to the first argument (tau).
    interpolant_dt = jacrev(interpolant)
    gamma_dt = grad(gamma)

    def flow_match(x_tp1, xt, t, tau, params, key):
        """Loss contribution of one (x_tp1, xt, t, tau) tuple."""
        # Derive a key specific to this data point from the shared key.
        r_data = fold_in_data(x_tp1, xt, t, tau)
        key = jax.random.fold_in(key, r_data)

        tau = jnp.squeeze(tau)
        t = jnp.squeeze(t)

        D = x_tp1.shape[0]

        # Two independent standard-normal draws of dimension D.
        # NOTE(review): the state is perturbed with `noise_i` while the
        # regression target uses the independent draw `noise_l`, and the same
        # target is used for both antithetic branches — confirm this is intended.
        noise_i, noise_l = jax.random.normal(key, shape=(2, D))

        dt_i = jnp.squeeze(interpolant_dt(tau, x_tp1, xt))
        g_dt = gamma_dt(tau)

        x_mid = interpolant(tau, x_tp1, xt)
        perturbation = gamma(tau) * noise_i
        target = dt_i + g_dt * noise_l

        # Antithetic pair: evaluate the model at x_mid +/- perturbation.
        s_plus = s_fn(tau.reshape(1), x_mid + perturbation, t.reshape(1), params)
        s_minus = s_fn(tau.reshape(1), x_mid - perturbation, t.reshape(1), params)

        l_plus = jnp.dot(s_plus, s_plus) - 2 * jnp.dot(s_plus, target)
        l_minus = jnp.dot(s_minus, s_minus) - 2 * jnp.dot(s_minus, target)

        l = (l_plus + l_minus) / 2
        return jnp.squeeze(l)

    # Reduce over samples (axis 0 of both states), then tau, then time slices.
    fm_Vx = meanvmap(flow_match, in_axes=(0, 0, None, None, None, None))
    fm_Vx_Vtau = meanvmap(fm_Vx, in_axes=(None, None, None, 0, None, None))
    fm_Vx_Vtau_Vt = meanvmap(fm_Vx_Vtau, in_axes=(0, 0, 0, None, None, None))

    def loss_fn(params, tau_batch, X_tp1, X_t, t, key):
        return fm_Vx_Vtau_Vt(X_tp1, X_t, t, tau_batch, params, key)

    return loss_fn

In [17]:
from cone.utils.misc import get_rand_idx, interplate_in_t
from cone.integrate.quad import get_simpson_quadrature, get_gauss_quadrature


def get_sample_fn(X_data, t_data, bs_tau=16, bs_t=256, bs_n=256, n_t=10):
    """Build a sampler that returns one random pair of consecutive time slices.

    Args:
        X_data: array of shape (T, N, D).
        t_data: time stamps; divided by their last entry and reshaped to (T, 1).
        bs_tau: number of uniform `tau` draws per call.
        bs_t: unused by this sampler; kept so existing call sites keep working.
        bs_n: sample count passed to `get_rand_idx` for the N axis.
        n_t: the left time index is drawn from `range(n_t)`. The default of 10
            matches the value that used to be hard-coded; pass `T - 1` to draw
            from every consecutive pair.

    Returns:
        `sample_fn(key) -> (tau_batch, X_tp1, X_t, t)`, where `X_t` and `X_tp1`
        are the slices at the drawn index and the one after it (each with a
        leading axis of length 1) and `t` is the matching row of `t_data`.

    Raises:
        ValueError: if `n_t` is not in `[1, T - 1]`, in which case index
            `t_idx + 1` could fall outside the data.
    """
    t_data = t_data / t_data[-1]  # normalize in [0,1]
    X_data = jnp.asarray(X_data)
    t_data = jnp.asarray(t_data).reshape(-1, 1)

    T, N, D = X_data.shape
    if not 1 <= n_t <= T - 1:
        raise ValueError(f"n_t must be in [1, {T - 1}], got {n_t}")

    def sample_fn(in_key):
        key_t, x_key, tau_key = jax.random.split(in_key, num=3)

        # sample tau
        tau_batch = jax.random.uniform(tau_key, (bs_tau, 1), minval=0.0, maxval=1.0)
        # Sort along the batch axis; the default (last) axis has length 1,
        # which would leave the draws unsorted.
        tau_batch = jnp.sort(tau_batch, axis=0)

        # sample x: the same sample indices are used at every time slice
        sample_idx = get_rand_idx(x_key, N, bs_n)
        X_batch = X_data[:, sample_idx]

        # sample t: a single left index from range(n_t)
        t_idx = jax.random.choice(key_t, n_t, shape=(1,), replace=False)
        t_idx = jnp.squeeze(t_idx).astype(jnp.int32)

        # Length-1 slices keep the leading time axis on every output.
        t = t_data[t_idx : t_idx + 1]
        X_t = X_batch[t_idx : t_idx + 1]
        X_tp1 = X_batch[t_idx + 1 : t_idx + 2]

        return tau_batch, X_tp1, X_t, t

    return sample_fn

In [18]:
from cone.net.mlp import MLP
from cone.net.adam import adam_opt
from cone.utils.misc import normalize_data

# Layer widths of the MLP: five layers of 64 units and a final layer whose
# width equals the latent dimensionality.
features = [64, 64, 64, 64, 64, latent_features]
s_net = MLP(features=features, squeeze=True)


def s_fn(tau, x, t, params):
    """Evaluate `s_net` on the concatenation of `tau`, `x` and `t`.

    The three inputs are joined along their first axis in the order
    (tau, x, t) and the result is passed to `s_net.apply` with `params`.
    """
    net_input = jnp.concatenate((tau, x, t))
    return s_net.apply(params, net_input)


# Dummy input for parameter initialisation: latent vector plus the two extra
# entries (tau and t) that `s_fn` concatenates onto it.
# NOTE: this rebinds `x_dummy`, which earlier held the VAE's image-shaped batch.
x_dummy = jnp.ones((64, latent_features + 2))
s_params_init = s_net.init(key, x_dummy)


# NOTE: this rebinds `loss_fn` (previously the VAE loss closed over by
# `train_step`) to the flow-matching loss.
loss_fn = get_flow_loss_fn(s_fn)
sample_fn = get_sample_fn(t_latents, t_eval, bs_tau=1)

In [None]:
import random

# Draw a fresh, non-deterministic seed so each run of this cell shows a
# different batch. `randint` needs integer bounds: a float such as 1e6 is
# rejected by recent Python versions.
# NOTE: this rebinds the notebook-wide `key`, so later cells are no longer
# reproducible from the original seed.
key = jax.random.PRNGKey(random.randint(0, 10**6))
# `sample_fn` returns (tau_batch, X_tp1, X_t, t); unpack in that order so the
# names match the slices they hold.
tau_batch, X_tp1, X_t, t = sample_fn(key)
jnp.squeeze(t)

In [None]:
# Scatter the first two latent coordinates of the two sampled slices
# (leading index 0 selects the single time slice in each array).
plt.scatter(x=X_t[0, :, 0], y=X_t[0, :, 1])
plt.scatter(x=X_tp1[0, :, 0], y=X_tp1[0, :, 1])
plt.show()

In [None]:
from cone.net.adam import adam_opt

# Fit the flow model with the project's Adam helper, starting from
# `s_params_init` and drawing batches from `sample_fn`.
# `loss_key=True` presumably makes the helper pass a PRNG key to `loss_fn`,
# which takes `key` as its last argument — confirm in `cone.net.adam`.
opt_params_s, loss_history = adam_opt(
    s_params_init,
    loss_fn,
    sample_fn,
    steps=5000,
    learning_rate=5e-3,
    verbose=True,
    key=key,
    loss_key=True,
)

In [22]:
from cone.integrate.sde import odeint_rk4


def solve_test_cfm(s_fn, params, ics, t_int, T):
    """Roll initial conditions through the learned flow, one time at a time.

    For every entry of `t_int` the model `s_fn` is integrated with
    `odeint_rk4` over `T` evenly spaced `tau` values in [0, 1]; the last state
    of each segment is the initial condition of the next one.

    Args:
        s_fn: callable `s_fn(tau, x, t, params)`, vmapped here over the
            leading axis of `x`.
        params: model parameters forwarded to `s_fn`.
        ics: batch of initial states.
        t_int: physical times; reshaped to a column and iterated row by row.
        T: number of `tau` grid points per segment.

    Returns:
        All segments concatenated along the first axis, with singleton axes
        squeezed out.
    """
    tau_grid = jnp.linspace(0, 1, T).reshape(-1, 1)
    batched_s = vmap(s_fn, (None, 0, None, None))

    @jit
    def integrate_segment(state0, physical_t, params):
        def rhs(tau, y):
            return batched_s(tau, y, physical_t, params)

        return odeint_rk4(rhs, state0, tau_grid)

    segments = []
    state = ics
    for physical_t in tqdm(t_int.reshape(-1, 1), desc="CFM"):
        segment = integrate_segment(state, physical_t, params)
        segments.append(segment)
        # Chain segments: the end of this one starts the next.
        state = segment[-1]

    return jnp.squeeze(jnp.concatenate(segments))

In [None]:
# Roll latents of the first class through the learned flow over all times.
ic = t_latents[0]
t_int = t_eval
T_tau = 32  # number of tau grid points per segment
n_plot = 1000
# `n_plot` evenly spaced row indices into `ic`; indices repeat when `ic` has
# fewer than `n_plot` rows.
idx_sample = np.linspace(0, ic.shape[0] - 1, n_plot, dtype=np.uint32)

sol_cfm = solve_test_cfm(s_fn, opt_params_s, ic[idx_sample], t_int, T_tau)