In [2]:
from scripts.model import SPIRED_Stab
import torch

# Load the SPIRED-Stab parameters; every entry of device_list is 'cpu'.
model = SPIRED_Stab(device_list = ['cpu', 'cpu', 'cpu', 'cpu'])
# map_location='cpu' lets a checkpoint that was saved from a GPU session
# deserialize on a machine without CUDA; load_state_dict then copies the
# values into the (CPU) parameters either way.
model.load_state_dict(torch.load('data/model/SPIRED-Stab.pth', map_location = 'cpu'))
model.eval()

# Load the ESM-2 650M model. The second value returned by torch.hub.load is
# discarded here; the converter built below from the 3B load is the one
# used for tokenisation.
esm2_650M, _ = torch.hub.load('facebookresearch/esm:main', 'esm2_t33_650M_UR50D')
esm2_650M.eval()

# Load the ESM-2 3B model and build the batch converter from its alphabet.
esm2_3B, esm2_alphabet = torch.hub.load('facebookresearch/esm:main', 'esm2_t36_3B_UR50D')
esm2_3B.eval()
esm2_batch_converter = esm2_alphabet.get_batch_converter()

Using cache found in /home/xiaopeng/.cache/torch/hub/facebookresearch_esm_main
Using cache found in /home/xiaopeng/.cache/torch/hub/facebookresearch_esm_main


# Evaluate variants one-by-one

In [6]:
def get_data_single(seq, device = 'cpu'):
    """Build the feature dict for a single sequence.

    The sequence is tokenised with the module-level ``esm2_batch_converter``
    and run through both ESM-2 models under ``torch.no_grad()``.

    Args:
        seq: sequence string handed to the batch converter (with an empty label).
        device: device the token tensor is moved to before each forward pass.
            The models themselves are not moved by this function.

    Returns:
        dict with three entries:
            'target_tokens': the token tensor with the first and last position
                along dim 1 removed (presumably the BOS/EOS tokens -- TODO confirm).
            'esm2-3B': the representations of layers 0..36 of the 3B model,
                stacked along a new dim 2, trimmed the same way, as float32.
            'embedding': layer-33 representation of the 650M model for batch
                item 0, trimmed the same way, with the batch dim restored.
    """
    with torch.no_grad():
        # The converter returns a 3-tuple; only the token tensor is needed.
        _labels, _strings, target_tokens = esm2_batch_converter([('', seq)])
        tokens_on_device = target_tokens.to(device)

        # 3B model: collect every requested layer, ordered by layer index.
        out_3B = esm2_3B(tokens_on_device, repr_layers = range(37), need_head_weights = False, return_contacts = False)
        per_layer = out_3B["representations"]
        stacked_layers = torch.stack([per_layer[idx] for idx in sorted(per_layer)], dim = 2)
        f1d_esm2_3B = stacked_layers[:, 1:-1].to(dtype = torch.float32)

        # 650M model: only the layer-33 representation is used.
        out_650M = esm2_650M(tokens_on_device, repr_layers = [33], return_contacts = False)
        last_layer = out_650M['representations'][33]
        f1d_esm2_650M = last_layer[0][1:-1].unsqueeze(0)

    return {
        'target_tokens': target_tokens[:, 1:-1],
        'esm2-3B': f1d_esm2_3B,
        'embedding': f1d_esm2_650M,
    }

In [7]:
import numpy as np

