In [1]:
import os
import torch
import numpy as np
import math
from tqdm import tqdm
import wandb
import pickle
import pandas as pd

from data.rna_object import RNA
from data.dataloader import get_train_test_dataloaders, RNATorsionalAnglesDataset, collate_fn

from torch.utils.data import DataLoader

from models.transformer import TorsionalAnglesTransformerDecoder

In [2]:
torch.manual_seed(0)
np.random.seed(0)

pdbs_path="/home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs/"
processed_dir = "/ssd_scratch/users/sriram.devata/rna_structure/dataset/"
perfect_pdb_files_train_val_test_path="/home2/sriram.devata/rna_project/cdhit/torrna_train_val_test.pkl"

with open(perfect_pdb_files_train_val_test_path, "rb") as fp:
        training_pdbs, validation_pdbs, testing_pdbs = pickle.load(fp)
training_files = list()
for each_training_pdb in training_pdbs:
        training_files.append(f"{pdbs_path}/{each_training_pdb}.pdb")
testing_files = list()
for each_testing_pdb in testing_pdbs:
        testing_files.append(f"{pdbs_path}/{each_testing_pdb}.pdb")

pdbs_to_predict = 1000
list_of_pdbs_to_predict = testing_files[:pdbs_to_predict]
predict_dataset = RNATorsionalAnglesDataset(list_of_pdbs_to_predict, processed_dir=processed_dir, type_dataset="test")
predict_dataloader = DataLoader(predict_dataset, collate_fn=collate_fn, batch_size=32)

train_dataset = RNATorsionalAnglesDataset(training_files, processed_dir=processed_dir, type_dataset="train")
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=32)


