In [None]:
import esm
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
import torch
import umap
import colorsys
import random

In [None]:
# ESM-2 checkpoint names, keyed by the layer count that NUM_LAYERS selects.
models = {
    6: "esm2_t6_8M_UR50D",
    12: "esm2_t12_35M_UR50D",
    30: "esm2_t30_150M_UR50D",
    33: "esm2_t33_650M_UR50D",
}

# Which checkpoint to load; also the upper bound for per-layer loops below.
NUM_LAYERS = 30

# Prefer CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Look the loader up by name on `esm.pretrained` and call it.
load_checkpoint = getattr(esm.pretrained, models[NUM_LAYERS])
model, alphabet = load_checkpoint()
batch_converter = alphabet.get_batch_converter()

# Inference mode, then move the weights to the selected device.
model.eval()
model.to(device)

def run_esm(model, batch_converter, data, layers=range(NUM_LAYERS+1), contacts=False):
    """Run `model` on each entry of `data`, one entry per forward pass.

    Args:
        model: Callable torch module accepting
            `(tokens, repr_layers=..., return_contacts=...)` and returning a
            dict of tensors and/or dicts of tensors.
        batch_converter: Callable that turns a list of `data` entries into a
            `(labels, strings, tokens)` triple.
        data: Iterable of entries understood by `batch_converter`.
        layers: Layer indices forwarded to the model as `repr_layers`.
        contacts: Forwarded to the model as `return_contacts`.

    Returns:
        A list with one dict per entry of `data`. Each dict mirrors the model
        output, with every tensor detached and moved to the CPU.
    """
    # Send tokens to wherever the model's weights live, so the function does
    # not depend on a module-level `device` matching the passed-in model.
    first_param = next(model.parameters(), None)

    results = []
    for prot in tqdm(data):
        _, _, batch_tokens = batch_converter([prot])
        if first_param is not None:
            batch_tokens = batch_tokens.to(first_param.device)

        with torch.no_grad():
            output = model(batch_tokens, repr_layers=layers, return_contacts=contacts)

        detached = {}
        for key, value in output.items():
            if isinstance(value, dict):
                # Nested outputs (e.g. "representations") map a sub-key to a tensor.
                detached[key] = {sub_key: tensor.detach().cpu() for sub_key, tensor in value.items()}
            else:
                detached[key] = value.detach().cpu()
        results.append(detached)
    return results

In [None]:
# Seed the stdlib RNG so the shuffle below and every later use of `random`
# give the same result on each Restart & Run All.
SEED = 42
random.seed(SEED)

# Reference protein sequence.
seq1 = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFTYGVQCFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"
# Same residues as seq1 in a random order (same length and composition).
seq2 = ''.join(random.sample(seq1, len(seq1)))

def mutate(sequence, fraction, rng=None):
    """Substitute a fraction of the residues in a protein sequence.

    Args:
        sequence: Protein sequence as a string.
        fraction: Share of positions to mutate, in [0, 1]. The number of
            mutated positions is `int(len(sequence) * fraction)` (truncated).
        rng: Optional object exposing `sample` and `choice` (for example a
            `random.Random` instance). Defaults to the global `random`
            module, so seeding it with `random.seed` is still honoured.

    Returns:
        A new string of the same length. Every selected position holds one of
        the 20 standard amino acids other than the residue it replaced.

    Raises:
        ValueError: If `fraction` is outside [0, 1].
    """
    if not 0 <= fraction <= 1:
        raise ValueError("Fraction must be between 0 and 1.")
    rng = random if rng is None else rng

    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    num_mutations = int(len(sequence) * fraction)
    mutated_sequence = list(sequence)
    # Positions are drawn without replacement, so each one changes exactly once.
    for pos in rng.sample(range(len(sequence)), num_mutations):
        # Exclude the current residue so the position always changes.
        choices = [aa for aa in amino_acids if aa != mutated_sequence[pos]]
        mutated_sequence[pos] = rng.choice(choices)
    return ''.join(mutated_sequence)


# Number of mutated variants generated per reference sequence, and the
# fraction of residues substituted in each variant.
N_VARIANTS = 49
MUTATION_FRACTION = 0.05

# (name, sequence) pairs: the two references first, then variants of each,
# interleaved as prot1_k, prot2_k.
data = [("prot1", seq1), ("prot2", seq2)]
for idx in range(N_VARIANTS):
    data.append((f"prot1_{idx}", mutate(seq1, MUTATION_FRACTION)))
    data.append((f"prot2_{idx}", mutate(seq2, MUTATION_FRACTION)))

