For benchmarking, we calculated the four metrics: recovery rate, perplexity, rmsd and tmscore. 

In [2]:
from RhoDesign import RhoDesignModel
from alphabet import Alphabet
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random
from util import load_structure, extract_coords_from_structure, seq_rec_rate,CoordBatchConverter
import os
random.seed(0)

_device = 0
alphabet = Alphabet(['A','G','C','U','X'])
batch_converter = CoordBatchConverter(alphabet)

class args_class:  
    def __init__(self, encoder_embed_dim, decoder_embed_dim, dropout):
        self.local_rank = int(os.getenv("LOCAL_RANK", -1))
        self.device_id = [0, 1, 2, 3, 4, 5, 6, 7]
        self.epochs = 100
        self.lr = 1e-5
        self.batch_size = 1
        self.encoder_embed_dim = encoder_embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.dropout = dropout
        self.gvp_top_k_neighbors = 15
        self.gvp_node_hidden_dim_vector = 256
        self.gvp_node_hidden_dim_scalar = 512
        self.gvp_edge_hidden_dim_scalar = 32
        self.gvp_edge_hidden_dim_vector = 1
        self.gvp_num_encoder_layers = 3
        self.gvp_dropout = 0.1
        self.encoder_layers = 3
        self.encoder_attention_heads = 4
        self.attention_dropout = 0.1
        self.encoder_ffn_embed_dim = 512
        self.decoder_layers = 3
        self.decoder_attention_heads = 4
        self.decoder_ffn_embed_dim = 512

def get_sequence_loss(model, batch , _device):
    device = _device
    # batch_converter = CoordBatchConverter(alphabet)
    
    coords, confidence, strs, tokens, padding_mask,ss_ct_map = batch_converter(
        batch, device=device)
    
    c = coords[:,:,[0,1,2],:] # the four backbone atoms
    adc = coords[:,:,:,:] # eight atoms which are used to compute dihedral angles
    padding_mask = padding_mask.bool()

    prev_output_tokens = tokens[:, :-1].to(device)
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    logits, _ = model.forward(c, adc,ss_ct_map,padding_mask, confidence, prev_output_tokens)
    loss = F.cross_entropy(logits, target, reduction='none')
    loss = loss[0].cpu().detach().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask

def score_sequence(model, batch,_device):
    loss, target_padding_mask = get_sequence_loss(model, batch,_device)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
    return ll_fullseq

def score_backbone(model, coords, seq, ss_ct_map, _device):
    batch = [(coords, None, seq,ss_ct_map)]
    ll= score_sequence(model, batch,_device) 
    ppl = np.exp(-ll)
    return ppl

def eval_ppl(model,pdb_list,model_path):
    """
    fpath: path to pdb file
    """
    
    temp=torch.load(model_path) 
    model.load_state_dict(temp)  
    model.eval()

    with torch.no_grad():
        pfile = './../data/test/'
        ssfile = './../data/test_ss/'
        ppl = []
        wrong_ppl = []
        wrong_p = []
        for i in tqdm(pdb_list):
            fpath = pfile+i+'.pdb'
            ss_path = ssfile+i+'.npy'
            s = load_structure(fpath)
            coords, seq = extract_coords_from_structure(s)
            ss_ct_map = np.load(ss_path)
            ppl_v = score_backbone(model,coords,seq,ss_ct_map,_device)
            ppl.append(ppl_v)
    return np.mean(ppl)


def eval(model,pdb_list,model_path,_device):
    """
    fpath: path to pdb file
    """
    test_path = './../data/test/'
    test_ss_path = './../data/test_ss/'
    
    model_dir=torch.load(model_path) 
    model.load_state_dict(model_dir)  
    model.eval()
    rc = []
    for pdb_name in tqdm(pdb_list):
        pdb_path = test_path + pdb_name + '.pdb'
        ss_path = test_ss_path + pdb_name.split('.')[0] + '.npy'
        ss_ct_map = np.load(ss_path)
        pdb = load_structure(pdb_path)
        coords, seq = extract_coords_from_structure(pdb)
        pred_seq = model.sample(coords,ss_ct_map,_device,temperature=1e-5)
        rc_value = seq_rec_rate(seq,pred_seq)
        rc.append(rc_value)
    
    return np.mean(rc)


args = args_class(512,512,0.1)
dictionary = Alphabet(['A','G','C','U','X'])
model = RhoDesignModel(args, dictionary).cuda(device=_device)



