In [50]:
import torch
import esm
from torch.utils.data import TensorDataset
from utils import refdb_find_shift, refdb_get_cs_seq, refdb_get_shift_re, refdb_get_seq, get_HA_shifts, get_shifts, shiftx_get_cs_seq, shiftx_get_shift_re
from utils import align_bmrb_pdb
import os
import math
from torch.utils.data import DataLoader
from model import regression
from torch.utils.data import random_split
import argparse
import numpy as np
import pandas as pd
import sys
import math

### Load ESM model (A Protein Language Model)
During data processing, the ESM model is first used to convert each protein sequence into per-residue embeddings.

In [51]:
# Load the pretrained ESM-2 protein language model (esm2_t33_650M_UR50D) and its alphabet.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Switch to evaluation mode (as in the fair-esm usage example) so dropout is disabled and the
# embeddings are deterministic; the model is only used for inference in this notebook.
esm_model.eval()
# Converts a list of (label, sequence) pairs into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()

### Load the pre-trained regression models for chemical shift prediction
One regression model is loaded per atom type (CA, CB, C, N, H, HA). Each model takes the ESM embeddings as input and outputs, for every residue, the chemical shift of its atom type.

In [52]:
# Atom types for which a trained regression head exists; this order is also the order in which
# the evaluation functions below return their per-atom RMSEs.
atom_types = ["CA", "CB", "C", "N", "H", "HA"]
ckpt_dir = os.path.join("plm_cs", "ckpt", "model_ckpt")

model = {}
for atom in atom_types:
    # NOTE(review): 1280 presumably matches the ESM-2 embedding width; the meaning of
    # 512, 8 and 0.1 is defined in model.regression -- confirm there before changing.
    head = regression(1280, 512, 8, 0.1)
    ckpt_path = os.path.join(ckpt_dir, "reg_" + atom.lower() + ".pth")
    # weights_only=True restricts unpickling to tensors/plain containers (safe for a
    # state_dict) and avoids torch.load's FutureWarning about the default.
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=True)
    head.load_state_dict(state_dict)
    head.eval()  # inference only
    model[atom] = head

  model[atom].load_state_dict(torch.load(file, map_location=torch.device('cpu')))


### A simple example: predicting shifts for a 21-residue sequence

In [None]:
sequence = "MVKVYAPASSANMSVLIQDLM"    # test sequence
output_file = os.path.join("result", "result_test.csv")    # output file path and name

idx_repr = 33  # Using the output of the 33rd layer as the embedding (do not change)
max_len = 512  # embeddings/masks are zero-padded to this fixed length before the regression heads
if len(sequence) > max_len:
    raise ValueError(f"sequence has {len(sequence)} residues; at most {max_len} are supported")

data = [("protein1", sequence)]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
with torch.no_grad():
    out_esm = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=False)

token_representations = out_esm["representations"][idx_repr]
# Take batch item 0 and drop the first/last token positions (special tokens added by the batch
# converter), leaving one row per residue. Indexing instead of .squeeze() keeps the residue axis
# even for a 1-residue sequence.
embedding = token_representations[0, 1:-1, :]

# True marks real residue positions, False marks padding.
# NOTE(review): presumably the convention model.regression expects -- confirm there.
padding_mask = torch.zeros(max_len).bool()
padding_mask[:embedding.shape[0]] = True
padding_mask = padding_mask.unsqueeze(0)
embedding = torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[0]))
# Selects the first len(sequence) positions of each prediction.
mask = torch.tensor([True] * len(sequence))
mask = torch.nn.functional.pad(mask, (0, max_len - mask.shape[0]), value=False)

# Column order: sequence, then one column per atom type in atom_types order.
predictions = {"sequence": list(sequence)}
with torch.no_grad():  # inference only: no autograd graph needed
    for atom_type in atom_types:
        out = model[atom_type](embedding.unsqueeze(0), padding_mask)
        predictions[atom_type] = out.squeeze(2).squeeze(0)[mask].tolist()

df = pd.DataFrame(predictions)
# Create the output folder if needed so to_csv does not fail on a fresh checkout.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
df.to_csv(output_file)

### Evaluation on the ShiftX test set

In [136]:
def test_on_shiftxfile(file_path, out_path, atom_types):
    """Predict chemical shifts for one ShiftX file and report the RMSE per atom type.

    Uses the module-level ``esm_model``, ``batch_converter`` and ``model`` objects.

    Args:
        file_path: path of the input ShiftX file.
        out_path: folder where the per-residue labels/predictions CSV is written (created if
            missing). The CSV is named ``<basename of file_path>.csv``.
        atom_types: atom types to evaluate; each must be a key of ``model``.

    Returns:
        A list with one RMSE per entry of ``atom_types``, in the same order. A value can be
        NaN when the mask returned by ``shiftx_get_shift_re`` selects no residue. The list is
        empty, and no CSV is written, when the sequence read from the file contains '_'.
    """
    bmrb_seq = refdb_get_seq(file_path)
    s, e = refdb_find_shift(file_path)
    cs_seq = shiftx_get_cs_seq(file_path, s, e)
    matched = align_bmrb_pdb(bmrb_seq, cs_seq)
    six_rmse = []

    idx_repr = 33  # Using the output of the 33rd layer as the embedding (do not change)
    max_len = 512  # fixed padded length used for embeddings, labels and masks

    if '_' in bmrb_seq:
        # Previously this path crashed with UnboundLocalError on `df`; skip it explicitly.
        print(file_path + " skipped: sequence contains '_'")
        return six_rmse

    seq_len = len(bmrb_seq)

    # The ESM embedding depends only on the sequence, so compute it once for all atom types.
    batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", bmrb_seq)])
    with torch.no_grad():
        out_esm = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=False)
    # Batch item 0 without the first/last token positions -> one row per residue (indexing
    # instead of .squeeze() keeps the residue axis for a 1-residue sequence).
    embedding = out_esm["representations"][idx_repr][0, 1:-1, :]
    padding_mask = torch.zeros(max_len).bool()
    padding_mask[:embedding.shape[0]] = True  # True = real residue position
    padding_mask = padding_mask.unsqueeze(0)
    embedding = torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[0]))

    columns = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
    loss_func = torch.nn.MSELoss()
    for atom_type in atom_types:
        shift, mask = shiftx_get_shift_re(file_path, s, e, bmrb_seq, matched, atom_type)
        # as_tensor avoids the copy-construct warning when the helper already returns tensors.
        label = torch.as_tensor(shift)
        mask = torch.as_tensor(mask)
        label = torch.nn.functional.pad(label, (0, max_len - label.shape[0]))
        mask = torch.nn.functional.pad(mask, (0, max_len - mask.shape[0]), value=False)

        with torch.no_grad():  # inference only
            out = model[atom_type](embedding.unsqueeze(0), padding_mask)
        out = out.squeeze(2).squeeze(0)[0:seq_len]
        label = label[0:seq_len]
        mask = mask[0:seq_len]

        # RMSE over the residues selected by the mask.
        rmse = math.sqrt(loss_func(out[mask], label[mask]).item())
        columns[atom_type + '_pred'] = out.numpy()
        columns[atom_type + '_label'] = label.numpy()
        print(file_path + atom_type+" Inference finished, RMSE is: ", rmse)
        six_rmse.append(rmse)

    # Keep only columns for evaluated atom types so a subset of atom_types does not mix
    # full-length arrays with empty lists (which pandas rejects).
    df = pd.DataFrame({key: value for key, value in columns.items()
                       if key.rsplit('_', 1)[0] in atom_types})
    os.makedirs(out_path, exist_ok=True)

    file_name = os.path.basename(file_path)
    out_file = os.path.join(out_path, file_name + ".csv")
    df.to_csv(out_file)
    return six_rmse

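Evaluate every file in the ShiftX test set: per-residue predictions and labels are saved as one CSV per file, and the mean RMSE per atom type is printed.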

In [None]:
path_shiftx_testset = "./dataset/shiftx_test_set"
path_shiftx_results = "./result/shiftx_test_set"

# Per-atom RMSE lists. test_on_shiftxfile returns one value per entry of atom_types, in that
# order, so results are matched to atoms by name rather than by hard-coded index (atom_types
# lists H before HA).
rmse_by_atom = {atom: [] for atom in atom_types}
for root, directories, files in os.walk(path_shiftx_testset):
    for file in sorted(files):
        if file.startswith("."):
            continue  # skip hidden files
        # Join with `root` so files found in sub-folders resolve to their real location.
        file_path = os.path.join(root, file)
        six_rmse = test_on_shiftxfile(file_path, path_shiftx_results, atom_types)
        # zip() tolerates an empty result (skipped file); NaN RMSEs are left out of the mean.
        for atom, rmse in zip(atom_types, six_rmse):
            if not math.isnan(rmse):
                rmse_by_atom[atom].append(rmse)

all_ca_rmse = rmse_by_atom.get("CA", [])
all_cb_rmse = rmse_by_atom.get("CB", [])
all_c_rmse = rmse_by_atom.get("C", [])
all_n_rmse = rmse_by_atom.get("N", [])
all_ha_rmse = rmse_by_atom.get("HA", [])
all_h_rmse = rmse_by_atom.get("H", [])
print("CA_rmse: ", np.mean(all_ca_rmse))
print("CB_rmse: ", np.mean(all_cb_rmse))
print("C_rmse: ", np.mean(all_c_rmse))
print("N_rmse: ", np.mean(all_n_rmse))
print("HA_rmse: ", np.mean(all_ha_rmse))
print("H_rmse: ", np.mean(all_h_rmse))

./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresno


  label= torch.tensor(shift)
  mask = torch.tensor(mask)


./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoCA Inference finished, RMSE is:  1.4815124214879851
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoCB Inference finished, RMSE is:  1.6144127376220248
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoC Inference finished, RMSE is:  1.0556987859077407
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoN Inference finished, RMSE is:  3.7331244637325915
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoH Inference finished, RMSE is:  0.5598468463374503
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoHA Inference finished, RMSE is:  0.40396063639532037
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresno
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoCA Inference finished, RMSE is:  0.5202485202711437
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoCB Inference finished, RMSE is:  nan
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoC Inference finished, RMSE is:  0.801852

### Evaluation on the solution-NMR test set

In [151]:
from utils import extract_protein_sequence

def test_on_solutionnmr(file_path, out_path, atom_types):
    """Predict chemical shifts for every usable sequence of one solution-NMR file and report RMSEs.

    A sequence is evaluated only if it contains no '_' and is shorter than 512 residues.
    Uses the module-level ``esm_model``, ``batch_converter`` and ``model`` objects.

    Args:
        file_path: path of the input solution-NMR file.
        out_path: folder for the output CSV (created if missing), named ``<basename>.csv``.
            Rows of all evaluated sequences are stacked; the ``seq_index`` column gives the
            position of each row's sequence in the file.
        atom_types: atom types to evaluate; each must be a key of ``model``.

    Returns:
        A flat list of RMSEs: for each evaluated sequence (in file order), one value per entry
        of ``atom_types`` in that order. A value can be NaN when a mask selects no residue.
        The list is empty, and no CSV is written, when no sequence qualifies.
    """
    # Defined locally instead of relying on the global from the example cell.
    idx_repr = 33  # Using the output of the 33rd layer as the embedding (do not change)
    max_len = 512  # fixed padded length used for embeddings, labels and masks

    bmrb_seq_list = extract_protein_sequence(file_path)
    six_rmse = []
    frames = []
    # get_shifts / get_HA_shifts are given the whole sequence list and return per-sequence
    # results, so call each at most once per atom type instead of once per sequence.
    shift_cache = {}
    loss_func = torch.nn.MSELoss()

    for i, bmrb_seq in enumerate(bmrb_seq_list):
        if '_' in bmrb_seq or len(bmrb_seq) >= max_len:
            continue
        seq_len = len(bmrb_seq)

        batch_labels, batch_strs, batch_tokens = batch_converter([("protein1", bmrb_seq)])
        with torch.no_grad():
            # Contacts are not used, so they are not computed.
            results = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=False)
        # Batch item 0 without the first/last token positions -> one row per residue.
        embedding = results["representations"][idx_repr][0, 1:-1, :]
        # The padding mask describes the embedding, so it is sized from the embedding length.
        padding_mask = torch.zeros(max_len).bool()
        padding_mask[:embedding.shape[0]] = True  # True = real residue position
        padding_mask = padding_mask.unsqueeze(0)
        embedding = torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[0]))

        columns = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
        for atom_type in atom_types:
            if atom_type not in shift_cache:
                if atom_type == "HA":
                    shift_cache[atom_type] = get_HA_shifts(file_path, "HA", bmrb_seq_list)
                else:
                    shift_cache[atom_type] = get_shifts(file_path, atom_type, bmrb_seq_list)
            shifts, masks = shift_cache[atom_type]
            # as_tensor avoids the copy-construct warning when the helpers return tensors.
            label = torch.as_tensor(shifts[i])
            mask = torch.as_tensor(masks[i])
            label = torch.nn.functional.pad(label, (0, max_len - label.shape[0]))
            mask = torch.nn.functional.pad(mask, (0, max_len - mask.shape[0]), value=False)

            with torch.no_grad():  # inference only
                out = model[atom_type](embedding.unsqueeze(0), padding_mask)
            out = out.squeeze(2).squeeze(0)[0:seq_len]
            label = label[0:seq_len]
            mask = mask[0:seq_len]

            # RMSE over the residues selected by the mask.
            rmse = math.sqrt(loss_func(out[mask], label[mask]).item())
            columns[atom_type + '_pred'] = out.numpy()
            columns[atom_type + '_label'] = label.numpy()
            print(file_path + atom_type+" Inference finished, rmse is: ", rmse)
            six_rmse.append(rmse)

        # Keep only evaluated atom types so a subset of atom_types still builds a valid frame.
        frame = pd.DataFrame({key: value for key, value in columns.items()
                              if key.rsplit('_', 1)[0] in atom_types})
        frame["seq_index"] = i  # appended last so existing column positions are unchanged
        frames.append(frame)

    if not frames:
        # Previously this path crashed with UnboundLocalError on `df`.
        print(file_path + " skipped: no sequence without '_' and shorter than 512 residues")
        return six_rmse

    file_name = os.path.basename(file_path)
    out_file = os.path.join(out_path, file_name + ".csv")
    os.makedirs(out_path, exist_ok=True)
    # Save every evaluated sequence, not just the last one.
    pd.concat(frames, ignore_index=True).to_csv(out_file)
    return six_rmse

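Evaluate every file in the solution-NMR test set: per-residue predictions and labels are saved as one CSV per file, and the mean RMSE per atom type is printed.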

In [None]:
path_solution_testset = "./dataset/solution_nmr_test_set"
path_solution_results = "./result/solution_nmr_test_set"

# Per-atom RMSE lists. test_on_solutionnmr returns len(atom_types) values per evaluated
# sequence, each group in atom_types order, so position k belongs to atom_types[k % n].
# This counts every evaluated sequence of a multi-sequence file, not only the first one.
rmse_by_atom = {atom: [] for atom in atom_types}
n_atoms = len(atom_types)
for root, directories, files in os.walk(path_solution_testset):
    for file in sorted(files):
        if file.startswith("."):
            continue  # skip hidden files
        # Join with `root` so files found in sub-folders resolve to their real location.
        file_path = os.path.join(root, file)
        six_rmse = test_on_solutionnmr(file_path, path_solution_results, atom_types)
        for k, rmse in enumerate(six_rmse):
            if not math.isnan(rmse):  # NaN RMSEs are left out of the mean
                rmse_by_atom[atom_types[k % n_atoms]].append(rmse)

all_ca_rmse = rmse_by_atom.get("CA", [])
all_cb_rmse = rmse_by_atom.get("CB", [])
all_c_rmse = rmse_by_atom.get("C", [])
all_n_rmse = rmse_by_atom.get("N", [])
all_ha_rmse = rmse_by_atom.get("HA", [])
all_h_rmse = rmse_by_atom.get("H", [])
print("CA_rmse: ", np.mean(all_ca_rmse))
print("CB_rmse: ", np.mean(all_cb_rmse))
print("C_rmse: ", np.mean(all_c_rmse))
print("N_rmse: ", np.mean(all_n_rmse))
print("HA_rmse: ", np.mean(all_ha_rmse))
print("H_rmse: ", np.mean(all_h_rmse))

./dataset/solution_nmr_test_set/11491.strCA Inference finished, rmse is:  0.9183314803936145
./dataset/solution_nmr_test_set/11491.strCB Inference finished, rmse is:  0.8407790188280356
./dataset/solution_nmr_test_set/11491.strC Inference finished, rmse is:  0.959058211022734
./dataset/solution_nmr_test_set/11491.strN Inference finished, rmse is:  2.283902442158029
./dataset/solution_nmr_test_set/11491.strH Inference finished, rmse is:  0.33693257156614165
./dataset/solution_nmr_test_set/11491.strHA Inference finished, rmse is:  0.21729673719082884
./dataset/solution_nmr_test_set/16116.strCA Inference finished, rmse is:  0.75678596577717
./dataset/solution_nmr_test_set/16116.strCB Inference finished, rmse is:  0.6305041183893799
./dataset/solution_nmr_test_set/16116.strC Inference finished, rmse is:  0.9691719858860544
./dataset/solution_nmr_test_set/16116.strN Inference finished, rmse is:  2.3894221452416873
./dataset/solution_nmr_test_set/16116.strH Inference finished, rmse is:  0.28