In [19]:
import esm
import pandas as pd
import numpy as np
import plotly.express as px
from tqdm import tqdm
import torch
import umap
import colorsys

In [2]:
models = {
    6: "esm2_t6_8M_UR50D",
    12: "esm2_t12_35M_UR50D",
    30: "esm2_t30_150M_UR50D",
    33: "esm2_t33_650M_UR50D",
}

NUM_LAYERS = 30

model, alphabet = getattr(esm.pretrained, models[NUM_LAYERS])()
batch_converter = alphabet.get_batch_converter()
model.eval()


def run_esm(model, batch_converter, data, layers=range(NUM_LAYERS), contacts=False):
    """Runs the `model`, returns info from `layers`"""
    results = []
    for prot in tqdm(data):
        batch_labels, batch_strs, batch_tokens = batch_converter([prot])
        with torch.no_grad():
            i = model.forward(batch_tokens, repr_layers=layers, return_contacts=contacts)
            results.append(i)
    return results

In [4]:
seq = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFTYGVQCFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
data = [("prot1", seq)]

In [12]:
embeddings = run_esm(model, batch_converter, data)
embeddings = torch.stack([embeddings[0]["representations"][x] for x in range(NUM_LAYERS + 1)]).squeeze(1)
embeddings_list = [pd.DataFrame(embeddings[i].numpy()) for i in range(embeddings.shape[0])]

100%|██████████| 1/1 [00:01<00:00,  1.60s/it]


KeyError: 30

In [14]:
def make_relation(from_df, to_df):
    left = pd.DataFrame(data=np.arange(len(from_df)), index=from_df.index)
    right = pd.DataFrame(data=np.arange(len(to_df)), index=to_df.index)
    merge = pd.merge(left, right, left_index=True, right_index=True)
    return dict(merge.values)


relations = [make_relation(embeddings_list[i], embeddings_list[i + 1]) for i in range(len(embeddings_list) - 1)]

In [15]:
aligned_mapper = umap.AlignedUMAP(
    metric="euclidean",
    n_neighbors=5,
    alignment_regularisation=0.1,
    alignment_window_size=5,
    n_epochs=200,
    random_state=42,
    min_dist=0.5,
).fit(embeddings_list, relations=relations)

final_embeddings = aligned_mapper.embeddings_


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!


Graph is not fully connected, spectral embedding may not work as expected.


Graph is not fully connected, spectral embedding may not work as expected.


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling back to random initialisation!


failed. This is likely due to too small an eigengap. Consider
adding some noise or jitter to your data.

Falling ba

In [16]:
df = []
for idx, x in enumerate(final_embeddings):
    df_i = pd.DataFrame(x, columns=["x", "y"])
    df_i["layer"] = idx
    df_i["aminoacid"] = ["START"] + list(seq) + ["STOP"]
    df_i.reset_index(inplace=True)
    df.append(df_i)
df = pd.concat(df)

In [20]:
def generate_colors(n):
    hues = np.linspace(0.67, 0, n)  # Start from blue (0.67) to red (0)
    colors = [colorsys.hsv_to_rgb(h, 1, 1) for h in hues]  # saturation=1, value=1 for full color
    # Convert to plotly format: 'rgb(r,g,b)'
    colors = [f"rgb({int(r*255)},{int(g*255)},{int(b*255)})" for r, g, b in colors]
    return colors

In [21]:
px.line_3d(
    df,
    x="x",
    y="layer",
    z="y",
    color="index",
    width=1500,
    height=1000,
    color_discrete_sequence=generate_colors(240),
    hover_name="aminoacid",
)