In [1]:
import torch
import esm
from torch.utils.data import TensorDataset
from utils import refdb_find_shift, refdb_get_cs_seq, refdb_get_shift_re, refdb_get_seq, get_HA_shifts, get_shifts, shiftx_get_cs_seq, shiftx_get_shift_re
from utils import align_bmrb_pdb
import os
import math
from torch.utils.data import DataLoader
from model import PLM_CS
from torch.utils.data import random_split
import argparse
import numpy as np
import pandas as pd
import sys

### Data process
In the data processing process, the esm model is used in advance to convert the sequence to embeddings and saved as tensordataset

Load the esm2 model

In [3]:
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [7]:
def data_process(refdb_path, save_path, atom_type):
    all_esm_vec = torch.zeros(1, 512, 1280)
    all_label = torch.zeros((1, 512))
    all_mask = torch.zeros((1, 512)).bool()
    all_padding_mask = torch.zeros((1, 512)).bool()
    for root, directories, files in os.walk(refdb_path):
        for file in files:
            file_path = os.path.join(root, file)
            bmrb_seq = refdb_get_seq(file_path)
            s, e = refdb_find_shift(file_path)
            cs_seq = refdb_get_cs_seq(file_path, s, e)
            matched = align_bmrb_pdb(bmrb_seq, cs_seq)
            shift, mask = refdb_get_shift_re(file_path, s, e, bmrb_seq, matched, atom_type)
            if '_' not in bmrb_seq and 0<len(bmrb_seq) < 512:
                data = [("protein1", bmrb_seq)]
                batch_labels, batch_strs, batch_tokens = batch_converter(data)
                with torch.no_grad():
                    results = esm_model(batch_tokens, repr_layers=[33], return_contacts=True)
                token_representations = results["representations"][33]
                embedding = token_representations[:, 1:-1, :].squeeze()
                embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
                # padding the size of tensor from "res*1280" to 512*1280
                label = torch.tensor(shift)
                padding_mask = torch.zeros(512).bool()
                padding_mask[:label.shape[0]] = True
                label = torch.nn.functional.pad(label, (0, 512-label.shape[0]))
                # padding the size of tensor from "res" to 512
                mask = torch.tensor(mask)
                mask = torch.nn.functional.pad(mask, (0, 512-mask.shape[0]), value=False)
                if not torch.all(mask.eq(False)):
                    all_esm_vec = torch.cat((all_esm_vec, embedding.unsqueeze(0)), dim=0)
                    all_label = torch.cat((all_label, label.unsqueeze(0)), dim=0)
                    all_mask = torch.cat((all_mask, mask.unsqueeze(0)), dim=0)
                    all_padding_mask = torch.cat((all_padding_mask, padding_mask.unsqueeze(0)), dim=0)
        all_esm_vec = all_esm_vec[1:, :, :]
        all_label = all_label[1:, :]
        all_mask = all_mask[1:, :]
        all_padding_mask = all_padding_mask[1:, :]
        dataset = TensorDataset(all_esm_vec, all_label, all_mask, all_padding_mask)
        torch.save(dataset, save_path)
        print("Data saved successfully, size of dataset is: ", all_esm_vec.shape)

Create the tensordatasets of 6 atom types

In [8]:
from utils import extract_protein_sequence, refdb_find_shift, refdb_get_cs_seq, refdb_get_shift_re
from utils import align_bmrb_pdb
import os
atom_types = ["CA","CB","C","N","H","HA"]
refdb_path = "./dataset/RefDB_test_remove"
save_dir = "./dataset/tensordataset/"
for atom_type in atom_types:
    save_path = save_dir + atom_type + ".pt"
    data_process(refdb_path, save_path, atom_type)

  label = torch.tensor(shift)
  mask = torch.tensor(mask)


KeyboardInterrupt: 

### Train

Set random seeds

In [8]:
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# if you are using multi-GPU, you can use torch.cuda.manual_seed_all(seed) to set all seeds.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Set hyperparameters

In [9]:
parser = argparse.ArgumentParser()
sys.argv = ['train_your_model.ipynb', '--batchsize', '16', '--N', '6', '--dropout', '0.1', '--d_model', '512', '--d_vec', '1280', '--n_head', '8', '--shuffle', 'False', '--epoch', '20000', '--lr', '5e-4', '--device', 'cpu']

