In [1]:
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from jaxtyping import Key, Array
import optax
import tensorflow_probability.substrates.jax.distributions as tfd
from sklearn import datasets
from tqdm.notebook import trange

from models import ResidualNetwork

In [2]:
class Encoder(eqx.Module):
    net: eqx.nn.MLP

    def __init__(self, mlp_kwargs, *, key):
        self.net = eqx.nn.MLP(**mlp_kwargs, key=key)
    
    def __call__(self, x_t, t):
        # This needs to output mu(x_t, t), sigma(x_t, t) for variational distribution
        # return self.net(jnp.concatenate([x_t.flatten(), t])) 
        # This is actually score of log(q(z|x_t, t))

        mu_t, sigma_t = jnp.split(self.net(jnp.concatenate([x_t.flatten(), t])), 2)
        return jax.scipy.stats.multivariate_normal.logpdf(z, mu_t, jnp.diag(sigma_t))

    def encode(self, x_t, t):
        return self.(x_t, t)

    def score(self, z, x_t, t):
        return jax.jacfwd(self, argnums=1)(z, x_t, t)

    def prior_log_prob_z(self, z):
        return jax.scipy.stats.norm.logpdf(z).sum()

    def kl(self, x_t, t):
        mu_t, sigma_t = jnp.split(self(x_t, t), 2) 
        return mu.T @ mu + sigma.sum() - mu.size - jnp.prod(sigma)

In [3]:
class VDAE(eqx.Module):
    encoder: Encoder
    score_network: ResidualNetwork

    def __init__(self, encoder, score_network):
        self.encoder = encoder
        self.score_network = score_network

    def score(self, x_t, t):
        return self.score_network(x_t, t) + self.encoder.score(x_t, t)


In [4]:
key = jr.key(0)

score_network = ResidualNetwork(
    in_size=2, 
    out_size=2, 
    width_size=128, 
    depth=2, 
    y_dim=1, # Just scalar time
    activation=jax.nn.gelu, 
    key=key
)

encoder = Encoder(
    dict(
        in_size=2 + 1, # [x_t, t]
        out_size=1 + 1, # Latent dim = 1
        width_size=32, 
        depth=2, 
        activation=jax.nn.tanh
    ),
    key=key
)

score_network = eqx.tree_deserialise_leaves("sgm.eqx", score_network)

vdae = VDAE(encoder, score_network)

In [5]:
def kl_loss(mu, sigma):
    # Assuming sigma is diagonal elements of covariance
    return mu.T @ mu + sigma.sum() - mu.size - jnp.prod(sigma)

def multivariate_normal(mu, sigma):
    return tfd.MultivariateNormalFullCovariance(mu, jnp.diag(sigma))

def multivariate_prior(dim):
    return tfd.MultivariateNormalFullCovariance(jnp.zeros(dim), jnp.eye(dim))

# multivariate_normal(jnp.zeros(10), jnp.eye(10)).kl_divergence(multivariate_prior(10))

In [6]:
from functools import partial

int_beta = lambda t: t  # Try experimenting with other options here!

weight = lambda t: 1 - jnp.exp(-int_beta(t))  # Just chosen to upweight the region near t=0.


def dataloader(x, batch_size, *, key):
    dataset_size = x.shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        key, subkey = jr.split(key, 2)
        perm = jr.permutation(subkey, indices)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield x[batch_perm]
            start = end
            end = start + batch_size


def single_loss_fn(encoder, score_network, weight, int_beta, x, t, key):
    # Encoder training objective given trained diffusion model
    t = jnp.atleast_1d(t)

    # Diffusion loss calculations
    mean = x * jnp.exp(-0.5 * int_beta(t))
    var = jnp.maximum(1. - jnp.exp(-int_beta(t)), 1e-5)
    std = jnp.sqrt(var)
    noise = jr.normal(key, x.shape)
    x_t = mean + std * noise

    mu_t, sigma_t = jnp.split(encoder(x_t, t), 2) 

    score = score_network(x_t, t) + encoder.score(x_t, t)

    # Score of encoder model plus score of diffusion model
    return weight(t) * jnp.square(score + noise / std) - kl_loss(mu_t, sigma_t)


def batch_loss_fn(encoder, score_network, weight, int_beta, x, t1, key):
    batch_size = x.shape[0]
    tkey, losskey = jr.split(key)
    losskey = jr.split(losskey, batch_size)
    # Low-discrepancy sampling over t to reduce variance
    t = jr.uniform(tkey, (batch_size,), minval=0., maxval=t1 / batch_size)
    t = t + (t1 / batch_size) * jnp.arange(batch_size)
    loss_fn = partial(single_loss_fn, encoder, score_network, weight, int_beta)
    loss_fn = jax.vmap(loss_fn)
    return jnp.mean(loss_fn(x, t, losskey))


@eqx.filter_jit
def make_step(encoder, score_network, weight, int_beta, x, t1, key, opt_state, opt_update):
    loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
    loss, grads = loss_fn(encoder, score_network, weight, int_beta, x, t1, key)
    updates, opt_state = opt_update(grads, opt_state, encoder)
    encoder = eqx.apply_updates(encoder, updates)
    key = jr.split(key, 1)[0]
    return loss, encoder, key, opt_state

In [7]:
opt = optax.adamw(1e-3)

opt_state = opt.init(eqx.filter(encoder, eqx.is_array)) # Gradients only needed for encoder

In [8]:
X, Y = datasets.make_moons(10_000, noise=0.05)
X, Y = jnp.asarray(X), jnp.asarray(Y)[:, jnp.newaxis]

In [9]:
key, train_key, loader_key = jr.split(key, 3) 

t1 = 1.
batch_size = 1000 
lr = 1e-3
num_steps = 100_000

int_beta = lambda t: t  # Try experimenting with other options here!
weight = lambda t: 1 - jnp.exp(-int_beta(t))  # Just chosen to upweight the region near t=0.

total_value = 0
total_size = 0
with trange(num_steps) as bar:
    for step, data in zip(
        bar, dataloader(X, batch_size, key=loader_key)
    ):
        value, encoder, train_key, opt_state = make_step(
            encoder, 
            score_network, 
            weight, 
            int_beta, 
            data, 
            t1, 
            train_key, 
            opt_state, 
            opt.update
        )
        total_value += value.item()
        total_size += 1
        if (step % 100) == 0 or step == num_steps - 1:
            bar.set_postfix_str(f"Loss={total_value / total_size:.3E}")
            total_value = 0
            total_size = 0

  0%|          | 0/100000 [00:00<?, ?it/s]

TypeError: Encoder.__call__() missing 1 required positional argument: 't'