In [None]:
import torch
import esm
from torch.utils.data import TensorDataset
from utils import extract_protein_sequence, refdb_find_shift, refdb_get_cs_seq, refdb_get_shift_re
from utils import align_bmrb_pdb
import os
import math
from torch.utils.data import DataLoader
from model import regression
from torch.utils.data import random_split
import argparse
import numpy as np

### Data process
During data processing, the ESM-2 model is used ahead of training to convert each protein sequence into per-residue embeddings, which are then saved as a `TensorDataset`.

Load the esm2 model

In [None]:
# Load the pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) and its alphabet.
# `model` and `batch_converter` are read as module-level globals by data_process.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Turns a list of (name, sequence) pairs into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()
# Evaluation mode: the model is only used for inference (under torch.no_grad()).
model.eval()

In [None]:
def data_process(refdb_path, save_path, atom_type="CA"):
    """Embed RefDB sequences with ESM-2 and save them as a padded TensorDataset.

    Every file found under ``refdb_path`` is parsed with the ``utils`` helpers.
    Each usable sequence (no ``'_'`` characters, length in 1..511, at least one
    observed shift) contributes one row to four tensors, all padded to 512
    residues:

    * embeddings   -- float, ``(N, 512, 1280)``, layer-33 ESM-2 representations
    * labels       -- float32, ``(N, 512)``, chemical shifts (0 in the padding)
    * masks        -- bool, ``(N, 512)``, taken from ``refdb_get_shift_re``
    * padding mask -- bool, ``(N, 512)``, True on real residues, False on padding

    Relies on the module-level ``model`` and ``batch_converter`` defined in the
    ESM-2 loading cell.

    Args:
        refdb_path: Directory that is walked recursively for RefDB files.
        save_path: File the resulting ``TensorDataset`` is written to.
        atom_type: Atom whose shifts are extracted (e.g. "CA", "CB", "C", "N",
            "H", "HA"); forwarded unchanged to ``refdb_get_shift_re``.
    """
    max_len = 512
    embed_dim = 1280
    esm_vecs, labels, masks, padding_masks = [], [], [], []

    for root, _directories, files in os.walk(refdb_path):
        for file in files:
            # NOTE(review): as before, only the file stem is handed to the helpers
            # (neither `root` nor the extension) -- confirm they resolve the path.
            file_path = str(file.split(".")[0])
            bmrb_seq_list = extract_protein_sequence(file_path)
            s, e = refdb_find_shift(file_path)
            cs_seq = refdb_get_cs_seq(file_path, s, e)
            # The original passed `bmrb_seq` here, which is unbound at this point.
            # NOTE(review): assumes both helpers accept the whole sequence list,
            # since their results are indexed per sequence below -- TODO confirm.
            matched = align_bmrb_pdb(bmrb_seq_list, cs_seq)
            shift, mask = refdb_get_shift_re(file_path, s, e, bmrb_seq_list, matched, atom_type)

            for i, bmrb_seq in enumerate(bmrb_seq_list):
                if '_' in bmrb_seq or not 0 < len(bmrb_seq) < max_len:
                    continue

                # Separate name so the per-file `mask` list is not overwritten.
                seq_mask = torch.as_tensor(mask[i], dtype=torch.bool)
                if not seq_mask.any():
                    # No observed shift for this atom type: skip before the
                    # expensive ESM-2 forward pass.
                    continue

                data = [("protein1", bmrb_seq)]
                batch_labels, batch_strs, batch_tokens = batch_converter(data)
                with torch.no_grad():
                    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
                token_representations = results["representations"][33]
                # Drop the BOS/EOS tokens. Index the batch dimension explicitly:
                # .squeeze() would also collapse a length-1 sequence.
                embedding = token_representations[0, 1:-1, :]
                # Pad from (res, 1280) to (512, 1280).
                embedding = torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[0]))

                label = torch.as_tensor(shift[i], dtype=torch.float32)
                padding_mask = torch.zeros(max_len).bool()
                padding_mask[:label.shape[0]] = True
                # Pad from (res,) to (512,).
                label = torch.nn.functional.pad(label, (0, max_len - label.shape[0]))
                seq_mask = torch.nn.functional.pad(seq_mask, (0, max_len - seq_mask.shape[0]), value=False)

                esm_vecs.append(embedding)
                labels.append(label)
                masks.append(seq_mask)
                padding_masks.append(padding_mask)

    # Assemble and save once, after the whole tree has been walked.
    if esm_vecs:
        all_esm_vec = torch.stack(esm_vecs, dim=0)
        all_label = torch.stack(labels, dim=0)
        all_mask = torch.stack(masks, dim=0)
        all_padding_mask = torch.stack(padding_masks, dim=0)
    else:
        # Nothing usable was found: save an empty dataset with the right shapes.
        all_esm_vec = torch.zeros(0, max_len, embed_dim)
        all_label = torch.zeros((0, max_len))
        all_mask = torch.zeros((0, max_len)).bool()
        all_padding_mask = torch.zeros((0, max_len)).bool()

    dataset = TensorDataset(all_esm_vec, all_label, all_mask, all_padding_mask)
    torch.save(dataset, save_path)
    print("Data saved successfully, size of dataset is: ", all_esm_vec.shape)