pdb_list = os.listdir('./../data/test/')
pdb_list = [i.split('.')[0] for i in pdb_list]

model_path = './../model/ss_apexp_best.pth'

recovery_rate = eval(model,pdb_list,model_path,_device) 
perplexity = eval_ppl(model,pdb_list,model_path)

print('recovery_rate:', recovery_rate)
print('perplexity:', perplexity)

100%|██████████| 279/279 [01:14<00:00,  3.77it/s]
100%|██████████| 279/279 [00:16<00:00, 17.08it/s]

recovery_rate: 0.5278057973963203
perplexity: 2.431720873359259





In [3]:
from RhoDesign import RhoDesignModel
from alphabet import Alphabet
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random
from util import load_structure, extract_coords_from_structure, seq_rec_rate,CoordBatchConverter
import os
random.seed(0)

_device = 0
alphabet = Alphabet(['A','G','C','U','X'])
batch_converter = CoordBatchConverter(alphabet)

class args_class:  
    def __init__(self, encoder_embed_dim, decoder_embed_dim, dropout):
        self.local_rank = int(os.getenv("LOCAL_RANK", -1))
        self.device_id = [0, 1, 2, 3, 4, 5, 6, 7]
        self.epochs = 100
        self.lr = 1e-5
        self.batch_size = 1
        self.encoder_embed_dim = encoder_embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.dropout = dropout
        self.gvp_top_k_neighbors = 15
        self.gvp_node_hidden_dim_vector = 256
        self.gvp_node_hidden_dim_scalar = 512
        self.gvp_edge_hidden_dim_scalar = 32
        self.gvp_edge_hidden_dim_vector = 1
        self.gvp_num_encoder_layers = 3
        self.gvp_dropout = 0.1
        self.encoder_layers = 3
        self.encoder_attention_heads = 4
        self.attention_dropout = 0.1
        self.encoder_ffn_embed_dim = 512
        self.decoder_layers = 3
        self.decoder_attention_heads = 4
        self.decoder_ffn_embed_dim = 512

def get_sequence_loss(model, batch , _device):
    device = _device
    # batch_converter = CoordBatchConverter(alphabet)
    
    coords, confidence, strs, tokens, padding_mask,ss_ct_map = batch_converter(
        batch, device=device)
    
    c = coords[:,:,[0,1,2],:] # the four backbone atoms
    adc = coords[:,:,:,:] # eight atoms which are used to compute dihedral angles
    padding_mask = padding_mask.bool()

    prev_output_tokens = tokens[:, :-1].to(device)
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    logits, _ = model.forward(c, adc,ss_ct_map,padding_mask, confidence, prev_output_tokens)
    loss = F.cross_entropy(logits, target, reduction='none')
    loss = loss[0].cpu().detach().numpy()
    target_padding_mask = target_padding_mask[0].cpu().numpy()
    return loss, target_padding_mask

def score_sequence(model, batch,_device):
    loss, target_padding_mask = get_sequence_loss(model, batch,_device)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
    return ll_fullseq

def score_backbone(model, coords, seq, ss_ct_map, _device):
    batch = [(coords, None, seq,ss_ct_map)]
    ll= score_sequence(model, batch,_device) 
    ppl = np.exp(-ll)
    return ppl

def eval_ppl(model,pdb_list,model_path):
    """
    fpath: path to pdb file
    """
    
    temp=torch.load(model_path) 
    model.load_state_dict(temp)  
    model.eval()

    with torch.no_grad():
        pfile = '/home/hedongchen/projects/RNA3D_DATA/pdb/' # please specify the path to the pdb files of cross-fold validation datasets
        ssfile = '/home/hedongchen/projects/RNA3D_DATA/ss/'
        ppl = []
        wrong_ppl = []
        wrong_p = []
        for i in tqdm(pdb_list):
            fpath = pfile+i+'.pdb'
            ss_path = ssfile+i+'.npy'
            s = load_structure(fpath)
            coords, seq = extract_coords_from_structure(s)
            ss_ct_map = np.load(ss_path)
            ppl_v = score_backbone(model,coords,seq,ss_ct_map,_device)
            ppl.append(ppl_v)
    return np.mean(ppl)


