In [1]:
# !pip3 install fair-esm  # latest release, OR:
# !pip3 install git+https://github.com/facebookresearch/esm.git  # bleeding edge, current repo main branch
# !pip3 install -U scikit-learn
# !pip3 install torch torchvision torchaudio

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

from sklearn.decomposition import PCA
import math

import torch
import esm

In [3]:
class config:
    """Notebook-wide settings.

    sequence_idx_to_mutations: offset added to a mutation's structure-numbered
    position to obtain its 0-based index in the full sequence,
    e.g. K50Q refers to index 51 of the sequence.
    """

    sequence_idx_to_mutations: int = 1

In [4]:
def mutate_full_sequence(wt_sequence: str, mutations, offset=None) -> str:
    """Apply point mutations such as ``"M100N"`` to a wild-type sequence.

    Each mutation has the form ``<wild-type residue><position><new residue>``.
    The position is turned into a 0-based index into ``wt_sequence`` by adding
    ``offset``; with ``config.sequence_idx_to_mutations == 1``, K50Q edits
    ``wt_sequence[51]``.

    Args:
        wt_sequence: wild-type amino-acid sequence (not modified).
        mutations: iterable of mutation strings, or a single comma-separated
            string. Surrounding whitespace is stripped and empty entries are
            skipped (e.g. from a trailing comma).
        offset: value added to each mutation position to get the sequence
            index. Defaults to ``config.sequence_idx_to_mutations``.

    Returns:
        The mutated sequence as a new string.

    Raises:
        ValueError: if a mutation cannot be parsed, maps outside the sequence,
            or its wild-type residue does not match ``wt_sequence``.
    """
    if offset is None:
        offset = config.sequence_idx_to_mutations
    if isinstance(mutations, str):
        mutations = mutations.split(",")

    new_sequence = list(wt_sequence)
    for raw in mutations:
        mut = raw.strip()
        if not mut:
            continue
        try:
            wt_aa, position, new_aa = mut[0], int(mut[1:-1]), mut[-1]
        except ValueError as e:
            raise ValueError(f"cannot parse mutation {mut!r} in {mutations!r}") from e

        idx = position + offset
        # Explicit bounds check: a negative index would otherwise wrap around silently.
        if not 0 <= idx < len(wt_sequence):
            raise ValueError(
                f"mutation {mut!r} maps to index {idx}, outside a sequence of length {len(wt_sequence)}"
            )
        # Raise (rather than assert) so the check survives `python -O`.
        if wt_sequence[idx] != wt_aa:
            raise ValueError(
                f"mutation {mut!r} expects {wt_aa!r} at index {idx}, found {wt_sequence[idx]!r}"
            )
        new_sequence[idx] = new_aa

    return "".join(new_sequence)

In [5]:
def clean_up_name(name):
    """Return the text before the first underscore (the whole name if it has none)."""
    head, _separator, _rest = name.partition("_")
    return head

In [6]:
# Load the combined, cleaned sequence/activity sheet of the Kemp workbook.
df = pd.read_excel(io="KempData.xlsx", sheet_name="SeqAct_combined_cleaned")

In [7]:
# Keep only the columns used below: variant name, mutation list and measured slope.
df = df.loc[:, ["Name", "Mutations", "Average Slope"]]

In [8]:
# Reference slope used for normalisation: "Average Slope" of the (first) HG3.R5 row.
# Single .loc[row_mask, column] lookup instead of chained indexing.
df.loc[df.Name == "HG3.R5", "Average Slope"].iloc[0]

0.0032823688146261126

In [9]:
# Express every spreadsheet slope relative to HG3.R5, so HG3.R5 has activity 1.0.
# Look the reference up once with .loc[row_mask, column] (no chained indexing).
reference_slope = df.loc[df.Name == "HG3.R5", "Average Slope"].iloc[0]
df["activity"] = df["Average Slope"] / reference_slope

