In [1]:
import pandas as pd

df_protherm_mm = pd.read_csv('data/protherm_multiple.csv')


In [2]:
from scripts.model import SPIRED_Stab
import torch

# load parameter
model = SPIRED_Stab(device_list = ['cpu', 'cpu', 'cpu', 'cpu'])
model.load_state_dict(torch.load('data/model/SPIRED-Stab.pth'))
model.eval()

# load ESM-2 650M model
esm2_650M, _ = torch.hub.load('facebookresearch/esm:main', 'esm2_t33_650M_UR50D')
esm2_650M.eval()
    
# load ESM-2 3B model
esm2_3B, esm2_alphabet = torch.hub.load('facebookresearch/esm:main', 'esm2_t36_3B_UR50D')
esm2_3B.eval()
esm2_batch_converter = esm2_alphabet.get_batch_converter()

Using cache found in /home/xiaopeng/.cache/torch/hub/facebookresearch_esm_main
Using cache found in /home/xiaopeng/.cache/torch/hub/facebookresearch_esm_main


# Single variant evaluation

In [3]:
def get_data_single(seq, device = 'cpu'):
    
    with torch.no_grad():
        _, _, target_tokens = esm2_batch_converter([('', seq)])
        results = esm2_3B(target_tokens.to(device), repr_layers = range(37), need_head_weights = False, return_contacts = False)
        f1d_esm2_3B = torch.stack([v for _, v in sorted(results["representations"].items())], dim = 2)
        f1d_esm2_3B = f1d_esm2_3B[:, 1:-1]
        f1d_esm2_3B = f1d_esm2_3B.to(dtype = torch.float32)
        
        result_esm2_650m = esm2_650M(target_tokens.to(device), repr_layers = [33], return_contacts = False)
        f1d_esm2_650M = result_esm2_650m['representations'][33][0, 1:-1, :].unsqueeze(0)
    
    data = {
            'target_tokens': target_tokens[:, 1:-1],
            'esm2-3B': f1d_esm2_3B,
            'embedding': f1d_esm2_650M
        }
    return data

In [4]:
mut_seq = df_protherm_mm.mut_seq# [:5]
wt_seq = df_protherm_mm.wt_seq# [:5]

In [5]:
mut_data = [get_data_single(seq) for seq in mut_seq]
wt_data = [get_data_single(seq) for seq in wt_seq]

In [6]:
import numpy as np

mut_pos_torch_list = [torch.tensor((np.array(list(wt_s)) != np.array(list(mut_s))).astype(int).tolist()) for wt_s, mut_s in zip(wt_seq, mut_seq)]

In [7]:
ddG_list = []
dTm_list = []
with torch.no_grad():
    for wt_d, mut_d, mut_pos in zip(wt_data, mut_data, mut_pos_torch_list):
        ddG, dTm, _, _ = model(wt_d, mut_d, mut_pos)
        print(ddG.item(), dTm.item())
        ddG_list.append(ddG.item())
        dTm_list.append(dTm.item())

-0.8959712982177734 -5.4088287353515625
0.17664146423339844 0.7661285400390625
-2.1187267303466797 -12.159454345703125
-2.5264663696289062 -13.423385620117188
-3.137950897216797 -17.839599609375
-0.7924861907958984 -4.6500244140625
-0.19548988342285156 -0.8749542236328125
0.38581275939941406 1.9339447021484375
-0.6361484527587891 -3.595428466796875
0.2747611999511719 1.463165283203125
-0.8142414093017578 -4.7284393310546875
-0.47129058837890625 -2.429443359375
0.7101898193359375 3.3847198486328125
-0.7559909820556641 -3.037139892578125
0.12155723571777344 0.8222808837890625
0.5845165252685547 3.7150115966796875
-1.1588821411132812 -5.9815521240234375
-1.3712310791015625 -7.8416290283203125
-0.5637397766113281 -3.096649169921875
-0.21100997924804688 -0.1269989013671875
-0.35350608825683594 -1.269378662109375
0.03242301940917969 0.7830047607421875
0.396270751953125 2.0268096923828125
-0.3726825714111328 -1.3959197998046875
-0.2538585662841797 -1.29864501953125
-0.5006160736083984 -2.6753

In [17]:
target = df_protherm_mm['ddg']
pred = ddG_list
# pred = dTm_list

# Calculate Spearman's correlation
from scipy.stats import spearmanr
print(spearmanr(target, pred))

# Calculate Pearson's correlation
from scipy.stats import pearsonr
print(pearsonr(target, pred))

SignificanceResult(statistic=0.6875936530362611, pvalue=1.9035339721538445e-119)
PearsonRResult(statistic=0.6427074292094982, pvalue=9.385036181633755e-100)


# Batch evaluation

In [8]:
# def get_data_batch(seqs, device='cpu'):
#     with torch.no_grad():
#         seq_to_convert = [('protein1', seq) for seq in seqs]
#         _, _, target_tokens = esm2_batch_converter(seq_to_convert)

#         result_esm2_650m = esm2_650M(target_tokens.to(device), repr_layers = [33], return_contacts = False)
#         f1d_esm2_650M = result_esm2_650m['representations'][33][:, 1:-1, :] #.unsqueeze(0)

#         results = esm2_3B(target_tokens.to(device), repr_layers=range(37), need_head_weights=False, return_contacts=False)
#         f1d_esm2_3B = torch.stack([v for _, v in sorted(results["representations"].items())], dim=2)
#         f1d_esm2_3B = f1d_esm2_3B[:, 1:-1].to(dtype = torch.float32)

#         data = [{'target_tokens': target_tokens[i][1:-1].unsqueeze(0),
#                 'esm2-3B': f1d_esm2_3B[i].unsqueeze(0),
#                 'embedding': f1d_esm2_650M[i].unsqueeze(0)} for i in range(len(seqs))]
    
#     return data


In [None]:
mut_seq = df_protherm_mm.mut_seq[:5]
wt_seq = df_protherm_mm.wt_seq[:5]

In [9]:
# mut_data = get_data_batch(mut_seq)
# wt_data = get_data_batch(wt_seq)

In [10]:
# ddG_list = []
# dTm_list = []
# for wt_d, mut_d, mut_pos in zip(wt_data, mut_data, mut_pos_torch_list):
#     ddG, dTm, _, _ = model(wt_d, mut_d, mut_pos)
#     print(ddG, dTm)
#     ddG_list.append(ddG)
#     dTm_list.append(dTm)