def pred_ddG_dTm(mut_seqs, wt_seqs):
    """Predict ddG and dTm for paired mutant / wild-type sequences.

    Each sequence is featurised with ``get_data_single`` and each pair is
    scored by the module-level ``model``. The model's third argument is a
    0/1 tensor marking the positions at which the two sequences differ.

    Args:
        mut_seqs: iterable of mutant sequences (strings).
        wt_seqs: iterable of wild-type sequences, paired element-wise with
            ``mut_seqs``.

    Returns:
        ``(ddG_list, dTm_list)``: two lists holding the ``.item()`` value of
        the model's first and second output for every pair, in input order.

    Raises:
        ValueError: if the two inputs contain a different number of
            sequences, or if the sequences of a pair differ in length (the
            mutation mask is built position by position).
    """
    # Materialise once so generators / pandas Series can be measured and
    # iterated several times.
    mut_seqs = list(mut_seqs)
    wt_seqs = list(wt_seqs)

    # Validate before running the (expensive) language models.
    if len(mut_seqs) != len(wt_seqs):
        raise ValueError(
            f"mut_seqs and wt_seqs must have the same number of sequences, "
            f"got {len(mut_seqs)} and {len(wt_seqs)}"
        )
    for idx, (wt_s, mut_s) in enumerate(zip(wt_seqs, mut_seqs)):
        if len(wt_s) != len(mut_s):
            raise ValueError(
                f"pair {idx}: wild-type length {len(wt_s)} != mutant length {len(mut_s)}"
            )

    mut_data = [get_data_single(seq) for seq in mut_seqs]
    wt_data = [get_data_single(seq) for seq in wt_seqs]

    # Build the masks from the function's own arguments. The previous code
    # zipped the notebook globals `wt_seq` / `mut_seq` here, so results
    # depended on whatever those names happened to hold in the kernel.
    mut_pos_torch_list = [torch.tensor([int(wt_aa != mut_aa) for wt_aa, mut_aa in zip(wt_s, mut_s)])
                          for wt_s, mut_s in zip(wt_seqs, mut_seqs)]

    ddG_list = []
    dTm_list = []
    with torch.no_grad():
        for wt_d, mut_d, mut_pos in zip(wt_data, mut_data, mut_pos_torch_list):
            ddG, dTm, _, _ = model(wt_d, mut_d, mut_pos)
            print(ddG.item(), dTm.item())
            ddG_list.append(ddG.item())
            dTm_list.append(dTm.item())
    return ddG_list, dTm_list


In [9]:
import pandas as pd
from scipy.stats import spearmanr, pearsonr

# Quick check on the first five rows of data/protherm_multiple.csv.
df_protherm_mm = pd.read_csv('data/protherm_multiple.csv')
mut_seq = df_protherm_mm['mut_seq'].head(5)
wt_seq = df_protherm_mm['wt_seq'].head(5)
target = df_protherm_mm['ddg'].head(5)

ddG_list, dTm_list = pred_ddG_dTm(mut_seq, wt_seq)

# Spearman correlation of the 'ddg' column with the predicted ddG, then
# Pearson correlation of the same column with the predicted dTm.
# NOTE(review): the second line compares a ddG target with dTm predictions --
# confirm this is intended rather than pearsonr(target, ddG_list).
print(spearmanr(target, ddG_list))
print(pearsonr(target, dTm_list))

SignificanceResult(statistic=0.8999999999999998, pvalue=0.03738607346849874)
PearsonRResult(statistic=0.8835027286968199, pvalue=0.046888792768475386)


In [10]:
# Reference sequence used as the wild type for every variant in this cell.
savinase = 'AQSVPWGISRVQAPAAHNRGLTGSGVKVAVLDTGISTHPDLNIRGGASFVPGEPSTQDGNGHGTHVAGTIAALNNSIGVLGVAPSAELYAVKVLGASGSGSVSSIAQGLEWAGNNGMHVANLSLGSPSPSATLEQAVNSATSRGVLVVAASGNSGAGSISYPARYANAMAVGATDQNNNRASFSQYGAGLDIVAPGVNVQSTYPGSTYASLNGTSMATPHVAGAAALVKQKNPSWSNVQIRNHLKNTATSLGSTNLYGSGLVNAEAATR'

df_stab_mm = pd.read_csv('data/stab_data_bsj_r1.csv')

# Every variant in the 'seq' column is paired with the same wild type.
mut_seq = df_stab_mm['seq']
wt_seq = [savinase for _ in range(len(mut_seq))]
target = df_stab_mm['Stability']

ddG_list, dTm_list = pred_ddG_dTm(mut_seq, wt_seq)

# Spearman correlation of 'Stability' with the predicted ddG, then Pearson
# correlation of 'Stability' with the predicted dTm.
print(spearmanr(target, ddG_list))
print(pearsonr(target, dTm_list))
