## Variational Autoencoders
### Notebook content modified from code given in the "Variational Autoencoder" section of the "Deep Learning for Molecules and Materials" textbook (https://dmol.pub/dl/VAE.html)
- ### White, A. D. (2022). Deep learning for molecules and materials. Living journal of computational molecular science, 3(1).

In [None]:
#Install packages
!pip install numpy matplotlib seaborn jax

## VAE for Discrete Data

Note: The features are classes; we *are not* trying to make a classifier that takes in features and outputs classes. VAEs are for unlabeled data.


Our first example will be to generate new example classes from a distribution of possible classes. An application for this might be to sample conditions of an experiment.
- **Q: Can you think of any other applications for using a VAE with discrete data?**

Task overview:
- Features: $x$, represented as one-hot vectors corresponding to a class
- Goal: Learn the distribution $P(x)$ so that we can sample new $x$'s
 - **Q: Along with being able to sample new points, what else can we do after learning the latent space?**
   - Learning the latent space can also provide a way to embed your features into low dimensional continuous vectors, allowing you to do things like optimization because you've moved from discrete classes to continuous vectors
   - This is an extra benefit; our loss and training goal are still to model $P(x)$


Implementing the encoder & decoder:
- **Q: What are the inputs & outputs for the encoder?**
 - The encoder $q_\phi(z | x)$ should output a *probability distribution* for vectors of real numbers $z$ and take an input of a one-hot vector $x$
 - **Q: How can we have the encoder output a probability distribution? Is there a simpler way to do this?**
   - We defined $P(z)$ to be normally distributed, let's assume that the form of $q_\phi(z | x)$ should be normal. Then our neural network could output the parameters to a normal distribution (mean/variance) for $z$, rather than trying to output a probability at every possible $z$ value. It's up to you if you want to have $q_\phi(z | x)$ output a D-dimensional Gaussian distribution with a covariance matrix or just output D independent normal distributions. Having $q_\phi(z | x)$ output a normal distribution also allows us to analytically simplify the expectation/integral in the KL-divergence term.

- **Q: What are the inputs & outputs for the decoder?**
  - The decoder $p_\theta(x | z)$ should output a probability distribution over classes given the input, a real vector $z$.
   - We can use the same form we use for classification: softmax activation. Just remember that we're not trying to output a specific $x$, just a probability distribution of $x$'s.
   - Softmax is similar to the sigmoid function, but is used when there are multiple classes

- **Q: What parameters do we have control over/can change for the VAE?**
  - The hyperparameters of the encoder and decoder and the size of $z$. It makes sense to have the encoder and decoder share as many hyperparameters as possible, since they're somewhat symmetric. Just remember that the encoder in our example is outputting a mean and variance, which means using regression, and the decoder is outputting a normalized probability vector, which means using softmax. Let's get started!

### The Data
The data is 1024 points $\vec{x}_i$ where each $\vec{x}_i$ is a 32 dimensional one-hot vector indicating class. We won't define the classes -- the data is synthetic. Since a VAE is unsupervised learning, there are no labels. Let's start by examining the data. We'll sum the occurrences of each class to see what the distribution of classes looks like.
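
As a minimal sketch of what this encoding looks like, the snippet below builds one-hot vectors for a made-up set of class indices (the names `toy_classes` and `n_toy_classes` are hypothetical and not part of the notebook's data) and sums them to get per-class counts:

```python
import numpy as np

# hypothetical toy labels: 6 points, 4 classes (not the notebook's data)
toy_classes = np.array([0, 2, 2, 3, 1, 2])
n_toy_classes = 4

# each row of the identity matrix is a one-hot vector
toy_onehot = np.eye(n_toy_classes, dtype=int)[toy_classes]

# summing over the data axis counts how often each class occurs
toy_counts = toy_onehot.sum(axis=0)
print(toy_onehot.shape)  # (6, 4)
print(toy_counts)        # [1 1 3 1]
```

The same sum over `axis=0` is what produces the class histogram for the real data.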

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import urllib
import seaborn as sns
import jax.numpy as jnp
from jax.example_libraries import optimizers
import jax

In [None]:
# Generate data to use for training (no labels, just a vector indicating the class a specific x_i belongs to)
sampled_z = np.random.choice([0, 1], size=1024) # Generates a random array of 1s and 0s (size = 1024)

# Draw 1024 continuous values from a two-component Gaussian mixture;
# sampled_z selects which component each value comes from
data = ((sampled_z + 1) % 2) * np.random.normal(
    size=sampled_z.shape, loc=-1, scale=0.5
) + sampled_z * np.random.normal(size=sampled_z.shape, loc=1, scale=0.25)

nbins = 32
_, bins = np.histogram(data, bins=nbins)

# Bin each value into one of 32 classes, giving an array of shape (1024, 32)
# where each row is a one-hot vector of length 32
class_data = np.apply_along_axis(lambda x: np.histogram(x, bins)[0], 1, data.reshape(-1, 1))

nclasses = nbins

In [None]:
# Print out a couple x_i values
print(f'x_8: {class_data[8]}')
print(f'x_24: {class_data[24]}')

In [None]:
# Visualize distribution of classes
plt.bar(np.arange(nclasses), height=np.sum(class_data, axis=0))
plt.xlabel("Class Index")
plt.ylabel("Frequency")
plt.show()

### The Encoder & Decoder

Our encoder will be a basic two hidden layer network.

We will output a $D\times2$ matrix, where the first column is means and the second is standard deviations for independent normal distributions that make up our guess for $q(z | x)$. Outputting a mean is simple: just use no activation. Outputting a standard deviation takes more care because it must lie on $(0, \infty)$. `jax.nn.softplus` can accomplish this.

The decoder should output a vector of probabilities for $\vec{x}$. This can be achieved by just adding a softmax to the output. The rest is nearly identical to the encoder.
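
To make the two output activations concrete, here is a small sketch applied to arbitrary made-up values (the array `raw` is hypothetical and stands in for the last dense layer's output):

```python
import jax
import jax.numpy as jnp

# hypothetical raw network outputs (any real numbers)
raw = jnp.array([-3.0, 0.0, 2.5])

# softplus(x) = log(1 + exp(x)) is positive, so it can serve as a std dev
std = jax.nn.softplus(raw)

# softmax is non-negative and sums to 1, so it can serve as class probabilities
probs = jax.nn.softmax(raw)

print(std)
print(probs, probs.sum())
```

Softplus is only one way to enforce positivity; exponentiating the output is another common choice.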

In [None]:
# Helper function for generating a random vector of a specific size from the normal distribution
def random_vec(size):
    return np.random.normal(size=size, scale=1)

In [None]:
# Define parameters
latent_dim = 1
hidden_dim = 16
input_dim = nclasses # = 32; defined this above when generating the data


def encoder(x, theta):
    """The encoder takes as input x and gives out probability of z,
    expressed as normal distribution parameters. Assuming each z dim is independent,
    output |z| x 2 matrix"""
    w1, w2, w3, b1, b2, b3 = theta #get weights
    hx = jax.nn.relu(w1 @ x + b1) #hidden layer 1 (dense layer)
    hx = jax.nn.relu(w2 @ hx + b2) #hidden layer 2 (dense layer)
    out = w3 @ hx + b3 #output - get mean & standard deviation from this
    # slice out stddeviation and make it positive
    reshaped = out.reshape((-1, 2))
    # we slice with ':' to keep rank same
    std = jax.nn.softplus(reshaped[:, 1:])
    mu = reshaped[:, 0:1]
    return jnp.concatenate((mu, std), axis=1)


def init_theta(input_dim, hidden_units, latent_dim):
    """Create initial theta parameters"""
    w1 = random_vec(size=(hidden_units, input_dim))
    b1 = np.zeros(hidden_units)
    w2 = random_vec(size=(hidden_units, hidden_units))
    b2 = np.zeros(hidden_units)
    # need two params per dim (mean, std)
    w3 = random_vec(size=(latent_dim * 2, hidden_units))
    b3 = np.zeros(latent_dim * 2)
    return [w1, w2, w3, b1, b2, b3]


# test encoder
theta = init_theta(input_dim, hidden_dim, latent_dim)
encoder_result = encoder(class_data[0], theta)
print(f'Shape of encoder output: {np.shape(encoder_result)}')
encoder_result

In [None]:
def decoder(z, phi):
    """decoder takes as input the latent variable z and gives out probability of x.
    Decoder outputs a real-valued vector, then we use softmax activation to get probability across
    possible values of x.
    """
    w1, w2, w3, b1, b2, b3 = phi
    hz = jax.nn.relu(w1 @ z + b1) #hidden layer
    hz = jax.nn.relu(w2 @ hz + b2) #hidden layer
    out = jax.nn.softmax(w3 @ hz + b3) #output
    return out


def init_phi(input_dim, hidden_units, latent_dim):
    """Create initial phi parameters"""
    w1 = random_vec(size=(hidden_units, latent_dim))
    b1 = np.zeros(hidden_units)
    w2 = random_vec(size=(hidden_units, hidden_units))
    b2 = np.zeros(hidden_units)
    w3 = random_vec(size=(input_dim, hidden_units))
    b3 = np.zeros(input_dim)
    return [w1, w2, w3, b1, b2, b3]


# test decoder
phi = init_phi(input_dim, hidden_dim, latent_dim)
decoder_result = decoder(np.array([1.2] * latent_dim), phi)
print(f'Shape of decoder output: {np.shape(decoder_result)}')
decoder_result

### Training

We use the negative ELBO as the training loss:

$$
l = -\textrm{E}_{z \sim q_\phi(z | x_i)}\left[\log p_{\theta}(x_i | z)\right] + \textrm{KL}\left[(q_\phi(z | x_i))|| P(z)\right]
$$

where $P(z)$ is the standard normal distribution and we approximate expectations using a single sample from the encoder. We need to expand the KL-divergence term to implement this. Both $P(z)$ and $q_\phi(z | x_i)$ are normal. You can look up the KL-divergence between two normal distributions:

\begin{equation}
KL(q, p) = \log \frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2 \sigma_p^2} - \frac{1}{2}
\end{equation}

we can simplify because $P(z)$ is standard normal ($\sigma = 1, \mu = 0$)

\begin{equation}
\textrm{KL}\left[(q_\phi(z | x_i))|| P(z)\right] = -\log \sigma_i + \frac{\sigma_i^2}{2} + \frac{\mu_i^2}{2} - \frac{1}{2}
\end{equation}

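where $\mu_i, \sigma_i$ are the output from $q_\phi(z | x_i)$. Note that in the code for this discrete example the encoder parameters are named `theta` and the decoder parameters are named `phi`, the reverse of the subscripts in the equations.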

For a latent space dimension of 3, $x_i$ has the shape (1,n_classes), $\sigma_i$ has the shape (1,3), and $\mu_i$ has the shape (1,3).
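
As a sanity check on the simplified KL expression, the sketch below compares the closed form against a Monte Carlo estimate of $\textrm{E}_{q}[\log q(z) - \log P(z)]$ in one dimension. The function names and the values of `mu` and `sigma` are illustrative choices, not part of the notebook:

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
    """Closed-form KL[N(mu, sigma^2) || N(0, 1)]."""
    return -np.log(sigma) + 0.5 * sigma**2 + 0.5 * mu**2 - 0.5

def kl_monte_carlo(mu, sigma, n=200_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q."""
    rng = np.random.default_rng(seed)
    z = rng.normal(loc=mu, scale=sigma, size=n)
    log_q = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)
    log_p = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
    return np.mean(log_q - log_p)

mu, sigma = 0.7, 0.4  # arbitrary illustrative values
closed = kl_to_standard_normal(mu, sigma)
sampled = kl_monte_carlo(mu, sigma)
print(closed, sampled)  # the two numbers should agree closely

# the divergence vanishes when q is itself the standard normal
print(kl_to_standard_normal(0.0, 1.0))
```

For independent latent dimensions the total KL-divergence is the sum of this expression over dimensions.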

In [None]:
@jax.jit
def loss(x, theta, phi, rng_key):
    """VAE Loss"""
    # reconstruction loss
    sampled_z_params = encoder(x, theta)
    # reparameterization trick
    # we use standard normal sample and multiply by parameters
    # to ensure derivatives correctly propagate to encoder
    sampled_z = (
        jax.random.normal(rng_key, shape=(latent_dim,)) * sampled_z_params[:, 1]
        + sampled_z_params[:, 0]
    )
    # log of prob
    rloss = -jnp.log(decoder(sampled_z, phi) @ x.T + 1e-8)
    # Array of KL loss for x_i (dimension = latent space dimension)
    klloss = (
        -0.5
        - jnp.log(sampled_z_params[:, 1])
        + 0.5 * sampled_z_params[:, 0] ** 2
        + 0.5 * sampled_z_params[:, 1] ** 2
    )
    # combined
    return jnp.array([rloss, jnp.mean(klloss)])


# test loss function
loss(class_data[0], theta, phi, jax.random.PRNGKey(0))

Our loss works! Now we need to make it batched so we can train in batches. Luckily this is easy with `jax.vmap`.

In [None]:
batched_loss = jax.vmap(loss, in_axes=(0, None, None, None), out_axes=0)
batched_decoder = jax.vmap(decoder, in_axes=(0, None), out_axes=0)
batched_encoder = jax.vmap(encoder, in_axes=(0, None), out_axes=0)

# test batched loss
batched_loss(class_data[:4], theta, phi, jax.random.PRNGKey(0))
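
The `in_axes` argument controls which inputs are mapped over. As a small illustration, using a hypothetical function `affine` rather than the VAE itself, `0` maps over the first axis of an argument while `None` shares the argument across the whole batch:

```python
import jax
import jax.numpy as jnp

def affine(x, w):
    """Operates on a single vector x with shared weights w."""
    return w @ x

w = jnp.arange(6.0).reshape(2, 3)   # shared parameters, shape (2, 3)
xs = jnp.ones((4, 3))               # batch of 4 input vectors

# map over axis 0 of xs; None means w is shared, not mapped
batched_affine = jax.vmap(affine, in_axes=(0, None))
out = batched_affine(xs, w)
print(out.shape)  # (4, 2)
```

One consequence for `batched_loss` above: because `rng_key` is passed with `None`, every element of a batch uses the same random draw for the reparameterization step.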

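We'll make our gradient take the average over the batch. The mean below also averages the reconstruction and KL terms, which only rescales the loss by a constant factor.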

In [None]:
grad = jax.grad(
    lambda x, theta, phi, rng_key: jnp.mean(batched_loss(x, theta, phi, rng_key)),
    (1, 2),
)
fast_grad = jax.jit(grad)
fast_loss = jax.jit(batched_loss)

Alright, great! An important detail we've skipped so far is that when using `jax` to generate random numbers, we must step our random number generator forward. You can do that using `jax.random.split`. Otherwise, you'll get the same random numbers at each draw.
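
A short sketch of this behavior (the variable names are only for illustration): reusing a key reproduces the same numbers, while consuming a fresh subkey from `jax.random.split` gives a new draw.

```python
import jax

key = jax.random.PRNGKey(0)

# reusing the same key reproduces the same draw
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# splitting yields a new key to carry forward and a subkey to consume
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))

print(bool((a == b).all()))  # True
print(bool((a == c).all()))  # False
```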

We're going to use a `jax` optimizer here. This is to simplify parameter updates. We have a lot of parameters and they are nested, which would be cumbersome to update by hand with Python for loops.

In [None]:
#Train for 16 epochs
batch_size = 32
epochs = 16

key = jax.random.PRNGKey(0)
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-1)
theta0 = init_theta(input_dim, hidden_dim, latent_dim)
phi0 = init_phi(input_dim, hidden_dim, latent_dim)
opt_state = opt_init((theta0, phi0))
losses = []
for e in range(epochs):
    for bi, i in enumerate(range(0, len(data), batch_size)):
        # take a batch of shape B x nclasses
        batch = class_data[i : (i + batch_size)]
        # update random number key
        key, subkey = jax.random.split(key)
        # get current parameter values from optimizer
        theta, phi = get_params(opt_state)
        last_state = opt_state
        # compute gradient and update
        grad = fast_grad(batch, theta, phi, key)
        opt_state = opt_update(bi, grad, opt_state)
        lvalue = jnp.mean(fast_loss(batch, theta, phi, subkey), axis=0)
        losses.append(lvalue)

In [None]:
#Plot loss vs training step (one value is recorded per batch)
plt.plot([l[0] for l in losses], label="Reconstruction")
plt.plot([l[1] for l in losses], label="KL")
plt.plot([l[1] + l[0] for l in losses], label="ELBO")
plt.legend()
plt.ylim(-5, 5)
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.show()

### Evaluating the VAE

Remember our goal with the VAE is to reproduce $P(x)$. We can sample from our VAE using the chosen $P(z)$ and our decoder. Let's compare that distribution with our training distribution.

In [None]:
#Plot distribution of training data & values sampled from VAE

zs = np.random.normal(size=(1024, 1))
sampled_x = batched_decoder(zs, phi)
fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
axs[0].set_title("Training Data")
axs[0].bar(np.arange(nbins), height=np.sum(class_data, axis=0))
axs[0].set_xlabel("Class Index")
axs[0].set_ylabel("Frequency")
axs[1].set_title("VAE Samples")
axs[1].bar(np.arange(nbins), height=np.sum(sampled_x, axis=0))
axs[1].set_xlabel("Class Index")
plt.tight_layout()
plt.show()
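
Note that the right-hand panel sums the decoder's probability vectors rather than counting drawn classes. If discrete samples are wanted instead, one option is `jax.random.categorical`. This is a sketch with a made-up `probs` array standing in for `batched_decoder` output:

```python
import jax
import jax.numpy as jnp

# hypothetical decoder output: 5 probability vectors over 4 classes
probs = jnp.array([[0.7, 0.1, 0.1, 0.1]] * 5)

key = jax.random.PRNGKey(0)
# categorical expects log-probabilities (logits), so take the log first
class_idx = jax.random.categorical(key, jnp.log(probs + 1e-8), axis=-1)

# convert sampled class indices back into one-hot vectors
onehot = jax.nn.one_hot(class_idx, probs.shape[-1])

print(class_idx.shape, onehot.shape)  # (5,) (5, 4)
```

Summing `onehot` over `axis=0` would then give sampled class counts comparable to the training histogram.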

It appears we have succeeded! There were two more goals of the VAE model: making the encoder give output similar to $P(z)$ and being able to reconstruct. These goals are often opposed and they represent the two terms in the loss: reconstruction and KL-divergence. Let's examine the KL-divergence term, which causes the encoder to give output similar to a standard normal. We'll encode our training data and look at the resulting average mean and standard deviation.

In [None]:
d = batched_encoder(class_data, theta)
print("Average mu = ", np.mean(d[..., 0]), "Average std dev = ", np.mean(d[..., 1]))

Wow! Very close to a standard normal. So our model satisfied the match between the encoder and $P(z)$. The last thing to check is reconstruction. The encoder outputs a distribution, so I'll only use the most likely $z$ value (the mean) to do the reconstruction.

In [None]:
plt.plot(decoder(encoder(class_data[2], theta)[0:1, 0], phi), label="P(x)")
plt.axvline(np.argmax(class_data[2]), color="C1", label="x")
plt.legend()
plt.show()

The reconstruction is not great: it puts a lot of probability mass on other points. In fact, the reconstruction seems to not use the encoder's information at all -- it looks like $P(x)$. The reason for this is that our KL-divergence term dominates the loss, so the encoder matches the prior very well but carries little information about $x$.

## Re-balancing VAE Reconstruction and KL-Divergence

Often we desire more reconstruction at the cost of making the latent space less normal. This can be done by adding a term that adjusts the balance between the reconstruction loss and the KL-divergence. You would choose to do this if you want to use the latent space for something and are not just interested in creating a model $\hat{P}(x)$. Here is the modified ELBO equation for training:

$$
l = -\textrm{E}_{z \sim q_\phi(z | x_i)}\left[\log p_{\theta}(x_i | z)\right] + \beta\cdot\textrm{KL}\left[(q_\phi(z | x_i))|| P(z)\right]
$$

where $\beta > 1$ emphasizes the encoder distribution matching the chosen latent distribution (standard normal) and $\beta < 1$ emphasizes reconstruction accuracy.

In [None]:
def modified_loss(x, theta, phi, rng_key, beta):
    """This loss allows you to vary which term is more important
    with beta. Beta = 0 - all reconstruction, beta = 1 - ELBO"""
    bl = batched_loss(x, theta, phi, rng_key)
    l = bl @ jnp.array([1.0, beta])
    return jnp.mean(l)


new_grad = jax.grad(modified_loss, (1, 2))
fast_grad = jax.jit(new_grad)

In [None]:
#Train for 32 epochs
#note we used a lower step size for this loss and more epochs
opt_init, opt_update, get_params = optimizers.adam(step_size=5e-2)
epochs = 32
theta0 = init_theta(input_dim, hidden_dim, latent_dim)
phi0 = init_phi(input_dim, hidden_dim, latent_dim)
opt_state = opt_init((theta0, phi0))
beta = 0.2
losses = []
for e in range(epochs):
    for bi, i in enumerate(range(0, len(data), batch_size)):
        # take a batch of shape B x nclasses
        batch = class_data[i : (i + batch_size)]
        # update random number key
        key, subkey = jax.random.split(key)
        # get current parameter values from optimizer
        theta, phi = get_params(opt_state)
        last_state = opt_state
        # compute gradient and update
        grad = fast_grad(batch, theta, phi, key, beta)
        opt_state = opt_update(bi, grad, opt_state)
        lvalue = jnp.mean(fast_loss(batch, theta, phi, subkey), axis=0)
        losses.append(lvalue)

In [None]:
#Plot loss vs training step (one value is recorded per batch)
plt.plot([l[0] for l in losses], label="Reconstruction")
plt.plot([l[1] for l in losses], label="KL")
plt.plot([l[1] + l[0] for l in losses], label="ELBO")
plt.legend()
plt.ylim(-5, 5)
plt.xlabel("training step")
plt.ylabel("loss")
plt.show()

You can see the error is higher, but let's see how it did at our three metrics.

In [None]:
#Plot distribution of training data & values sampled from VAE
zs = np.random.normal(size=(1024, 1))
sampled_x = batched_decoder(zs, phi)
fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
axs[0].set_title("Training Data")
axs[0].bar(np.arange(nbins), height=np.sum(class_data, axis=0))
axs[0].set_xlabel("Class Index")
axs[0].set_ylabel("Frequency")
axs[1].set_title("VAE Samples")
axs[1].bar(np.arange(nbins), height=np.sum(sampled_x, axis=0))
axs[1].set_xlabel("Class Index")
plt.tight_layout()
plt.show()

A little bit worse on $P(x)$, but overall not bad. What about our goal, the reconstruction?

In [None]:
plt.plot(decoder(encoder(class_data[4], theta)[0:1, 0], phi), label="P(x)")
plt.axvline(np.argmax(class_data[4]), color="C1", label="x")
plt.legend()
plt.show()

What about our encoder's agreement with a standard normal?

In [None]:
d = batched_encoder(class_data, theta)
print("Average mu = ", np.mean(d[..., 0]), "Average std dev = ", np.mean(d[..., 1]))

The standard deviation is much smaller! So we squeezed our latent space a little in exchange for better reconstruction.

### Disentangling $\beta$-VAE

You can adjust $\beta$ in the opposite direction, to weight matching the prior Gaussian distribution more strongly. This can better condition the encoder so that each of the latent dimensions is truly independent. This can be important if you want to disentangle your input features to arrive at an orthogonal projection. This of course comes at the cost of reconstruction accuracy, but can be more important if you're interested in the latent space rather than generating new samples (Mathieu et al. 2019).

References: <br>
Mathieu, E., Rainforth, T., Siddharth, N., & Teh, Y. W. (2019, May). Disentangling disentanglement in variational autoencoders. In International conference on machine learning (pp. 4402-4412). PMLR.

## Bead-Spring Polymer VAE

This polymer has each bead (atom) joined by a harmonic bond, a harmonic angle between each three, and a Lennard-Jones interaction potential. Knowing these items will not be necessary for the example. Each of our data points below will be the x and y coordinates of 12 beads. We'll construct a VAE that can compress the trajectory to some latent space and generate new conformations.

Since we're now working with continuous features $x$, we need to make a few key changes. The encoder will remain the same, but the decoder now must output a $p_\theta(x | z)$ that gives a probability to all possible $x$ values. Above, we only had a finite number of classes but now any $x$ is possible. As we did for the encoder, we'll assume that $p_\theta(x | z)$ should be normal and we'll output the parameters of the normal distribution from our network. This requires an update to the reconstruction loss to be a log of a normal, but otherwise things will be identical.

Don't worry about this too much, but a small detail is that the log-likelihood for a normal distribution with a single observation cannot have an unknown standard deviation. Our new normal distribution parameters for the decoder will have a single observation for a single $x$ in training. If you make the standard deviation trainable, the likelihood can be made arbitrarily large by putting the mean on the point and shrinking the standard deviation toward zero, since you only have one point. Thus, I'll make the decoder standard deviation a hyperparameter. We don't see this issue with the encoder, which also outputs a normal distribution, because we are training the encoder with the KL-divergence term and not the likelihood of observations (reconstruction loss).

To begin, we'll need to align points from a trajectory in order to ensure the data is translationally and rotationally invariant. This will then serve as our training data. The space of our problem will be 12 2D vectors. Our system need not be permutation invariant, so we can flatten these vectors into a 24 dimensional input. The code below loads and aligns the trajectory.


In [None]:
# imports
import numpy as np
import jax.numpy as jnp
from jax.example_libraries import optimizers
import jax
import matplotlib.pyplot as plt
import urllib
import seaborn as sns

In [None]:
###---------Transformation Functions----###
# for rotational/translational invariance

def center_com(paths):
    """Align paths to COM at each frame"""
    # center of mass
    coms = np.mean(paths, axis=-2, keepdims=True)
    return paths - coms

def make_2drot(angle):
    """Defines a rotation matrix"""
    mats = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
    # swap so batch axis is first
    return np.swapaxes(mats, 0, -1)

def find_principle_axis(points):
    """Compute single principle axis for points"""
    inertia = points.T @ points
    evals, evecs = np.linalg.eigh(inertia)
    # get biggest eigenvalue
    order = np.argsort(evals)
    return evecs[:, order[-1]]

def align_principle(paths, axis_finder=find_principle_axis):
    """Rotates the data at each frame
    so that the principle axis remains constant"""
    vecs = [axis_finder(p) for p in paths]
    vecs = np.array(vecs)
    # find angle to rotate so these are pointed towards pos x
    cur_angle = np.arctan2(vecs[:, 1], vecs[:, 0])
    rot_angle = -cur_angle
    rot_mat = make_2drot(rot_angle)
    return paths @ rot_mat
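
To see what the alignment buys us, the sketch below re-implements the same idea for a single frame (the function `align_frame` and the toy data are hypothetical, written only for this check) and confirms that a rotated and translated copy of a frame aligns to the same coordinates. Eigenvectors have an arbitrary sign, so agreement is only expected up to a 180 degree flip.

```python
import numpy as np

def align_frame(points):
    """Center one (N, 2) frame and rotate its principal axis onto the x-axis.
    A single-frame re-implementation for illustration only."""
    centered = points - points.mean(axis=0, keepdims=True)
    evals, evecs = np.linalg.eigh(centered.T @ centered)
    axis = evecs[:, np.argmax(evals)]
    angle = -np.arctan2(axis[1], axis[0])
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    return centered @ rot.T

rng = np.random.default_rng(0)
# elongated toy "polymer" with 12 beads
frame = rng.normal(size=(12, 2)) * np.array([3.0, 0.5])

# apply an arbitrary rotation and translation to the same frame
t = 1.1
R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
moved = frame @ R.T + np.array([5.0, -2.0])

a, b = align_frame(frame), align_frame(moved)
same = np.allclose(a, b, atol=1e-8) or np.allclose(a, -b, atol=1e-8)
print(same)
```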

In [None]:
# import data
urllib.request.urlretrieve(
    "https://github.com/whitead/dmol-book/raw/main/data/long_paths.npz",
    "long_paths.npz",
)
paths = np.load("long_paths.npz")["arr"]
# transform to be rot/trans invariant
data = align_principle(center_com(paths))
# visualize all the data
cmap = plt.get_cmap("cool")
for i in range(0, data.shape[0], 16):
    plt.plot(data[i, :, 0], data[i, :, 1], "-", alpha=0.1, color="C2")
plt.title("All Frames")
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# visualize a single data point
plt.plot(data[3, :, 0], data[3, :, 1], "-", color="C2")

Before training, let’s examine some of the marginals of the data to visualize the data. Marginals mean we’ve transformed (by integration) our probability distribution to be a function of only 1-2 variables so that we can plot nicely. We’ll look at the pairwise distance between beads (beads are indexed 0-11).

In [None]:
fig, axs = plt.subplots(ncols=4, squeeze=True, figsize=(16, 4))
for i, j in enumerate(range(1, 9, 2)):
    axs[i].set_title(f"Dist between 0-{j}")
    sns.histplot(np.linalg.norm(data[:, 0] - data[:, j], axis=1), ax=axs[i], kde=True, stat="density")
plt.tight_layout()

These look a little like the chi distribution with two degrees of freedom. Notice that the support (x-axis) changes between them though. We’ll keep an eye on these when we evaluate the efficacy of our VAE.
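
For reference, the chi distribution with two degrees of freedom is the distribution of the length of a 2D vector whose components are independent standard normals. A quick sketch on synthetic samples (not the polymer data) compares the sample mean with the analytical mean $\sqrt{\pi/2}$:

```python
import numpy as np

rng = np.random.default_rng(0)
# lengths of 2D vectors with independent standard normal components
vecs = rng.normal(size=(100_000, 2))
r = np.linalg.norm(vecs, axis=1)

# sample mean vs. analytical mean of the chi distribution with 2 dof
print(r.mean(), np.sqrt(np.pi / 2))
```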

### VAE Model

Now we'll build the VAE similar to before.

In [None]:
input_dim = 12 * 2 # 12 points each with x and y coordinates
hidden_units = 256 # dimension of hidden layers in encoder/decoder
num_layers = 4 # number of dense layers (including the output layer)
latent_dim = 2 # dimension of z

# randomly initialize weights for the decoder
def init_theta(input_dim, hidden_units, latent_dim, num_layers, key, scale=0.1):
    key, subkey = jax.random.split(key)
    # theta[0] takes a vector from latent layer to a hidden layer
    w1 = jax.random.normal(key=subkey, shape=(hidden_units, latent_dim)) * scale
    b1 = jnp.zeros(hidden_units)
    theta = [(w1, b1)]
    # theta[i] takes a vector from hidden layer to hidden layer
    for i in range(1, num_layers - 1):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(key=subkey, shape=(hidden_units, hidden_units)) * scale
        b = jnp.zeros(hidden_units)
        theta.append((w, b))
    key, subkey = jax.random.split(key)
    # theta[-1] takes a vector from hidden layer to output layer
    w = jax.random.normal(key=subkey, shape=(input_dim, hidden_units)) * scale
    b = jnp.zeros(input_dim)
    theta.append((w, b))
    return theta, key

def decoder(z, theta):
    num_layers = len(theta)
    for i in range(num_layers - 1):
        w, b = theta[i]
        # dense layer with relu activation
        z = jax.nn.relu(w @ z + b)
    w, b = theta[-1]
    # dense layer to get output
    x = w @ z + b
    # returning x which is the mean of the distribution
    return x

# randomly initialize weights for the encoder
def init_phi(input_dim, hidden_units, latent_dim, num_layers, key, scale=0.1):
    key, subkey = jax.random.split(key)
    # phi[0] takes a vector from input layer to a hidden layer
    w1 = jax.random.normal(key=subkey, shape=(hidden_units, input_dim)) * scale
    b1 = jnp.zeros(hidden_units)
    phi = [(w1, b1)]
    # phi[i] takes a vector from hidden layer to hidden layer
    for i in range(1, num_layers - 1):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(key=subkey, shape=(hidden_units, hidden_units)) * scale
        b = jnp.zeros(hidden_units)
        phi.append((w, b))
    key, subkey = jax.random.split(key)
    # phi[-1] takes a vector from hidden layer to latent layer
    w = jax.random.normal(key=subkey, shape=(latent_dim * 2, hidden_units)) * scale
    b = jnp.zeros(latent_dim * 2)
    phi.append((w, b))
    return phi, key

# returns the mean and std deviation for the distribution in the latent space
def encoder(x, phi):
    num_layers = len(phi)
    for i in range(num_layers - 1):
        w, b = phi[i]
        # dense layer with relu activation
        x = jax.nn.relu(w @ x + b)
    w, b = phi[-1]
    # dense layer to get output
    hz = w @ x + b
    hz = hz.reshape(-1, 2)
    mu = hz[:, 0:1]
    # softplus ensures standard deviation is in the range (0, infty)
    std = jax.nn.softplus(hz[:, 1:2])
    return jnp.concatenate((mu, std), axis=1)

### Loss

The loss function is similar to above, but I will not bother sampling from the decoded distribution, and instead just take the value outputted from the decoder. You can see the only change is that we drop the output Gaussian standard deviation from the loss, which remember was not trainable anyway.
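
The reason a squared error can stand in for the Gaussian log-likelihood is that, with a fixed standard deviation, the two differ only by a scale factor and an additive constant. The sketch below checks this numerically with made-up vectors; `gaussian_nll`, `sse`, and the value of `sigma` are assumptions for illustration:

```python
import numpy as np

def gaussian_nll(x, mean, sigma):
    """Negative log-likelihood of x under independent N(mean, sigma^2)."""
    d = x.size
    return (0.5 * np.sum((x - mean) ** 2) / sigma**2
            + d * np.log(sigma) + 0.5 * d * np.log(2 * np.pi))

def sse(x, mean):
    """Sum of squared errors between x and mean."""
    return np.sum((x - mean) ** 2)

rng = np.random.default_rng(0)
x = rng.normal(size=24)        # stand-in for one flattened conformation
mean_a = rng.normal(size=24)   # two hypothetical decoder outputs
mean_b = rng.normal(size=24)
sigma = 1.0                    # fixed hyperparameter (assumed value)

# differences in NLL equal differences in squared error scaled by 1 / (2 sigma^2)
lhs = gaussian_nll(x, mean_a, sigma) - gaussian_nll(x, mean_b, sigma)
rhs = (sse(x, mean_a) - sse(x, mean_b)) / (2 * sigma**2)
print(np.isclose(lhs, rhs))
```

Since the scale factor multiplies only the reconstruction term, changing the fixed standard deviation has the same effect as rescaling $\beta$.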

In [None]:
@jax.jit
def loss(x, theta, phi, rng_key):
    """VAE Loss"""
    # reconstruction loss
    sampled_z_params = encoder(x, phi)
    # reparameterization trick
    # we use standard normal sample and multiply by parameters
    # to ensure derivatives correctly propagate to encoder
    sampled_z = (
        jax.random.normal(rng_key, shape=(latent_dim,)) * sampled_z_params[:, 1]
        + sampled_z_params[:, 0]
    )
    # sum of squared errors as the reconstruction loss
    xp = decoder(sampled_z, theta)
    rloss = jnp.sum((xp - x) ** 2)

    # KL divergence loss
    klloss = (
        -0.5
        - jnp.log(sampled_z_params[:, 1] + 1e-8)
        + 0.5 * sampled_z_params[:, 0] ** 2
        + 0.5 * sampled_z_params[:, 1] ** 2
    )
    # combined
    return jnp.array([rloss, jnp.mean(klloss)])


# update compiled functions
# vmap allows a function to compute on a vector
# by performing the operation entrywise
batched_loss = jax.vmap(loss, in_axes=(0, None, None, None), out_axes=0)
batched_decoder = jax.vmap(decoder, in_axes=(0, None), out_axes=0)
batched_encoder = jax.vmap(encoder, in_axes=(0, None), out_axes=0)

In [None]:
# incorporates rebalancing into loss
def modified_loss(x, theta, phi, rng_key, beta):
    """This loss allows you to vary which term is more important
    with beta. Beta = 0 - all reconstruction, beta = 1 - ELBO"""
    bl = batched_loss(x, theta, phi, rng_key)
    l = bl @ jnp.array([1.0, beta])
    return jnp.mean(l)

# use modified_loss to compute gradients
grad = jax.grad(modified_loss, (1, 2))
fast_grad = jax.jit(grad)
fast_loss = jax.jit(batched_loss)

### Training

Finally comes the training. We'll flatten our input data and shuffle to prevent each batch from having similar conformations.

In [None]:
batch_size = 32
epochs = 250
key = jax.random.PRNGKey(0)

flat_data = data.reshape(-1, input_dim)
# shuffle whole rows so each conformation stays intact
# (independent=True would shuffle every coordinate column separately)
flat_data = jax.random.permutation(key, flat_data)
# optimizers from jax used to update parameters with stochastic gradient descent
# step size is the same as learning rate
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
# initialize theta and phi randomly
theta0, key = init_theta(input_dim, hidden_units, latent_dim, num_layers, key)
phi0, key = init_phi(input_dim, hidden_units, latent_dim, num_layers, key)
opt_state = opt_init((theta0, phi0))
losses = []
# KL/Reconstruction balance
# beta close to 0 favors reconstruction (see the loss plot below)
beta = 0.01

for e in range(epochs):
    # bi = batch number, i = index in the data
    for bi, i in enumerate(range(0, len(flat_data), batch_size)):
        # take a batch of shape B x input_dim
        batch = flat_data[i : (i + batch_size)].reshape(-1, input_dim)
        # update random number key
        key, subkey = jax.random.split(key)
        # get current parameter values from optimizer
        theta, phi = get_params(opt_state)
        last_state = opt_state
        # compute gradient and update
        grad = fast_grad(batch, theta, phi, key, beta)
        opt_state = opt_update(bi, grad, opt_state)
    # use large batch for tracking progress
    lvalue = jnp.mean(fast_loss(flat_data[:100], theta, phi, subkey), axis=0)
    losses.append(lvalue)

In [None]:
plt.plot([l[0] for l in losses], label="Reconstruction")
plt.plot([l[1] for l in losses], label="KL")
plt.plot([l[1] + l[0] for l in losses], label="ELBO")
plt.legend()
plt.ylim(0, 20)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

This model is undertrained. A latent space of 2, which we chose for plotting convenience, is also probably a little too compressed. Let's sample a conformation and see how it looks.

In [None]:
sampled_data = decoder(jax.random.normal(key, shape=[latent_dim]), theta).reshape(-1, 2)
plt.plot(sampled_data[:, 0], sampled_data[:, 1], "-o", alpha=1)
plt.xticks([])
plt.yticks([])
plt.show()

### Generate New Samples

Let's see how our samples look.

In [None]:
# visualize all the sampled data
fig, axs = plt.subplots(ncols=2, figsize=(12, 4))
sampled_data = batched_decoder(
    np.random.normal(size=(data.shape[0], latent_dim)), theta
).reshape(data.shape[0], -1, 2)
# visualize the training data
for i in range(0, data.shape[0]):
    axs[0].plot(data[i, :, 0], data[i, :, 1], "-", alpha=0.1, color="C2")
    axs[1].plot(
        sampled_data[i, :, 0], sampled_data[i, :, 1], "-", alpha=0.1, color="C2"
    )
axs[0].set_title("Training")
axs[1].set_title("Generated")
for i in range(2):
    axs[i].set_xticks([])
    axs[i].set_yticks([])
plt.show()

The samples are not perfect, but we're close. Let's examine the marginals to see how the distribution for the distances between beads for the generated data differs from the training data.

In [None]:
fig, axs = plt.subplots(ncols=4, squeeze=True, figsize=(16, 4))
for i, j in enumerate(range(1, 9, 2)):
    axs[i].set_title(f"Dist between 0-{j}")
    sns.histplot(np.linalg.norm(data[:, 0] - data[:, j], axis=1), ax=axs[i], kde=True, stat="density", label="training distribution")
    sns.kdeplot(
        np.linalg.norm(sampled_data[:, 0] - sampled_data[:, j], axis=1),
        ax=axs[i], color="red", label="sampled distribution"
    )
plt.legend()
plt.tight_layout()

You can see that there are some issues here as well. Remember that our latent space is quite small: 2D. So we should not be that surprised that we're losing information from our 24D input space. I encourage you to play with some of the hyperparameters such as latent space dimension and beta to see if you can get better marginals.