In [3]:
import torch
import os
import math
import numpy as np
import pandas as pd
import math

import esm
from utils import refdb_find_shift, refdb_get_seq, get_HA_shifts, get_shifts, shiftx_get_cs_seq, shiftx_get_shift_re
from utils import align_bmrb_pdb
from model import PLM_CS

### Load ESM model (A Protein Language Model)
In the data processing process, the ESM model is used in advance to convert the sequence to embeddings

In [3]:
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

### Load the pre-trained regression model for Chemical Shift Prediciton
This regression model use embeddings from ESM as input, and output the chemical shifts of a certain atom (CA, CB, C, N, H, HA)

In [4]:
atom_types = ["CA", "CB", "C", "N", "H", "HA"]
model = {}
for atom in atom_types:
    model[atom] = PLM_CS(1280, 512, 8, 0.1)
    file = os.path.join("plm_cs", "ckpt", "model_ckpt", "reg_" + atom.lower() + ".pth")
    model[atom].load_state_dict(torch.load(file, map_location=torch.device('cpu')))
    model[atom].eval()

  model[atom].load_state_dict(torch.load(file, map_location=torch.device('cpu')))


### A simple example: A sequence with 21 residues

In [7]:
sequence="MVKVYAPASSANMSVLIQDLM"    # test sequence
output_file = os.path.join("result", "result_test.csv")    # output file path and name

idx_repr = 33 # Using the output of the 33-rd layer as the embedding (do not change)
data = [("protein1", sequence)] 
batch_labels, batch_strs, batch_tokens = batch_converter(data)
with torch.no_grad():
    out_esm = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=False)  

token_representations = out_esm["representations"][idx_repr]
embedding = token_representations[:, 1:-1, :].squeeze()
padding_mask = torch.zeros(512).bool()
padding_mask[:embedding.shape[0]] = True
embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
mask = torch.tensor([True]*len(sequence))
mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
padding_mask = padding_mask.unsqueeze(0)

df = {"sequence": list(sequence), "CA": [0]*len(sequence), "CB": [0]*len(sequence), "C": [0]*len(sequence), "N": [0]*len(sequence), "H": [0]*len(sequence), "HA": [0]*len(sequence)}

for atom_type in atom_types:
    model_atom = model[atom_type]
    out = model_atom(embedding.unsqueeze(0), padding_mask)
    pred = out.squeeze(2).squeeze(0)[mask]
    df[atom_type] = pred.tolist()

df = pd.DataFrame(df)
df.to_csv(output_file)

### Evaluation on the ShiftX test set

In [9]:
def test_on_shiftxfile(file_path, out_path, atom_types):
    # For each file in the ShiftX dataset, calculate the RMSE of the predicted chemical shifts
    # file_path: the path of the input SHIFTX file
    # out_path: the path of the output file
    # atom_types: a list of atom types to be predicted
    # return: a list of RMSEs for each atom type
    # Note: the output file will be saved in the out_path folder with the same name as the input file
    #       and the extension changed to .csv

    bmrb_seq = refdb_get_seq(file_path)
    s, e = refdb_find_shift(file_path)
    #cs_seq = refdb_get_cs_seq(file_path, s, e)
    #bmrb_seq = refdb_get_seq(file_path)
    cs_seq = shiftx_get_cs_seq(file_path, s, e)
    matched = align_bmrb_pdb(bmrb_seq, cs_seq)
    six_rmse = []

    idx_repr = 33 # Using the output of the 33-rd layer as the embedding (do not change)
   
    if '_' not in bmrb_seq:
        df = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
        for atom_type in atom_types:
            shift, mask = shiftx_get_shift_re(file_path, s, e, bmrb_seq, matched, atom_type)
            label= torch.tensor(shift)
            mask = torch.tensor(mask)
            label = torch.nn.functional.pad(label, (0, 512-label.shape[0]))
            data = [("protein1", bmrb_seq)]

            batch_labels, batch_strs, batch_tokens = batch_converter(data) 
            with torch.no_grad():
                out_esm = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=False)  
            token_representations = out_esm["representations"][idx_repr]
            embedding = token_representations[:, 1:-1, :].squeeze()
            padding_mask = torch.zeros(512).bool()
            padding_mask[:embedding.shape[0]] = True
            padding_mask = padding_mask.unsqueeze(0)
            embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
            mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
            
            out = model[atom_type](embedding.unsqueeze(0), padding_mask)
            #out = out.squeeze(2).squeeze(0)
            out = out.squeeze(2).squeeze(0)[0:len(bmrb_seq)]
            label = label[0:len(bmrb_seq)]
            mask = mask[0:len(bmrb_seq)]
            
            loss_func = torch.nn.MSELoss()
            loss = loss_func(out[mask], label[mask])
            #loss = loss_func(out, label)
            rmse = math.sqrt(loss.item())
            a = out.detach().numpy()
            b = label.detach().numpy()
            df[atom_type+'_pred'] = a
            df[atom_type+'_label']= b
            print(file_path + atom_type+" Inference finished, RMSE is: ", rmse)
            six_rmse.append(rmse)
    df = pd.DataFrame(df)
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    file_name = os.path.basename(file_path)
    out_file = os.path.join(out_path, file_name+".csv")
    df.to_csv(out_file)
    return six_rmse