if os.path.exists(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl"):
        with open(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl", "rb") as input_file:
                print(f"Found RNA objects precomputed")
                rna_objects = pickle.load(input_file)
else:
        rna_objects = list()
        print(f"Making RNA objects")
        for each_pdb_file in tqdm(list_of_pdbs_to_predict):
                try:
                        rna_objects.append(RNA(each_pdb_file, calc_rna_fm_embeddings=True, load_dssr_dihedrals=True, load_coords=False))
                except:
                        pass
        with open(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl", "wb") as output_file:
                pickle.dump(rna_objects, output_file)
        
if os.path.exists(f"predict_seq.fasta"):
        print(f"Found precomputed FASTA sequences")
else:
        print(f"Computing FASTA sequences for SPOT-RNA-1D")
        with open(f"predict_seq.fasta", "w") as fp:
                for ran_object_idx,each_rna_object in enumerate(rna_objects):
                        pdb_code = each_rna_object.pdb_path.split('/')[-1].replace(".pdb", "")
                        fp.write(f">{pdb_code}\n")
                        fp.write(f"{each_rna_object.dssr_full_seq}\n")
        print(f"Run SPOT-RNA-1D")
        exit(1)


Found precomputed test dataset
Found precomputed train dataset
Found RNA objects precomputed
Found precomputed FASTA sequences


In [3]:
dict_attr = type('dict_attr', (object,), {})
args = dict_attr()
args.lr = 2e-4
args.embeddim = 640
args.hiddendim = 256
args.numheads = 4
args.numlayers = 3
args.dropout = 0.2
args.tol = 5

model = TorsionalAnglesTransformerDecoder(embed_dim=args.embeddim, hidden_dim=args.hiddendim, num_heads=args.numheads, num_layers=args.numlayers, dropout=args.dropout)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.load_state_dict(torch.load(f"/home2/sriram.devata/rna_project/rna_transformer/checkpoints/best_model_rna_transformer_{args.lr}_{args.embeddim}_{args.hiddendim}_{args.numheads}_{args.numlayers}_{args.dropout}_{args.tol}.pkl", map_location=device))
model.eval()
print(model)

TorsionalAnglesTransformerDecoder(
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=640, out_features=640, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=640, out_features=640, bias=True)
        )
        (linear1): Linear(in_features=640, out_features=2048, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=2048, out_features=640, bias=True)
        (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
        (dropout3): Dropout(p=0.2, inplace=

In [4]:
# ------------------- Helper functions -------------------

all_dihedral_angle_names = ["alpha" ,"beta" ,"gamma" ,"delta" ,"chi" ,"epsilon" ,"zeta" ,"eta" ,"theta"]

def convert_to_rad_and_back_to_degree(cur_res_spot_rna_pred_angles):

        final_degree_angles = list()
        for each_angle in cur_res_spot_rna_pred_angles:
                torsion_angle = torch.tensor([each_angle * math.pi / 180])
                final_degree_angle = torch.arctan2(torch.sin(torsion_angle),torch.cos(torsion_angle)).item() * 180 / math.pi
                final_degree_angles.append(final_degree_angle)

        return final_degree_angles

def make_rad_angle_tensor_to_deg_angle_list(rad_angle_tensor):

        constructed_angles = list()
        for each_angle_idx in range(len(all_dihedral_angle_names)):
                cos_angle = rad_angle_tensor[2*each_angle_idx]
                sin_angle = rad_angle_tensor[2*each_angle_idx+1]
                rad_angle = torch.arctan2(sin_angle,cos_angle)
                deg_angle = rad_angle.item() * 180 / math.pi
                constructed_angles.append(deg_angle)

        return constructed_angles

def calculate_mae(predicted, groundtruth, prediction_method="Unknown"):

        angle_errors_to_export = list()
        all_angle_errors = list()
        for each_angle_idx in range(len(all_dihedral_angle_names)):
                all_angle_errors.append(list())
                angle_errors_to_export.append(list())

        for each_rna_idx in range(len(predicted)):

                for each_angle_idx in range(len(all_dihedral_angle_names)):
                        angle_errors_to_export[each_angle_idx].append(list())

                for each_residue_idx in range(len(predicted[each_rna_idx])):

                        predicted_angles = predicted[each_rna_idx][each_residue_idx]
                        gt_angles = groundtruth[each_rna_idx][each_residue_idx]

                        for each_angle_idx in range(len(all_dihedral_angle_names)):

                                if prediction_method == "Random Baseline":
                                        difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])
                                        difference = sum(difference)/len(difference)
                                        difference = difference.item()
                                else:
                                        difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])

                                difference = min(difference, 360-difference)
                                if math.isnan(difference):
                                        continue
                                all_angle_errors[each_angle_idx].append(difference)
                                angle_errors_to_export[each_angle_idx][-1].append(difference)

        return_maes = list()
        all_errors = list()

#         print(f"\n----------\nMAEs for all dihedral angles predicted by {prediction_method}")
        for each_angle_idx in range(len(all_dihedral_angle_names)):
                mae = sum(all_angle_errors[each_angle_idx])/len(all_angle_errors[each_angle_idx])
                return_maes.append(mae)
                all_errors.append(all_angle_errors[each_angle_idx])
#                 print(f"{all_dihedral_angle_names[each_angle_idx]}: {mae:.3f}")

        return return_maes, all_errors, angle_errors_to_export