Create one `TensorDataset` for each of the 6 atom types

In [None]:
from utils import extract_protein_sequence, refdb_find_shift, refdb_get_cs_seq, refdb_get_shift_re
from utils import align_bmrb_pdb
import os
atom_types = ["CA", "CB", "C", "N", "H", "HA"]
# Raw string: "\d" and "\R" are invalid escape sequences in a normal literal.
# The value is unchanged.
# NOTE(review): "dateset" may be a typo for "dataset" (see save_path) -- confirm
# against the directory that actually exists on disk.
refdb = r"\dateset\RefDB_test_remove"
save_path = "./dataset/tensordataset/"
# torch.save does not create missing parent directories.
os.makedirs(save_path, exist_ok=True)
for atom_type in atom_types:
    # One output file per atom type, e.g. ./dataset/tensordataset/CA.pt
    data_path = save_path + atom_type + ".pt"
    # Pass the per-atom file, not the output directory.
    data_process(refdb, data_path, atom_type)

### Train

Set random seeds

In [None]:
seed = 42
# Apply the same seed through each PyTorch seeding entry point: the default
# generator, the current CUDA device, and all CUDA devices (multi-GPU).
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Set hyperparameters

In [None]:
def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would turn "False" into True."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

parser.add_argument('--batchsize', type=int, default=16, help='Batch size for training')
parser.add_argument('--N', type=int, default=6, help='number of encoder')
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--d_model', type=int, default=512, help='qkv d-model dimension')
parser.add_argument('--d_vec', type=int, default=1280, help='amino embedding dimension')
parser.add_argument('--n_head', type=int, default=8, help='number of attention heads')
parser.add_argument('--shuffle', type=_str2bool, default=False, help='shuffle dataset')
parser.add_argument('--epoch', type=int, default=20000, help='epoch time')
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
parser.add_argument('--device', type=str, default="cuda:0", help='device to train on, e.g. "cuda:0" or "cpu"')
# parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv holds
# the kernel's own arguments (e.g. "-f <connection file>"), which parse_args
# rejects with SystemExit. Unknown arguments are ignored here.
args, _unknown_args = parser.parse_known_args()

Select the atom type you want to train on

In [None]:
atom_type = "CA"
save_path = "./dataset/your_model_ckpt/"+atom_type+".pt"
# Change this path to load a different dataset.
dataset_path = "./dataset/tensordataset/" + atom_type + ".pt"
# The file is written by torch.save(TensorDataset, ...) in data_process, so it
# must be read back with torch.load, not np.load.
# SECURITY: this unpickles arbitrary objects -- only load files you created.
try:
    # Recent PyTorch defaults to weights_only=True, which refuses to unpickle a
    # TensorDataset.
    data = torch.load(dataset_path, weights_only=False)
except TypeError:
    # Older PyTorch versions do not accept the weights_only argument.
    data = torch.load(dataset_path)