In [None]:
# Raw per-protein model outputs, kept under their own name so `embeddings`
# only ever refers to the stacked tensor built below.
esm_outputs = run_esm(model, batch_converter, data)

# For each layer, drop the leading batch axis of every protein's
# representation and average over the next axis, giving one vector per
# (layer, protein). Resulting layout: [layer, protein, feature].
# NOTE(review): the mean covers every position the model returns, including
# any special tokens added by the batch converter - confirm this is intended.
embeddings = torch.stack([
    torch.stack([
        output["representations"][layer].squeeze(0).mean(0)
        for output in esm_outputs
    ])
    for layer in range(NUM_LAYERS + 1)
])

# One DataFrame per layer (rows are proteins), the input format used for AlignedUMAP below.
embeddings_list = [pd.DataFrame(layer_embeddings.numpy()) for layer_embeddings in embeddings]

In [None]:
def make_relation(from_df, to_df):
    """Map row positions of `from_df` to row positions of `to_df`.

    Rows are matched by index label (inner join). The result is a dict of
    `{position in from_df: position in to_df}` for every shared label.
    """
    src = pd.DataFrame({"src_pos": np.arange(len(from_df))}, index=from_df.index)
    dst = pd.DataFrame({"dst_pos": np.arange(len(to_df))}, index=to_df.index)
    shared = pd.merge(src, dst, left_index=True, right_index=True)
    return dict(zip(shared["src_pos"].to_numpy(), shared["dst_pos"].to_numpy()))


# One relation dict per pair of consecutive layers (layer k -> layer k + 1).
relations = [
    make_relation(current_layer, next_layer)
    for current_layer, next_layer in zip(embeddings_list, embeddings_list[1:])
]

In [None]:
# Configure the aligned UMAP model; the fixed random_state keeps the layout
# repeatable for identical inputs.
aligned_mapper = umap.AlignedUMAP(
    n_neighbors=5,
    min_dist=0.1,
    metric="euclidean",
    n_epochs=200,
    alignment_regularisation=0.1,
    alignment_window_size=5,
    random_state=42,
)

# Fit on the per-layer frames, using `relations` to link consecutive layers.
aligned_mapper.fit(embeddings_list, relations=relations)

# One fitted embedding per entry in `embeddings_list`.
final_embeddings = aligned_mapper.embeddings_

In [None]:
# Per-protein labels, derived once from `data` rather than from row position:
# anything named "prot1..." is tagged "gfp", everything else "mut".
protein_names = [name for name, _ in data]
protein_sources = ["gfp" if name.startswith("prot1") else "mut" for name in protein_names]

# Long-format table: one row per (layer, protein) with its 2-D coordinates.
layer_frames = []
for layer_idx, coords in enumerate(final_embeddings):
    layer_df = pd.DataFrame(coords, columns=["x", "y"])
    layer_df["layer"] = layer_idx
    layer_df["source"] = protein_sources
    layer_df["name"] = protein_names
    layer_frames.append(layer_df)
df = pd.concat(layer_frames)

In [None]:
def generate_colors(n):
    """Return `n` Plotly colour strings sweeping from blue to red.

    Hues are evenly spaced from 0.67 down to 0 at full saturation and value.
    Each colour is formatted as 'rgb(r,g,b)' with channels truncated to
    integers in 0-255.
    """
    palette = []
    for hue in np.linspace(0.67, 0, n):
        red, green, blue = colorsys.hsv_to_rgb(hue, 1, 1)
        palette.append(f"rgb({int(red*255)},{int(green*255)},{int(blue*255)})")
    return palette

In [None]:

# One 3-D line per protein, tracing its 2-D UMAP position across layers.
# Lines are coloured by source: green for "gfp", red for everything else.
fig = go.Figure()
for source, source_df in df.groupby("source"):
    color = "green" if source == "gfp" else "red"
    for name, trajectory in source_df.groupby("name"):
        fig.add_trace(
            go.Scatter3d(
                x=trajectory["x"],
                y=trajectory["layer"],
                z=trajectory["y"],
                mode="lines",
                line=dict(color=color, width=5),
                name=name,
            )
        )

# Label the axes so the figure can be read on its own: the layer index runs
# along the y axis, the two UMAP coordinates along x and z.
fig.update_layout(
    width=1500,
    height=1000,
    scene=dict(
        xaxis_title="UMAP x",
        yaxis_title="layer",
        zaxis_title="UMAP y",
    ),
)
fig.show()