Test each shiftx file and save results

In [None]:
path_shiftx_testset = "./dataset/shiftx_test_set"
path_shiftx_results = "./result/shiftx_test_set"

all_six_rmse = [[], [], [], [], [], []]
for root, directories, files in os.walk(path_shiftx_testset):
    for file in files:
        if file.startswith("."):
            continue
        file_path = os.path.join(path_shiftx_testset, file)
        six_rmse = test_on_shiftxfile(file_path, path_shiftx_results, atom_types)
        for i in range(len(six_rmse)):
            if not math.isnan(six_rmse[i]):
                all_six_rmse[i].append(six_rmse[i])
print("Average RMSE for CA: ", np.mean(all_six_rmse[0]))
print("Average RMSE for CB: ", np.mean(all_six_rmse[1]))
print("Average RMSE for C: ", np.mean(all_six_rmse[2]))
print("Average RMSE for N: ", np.mean(all_six_rmse[3]))
print("Average RMSE for H: ", np.mean(all_six_rmse[4]))
print("Average RMSE for HA: ", np.mean(all_six_rmse[5]))


  label= torch.tensor(shift)
  mask = torch.tensor(mask)


./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoCA Inference finished, RMSE is:  1.4815124214879851
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoCB Inference finished, RMSE is:  1.6144127376220248
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoC Inference finished, RMSE is:  1.0556987859077407
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoN Inference finished, RMSE is:  3.7331244637325915
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoH Inference finished, RMSE is:  0.5598468463374503
./dataset/shiftx_test_set/A018_bmr4879.str.corr.pdbresnoHA Inference finished, RMSE is:  0.40396063639532037
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoCA Inference finished, RMSE is:  0.5202485202711437
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoCB Inference finished, RMSE is:  nan
./dataset/shiftx_test_set/A040_bmr5358.str.corr.pdbresnoC Inference finished, RMSE is:  0.8018524820283517
./dataset/shiftx_test_set/A040_bmr5358.str.cor

In [14]:
print(all_ca_rmse)
print(np.mean(all_ca_rmse))

