## More with embeddings: structural similarity

Proteins can share a similar folded structure (and, in turn, a similar function) even when their sequences have diverged considerably. This seemingly contradictory behavior arises from co-evolution within proteins: residues that interact tend to mutate in pairs or groups, so the mutated protein keeps a similar structure and can still carry out the original function.

A model that looks purely at sequence would produce very different embeddings for such protein pairs despite their similar structures. A model that builds an internal representation of protein structure would instead place their embeddings close together. This notebook briefly explores whether ESM-1b does so.
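To make "close together" concrete, here is a minimal sketch of one common way to compare two embedding vectors, cosine similarity. The choice of metric is an assumption made for illustration (the notebook does not fix one at this point), and the toy 4-dimensional vectors stand in for real embeddings.

```python
import numpy as np

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Toy vectors standing in for real embeddings (hypothetical values).
emb_a = np.array([1.0, 2.0, 0.0, 1.0])
emb_b = np.array([2.0, 4.0, 0.0, 2.0])   # same direction as emb_a
emb_c = np.array([0.0, 0.0, 3.0, 0.0])   # orthogonal to emb_a

sim_ab = cosine_similarity(emb_a, emb_b)  # high: vectors point the same way
sim_ac = cosine_similarity(emb_a, emb_c)  # zero: vectors share no direction
```

Under this measure, two structurally similar domains would ideally score close to 1 even when their sequences differ.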

### Data
We first download the [SCOPe 2.07](https://scop.berkeley.edu/astral/ver=2.07) dataset, which contains sequences for a diverse range of protein *domains* (not whole proteins), classified by structural features. We use the subset in which sequences share less than 40% sequence identity with each other.

In [1]:
!wget https://scop.berkeley.edu/downloads/scopeseq-2.07/astral-scopedom-seqres-gd-sel-gs-bib-40-2.07.fa -O ../data/scope207.fa

--2021-10-03 10:04:29--  https://scop.berkeley.edu/downloads/scopeseq-2.07/astral-scopedom-seqres-gd-sel-gs-bib-40-2.07.fa
Resolving scop.berkeley.edu (scop.berkeley.edu)... 128.32.236.13
Connecting to scop.berkeley.edu (scop.berkeley.edu)|128.32.236.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4026229 (3.8M)
Saving to: ‘../data/scope207.fa’


2021-10-03 10:04:29 (8.77 MB/s) - ‘../data/scope207.fa’ saved [4026229/4026229]



In [2]:
import re

from Bio import SeqIO

# we load in the proteins
records = list(SeqIO.parse("../data/scope207.fa", "fasta"))

# for each protein we separate the name, sequence and structure label
names = [record.name for record in records]
sequences = [str(record.seq.upper()) for record in records]
structure_label = [re.search(r".[.]\d*[.]\d*[.]\d*", rec.description).group(0) for rec in records]

In [3]:
print(names[0], sequences[0], structure_label[0])

d1dlwa_ SLFEQLGGQAAVQAVTAQFYANIQADATVATFFNGIDMPNQTNKTAAFLCAALGGPNAWTGRNLKEVHANMGVSNAQFTTVIGHLRSALTGAGVAAALVEQTVAVAETVRGDVVTV a.1.1.1
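The structure label is hierarchical: its four dot-separated parts are the class, fold, superfamily and family. As a small sketch of how the pattern from the cell above picks it out of a header, consider the following; the free text in `description` is a made-up placeholder rather than a real record, and the variable names are illustrative.

```python
import re

# Illustrative header in the style of the SCOPe FASTA file. The text after
# the label is a placeholder (an assumption) and plays no role in the match.
description = "d1dlwa_ a.1.1.1 (A:) example description text"

# Same pattern as above: one character, then three dot-separated numbers.
label = re.search(r".[.]\d*[.]\d*[.]\d*", description).group(0)

# Split the label into its four hierarchy levels.
scop_class, fold, superfamily, family = label.split(".")

# A fold identifier is the class plus the fold number.
fold_id = ".".join(label.split(".")[:2])
```

Here `label` is `a.1.1.1` and `fold_id` is `a.1`.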


For this example, we only look at domains made of alpha helices, that is, those whose structure label starts with `a`. We also keep only domains of at most 510 residues.

In [4]:
# we zip up the records to make filtering easier
prot_domains = zip(names, sequences, structure_label)

filtered_prot_domains = filter(lambda x: x[2][0]=='a', prot_domains)
filtered_prot_domains = filter(lambda x: len(x[1]) <= 510, filtered_prot_domains)
names, sequences, structure_label = list(zip(*filtered_prot_domains))

In [5]:
len(sequences)

2505

### Model

Let's set up inference across all accelerators (here, 8 TPU cores) to make quick work of embedding all 2505 protein domains. As before, we load the model parameters and construct the model.

In [6]:
import sys
sys.path.insert(0, '..')

In [7]:
from esmjax import models, modelio, tokenize

import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp

In [8]:
params_dict = modelio.load_model("../data/esm1b.h5")

# Instead of calling jnp.array, which places the array on the *first* device,
# we use `device_put_replicated` to send a copy of the weights to *all* devices
devices = jax.local_devices()
distrib_params_dict = jax.tree_map(lambda x: jax.device_put_replicated(x, devices), params_dict)

params = hk.data_structures.to_immutable_dict(distrib_params_dict)

Note that running `device_put_replicated` on an array adds a new leading dimension, which indexes the copies of the array across devices. The result is a `ShardedDeviceArray`: logically one array, but physically spread across all devices, one copy per device. For example:

In [9]:
print(params['esm1b/embed']['embeddings'].shape, type(params['esm1b/embed']['embeddings']))

(8, 33, 1280) <class 'jax.interpreters.pxla.ShardedDeviceArray'>
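The shape above can be reproduced with a NumPy stand-in. This is only a shape illustration under the assumption of 8 devices: it does not place anything on an accelerator and is not what `device_put_replicated` does internally.

```python
import numpy as np

num_devices = 8                      # matches the 8 TPU cores used here
embed_table = np.zeros((33, 1280))   # same shape as one copy of the table

# Stacking one copy per device produces the leading "device" axis.
replicated = np.stack([embed_table] * num_devices)
replicated_shape = replicated.shape
```

Indexing the leading axis, as in `replicated[0]`, recovers a single copy with the original `(33, 1280)` shape.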


Now we construct the model, except that this time, instead of `jit`-ing the apply function, we `pmap` it. JAX then runs it in parallel, with each device handling one slice of the leading axis:

In [10]:
esm1b_f = hk.transform(lambda x: models.ESM1b()(x))
esm1b_f = hk.without_apply_rng(esm1b_f)
esm1b_apply = jax.pmap(esm1b_f.apply)

### Computing embeddings

We use a batch size of 16 protein domains per TPU core, so 128 domains per step. Every core must receive an array of exactly the same shape, and 2505 isn't divisible by 128, so we pad the list with 55 blank sequences (for a total of 2560) and discard their outputs afterwards.

In [11]:
BATCH_SIZE = 16
NUM_DEVICES = jax.local_device_count()  # pmap and the replicated params both use local devices

superbatch_size = BATCH_SIZE * NUM_DEVICES
num_sequences = len(sequences)
num_sequences_padded = int(np.ceil(num_sequences / superbatch_size) * superbatch_size)

sequences_list = list(sequences)
sequences_list.extend([""] * (num_sequences_padded - num_sequences))

names_list = list(names)
names_list.extend([None] * (num_sequences_padded - num_sequences))
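The padding arithmetic can be checked by hand with integer ceiling division. The helper name `padded_length` below is hypothetical and used only for this sketch.

```python
def padded_length(num_items, chunk_size):
    """Smallest multiple of chunk_size that is >= num_items."""
    num_chunks = -(-num_items // chunk_size)   # ceiling division on integers
    return num_chunks * chunk_size

superbatch = 16 * 8                 # 16 sequences per core, 8 cores
padded = padded_length(2505, superbatch)
num_blank = padded - 2505           # blank sequences appended at the end
num_steps = padded // superbatch    # number of pmap calls needed
```

With these numbers, 2505 sequences are padded to 2560, which means 55 blank sequences and 20 steps of 128.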

Since the number of sequences is fairly small, it's fine to tokenize them all in one go.

In [12]:
names_list, sequences_list, tokens = tokenize.convert(zip(names_list, sequences_list))

We then embed all the sequences, 128 at a time (16 per core). This isn't particularly efficient (it takes roughly a full minute on a TPU v2-8); a better data infeed (for example with `tf.data`) would let it scale to larger protein datasets. For this example we keep things relatively straightforward.

Note that we sum the embeddings along the sequence dimension; we'll convert the sum to a mean later by dividing by the length of each protein. (This works because padding and other "add-on" tokens have their embeddings masked out to 0.)

In [13]:
embeddings = []

max_seq_len = tokens.shape[-1]

with jax.default_matmul_precision('float32'):
    for i in range(0, num_sequences_padded, superbatch_size):
        superbatch_tokens = tokens[i:i+superbatch_size, :]
        batch_tokens = jnp.reshape(superbatch_tokens, (NUM_DEVICES, BATCH_SIZE, max_seq_len))

        per_residue_embeddings = esm1b_apply(params, batch_tokens)["embeddings"]
        whole_prot_embeddings = per_residue_embeddings.sum(axis=-2)
        embeddings.append(whole_prot_embeddings.reshape(superbatch_size, 1280))
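The two reshapes in the loop rely on row order being preserved: splitting a superbatch into a `(devices, batch, length)` array and flattening it again returns rows in their original order. A small NumPy sketch of that round trip, with a tiny sequence length chosen for illustration:

```python
import numpy as np

num_devices, batch_size, seq_len = 8, 16, 5   # seq_len kept tiny on purpose
superbatch = np.arange(num_devices * batch_size * seq_len).reshape(
    num_devices * batch_size, seq_len)

# Split: row r goes to device r // batch_size, slot r % batch_size.
per_device = superbatch.reshape(num_devices, batch_size, seq_len)

# Merge: flattening the two leading axes restores the original row order.
merged = per_device.reshape(num_devices * batch_size, seq_len)
```

This is why row `i` of the concatenated output still corresponds to sequence `i` of the input list.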

We concatenate the outputs of all the batches, drop the rows that belong to the padding sequences, and convert each sum to a mean.

In [14]:
embeddings = np.concatenate(embeddings, axis=0)
embeddings = embeddings[:len(sequences), :]

prot_lens = np.array([len(seq) for seq in sequences]).reshape(-1, 1)
embeddings = embeddings / prot_lens
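To see why summing and then dividing by the true length gives the mean over real residues, here is a toy check. It assumes, as stated above, that embeddings at padded positions are already zero; the numbers are made up.

```python
import numpy as np

# Toy per-residue embeddings: 1 sequence, 5 positions, 2 features.
# The last two positions are padding and are already zeroed.
per_residue = np.array([[[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0],
                         [0.0, 0.0],
                         [0.0, 0.0]]])
true_len = np.array([3]).reshape(-1, 1)

# Sum over the position axis, then divide by the true length.
mean_via_sum = per_residue.sum(axis=-2) / true_len

# Direct mean over the 3 real positions, for comparison.
direct_mean = per_residue[:, :3, :].mean(axis=1)
```

Both routes give the same vector, because the zeroed positions contribute nothing to the sum.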

### Visualizing the embeddings

Now that we have the embeddings, we'd like to see how they relate to each other, especially for protein domains with similar structural features. We'll use the [TriMap](https://github.com/eamid/trimap) algorithm, which aims to preserve global structure (and not just local structure) better than UMAP or t-SNE. At this point our embeddings are just `ndarray`s, so you can use any analysis tool you like.

In [15]:
import trimap

In [16]:
low_d_embeddings = trimap.TRIMAP(n_dims=3).fit_transform(embeddings)

TRIMAP(n_inliers=10, n_outliers=5, n_random=5, distance=euclidean, lr=1000.0, n_iters=400, weight_adj=500.0, apply_pca=True, opt_method=dbd, verbose=True, return_seq=False)
running TriMap on 2505 points with dimension 1280
pre-processing
applied PCA
found nearest neighbors
sampled triplets
running TriMap with dbd
Iteration:  100, Loss: 76.999, Violated triplets: 0.0559
Iteration:  200, Loss: 73.609, Violated triplets: 0.0534
Iteration:  300, Loss: 71.704, Violated triplets: 0.0520
Iteration:  400, Loss: 70.677, Violated triplets: 0.0513
Elapsed time: 0:00:05.541536


We color the points by the fold they belong to. There are almost 300 fold types, however (e.g. `a.1`, `a.2`, ...; you can learn more about what they correspond to [here](https://scop.berkeley.edu/sunid=46456&ver=2.07)), so a static plot would obscure a lot of important detail. The interactive plot saved below can be found [here](https://htmlpreview.github.io/?https://github.com/irhum/esm-jax/blob/main/notebooks/embeddings_vis.html). In it, proteins with a similar fold pattern tend to cluster together, even though their sequences are very different.

In [17]:
folds = ['.'.join(label.split('.')[:2]) for label in structure_label]

In [18]:
import plotly.express as px

In [19]:
fig = px.scatter_3d(x=low_d_embeddings[:, 0], 
                    y=low_d_embeddings[:, 1],
                    z=low_d_embeddings[:, 2],
                    color=folds,
                    size_max=1)

In [20]:
fig.write_html('embeddings_vis.html')
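As a rough numeric complement to the plot, one could measure how often a domain's nearest neighbor in embedding space shares its fold. The sketch below is not part of the original analysis: the function name `nearest_neighbor_agreement` is hypothetical, Euclidean distance is an assumption, and the toy points are made up.

```python
import numpy as np

def nearest_neighbor_agreement(vectors, labels):
    """Fraction of points whose nearest other point has the same label."""
    vectors = np.asarray(vectors, dtype=float)
    labels = np.asarray(labels)
    # Pairwise squared Euclidean distances via the Gram matrix, which avoids
    # building an (n, n, d) intermediate array.
    sq_norms = (vectors ** 2).sum(axis=1)
    dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * vectors @ vectors.T
    np.fill_diagonal(dists, np.inf)   # a point is not its own neighbor
    nearest = dists.argmin(axis=1)
    return float((labels[nearest] == labels).mean())

# Toy data: two tight, well-separated pairs of points.
toy_vectors = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
agree_clean = nearest_neighbor_agreement(toy_vectors, ["a.1", "a.1", "a.2", "a.2"])
agree_mixed = nearest_neighbor_agreement(toy_vectors, ["a.1", "a.2", "a.2", "a.2"])
```

On the real data this would be called as `nearest_neighbor_agreement(embeddings, folds)`. Folds with a single member can never match, so the score has a ceiling below 1.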