def eval(model,pdb_list,model_path,_device):
    """
    fpath: path to pdb file
    """
    test_path = '/home/hedongchen/projects/RNA3D_DATA/pdb/' # please specify the path to the pdb files of cross-fold validation datasets
    test_ss_path = '/home/hedongchen/projects/RNA3D_DATA/ss/' 
    
    model_dir=torch.load(model_path) 
    model.load_state_dict(model_dir)  
    model.eval()
    rc = []
    for pdb_name in tqdm(pdb_list):
        pdb_path = test_path + pdb_name + '.pdb'
        ss_path = test_ss_path + pdb_name.split('.')[0] + '.npy'
        ss_ct_map = np.load(ss_path)
        pdb = load_structure(pdb_path)
        coords, seq = extract_coords_from_structure(pdb)
        pred_seq = model.sample(coords,ss_ct_map,_device,temperature=1e-5)
        rc_value = seq_rec_rate(seq,pred_seq)
        rc.append(rc_value)
    
    return np.mean(rc)


args = args_class(512,512,0.1)
dictionary = Alphabet(['A','G','C','U','X'])
model = RhoDesignModel(args, dictionary).cuda(device=_device)

name = 'seq'
rc = []
ppl = []

for i in range(5):
    path = f'/home/hedongchen/fold_{name}_{str(i)}.npy' # please specify the path to the pdb id file
    model_path = f'/home/hedongchen/projects/RhoDesign/model/f_{name}_{str(i)}/cf_16.pth' # please specify the path to the model checkpoint
    pdb_list = np.load(path)
    
    rc.append(eval(model,pdb_list,model_path,_device))
    ppl.append(eval_ppl(model,pdb_list,model_path))

recovery_rate = np.mean(rc)
perplexity = np.mean(ppl)

print('For cross-fold validation, when sequence-similarity < 0.6, average recovery_rate:', recovery_rate)
print('For cross-fold validation, when sequence-similarity < 0.6, average perplexity:', perplexity)

100%|██████████| 801/801 [03:56<00:00,  3.39it/s]
100%|██████████| 801/801 [00:50<00:00, 16.00it/s]
100%|██████████| 803/803 [03:11<00:00,  4.19it/s]
100%|██████████| 803/803 [00:43<00:00, 18.29it/s]
100%|██████████| 852/852 [05:08<00:00,  2.76it/s]
100%|██████████| 852/852 [01:02<00:00, 13.68it/s]
100%|██████████| 820/820 [02:48<00:00,  4.86it/s]
100%|██████████| 820/820 [00:40<00:00, 20.16it/s]
100%|██████████| 801/801 [03:37<00:00,  3.68it/s]
100%|██████████| 801/801 [00:47<00:00, 16.80it/s]

For cross-fold validation, when sequence-similarity < 0.6, average recovery_rate: 0.631432166540998
For cross-fold validation, when sequence-similarity < 0.6, average perplexity: 2.021607148379395





In [4]:
name = 'struc'
rc = []
ppl = []

for i in range(5):
    path = f'/home/hedongchen/fold_{name}_{str(i)}.npy' # please specify the path to the pdb id file
    model_path = f'/home/hedongchen/projects/RhoDesign/model/f_{name}_{str(i)}/cf_16.pth' # please specify the path to the model checkpoint
    pdb_list = np.load(path)
    
    rc.append(eval(model,pdb_list,model_path,_device))
    ppl.append(eval_ppl(model,pdb_list,model_path))

recovery_rate = np.mean(rc)
perplexity = np.mean(ppl)

print('For cross-fold validation, when structure-similarity < 0.5, average recovery_rate:', recovery_rate)
print('For cross-fold validation, when structure-similarity < 0.5, average perplexity:', perplexity)

100%|██████████| 802/802 [04:39<00:00,  2.87it/s]
100%|██████████| 802/802 [00:57<00:00, 13.88it/s]
100%|██████████| 803/803 [03:11<00:00,  4.19it/s]
100%|██████████| 803/803 [00:44<00:00, 18.18it/s]
100%|██████████| 801/801 [03:20<00:00,  3.99it/s]
100%|██████████| 801/801 [00:45<00:00, 17.69it/s]
100%|██████████| 803/803 [03:21<00:00,  3.99it/s]
100%|██████████| 803/803 [00:45<00:00, 17.52it/s]
100%|██████████| 840/840 [06:34<00:00,  2.13it/s]
100%|██████████| 840/840 [01:15<00:00, 11.11it/s]

For cross-fold validation, when structure-similarity < 0.5, average recovery_rate: 0.6494589401523532
For cross-fold validation, when structure-similarity < 0.5, average perplexity: 1.9656064742114487





For TM-score and RMSD, please firstly git clone the repository of RhoFold, and then use RhoFold to predict the structure of predicted sequences.

In [13]:
import pymol