[1.4815124214879851, 1.6144127376220248, 1.0556987859077407, 3.7331244637325915, 0.5598468463374503, 0.40396063639532037, 0.5202485202711437, 0.8018524820283517, 0.8311067558141593, 0.2825131660532903, 0.17743242206408755, 1.0810216607456493, 1.0457117250621095, 1.9399532660528176, 0.277084914838728, 0.29713261870733787, 0.5101243434117528, 0.9591848936815447, 0.9349756906276584, 0.12760752962183342, 0.09633190041666932, 1.3056053542160593, 1.7583909502879516, 1.1943088880892863, 3.90357525613039, 0.594526084847389, 0.7647454406088586, 1.0274804898194094, 0.8472141572085484, 1.0874905355633586, 0.1426169590777091, 0.42948359072739195, 0.43141840051702085, 0.5177672130738682, 0.6203967808272849, 0.1882909938459481, 0.16890158666230493, 1.5682765230976232, 1.707662644578073, 1.3952766371130445, 3.7686931093391474, 0.6454022389008705, 0.3632444957885847, 1.323531566100767, 0.8517139895020409, 1.1591246671401563, 2.381271914884394, 0.5004116092662427, 0.30043205203756534, 0.863719880535357

### Evaluation on the solution-NMR test set

In [16]:
from utils import extract_protein_sequence

def test_on_solutionnmr(file_path, out_path, atom_types):
    bmrb_seq_list = extract_protein_sequence(file_path)
    six_rmse = []
    for i, bmrb_seq in enumerate(bmrb_seq_list):
        if '_' not in bmrb_seq and len(bmrb_seq) < 512:
            data = [("protein1", bmrb_seq_list[i])]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[idx_repr], return_contacts=True)
            token_representations = results["representations"][idx_repr]
            embedding = token_representations[:, 1:-1, :].squeeze()
            embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
            
            df = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
            for atom_type in atom_types:
                if atom_type == "HA":
                    shifts, masks = get_HA_shifts(file_path, "HA", bmrb_seq_list)
                else:
                    shifts, masks = get_shifts(file_path, atom_type, bmrb_seq_list)
                label= torch.tensor(shifts[i])
                mask = torch.tensor(masks[i])
                padding_mask = torch.zeros(512).bool()
                padding_mask[:label.shape[0]] = True
                label = torch.nn.functional.pad(label, (0, 512 - label.shape[0]))
                mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
                padding_mask = padding_mask.unsqueeze(0)
                
                out = model[atom_type](embedding.unsqueeze(0), padding_mask)
                
                out = out.squeeze(2).squeeze(0)[0:len(bmrb_seq)]
                label = label[0:len(bmrb_seq)]
                mask = mask[0:len(bmrb_seq)]

                loss_func = torch.nn.MSELoss()
                loss = loss_func(out[mask], label[mask])
                rmse = math.sqrt(loss.item())
                a = out.detach().numpy()
                b = label.detach().numpy()
                df[atom_type+'_pred'] = a
                df[atom_type+'_label']= b
                # print(file_path + atom_type+" Inference finished, rmse is: ", rmse)
                six_rmse.append(rmse)
    
    df = pd.DataFrame(df)

    file_name = os.path.basename(file_path)
    out_file = os.path.join(out_path, file_name+".csv")
    if not os.path.exists(out_path):
        os.makedirs(out_path)
    df.to_csv(out_file)
    return six_rmse

Test each solution_nmr_test_set file

In [40]:
path_solution_testset = "./dataset/solution_nmr_test_set"
path_solution_results = "./result/solution_nmr_test_set"

all_six_rmse = [[], [], [], [], [], []]

exclude_bmrb = {"16068":[0, 1, 2, 3], "15593":[2], "6812":[0], "5467":[2], 7178:[0]}
# This samples are removed for their high RMSE values from both ShiftX2 and PLM-CS

for root, directories, files in os.walk("./dataset/solution_nmr_test_set"):
    for file in files:
        if file.startswith("."):
            continue
        file_path = os.path.join(path_solution_testset, file)
        six_rmse = test_on_solutionnmr(file_path, path_solution_results, atom_types)


        if file.split(".")[0] != "11471":
            # The PDB file corresponding to 11471 has two chains, but both chains are part of the first sequence in bmrb, so it is calculated twice
            for i in range(len(six_rmse)):
                if file.split(".")[0] not in exclude_bmrb or i not in exclude_bmrb[file.split(".")[0]]:
                    if not math.isnan(six_rmse[i]):
                        if i<6:
                            all_six_rmse[i].append(six_rmse[i])
                        else:
                            all_six_rmse[i-6].append(six_rmse[i])
        elif file.split(".")[0] == "4246" or file.split(".")[0] == "7125":
            # Both chains of the two BMRB files are in the test set
            for i in range(len(six_rmse)):
                if not math.isnan(six_rmse[i]):
                    if i<6:
                        all_six_rmse[i].append(six_rmse[i])
                    else:
                        all_six_rmse[i-6].append(six_rmse[i-6])
        else:
            for i in range(6):
                if not math.isnan(six_rmse[i]):
                    all_six_rmse[i].append(six_rmse[i])

        
            
print("Average RMSE for CA: ", np.mean(all_six_rmse[0]))
print("Average RMSE for CB: ", np.mean(all_six_rmse[1]))
print("Average RMSE for C: ", np.mean(all_six_rmse[2]))
print("Average RMSE for N: ", np.mean(all_six_rmse[3]))
print("Average RMSE for H: ", np.mean(all_six_rmse[4]))
print("Average RMSE for HA: ", np.mean(all_six_rmse[5]))

Average RMSE for CA:  1.2189667953766565
Average RMSE for CB:  1.3769884999033648
Average RMSE for C:  1.236773217075236
Average RMSE for N:  2.68440393577155
Average RMSE for H:  0.41676108711190213
Average RMSE for HA:  0.29135119968359896