parser.add_argument('--batchsize', type=int, default=16, help='Batch size for training')
parser.add_argument('--N', type=int, default=6, help='number of encoder')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--d_model', type=int, default=512, help='qkv d-model dimension')
parser.add_argument('--d_vec', type=int, default=1280, help='amino embedding dimension')
parser.add_argument('--n_head', type=int, default=8, help='number of attention heads')
parser.add_argument('--shuffle', type=bool, default=False, help='shuffle dataset')
parser.add_argument('--epoch', type=int, default=20000, help='epoch time')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--device', type=str, default="cpu", help='learning rate')
# Change if you have cuda devices
args = parser.parse_args()

In [10]:
def main(data, save_path):
    model = PLM_CS(args.d_vec, args.d_model, args.n_head, args.dropout)
    device = torch.device(args.device)
    train_loss_all = []
    val_loss_all = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), eps=1e-8,
                                   weight_decay=0)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer.zero_grad()
    loss_func = torch.nn.MSELoss()
    model.to(device)
    def init_weights1(model):
        if isinstance(model, torch.nn.Linear):
            torch.nn.init.kaiming_uniform(model.weight)

    def init_weights_kaiming(model):
        if hasattr(model, 'weight') and model.weight.dim() > 1:
            torch.nn.init.kaiming_uniform(model.weight.data)

    def init_weights_xavier(model):
        if isinstance(model, torch.nn.MultiheadAttention):
            torch.nn.init.xavier_uniform_(model.in_proj_weight)
            torch.nn.init.xavier_uniform_(model.out_proj.weight)

    model.apply(init_weights1)
    train_size = int(len(data) * 0.8)
    val_size = len(data) - train_size
    train_dataset, val_dataset = random_split(data, [train_size, val_size])
    traindata_loader = DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True)
    valdata_loader = DataLoader(val_dataset, batch_size=args.batchsize, shuffle=True)

    def train(epoch):
        model.train()
        epoch_loss = 0
        all_CA = 0
        for i, batch in enumerate(traindata_loader):
            mask, label, seq_vec, padding_mask = batch[2], batch[1], batch[0], batch[3]
            mask = mask.to(device)
            label = label.to(device)
            seq_vec = seq_vec.to(device)
            padding_mask = padding_mask.to(device)
            out = model(seq_vec, padding_mask)
            loss = torch.sqrt(loss_func(out.squeeze(2)[mask], label[mask]))
            # out, log_sigma = model(seq_vec, padding_mask)
            # sigma_normal = torch.sqrt(torch.mean(0.5*(torch.exp((-1)*log_sigma)) * (out.squeeze(2)[mask] - label[mask])**2 + 0.5*log_sigma))
            all_CA += label[mask].shape[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            mse = loss ** 2 * label[mask].shape[0]
            epoch_loss += mse.detach().item()
        return (epoch_loss / all_CA)

    def val(epoch):
        # model.train()
        epoch_loss = 0
        all_CA = 0
        for i, batch in enumerate(valdata_loader):
            mask, label, seq_vec, padding_mask = batch[2], batch[1], batch[0], batch[3]
            mask = mask.to(device)
            label = label.to(device)
            seq_vec = seq_vec.to(device)
            padding_mask = padding_mask.to(device)
            out = model(seq_vec, padding_mask)
            loss = loss_func(out.squeeze(2)[mask], label[mask])
            all_CA += label[mask].shape[0]
            loss = loss * label[mask].shape[0]
            epoch_loss += loss.item()
            rmse = math.sqrt(epoch_loss / all_CA)
        return rmse

    best_acc = 1.8
    for epoch in range(0, args.epoch):
        train_loss = train(epoch)
        val_loss = val(epoch)
        print(f'\tepoch{epoch:.3f}Train Loss: {train_loss:.3f} | val_rmse: {val_loss:7.3f}')
        train_loss_all.append(train_loss)
        val_loss_all.append(val_loss)

        if val_loss<best_acc:
            sp = save_path + f"epoch{epoch}_val{val_loss:.3f}.pth"
            state = {
                "epoch": epoch,
                "accuracy": val_loss,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }
            torch.save(state, sp)
            best_acc = val_loss

Train the models of six atom types separately

In [4]:
atom_types = ["CA", "CB", "C", "N", "H", "HA"]
for atom_type in atom_types:
    save_path = "./dataset/your_model_ckpt/"+atom_type+".pt"
    # you can change your model save path
    data = np.load("./dataset/tensordataset/"+atom_type+".pt", allow_pickle=True)
    main(data, save_path)

FileNotFoundError: [Errno 2] No such file or directory: './dataset/tensordataset/CA.pt'

In [15]:
def your_model(sequence, result_file_name):
    df = {"sequence": list(sequence), "CA": [0]*len(sequence), "CB": [0]*len(sequence), "C": [0]*len(sequence), "N": [0]*len(sequence), "H": [0]*len(sequence), "HA": [0]*len(sequence)}
    atom_types = ["CA", "CB", "C", "N", "H", "HA"]
    pred_shifts = {}
    for atom_type in atom_types:
        model = PLM_CS(1280, 512, 8, 0.1)
        data = [("protein1", sequence)]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        with torch.no_grad():
            results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
        token_representations = results["representations"][33]
        embedding = token_representations[:, 1:-1, :].squeeze()
        padding_mask = torch.zeros(512).bool()
        padding_mask[:embedding.shape[0]] = True
        embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
        mask = torch.tensor([True]*len(sequence))
        mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
        padding_mask = padding_mask.unsqueeze(0)

        model = PLM_CS(1280, 512, 8, 0.1)
        # model.load_state_dict(
        #     torch.load("./dataset/your_model_ckpt/"+atom_type+".pt", map_location=torch.device('cpu')))
        model.load_state_dict(
            torch.load("./plm-cs/ckpt/model_ckpt/reg_ca.pth", map_location=torch.device('cpu')))
        model.eval()
        out = model(embedding.unsqueeze(0), padding_mask)
        pred = out.squeeze(2).squeeze(0)[mask]
        df[atom_type] = pred.tolist()

    df = pd.DataFrame(df)
    df.to_csv("./result/"+str(result_file_name)+".csv")

An example of how to use your model

In [16]:

your_model(sequence="MVKVYAPASSANMSVLIQDLM", result_file_name="result")
# An example

  torch.load("./plm-cs/ckpt/model_ckpt/reg_ca.pth", map_location=torch.device('cpu')))