In [None]:
def main(data, save_path, val_fraction=0.0):
    """Train the regression model and checkpoint it whenever the score improves.

    Hyperparameters are read from the module-level ``args`` namespace.

    Args:
        data: Dataset yielding ``(seq_vec, label, mask, padding_mask)`` tuples,
            as produced by ``data_process``.
        save_path: Prefix for checkpoint files; ``epoch{E}_val{RMSE}.pth`` is
            appended to it verbatim.
        val_fraction: Share of ``data`` held out for validation, in ``[0, 1)``.
            The default 0.0 keeps the original behaviour of training on the
            whole dataset; in that case the training RMSE is used as the
            model-selection score because there is nothing to validate on.

    Returns:
        ``(train_loss_all, val_loss_all)``: per-epoch training MSE and validation
        RMSE (``nan`` for every epoch when there is no validation split).
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    model = regression(args.d_vec, args.d_model, args.n_head, args.dropout)
    device = torch.device(args.device)
    train_loss_all = []
    val_loss_all = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), eps=1e-8,
                                 weight_decay=0)
    loss_func = torch.nn.MSELoss()
    model.to(device)

    def init_weights1(module):
        # Kaiming-uniform init for every Linear layer; the in-place variant
        # replaces the deprecated torch.nn.init.kaiming_uniform.
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.kaiming_uniform_(module.weight)

    model.apply(init_weights1)

    val_size = int(len(data) * val_fraction)
    train_size = len(data) - val_size
    train_dataset, val_dataset = random_split(data, [train_size, val_size])
    traindata_loader = DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True)
    has_val = val_size > 0
    # Build the validation loader only when there is data for it; order does
    # not matter for an aggregate metric, so it is not shuffled.
    valdata_loader = DataLoader(val_dataset, batch_size=args.batchsize, shuffle=False) if has_val else None

    def train(epoch):
        """Run one optimisation pass; return the mean squared error per masked residue."""
        model.train()
        epoch_loss = 0
        all_CA = 0
        for i, batch in enumerate(traindata_loader):
            seq_vec, label, mask, padding_mask = batch[0], batch[1], batch[2], batch[3]
            mask = mask.to(device)
            label = label.to(device)
            seq_vec = seq_vec.to(device)
            padding_mask = padding_mask.to(device)
            out = model(seq_vec, padding_mask)
            # The optimised objective is the batch RMSE over masked residues.
            loss = torch.sqrt(loss_func(out.squeeze(2)[mask], label[mask]))
            n_masked = label[mask].shape[0]
            all_CA += n_masked
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate the summed squared error so the epoch mean is per residue.
            epoch_loss += loss.detach().item() ** 2 * n_masked
        return (epoch_loss / all_CA) if all_CA else float("nan")

    def val(epoch):
        """Return the validation RMSE, or nan when there is no validation data."""
        if not has_val:
            return float("nan")
        # Evaluate with dropout off and without building autograd graphs.
        model.eval()
        epoch_loss = 0
        all_CA = 0
        with torch.no_grad():
            for i, batch in enumerate(valdata_loader):
                seq_vec, label, mask, padding_mask = batch[0], batch[1], batch[2], batch[3]
                mask = mask.to(device)
                label = label.to(device)
                seq_vec = seq_vec.to(device)
                padding_mask = padding_mask.to(device)
                out = model(seq_vec, padding_mask)
                n_masked = label[mask].shape[0]
                if n_masked == 0:
                    # MSE over an empty selection is nan; skip such batches.
                    continue
                loss = loss_func(out.squeeze(2)[mask], label[mask])
                all_CA += n_masked
                epoch_loss += loss.item() * n_masked
        return math.sqrt(epoch_loss / all_CA) if all_CA else float("nan")

    # Checkpoints are written only once the score drops below this RMSE.
    best_acc = 1.8
    for epoch in range(0, args.epoch):
        train_loss = train(epoch)
        val_loss = val(epoch)
        # Without a validation split, select on the training RMSE instead.
        score = val_loss if has_val else math.sqrt(train_loss)
        val_text = f"{val_loss:7.3f}" if has_val else "    n/a"
        print(f'\tepoch {epoch} Train Loss: {train_loss:.3f} | val_rmse: {val_text}')
        train_loss_all.append(train_loss)
        val_loss_all.append(val_loss)

        if score < best_acc:
            sp = save_path + f"epoch{epoch}_val{score:.3f}.pth"
            ckpt_dir = os.path.dirname(sp)
            if ckpt_dir:
                # torch.save does not create missing parent directories.
                os.makedirs(ckpt_dir, exist_ok=True)
            state = {
                "epoch": epoch,
                # Key name kept for compatibility; the value is an RMSE.
                "accuracy": score,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }
            torch.save(state, sp)
            best_acc = score

    return train_loss_all, val_loss_all

# Entry point: start training with the dataset and checkpoint prefix chosen in
# the cells above.
if __name__ == '__main__':
    main(data, save_path)