In [None]:
#| echo: false 
#| output: false
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import os 
# Set before JAX is imported/initialised so XLA picks up this allocator choice.
# NOTE(review): presumably chosen so JAX allocates GPU memory on demand instead of
# preallocating a large fraction of it — see JAX's GPU memory allocation docs.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax.numpy as np  # import here so that any warnings about no GPU are not shown in website.
# Run a trivial op so JAX initialises its backend (and emits any warnings) inside this hidden cell.
np.arange(3)
# NOTE(review): presumably silences remaining Python warnings for the rendered notebook.
import shutup
shutup.please()

# Sampling Proteins

In this chapter, I am going to explore 
how we might sample new protein sequences using score models.
Score models are typically used to sample new _continuous_ data,
such as audio and images,
but I don't think we've seen much activity in the realm of sampling new _discrete_ data,
such as text or protein sequences.
This is a topic I have long been intellectually curious about,
given my interest in sequence machine learning dating back to graduate school,
where I planted the seeds of what would become my Insight project,
the [Flu Sequence Forecaster](http://fluforecaster.herokuapp.com).

## Recap

As we've seen before, score models start from input data drawn from some _density_.
That data can be 1D, such as Gaussian draws,
or 2D, such as the half-moons data,
or n-dimensional, such as numerical representations of proteins.
That's where we'll start, obtaining a numerical representation of proteins
by using a _non-variational_ autoencoder.

## Encoding Proteins

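Our autoencoder will be a simple one: one linear layer to encode, one linear layer to decode, and a sigmoid on the output so that reconstructions lie between 0 and 1, like the one-hot encoded inputs they reconstruct.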

In [None]:
import equinox as eqx 
from jax import random 
from jax.nn import sigmoid
from jax.scipy.special import expit


class LinearAutoEncoder(eqx.Module):
    """Autoencoder built from two linear layers with a sigmoid on the output.

    The encoder is a single linear projection into the latent space (no
    nonlinearity); the decoder is a single linear projection back to the
    input size, followed by a sigmoid.
    """

    # Linear map: input space -> latent space.
    encoder: eqx.Module
    # Linear map: latent space -> input space.
    decoder: eqx.Module

    def __init__(self, in_size: int, latent_dim_size: int = 512, key=random.PRNGKey(45)):
        """Build the two linear layers.

        :param in_size: Length of each (flattened) input vector.
        :param latent_dim_size: Width of the latent representation.
        :param key: PRNG key, split into one key per layer.
        """
        enc_key, dec_key = random.split(key)
        self.encoder = eqx.nn.Linear(
            in_features=in_size,
            out_features=latent_dim_size,
            key=enc_key,
        )
        self.decoder = eqx.nn.Linear(
            in_features=latent_dim_size,
            out_features=in_size,
            key=dec_key,
        )

    def __call__(self, x):
        """Encode, decode, and squash the reconstruction through a sigmoid."""
        return sigmoid(self.decoder(self.encode(x)))

    def encode(self, x):
        """Return the latent (pre-activation) representation of a single input."""
        return self.encoder(x)


# Smoke test on random input: the sigmoid head should keep every output within [0, 1].
lae = LinearAutoEncoder(2048)
one_batch = random.normal(key=random.PRNGKey(42), shape=(2048,))

reconstruction = lae(one_batch)
reconstruction.min(), reconstruction.max()

Now, we will grab a collection of real protein sequences to work with.
The FASTA file we will use contains H3N2 influenza sequences from my flu forecaster repository.

In [None]:
import os
import tempfile

import wget


url = "https://raw.githubusercontent.com/ericmjl/flu-sequence-predictor/master/data/20170531-H3N2-global.fasta"
# Use the OS temp directory rather than a hardcoded "/tmp" so this also works off Unix.
fasta_path = os.path.join(tempfile.gettempdir(), "h3n2.fasta")

# Download only once: on re-runs wget.download does not overwrite an existing
# file but writes a suffixed copy (e.g. "h3n2 (1).fasta"), re-fetching every time.
if os.path.exists(fasta_path):
    filename = fasta_path
else:
    filename = wget.download(url, out=fasta_path)
filename

## Obtain a multiple sequence alignment

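One of the easiest (though not the only) ways
to obtain numerical representations for a linear autoencoder
is to use SeqLike to generate a multiple sequence alignment
and then convert the alignment into a one-hot encoded NumPy array.
Aligning first matters because a linear layer expects fixed-size inputs,
and alignment brings every sequence (padded with gaps) to the same length.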

In [None]:
from Bio import SeqIO
from seqlike import aaSeqLike 
import pandas as pd
from tqdm.auto import tqdm

# Parse every record in the downloaded FASTA file into an amino-acid SeqLike.
records = SeqIO.parse(filename, "fasta")
seqs = [aaSeqLike(record) for record in tqdm(records)]

In [None]:
# Draw a reproducible random subset of 2,000 sequences, then align them
# so that every sequence ends up with the same length.
seqs = (
    pd.Series(seqs)
    .sample(2000, random_state=44)
    .seq.align()
)
seqs

In [None]:
# One-hot encode the aligned sequences.
# NOTE(review): presumably an array of shape (n_seqs, alignment_length, alphabet_size) — confirm.
seqs_oh = seqs.seq.to_onehot()


In [None]:

from jax import vmap 
import jax.numpy as np

def flatten(x: np.ndarray):
    """Collapse an array of any rank into a 1-D vector, in row-major order.

    :param x: Array to flatten (e.g. one sequence's one-hot matrix).
    :returns: 1-D array containing every element of ``x``.
    """
    return np.ravel(x)

# Flatten each sequence's one-hot matrix (vmapped over the leading axis) into one
# vector per sequence, the input format the linear autoencoder expects.
seqs_flattened = vmap(flatten)(seqs_oh)

In [None]:
def binary_cross_entropy(y_hat, y):
    """Mean binary cross-entropy between predicted probabilities and targets.

    :param y_hat: Predicted probabilities. They are NOT clipped here: values of
        exactly 0 or 1, or outside [0, 1], produce non-finite results, so clip
        beforehand if needed.
    :param y: Targets (typically 0/1), broadcast-compatible with ``y_hat``.
    :returns: Scalar mean over all elements.
    """
    positive_term = y * np.log(y_hat)
    negative_term = (1 - y) * np.log(1 - y_hat)
    return -(positive_term + negative_term).mean()

# Sanity check: score the untrained autoencoder's reconstruction of `one_batch`
# against a 0/1 target made by clipping `one_batch` to [0, 1] and rounding.
binary_cross_entropy(lae(one_batch), np.round(np.clip(one_batch, 0, 1)))

In [None]:
# Cross-entropy of a set of probabilities with itself equals their mean binary
# entropy: finite and non-negative, but not zero. The inputs must be valid
# probabilities, since raw standard-normal draws include negatives whose log is NaN.
probs = sigmoid(one_batch)
binary_cross_entropy(probs, probs)

In [None]:
def binary_cross_entropy_loss(model, y, tol=1e-6):
    """Reconstruction loss: mean binary cross-entropy of the model's outputs against its inputs.

    :param model: Callable mapping a single example to predicted probabilities;
        it is vmapped over the leading (batch) axis of ``y``.
    :param y: Batch (n >= 1) of ground-truth data, used both as the model input
        and as the reconstruction target.
    :param tol: Predictions are clipped to ``[tol, 1 - tol]`` so the logs stay finite.
    :returns: Scalar loss.
    """
    y_hat = vmap(model)(y)
    # Keep predictions strictly inside (0, 1); log(0) would make the loss infinite.
    y_hat = np.clip(y_hat, tol, 1 - tol)
    # binary_cross_entropy already averages over every element, so no further reduction is needed.
    return binary_cross_entropy(y_hat, y)

In [None]:
# Fresh, untrained model sized to the flattened one-hot vectors.
model = LinearAutoEncoder(in_size=seqs_flattened.shape[1])

# Loss before any training, as a baseline.
binary_cross_entropy_loss(model, seqs_flattened)

In [None]:
# Sanity check: encoding, decoding, then applying the sigmoid by hand must
# reproduce the model's forward pass. (A bare tuple here would not be displayed,
# because only a cell's last expression is shown, so assert instead.)
encoded = model.encode(seqs_flattened[0])
out = model.decoder(encoded)
assert np.allclose(sigmoid(out), model(seqs_flattened[0]))

vmap(model)(seqs_flattened)

In [None]:
# Per-sequence input shape (everything after the batch axis).
seqs_flattened.shape[1:]

In [None]:
# Train model
import optax


model = LinearAutoEncoder(in_size=len(seqs_flattened[0]))
# NOTE(review): optax.chain applies transforms in order, so optax.clip here clips
# the *Adam updates* (not the raw gradients) element-wise to ±0.001, which caps
# the effective step size below the 5e-3 learning rate. If gradient clipping was
# intended, the clip should come before adam — confirm before changing.
optimizer = optax.chain(optax.adam(5e-3), optax.clip(0.001))
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
dloss = eqx.filter_value_and_grad(binary_cross_entropy_loss)


@eqx.filter_jit
def train_step(model, opt_state, data):
    """Run one full-batch optimisation step.

    Compiled once with eqx.filter_jit instead of executing op-by-op on every iteration.

    :returns: (updated model, updated optimiser state, loss before the update).
    """
    loss_score, grads = dloss(model, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_score


n_steps = 20
iterator = tqdm(range(n_steps))
loss_history = []
# NOTE(review): `key`/`keys` are not used by this full-batch loop; kept in case later cells rely on them.
key = random.PRNGKey(555)
keys = random.split(key, n_steps)

for step in iterator:
    model, opt_state, loss_score = train_step(model, opt_state, seqs_flattened)
    iterator.set_description(f"Score· {loss_score:.3f}")
    loss_history.append(float(loss_score))


In [None]:
import matplotlib.pyplot as plt

# Training curve: one point per optimisation step.
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(
    xlabel="Training step",
    ylabel="Binary cross-entropy loss",
    title="Autoencoder training loss",
);

In [None]:
# Project every flattened sequence into the trained model's latent space.
embeddings = vmap(model.encode)(seqs_flattened)
embeddings

## Visualize

In [None]:
from umap import UMAP
import seaborn as sns
sns.set_context(context="notebook")

# Reduce the latent embeddings to 2-D for plotting; the seed keeps the layout reproducible.
um = UMAP(random_state=212)
um_embed = um.fit_transform(embeddings)

fig, ax = plt.subplots()
ax.scatter(um_embed[:, 0], um_embed[:, 1])
ax.set_xlabel("Dim 1")
ax.set_ylabel("Dim 2")
ax.set_title("UMAP Embedding of Encoded Sequences")
ax.set_aspect("equal")
sns.despine(ax=ax)

## Noising Data with Variance-Preserving SDE

You may have noticed in the previous chapter that whenever we noised up data,
the total variance of the data also increased.