In [10]:
# Hand-curated HG3-family variants as (name, full amino-acid sequence) pairs.
# data[0] must remain the wild-type "hg3" entry: the spreadsheet mutation lists
# are applied to data[0][1] further down.
# The order must also match the `activities` list, which is zipped against
# this one element by element.
data = [
    ("hg3", "MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMKWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSQLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3_k50Q", "MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSQLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.3", "MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMHWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGCLVWHRDLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRNWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.7", "MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGCLVWHRHLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRNWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.14", "MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGAGCLVWHSHLPSWVSSITDKNTLINVMKNHITTLMTRYKGKIRTWDVVGEAFNEDGSLRQNVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQMHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASSTPLLFDGNFNPKPAYNAIVQNLQQGSIEGRGHHHHHH"),
    ("hg3.17", "MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVWPEESMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGAGCLVWHNFLPSWVSSITDKNTLINVMKNHITTLMTRYKGKIRTWDVVGEAFNEDGSLRQNVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQMHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVMGVADPDSAFASSTPLLFDGNFNPKPAYNAIVQNLQQGSIEGRGHHHHHH"),
    
    ("hg3.R1", "MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSQLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPSAKLYIADYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.R2", "MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSHLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIADYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSMTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.R3", "MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSLQWDAIEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSHLPSWVSSITDKETLTNVMKNHITTLMTRYKGKIRAWDVVGSAFNEDGSLRQTVFLNVIGEDYIKIAFQTARAADPNAKLYIADYNLDSASYPKTQAIVNKVKQWRAAGVPIDGIGSMTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.R4", "MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSLQWDAIEPSQGNFNFAGADYVVNWAQQNGKLIGGGMLVWHSHLPSWVSSITDKETLTNVMKNHITTLMTRYKGKIRCWDVVGSAFNEDGSLRQTVFLNVIGEDYIKIAFQTARAADPNAKLYIADYNLDSASYPKTQAIVNKVKQWRAAGVPIDGIGSMTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITSFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"),
    ("hg3.R5", "MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSLQWDAVEPSQGNFNFAGADYVVNWAQQNGKLIGGGMLVWHSHLPSWVSSITDKETLTNVMKNHITTLMTRYKGKIRVWDVVGSAFNEDGSLRQTVFLNVIGEDYIKIAFQTARAADPNAKLYIADSNLDSASYPKTQAIVNKVKQWRAAGVPIDGIGSMTNLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITSFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH"), 

    ("HG3R5w17", "MAEAAQSIDQLMKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVWPEESLQWDAVEPSQGNFNFAGADYVVNWAQQNGKLIGAGCLVWHNHLPSWVSSITDKETLINVMKNHITTLMTRYKGKIRVWDVVGSAFNEDGSLRQNVFLNVIGEDYIKIAFQTARAADPNAKLYIADSNLDSASYPKTQAIVNKVKQWRAAGVPIDGIGSMMNLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITSMGVADPDSAFASSTPLLFDGNFNPKPAYNAIVQNLQQGSIEGRGHHHHHH"),
    ("HG317wR5", "MAEAAQSIDQLMKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVWPEESLQWDAVEPSQGNFNFAGADYVVNWAQQNGKLIGAGCLVWHNFLPSWVSSITDKETLINVMKNHITTLMTRYKGKIRTWDVVGSAFNEDGSLRQNVFLNVIGEDYIKIAFQTARAADPNAKLYIADSNLDSASYPKTQAIVNKVKQWRAAGVPIDGIGSMMNLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITSMGVADPDSAFASSTPLLFDGNFNPKPAYNAIVQNLQQGSIEGRGHHHHHH"),
]

# Measured activities of the hard-coded variants, in the same order as `data`.
activities = [1, 2.5, 8, 25, 250, 396, 3, 5, 66, 91, 458, 45, 45]
# Normalise by the hg3.R5 entry (458) instead of a duplicated literal. This
# mirrors the spreadsheet activities, which are divided by the HG3.R5 slope,
# so hg3.R5 is 1.0 in both sources.
max_value = activities[[name for name, *_ in data].index("hg3.R5")]
activities = [x / max_value for x in activities]

In [11]:
# Turn each hard-coded (name, seq) pair into (name, seq, activity).
# zip() silently drops unmatched items, so first check that the two lists line up.
assert len(data) == len(activities), (
    f"{len(data)} hard-coded sequences but {len(activities)} activity values"
)
data = [(name, seq, activity) for (name, seq), activity in zip(data, activities)]

In [12]:
# Spreadsheet mutation lists are relative to wild-type HG3, the first hard-coded
# record. Look it up once (loop-invariant) and make sure it really is hg3.
wt_name, wt_sequence = data[0][0], data[0][1]
assert wt_name == "hg3", f"expected wild-type 'hg3' as data[0], found {wt_name!r}"

for _, row in df.iterrows():
    # Strip whitespace so entries written with a space after the comma still parse.
    mutations = [m.strip() for m in row["Mutations"].split(",")]
    mutated_sequence = mutate_full_sequence(wt_sequence, mutations)
    data.append((row.Name, mutated_sequence, row["activity"]))

In [13]:
# Sanity check: show the first four (name, sequence, activity) records.
for k in range(4):
    record = data[k]
    print(record)
print("...")