In [5]:
rna_puzzles_directory = "/home2/sriram.devata/rna_project/standardized_dataset_rna_puzzles/"
for puzzle_submission_idx, each_dir in enumerate(os.listdir(rna_puzzles_directory)):
    if not os.path.isdir(f"{rna_puzzles_directory}/{each_dir}"):
        continue
    
    # print(f"------Puzzle: {each_dir}------")
    
    cur_puzzle_directory = f"{rna_puzzles_directory}/{each_dir}"
    
    all_puzzle_submissions = os.listdir(cur_puzzle_directory)
    puzzle_solution_file = None
    for each_puzzle_submission in all_puzzle_submissions:
        if "solution" in each_puzzle_submission and "pdb" in each_puzzle_submission:
            puzzle_solution_file = each_puzzle_submission
    if puzzle_solution_file is None:
        continue
    # print(puzzle_solution_file)
        
    # get DSSR angles
    puzzle_solution_rna_object = RNA(puzzle_solution_file, calc_rna_fm_embeddings=True, load_dssr_dihedrals=True, load_coords=False, dssr_path=cur_puzzle_directory)
    gt_torsional_angles = puzzle_solution_rna_object.dssr_torsion_angles.unsqueeze(0).to(device)
    all_gt_angles = list()
    for each_residue_idx in range(len(gt_torsional_angles[0])):
        cur_res_gt_angles = make_rad_angle_tensor_to_deg_angle_list(gt_torsional_angles[0][each_residue_idx])
        all_gt_angles.append(cur_res_gt_angles)
        
    # --- Predict the dihedral angles with TorRNA ---
    rna_fm_embeddings = puzzle_solution_rna_object.rna_fm_embeddings.unsqueeze(0).to(device)
    initial_embeddings = puzzle_solution_rna_object.rna_fm_initial_embeddings.unsqueeze(0).to(device)
    gt_torsional_angles = puzzle_solution_rna_object.dssr_torsion_angles.unsqueeze(0).to(device)
    padding_mask = torch.sum(torch.abs(gt_torsional_angles), dim=-1) < 1e-5
    output = model(rna_fm_embeddings, padding_mask, initial_embeddings).detach().cpu()[0]
    all_predicted_angles = list()
    for each_residue_idx in range(len(output)):
        if padding_mask[0][each_residue_idx]:
            continue	# True if this residue is padded
        predicted_angles = make_rad_angle_tensor_to_deg_angle_list(output[each_residue_idx])
        all_predicted_angles.append(predicted_angles)
    all_maes, all_predicted_errors, angle_errors_to_export = calculate_mae([all_predicted_angles], [all_gt_angles], prediction_method="TorRNA")
    # print(f"TorRNA MAEs: {all_maes}")


    all_submissions_rmsds = pd.read_csv(f"{rna_puzzles_directory}/{each_dir}/rmsds.csv")
    all_groups = list(all_submissions_rmsds["fn"])
    all_groups = ['_'.join(x.split('_')[:-2]) for x in all_groups if "solution" not in x]
    all_groups = list(set(all_groups))
    
    each_puzzle_all_groups_labels = ["TorRNA"]
    each_puzzle_all_groups_maes = [all_maes]
    # get the RMSDs of all submissions by this group
    for each_group in all_groups:
        each_group_all_rmsds = all_submissions_rmsds
        each_group_all_rmsds = all_submissions_rmsds.loc[all_submissions_rmsds['fn'].str.contains(each_group)]
        
        # get the filename of the submission with the lowest RMSD
        each_group_best_pdb = each_group_all_rmsds.nsmallest(n=1, columns="rmsd_all").iloc[0]["fn"]
        each_group_best_rna_object = RNA(each_group_best_pdb, calc_rna_fm_embeddings=True, load_dssr_dihedrals=True, load_coords=False, dssr_path=cur_puzzle_directory)
        
        # if there is a chain mismatch, no point comparing the predictions
        if puzzle_solution_rna_object.dssr_full_seq != each_group_best_rna_object.dssr_full_seq:
            continue
            
        # get the DSSR angles of each of the group submission
        each_group_best_angles = each_group_best_rna_object.dssr_torsion_angles.unsqueeze(0).to(device)
        all_each_group_best_angles = list()
        for each_residue_idx in range(len(each_group_best_angles[0])):
            cur_res_gt_angles = make_rad_angle_tensor_to_deg_angle_list(each_group_best_angles[0][each_residue_idx])
            all_each_group_best_angles.append(cur_res_gt_angles)
        all_maes, all_predicted_errors, angle_errors_to_export = calculate_mae([all_each_group_best_angles], [all_gt_angles], prediction_method="TorRNA")
        # print(f"{each_group}: {all_maes}")
        if sum(all_maes) < 0.001:
            continue # Some cases have all 0 MAEs, it's an error
        each_puzzle_all_groups_labels.append(each_group)
        each_puzzle_all_groups_maes.append(all_maes)

        
    if len(each_puzzle_all_groups_maes) <= 3:
        continue
    # print(f"------Puzzle: {each_dir}------")
    print(f"\midrule")
    # print(f"{each_dir} & & & & & & & & & \\\\")
    print(f"\\multicolumn{{10}}{{c}}{{ {each_dir} }} \\\\")
    print(f"\midrule")
    for each_group_name, each_group_mae in zip(each_puzzle_all_groups_labels,each_puzzle_all_groups_maes):
        each_group_name = each_group_name.replace('_', '\_')
        print(f"\\text{{ {each_group_name} }}", end='')
        for mae_idx, each_mae in enumerate(each_group_mae):
            
            if each_mae == min([x[mae_idx] for x in each_puzzle_all_groups_maes]):
                print(f" & \\textbf{{ {each_mae:.2f} }}", end='')
            else:
                print(f" & {each_mae:.2f} ", end='')
            
        print("\\\\")
    
    
    if puzzle_submission_idx > 30:
        break


