In [1]:
%matplotlib inline
%config InlineBackend.figure_format = "retina"

from itertools import islice

import flax.linen as nn
import h5py
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
from flax.training import train_state
from jax import random
from sklearn.manifold import TSNE

from neurovae import Conv1DVAE, sse_loss, gaussian_kld, reparameterize
from helper import fig_path

sns.set_theme(context="paper", style="darkgrid", rc={"axes.facecolor": "0.96"})
fontsize = "x-large"
params = {
    "font.family": "serif",
    "font.sans-serif": ["Computer Modern"],
    "axes.labelsize": fontsize,
    "legend.fontsize": fontsize,
    "xtick.labelsize": fontsize,
    "ytick.labelsize": fontsize,
    "legend.handlelength": 2,
}
plt.rcParams.update(params)
plt.rc("text", usetex=True)

In [2]:
def create_batches(data, batch_size, drop_remainder):

    data_size = data.shape[0]
    remainder = data_size % batch_size

    if drop_remainder and remainder != 0:
        it = iter(data[: data_size - remainder, :])
    else:
        it = iter(data)

    batches = []

    while batch := tuple(islice(it, batch_size)):
        batches.append(jnp.asarray(batch))

    return batches


latent_dim = 20
output_dim = 10000
batch_size = 64
epochs = 1
seed = 42

# prepare HH data
t_sim = 100.0
dt = 0.01
N = int(t_sim / dt)
t = np.linspace(0, t_sim, N)

vs = []
for i in range(10):
    infile = f"./hh_data/hh_sim_data_{i}.h5"
    with h5py.File(infile, "r") as f:
        for grp_name in f.keys():
            v = f[grp_name]["v"][:]
            v_scaled = (v - v.min(keepdims=True)) / (v.max(keepdims=True) - v.min(keepdims=True))
            vs.append(v_scaled)

v_train = vs[:8000]
v_test = vs[8000:]

# cast to jax
v_train = jnp.asarray(v_train, dtype=jnp.float32)
v_test = jnp.asarray(v_test, dtype=jnp.float32)

batches = create_batches(v_train, batch_size, drop_remainder=True)

# set values for learning rate scheduler
total_steps = len(batches) * epochs
init_lr = 1e-3
alpha_lr = 1e-2


def model():
    return Conv1DVAE(latent_dim)


def init_model(rng):
    rng, init_key = random.split(rng)

    initial_variables = jnp.ones((batch_size, output_dim), jnp.float32)
    params = model().init(init_key, initial_variables, rng)["params"]
    del initial_variables, init_key

    lr_schedule = optax.cosine_decay_schedule(init_lr, decay_steps=total_steps, alpha=alpha_lr)
    optimizer = optax.chain(optax.clip(1.0), optax.adamw(lr_schedule, nesterov=True))

    state = train_state.TrainState.create(
        apply_fn=model().apply,
        params=params,
        tx=optimizer,
    )
    return rng, state


def compute_metrics(recon_x, x, mean, logvar):
    mse = sse_loss(recon_x, x).mean()  # mean over batch
    kld = gaussian_kld(mean, logvar).mean()  # mean over batch
    elbo = mse + kld
    return {"elbo": elbo, "mse": mse, "kld": kld}


@jax.jit
def train_step(state, batch, z_rng):

    def loss_fn(params):
        recon_x, mean, logvar = model().apply({"params": params}, batch, z_rng)
        mse = sse_loss(recon_x, batch).mean()
        kld = gaussian_kld(mean, logvar).mean()
        elbo = mse + kld
        return elbo

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


@jax.jit
def eval_f(params, v_traces, z, z_rng):
    def eval_model(vae):
        recon_v_traces, mean, logvar = vae(v_traces, z_rng)

        comparison = jnp.concatenate(
            [
                v_traces[:8],
                recon_v_traces[:8],
            ]
        )

        generate_v_traces = vae.generate(z, assumption="gaussian")
        metrics = compute_metrics(recon_v_traces, v_traces, mean, logvar)
        return metrics, comparison, generate_v_traces

    return nn.apply(eval_model, model())({"params": params})

In [3]:
rng = random.key(seed)
rng, state = init_model(rng)

rng, z_key, eval_rng = random.split(rng, 3)

z = random.normal(z_key, (batch_size, latent_dim))  # prior
del z_key

epoch_metrics = []
for epoch in range(epochs):
    for batch in batches:
        rng, key = random.split(rng)
        state = train_step(state, batch, key)

    metrics, comparison, samples = eval_f(state.params, v_test, z, eval_rng)
    metrics["epoch"] = epoch + 1
    epoch_metrics.append(metrics)
    print(f"epoch: {epoch + 1}, ELBO: {metrics['elbo']:.4f}, MSE: {metrics['mse']:.4f}, KLD: {metrics['kld']:.4f}")

Enc in: (64, 10000)
Enc conv1: (64, 32)
Enc conv2: (64, 64)
Enc reshape: (64, 64)
Dec in: (64, 20)
Dec reshape1: (64, 10000, 64)
Dec convt1: (64, 10000, 64)
Dec convt2: (64, 10000, 32)
Dec convt3, recon_x: (64, 10000, 1)
recon_x flatten: (64, 10000)
Enc in: (64, 10000)
Enc conv1: (64, 32)
Enc conv2: (64, 64)
Enc reshape: (64, 64)
Dec in: (64, 20)
Dec reshape1: (64, 10000, 64)
Dec convt1: (64, 10000, 64)
Dec convt2: (64, 10000, 32)
Dec convt3, recon_x: (64, 10000, 1)
recon_x flatten: (64, 10000)
Enc in: (2000, 10000)
Enc conv1: (2000, 32)
Enc conv2: (2000, 64)
Enc reshape: (2000, 64)
Dec in: (2000, 20)
Dec reshape1: (2000, 10000, 64)
Dec convt1: (2000, 10000, 64)
Dec convt2: (2000, 10000, 32)
Dec convt3, recon_x: (2000, 10000, 1)
recon_x flatten: (2000, 10000)
Dec in: (64, 20)
Dec reshape1: (64, 10000, 64)
Dec convt1: (64, 10000, 64)
Dec convt2: (64, 10000, 32)
Dec convt3, recon_x: (64, 10000, 1)
recon_x flatten: (64, 10000)
epoch: 1, ELBO: 355.3152, MSE: 349.9854, KLD: 5.3298