def align_pdb(gt_file, compared_file):
    gt_filename_pre = os.path.splitext(os.path.split(gt_file)[1])[0]
    compare_filename_pre = os.path.splitext(os.path.split(compared_file)[1])[0]+'_des'
    if gt_filename_pre != compare_filename_pre:
        pymol.cmd.load(gt_file,object=gt_filename_pre)
        pymol.cmd.load(compared_file,object=compare_filename_pre)
        align_output = pymol.cmd.align(gt_filename_pre, compare_filename_pre, cycles=2)
        pymol.cmd.delete(gt_filename_pre)
        pymol.cmd.delete(compare_filename_pre)
        return align_output[0]
    else:
        return 0
    
def tm_score(gt_path, pred_path):
    os.system(f'/home/hedongchen/USalign {gt_path} {pred_path} > ./tmscore.txt -outfmt 2') # please specify USalign path here
    try:
        with open(f'./tmscore.txt','r') as f:
            data = f.readlines()
        t1 = float(data[1].split()[2])
    except:
        print(pred_path.split('/')[-1])
        return -1
    return t1

In [15]:
from tqdm import tqdm
name = 'seq'
tm = []
rmsd = []

for i in range(5):
    path = f'/home/hedongchen/fold_{name}_{str(i)}.npy' # please specify the path to the pdb id file
    pred_struc = f'/home/hedongchen/projects/RhoDesign/data/{name}_fold_{str(i)}_struc/'
    gt_struc = f'/home/hedongchen/projects/RNA3D_DATA/pdb/'
    pdb_list = np.load(path)
    for pdb in tqdm(pdb_list):
        gt_file = gt_struc + pdb + '.pdb'
        pred_file = pred_struc + pdb + '/unrelaxed_model.pdb'
        rmsd.append(align_pdb(gt_file, pred_file))
        tm.append(tm_score(gt_file, pred_file))

print('For cross-fold validation, when sequence-similarity < 0.6, average TM-score:', np.mean(tm))
print('For cross-fold validation, when sequence-similarity < 0.6, average RMSD:', np.mean(rmsd))

100%|██████████| 801/801 [00:37<00:00, 21.27it/s]
 20%|██        | 164/803 [00:03<00:16, 37.65it/s]



 22%|██▏       | 178/803 [00:04<00:18, 33.88it/s]



 67%|██████▋   | 537/803 [00:18<00:07, 34.52it/s]



 86%|████████▌ | 687/803 [00:26<00:08, 13.87it/s]



100%|██████████| 803/803 [00:30<00:00, 26.22it/s]
 25%|██▌       | 215/852 [00:08<00:13, 48.99it/s]



100%|██████████| 852/852 [00:51<00:00, 16.41it/s]
 97%|█████████▋| 792/820 [00:26<00:00, 39.66it/s]



100%|██████████| 820/820 [00:26<00:00, 30.71it/s]
 45%|████▍     | 360/801 [00:11<00:11, 37.53it/s]



 48%|████▊     | 388/801 [00:13<00:23, 17.86it/s]



100%|██████████| 801/801 [00:33<00:00, 23.83it/s]

For cross-fold validation, when sequence-similarity < 0.6, average TM-score: 0.301043561442237
For cross-fold validation, when sequence-similarity < 0.6, average RMSD: 13.271298147107505





In [14]:
from tqdm import tqdm
name = 'struc'
tm = []
rmsd = []

for i in range(5):
    path = f'/home/hedongchen/fold_{name}_{str(i)}.npy' # please specify the path to the pdb id file
    pred_struc = f'/home/hedongchen/projects/RhoDesign/data/{name}_fold_{str(i)}_struc/'
    gt_struc = f'/home/hedongchen/projects/RNA3D_DATA/pdb/'
    pdb_list = np.load(path)
    for pdb in tqdm(pdb_list):
        gt_file = gt_struc + pdb + '.pdb'
        pred_file = pred_struc + pdb + '/unrelaxed_model.pdb'
        rmsd.append(align_pdb(gt_file, pred_file))
        tm.append(tm_score(gt_file, pred_file))

print('For cross-fold validation, when structure-similarity < 0.5, average TM-score:', np.mean(tm))
print('For cross-fold validation, when structure-similarity < 0.5, average RMSD:', np.mean(rmsd))

  2%|▏         | 13/802 [00:00<00:17, 46.17it/s]