\midrule
\multicolumn{10}{c}{ rp19 } \\
\midrule
\text{ TorRNA } & \textbf{ 31.50 } & \textbf{ 12.79 } & \textbf{ 23.72 } & 13.34  & \textbf{ 12.71 } & \textbf{ 11.38 } & \textbf{ 16.47 } & 19.68  & 25.57 \\
\text{ 19\_RW3D } & 45.63  & 19.24  & 30.84  & 13.93  & 16.98  & 19.72  & 26.17  & 28.59  & 33.63 \\
\text{ 19\_Das\_Human } & 39.07  & 23.32  & 30.23  & \textbf{ 11.05 } & 12.81  & 19.99  & 25.18  & 20.19  & 29.86 \\
\text{ 19\_SimRNA } & 50.34  & 21.96  & 33.75  & 15.13  & 19.00  & 22.39  & 44.70  & 33.98  & 40.79 \\
\text{ 19\_Ding\_Human } & 37.55  & 18.05  & 28.31  & 16.77  & 15.15  & 13.41  & 20.80  & 23.42  & 31.44 \\
\text{ 19\_3dRNA } & 64.21  & 37.15  & 53.03  & 30.67  & 18.18  & 38.71  & 42.04  & 33.17  & 38.45 \\
\text{ 19\_Dokholyan } & 49.47  & 20.29  & 51.30  & 14.85  & 18.40  & 12.67  & 23.71  & 29.31  & 34.16 \\
\text{ 19\_LeeServer } & 34.73  & 17.87  & 28.32  & 25.31  & 24.20  & 29.19  & 24.57  & 25.22  & 33.43 \\
\text{ 19\_Chen\_Human } & 46.04  & 26.70  & 31.6

\midrule
\multicolumn{10}{c}{ rp08 } \\
\midrule
\text{ TorRNA } & \textbf{ 31.66 } & \textbf{ 12.79 } & \textbf{ 23.25 } & \textbf{ 6.57 } & \textbf{ 10.73 } & \textbf{ 14.09 } & 27.66  & 19.83  & 23.64 \\
\text{ 8\_Bujnicki } & 45.12  & 21.01  & 35.56  & 13.05  & 17.51  & 24.51  & 29.55  & 21.26  & 26.02 \\
\text{ 8\_Das } & 40.19  & 20.56  & 32.50  & 6.65  & 12.08  & 20.57  & \textbf{ 25.98 } & \textbf{ 14.41 } & \textbf{ 21.49 }\\
\text{ 8\_Adamiak } & 48.42  & 24.83  & 37.34  & 9.47  & 18.05  & 24.29  & 34.31  & 19.72  & 30.67 \\
\text{ 8\_Ding } & 49.11  & 20.79  & 38.26  & 11.82  & 14.75  & 19.63  & 38.18  & 32.91  & 34.27 \\
\text{ 8\_Chen } & 48.50  & 22.54  & 37.30  & 9.94  & 15.48  & 25.51  & 33.81  & 23.39  & 26.84 \\
\text{ 8\_Dokholyan } & 66.87  & 22.16  & 59.96  & 10.56  & 17.25  & 19.64  & 41.87  & 29.14  & 44.06 \\