### Evaluate

The example of evaluating proteins in the shiftx test set

In [5]:
def test_on_shiftxfile(file_path, atom_types):
    bmrb_seq = refdb_get_seq(file_path)
    s, e = refdb_find_shift(file_path)
    cs_seq = shiftx_get_cs_seq(file_path, s, e)
    matched = align_bmrb_pdb(bmrb_seq, cs_seq)
    six_rmse = []
    if '_' not in bmrb_seq:
        print(shift)
        df = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
        for atom_type in atom_types:
            shift, mask = shiftx_get_shift_re(file_path, s, e, bmrb_seq, matched, atom_type)
            label= torch.tensor(shift)
            mask = torch.tensor(mask)
            label = torch.nn.functional.pad(label, (0, 512-label.shape[0]))
            data = [("protein1", bmrb_seq)]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[33], return_contacts=False)
            token_representations = results["representations"][33]
            embedding = token_representations[:, 1:-1, :].squeeze()
            padding_mask = torch.zeros(512).bool()
            padding_mask[:embedding.shape[0]] = True
            embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
            mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
            model = PLM_CS(1280, 512, 8, 0.1)
            # model = classification(1280, 512, 8, 0.1)
            padding_mask = padding_mask.unsqueeze(0)
            model.load_state_dict(
                torch.load("./dataset/your_model_ckpt/"+atom_type+".pt", map_location=torch.device('cpu')))
            model.eval()
            out = model(embedding.unsqueeze(0), padding_mask)
            loss_func = torch.nn.MSELoss()
            loss = loss_func(out.squeeze(2).squeeze(0)[mask], label[mask])
            rmse = math.sqrt(loss.item())
            a = out.squeeze(2).squeeze(0)[mask].detach().numpy()
            b = label[mask].detach().numpy()
            df[atom_type+'_pred'] = a
            df[atom_type+'_label']= b
            print(file_path + atom_type+" Inference finished, rmse is: ", rmse)
            six_rmse.append(rmse)
    df = pd.DataFrame(df)
    df.to_csv("./result/"+str(file_path)+".csv")
    return six_rmse

Test each shiftx file

In [6]:
all_ca_rmse = []
all_cb_rmse = []
all_c_rmse = []
all_n_rmse = []
all_ha_rmse = []
all_h_rmse = []
for root, directories, files in os.walk("./dataset/shiftx_test_set"):
    for file in files:
        file_path = "./dataset/shiftx_test_set" + str(file)
        six_rmse = test_on_shiftxfile(file_path, atom_types)
        all_ca_rmse.append(six_rmse[0])
        all_cb_rmse.append(six_rmse[1])
        all_c_rmse.append(six_rmse[2])
        all_n_rmse.append(six_rmse[3])
        all_ha_rmse.append(six_rmse[4])
        all_h_rmse.append(six_rmse[5])