100%|██████████| 802/802 [00:45<00:00, 17.58it/s]
100%|██████████| 803/803 [00:30<00:00, 26.35it/s]




  7%|▋         | 58/801 [00:01<00:17, 41.55it/s]



 39%|███▊      | 309/801 [00:12<00:13, 37.62it/s]



100%|██████████| 801/801 [00:31<00:00, 25.33it/s]
 20%|██        | 164/803 [00:07<00:23, 27.74it/s]



 31%|███       | 250/803 [00:10<00:25, 21.71it/s]



100%|██████████| 803/803 [00:33<00:00, 24.00it/s]
100%|██████████| 840/840 [01:06<00:00, 12.58it/s]

For cross-fold validation, when structure-similarity < 0.5, average TM-score: 0.328571647320326
For cross-fold validation, when structure-similarity < 0.5, average RMSD: 13.027279806051704





For a given structure, to design the sequence, we can use the following code:

In [6]:
from RhoDesign import RhoDesignModel
from alphabet import Alphabet
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random
from util import load_structure, extract_coords_from_structure, seq_rec_rate
import os
import argparse
random.seed(1)



class args_class:  # use the same param as esm-if1, waiting to be adjusted...
    def __init__(self, encoder_embed_dim, decoder_embed_dim, dropout):
        self.local_rank = int(os.getenv("LOCAL_RANK", -1))
        self.device_id = [0, 1, 2, 3, 4, 5, 6, 7]
        self.epochs = 100
        self.lr = 1e-5
        self.batch_size = 1
        self.encoder_embed_dim = encoder_embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.dropout = dropout
        self.gvp_top_k_neighbors = 15
        self.gvp_node_hidden_dim_vector = 256
        self.gvp_node_hidden_dim_scalar = 512
        self.gvp_edge_hidden_dim_scalar = 32
        self.gvp_edge_hidden_dim_vector = 1
        self.gvp_num_encoder_layers = 3
        self.gvp_dropout = 0.1
        self.encoder_layers = 3
        self.encoder_attention_heads = 4
        self.attention_dropout = 0.1
        self.encoder_ffn_embed_dim = 512
        self.decoder_layers = 3
        self.decoder_attention_heads = 4
        self.decoder_ffn_embed_dim = 512


def eval(model,pdb_path,ss_path,save_path,_device,temp=1e-5):
    """
    fpath: path to pdb file
    """

    model_path = './../checkpoint/ss_apexp_best.pth'
    
    model_dir=torch.load(model_path) 
    model.load_state_dict(model_dir)  
    model.eval()
    rc = []
    
    ss_ct_map = np.load(ss_path)
    pdb = load_structure(pdb_path)
    coords, seq = extract_coords_from_structure(pdb)

    pred_seq = model.sample(coords,ss_ct_map,_device,temperature=temp)
    rc_value = seq_rec_rate(seq,pred_seq)
    rc.append(rc_value)
    with open(os.path.join(save_path,'pred_seq.fasta'),'w') as f:
        f.write('>predicted_by_RhoDesign'+'\n')
        f.write(pred_seq+'\n')
    print('original sequence: ' + seq)
    print('sequence: ' + pred_seq)
    print('recovery rate: ' + str(np.mean(rc)))



pdb = '/home/hedongchen/projects/RNA3D_DATA/pdb/4v6x_A8.pdb' #specify the path to the pdb file
ss = '/home/hedongchen/projects/RNA3D_DATA/ss/4v6x_A8.npy' #specify the path to the secondary structure file
save_path = './../example/'
_device = 3
temp = 1

model_args = args_class(512,512,0.1)
dictionary = Alphabet(['A','G','C','U','X'])
model = RhoDesignModel(model_args, dictionary).cuda(device=_device)
eval(model,pdb,ss,save_path,_device,temp)


original sequence: CGACUCUUAGCGGUGGAUCACUCGGCUCGUGCGUCGAUGAAGAACGCAGCUAGCUGCGAGAAUUAAUGUGAAUUGCAGGACACAUUGAUCAUCGACACUUCGAACGCACUUGCGGCCCCGGGUUCCUCCCGGGGCUACGCCUGUCUGAGCGUCGCUU
sequence: UAACUUCCGGCGGCGGACCACUCGGUCUGCAUACCGAUGAAGGACGUAACGAGCUGCAAAGACUAAUGCGAACUACGGAAUGUAGUAAUUGCUAGCGUUCUGCAUGCGUACACGACCCCAAGCUUCCCCCAGGGCAUUGUUCAUCUGAGCAUGCAUU
recovery rate: 0.5605095541401274