('hg3', 'MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMKWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSQLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH', 0.002183406113537118)
('hg3_k50Q', 'MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMQWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGMLVWHSQLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRAWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRASTTPLLFDGNFNPKPAYNAIVQDLQQGSIEGRGHHHHHH', 0.0054585152838427945)
('hg3.3', 'MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVWPENSMHWDATEPSQGNFNFAGADYLVNWAQQNGKLIGGGCLVWHRDLPSWVSSITDKNTLTNVMKNHITTLMTRYKGKIRNWDVVGEAFNEDGSLRQTVFLNVIGEDYIPIAFQTARAADPNAKLYIMDYNLDSASYPKTQAIVNRVKQWRAAGVPIDGIGSQTHLSAGQGAGVLQALPLLASAGTPEVSILMLDVAGASPTDYVNVVNACLNVQSCVGITVFGVADPDSWRA

In [14]:
# Total records (hard-coded variants + spreadsheet rows) before de-duplication.
len(data)

210

In [15]:
# Positional columns: 0 = name, 1 = sequence, 2 = activity.
ndf = pd.DataFrame(data=data)

In [16]:
# Drop repeated sequences (column 1). With keep="first", a hard-coded variant wins
# over a spreadsheet row carrying the same sequence, because the hard-coded
# records were added to `data` first. ignore_index gives a fresh 0..n-1 index.
ndf = ndf.drop_duplicates(subset=1, keep="first", ignore_index=True)

In [17]:
# The 20 most active records (column 2 = normalised activity).
ndf.sort_values(by=2, ascending=False).head(20)

Unnamed: 0,0,1,2
13,1,MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,1.125701
14,2,MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,1.111081
15,3,MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVW...,1.034208
16,4,MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVW...,1.007588
10,hg3.R5,MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,1.0
17,5,MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,0.985869
18,6,MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,0.975596
19,7,MAEAAQSIDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,0.968184
20,8,MAEAAQSVDQLMKARGKVYFGVATDQNRLTTGKNAAIIKADFGMVW...,0.963019
21,9,MAEAAQSVDQLIKARGKVYFGVATDQNRLTTGKNAAIIQADFGMVW...,0.955368


In [18]:
# Back to a list of (name, sequence, activity) tuples, now free of duplicate sequences.
# itertuples yields plain tuples directly: no per-row Series (as with iterrows)
# and no unused index variable.
data = list(ndf.itertuples(index=False, name=None))

In [19]:
# Positional indices 0..len(data)-1 into `data`; they are split into batches further down.
ls = range(len(data))

In [20]:
# Record count after removing duplicate sequences.
len(data)

208

In [21]:
# Number of sequences per ESM forward pass. 1 reproduces the original
# one-sequence-per-chunk split; presumably chosen to limit memory use with the
# 3B-parameter model. Raise it if memory allows.
ESM_BATCH_SIZE = 1
# ceil(n / batch) contiguous index chunks, each holding at most ESM_BATCH_SIZE indices.
chunks = np.array_split(ls, math.ceil(len(data) / ESM_BATCH_SIZE))

In [22]:
import gc

In [23]:
# Load ESM-2 (3B) once, outside the loop. Loading it per chunk re-reads the
# checkpoint on every iteration.
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results
REPR_LAYER = 36  # last transformer layer of esm2_t36

reps = []
for c in chunks:
    # (name, sequence) pairs for this chunk; ESM does not need the activity column.
    dta = [(name, seq) for (name, seq, _activity) in data[c.min(): c.max() + 1]]
    batch_labels, batch_strs, batch_tokens = batch_converter(dta)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    print("predicting")
    # Extract per-residue representations (on CPU). Contact maps are never used,
    # so skip computing them.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[REPR_LAYER], return_contacts=False)
    token_representations = results["representations"][REPR_LAYER]

    print("representations")
    # Per-sequence embedding = mean over residue tokens only. Token 0 is the
    # beginning-of-sequence token, and the last non-padding token (EOS for ESM-2)
    # is excluded too.
    sequence_representations = [
        token_representations[i, 1 : tokens_len - 1].mean(0)
        for i, tokens_len in enumerate(batch_lens)
    ]
    print("done")
    reps.append(sequence_representations)
    # Release the large per-batch tensors before collecting.
    del results, token_representations
    gc.collect()

In [25]:
# Flatten the per-chunk lists into one list with one embedding per sequence, in `data` order.
reps_flat = []
for chunk_reps in reps:
    reps_flat.extend(chunk_reps)

In [26]:
# Convert each embedding tensor to numpy and stack them into one array, one row per sequence.
embedding_rows = [embedding.numpy() for embedding in reps_flat]
nx = np.array(embedding_rows)

In [33]:
# (number of sequences, embedding size): one mean-pooled ESM embedding per row.
nx.shape

(208, 2560)

In [34]:
# Project the embeddings to 2-D for plotting. With the default svd_solver="auto",
# scikit-learn can choose the randomized solver for inputs of this shape, so
# fix random_state to make the projection reproducible across runs.
pca = PCA(n_components=2, random_state=0)
ll = pca.fit_transform(nx)

In [35]:
# Variant names, in the same order as `data`.
names = []
for name, _sequence, _activity in data:
    names.append(name)

In [36]:
# Append the 2-D PCA coordinates to each (name, seq, activity) record.
# Rows of `ll` follow the order of `data`. zip() would silently drop an
# unmatched tail, so check that the lengths agree.
assert len(ll) == len(data), f"{len(ll)} PCA rows for {len(data)} records"
new_data = [(*items, encoding[0], encoding[1]) for encoding, items in zip(ll, data)]

In [37]:
# One row per variant: identity, measured activity and PCA coordinates.
plot_columns = ["names", "seq", "activity", "x", "y"]
results_df = pd.DataFrame(new_data, columns=plot_columns)

In [38]:
# Persist the plotting table without the pandas index column.
results_df.to_csv(path_or_buf="plot_df_encoding.csv", index=False)