print("CA_rmse: ", np.mean(all_ca_rmse))
print("CB_rmse: ", np.mean(all_cb_rmse))
print("C_rmse: ", np.mean(all_c_rmse))
print("N_rmse: ", np.mean(all_n_rmse))
print("HA_rmse: ", np.mean(all_ha_rmse))
print("H_rmse: ", np.mean(all_h_rmse))

FileNotFoundError: [Errno 2] No such file or directory: './dataset/shiftx_test_setA001_bmr4032.str.corr.pdbresno'

The example of evaluating proteins in the solution_nmr_testset

In [None]:
def test_on_solutionnmr(file_path, atom_types):
    bmrb_seq_list = extract_protein_sequence(file_path)
    six_rmse = []
    for i, bmrb_seq in enumerate(bmrb_seq_list):
        if '_' not in bmrb_seq and len(bmrb_seq) < 512:
            data = [("protein1", bmrb_seq_list[i])]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            with torch.no_grad():
                results = esm_model(batch_tokens, repr_layers=[33], return_contacts=True)
            token_representations = results["representations"][33]
            embedding = token_representations[:, 1:-1, :].squeeze()
            embedding = torch.nn.functional.pad(embedding, (0, 0, 0, 512 - embedding.shape[0]))
            model = PLM_CS(1280, 512, 8, 0.1)
            df = {'CA_label':[], 'CA_pred':[], 'CB_label':[], 'CB_pred':[], 'C_label':[], 'C_pred':[], 'N_label':[], 'N_pred':[], 'HA_label':[], 'HA_pred':[], 'H_label':[], 'H_pred':[], }
            for atom_type in atom_types:
                if atom_type == "HA":
                    shifts, masks = get_HA_shifts(file_path, "HA", bmrb_seq_list)
                else:
                    shifts, masks = get_shifts(file_path, atom_type, bmrb_seq_list)
                label= torch.tensor(shifts[i])
                mask = torch.tensor(masks[i])
                padding_mask = torch.zeros(512).bool()
                padding_mask[:label.shape[0]] = True
                label = torch.nn.functional.pad(label, (0, 512 - label.shape[0]))
                mask = torch.nn.functional.pad(mask, (0, 512 - mask.shape[0]), value=False)
                padding_mask = padding_mask.unsqueeze(0)
                model.load_state_dict(
                    torch.load("F:\\nmrprediction\\CSpre\\inmemory\\best_model\\" + atom_types[atom_type] , map_location=torch.device('cpu')))
                model.eval()
                out = model(embedding.unsqueeze(0), padding_mask)
                loss_func = torch.nn.MSELoss()
                loss = loss_func(out.squeeze(2).squeeze(0)[mask], label[mask])
                rmse = math.sqrt(loss.item())
                a = out.squeeze(2).squeeze(0)[mask].detach().numpy()
                b = label[mask].detach().numpy()
                df[atom_type+'_pred'] = a
                df[atom_type+'_label']= b
                print(file_path + atom_type+" Inference finished, rmse is: ", rmse)
                six_rmse.append(rmse)
    df = pd.DataFrame(df)
    df.to_csv("F:\\nmrprediction\\CSpre\\dataset\\all_infer\\"+str(file_path)+".csv")
    
    return six_rmse

Test each solution_nmr_test_set file

In [None]:
all_ca_rmse = []
all_cb_rmse = []
all_c_rmse = []
all_n_rmse = []
all_ha_rmse = []
all_h_rmse = []
for root, directories, files in os.walk("./dataset/solution_nmr_test_set"):
    for file in files:
        file_path = "./dataset/solution_nmr_test_set" + str(file)
        six_rmse = test_on_solutionnmr(file_path, atom_types)
        all_ca_rmse.append(six_rmse[0])
        all_cb_rmse.append(six_rmse[1])
        all_c_rmse.append(six_rmse[2])
        all_n_rmse.append(six_rmse[3])
        all_ha_rmse.append(six_rmse[4])
        all_h_rmse.append(six_rmse[5])
print("CA_rmse: ", np.mean(all_ca_rmse))
print("CB_rmse: ", np.mean(all_cb_rmse))
print("C_rmse: ", np.mean(all_c_rmse))
print("N_rmse: ", np.mean(all_n_rmse))
print("HA_rmse: ", np.mean(all_ha_rmse))
print("H_rmse: ", np.mean(all_h_rmse))