In [1]:
import torch
import numpy as np
import math
from tqdm import tqdm
import wandb
import pickle

from data.rna_object import RNA
from data.dataloader import get_train_test_dataloaders, RNATorsionalAnglesDataset, collate_fn

from torch.utils.data import DataLoader

from models.transformer import TorsionalAnglesTransformerDecoder

In [2]:
all_dihedral_angle_names = ["alpha" ,"beta" ,"gamma" ,"delta" ,"chi" ,"epsilon" ,"zeta" ,"eta" ,"theta"]

def convert_to_rad_and_back_to_degree(cur_res_spot_rna_pred_angles):

	final_degree_angles = list()
	for each_angle in cur_res_spot_rna_pred_angles:
		torsion_angle = torch.tensor([each_angle * math.pi / 180])
		final_degree_angle = torch.arctan2(torch.sin(torsion_angle),torch.cos(torsion_angle)).item() * 180 / math.pi
		final_degree_angles.append(final_degree_angle)

	return final_degree_angles

def make_rad_angle_tensor_to_deg_angle_list(rad_angle_tensor):

	constructed_angles = list()
	for each_angle_idx in range(len(all_dihedral_angle_names)):
		cos_angle = rad_angle_tensor[2*each_angle_idx]
		sin_angle = rad_angle_tensor[2*each_angle_idx+1]
		rad_angle = torch.arctan2(sin_angle,cos_angle)
		deg_angle = rad_angle.item() * 180 / math.pi
		constructed_angles.append(deg_angle)

	return constructed_angles

def calculate_mae(predicted, groundtruth, prediction_method="Unknown"):

	all_angle_errors = list()
	for each_angle_idx in range(len(all_dihedral_angle_names)):
		all_angle_errors.append(list())

	for each_rna_idx in range(len(predicted)):
		for each_residue_idx in range(len(predicted[each_rna_idx])):

			predicted_angles = predicted[each_rna_idx][each_residue_idx]
			gt_angles = groundtruth[each_rna_idx][each_residue_idx]

			for each_angle_idx in range(len(all_dihedral_angle_names)):

				if prediction_method == "Random Baseline":
					difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])
					difference = sum(difference)/len(difference)
					difference = difference.item()
				else:
					difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])

				difference = min(difference, 360-difference)
				if math.isnan(difference):
					continue
				all_angle_errors[each_angle_idx].append(difference)

	return_maes = list()
	all_errors = list()

	print(f"\n----------\nMAEs for all dihedral angles predicted by {prediction_method}")
	for each_angle_idx in range(len(all_dihedral_angle_names)):
		mae = sum(all_angle_errors[each_angle_idx])/len(all_angle_errors[each_angle_idx])
		return_maes.append(mae)
		all_errors.append(all_angle_errors[each_angle_idx])
		print(f"{all_dihedral_angle_names[each_angle_idx]}: {mae:.3f}")

	return return_maes, all_errors

def calculate_mae_for_positions(predicted, groundtruth, nuc_positions, prediction_method=""):
    all_angle_errors = list()
    for each_angle_idx in range(len(all_dihedral_angle_names)):
        all_angle_errors.append(list())

    for each_residue_idx in nuc_positions:
        
        predicted_angles = predicted[each_residue_idx]
        gt_angles = groundtruth[each_residue_idx]

        for each_angle_idx in range(len(all_dihedral_angle_names)):

            if prediction_method == "Random Baseline":
                difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])
                difference = sum(difference)/len(difference)
            else:
                difference = abs(predicted_angles[each_angle_idx] - gt_angles[each_angle_idx])

            difference = min(difference, 360-difference)
            if math.isnan(difference):
                continue
            all_angle_errors[each_angle_idx].append(difference)

    return_maes = list()
    all_errors = list()

    for each_angle_idx in range(len(all_dihedral_angle_names)):
        try:
            mae = sum(all_angle_errors[each_angle_idx])/len(all_angle_errors[each_angle_idx])
        except:
            mae = 0
        return_maes.append(mae)
        all_errors.append(all_angle_errors[each_angle_idx])

    return return_maes, all_errors

In [5]:
dict_attr = type('dict_attr', (object,), {})
args = dict_attr()
args.lr = 2e-4
args.embeddim = 640
args.hiddendim = 256
args.numheads = 4
args.numlayers = 3
args.dropout = 0.2
args.tol = 5

# Setting up the transformer model
model = TorsionalAnglesTransformerDecoder(embed_dim=args.embeddim, hidden_dim=args.hiddendim, num_heads=args.numheads, num_layers=args.numlayers, dropout=args.dropout)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.load_state_dict(torch.load(f"/home2/sriram.devata/rna_project/rna_transformer/checkpoints/best_model_rna_transformer_{args.lr}_{args.embeddim}_{args.hiddendim}_{args.numheads}_{args.numlayers}_{args.dropout}_{args.tol}.pkl", map_location=device))
model.eval()

TorsionalAnglesTransformerDecoder(
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=640, out_features=640, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=640, out_features=640, bias=True)
        )
        (linear1): Linear(in_features=640, out_features=2048, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=2048, out_features=640, bias=True)
        (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
        (dropout3): Dropout(p=0.2, inplace=

In [6]:
pdbs_path="/home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs/"
dssr_path=pdbs_path
processed_dir = "/ssd_scratch/users/sriram.devata/rna_structure/dataset/"
perfect_pdb_files_train_val_test_path="/home2/sriram.devata/rna_project/cdhit/torrna_train_val_test.pkl"

with open(perfect_pdb_files_train_val_test_path, "rb") as fp:
	training_pdbs, validation_pdbs, testing_pdbs = pickle.load(fp)
training_files = list()
for each_training_pdb in training_pdbs:
	training_files.append(f"{pdbs_path}/{each_training_pdb}.pdb")
testing_files = list()
for each_testing_pdb in testing_pdbs:
	testing_files.append(f"{pdbs_path}/{each_testing_pdb}.pdb")

In [7]:
def get_all_regions(each_test_pdb, pdbs_path):
    each_pdb_file = f"{pdbs_path}/{each_test_pdb}.pdb"
    try:
        rna_object = RNA(each_pdb_file, calc_rna_fm_embeddings=True, load_dssr_dihedrals=True, load_coords=False)
    except:
        return None
    # print(pdb_idx, rna_object.dssr_full_seq, len(rna_object.dssr_full_seq))
    
    
    # --- Predict the dihedral angles ---
    rna_fm_embeddings = rna_object.rna_fm_embeddings.unsqueeze(0).to(device)
    initial_embeddings = rna_object.rna_fm_initial_embeddings.unsqueeze(0).to(device)
    pading_mask_shape = list(rna_fm_embeddings.shape[:-1])
    pading_mask_shape[-1] -= 2
    padding_mask = torch.zeros(pading_mask_shape, device=device, dtype=torch.bool)

    # Get the dihedral angle predictions. This will be the "ground truth" dihedral angles
    output = model(rna_fm_embeddings, padding_mask, initial_embeddings).detach().cpu()[0]
    rad_gt_angles = rna_object.dssr_torsion_angles
    all_predicted_angles = list()
    all_gt_angles = list()
    for each_residue_idx in range(len(output)):
        if padding_mask[0][each_residue_idx]:
            continue	# True if this residue is padded
        predicted_angles = make_rad_angle_tensor_to_deg_angle_list(output[each_residue_idx])
        gt_angles = make_rad_angle_tensor_to_deg_angle_list(rad_gt_angles[each_residue_idx])
        
        all_predicted_angles.append(predicted_angles)
        all_gt_angles.append(gt_angles)
    
    
    each_pdb_dssr = f"{dssr_path}/{each_test_pdb}.out"
    with open(each_pdb_dssr, 'r') as f:
        each_pdb_dssr_lines = f.readlines()
        each_pdb_dssr_lines = [x.strip() for x in each_pdb_dssr_lines]
        
    base_pair_lines = list()
    for line_idx, each_line in enumerate(each_pdb_dssr_lines):
        if "List of" in each_line and "base pair" in each_line:
            for each_base_pair_line_idx, each_base_pair_line in enumerate(each_pdb_dssr_lines[line_idx+2:]):
                if "****" in each_base_pair_line:
                    break
                if len(each_base_pair_line) > 0:
                    base_pair_lines.append(each_base_pair_line)
            break
    
    all_base_pairs = list() # Has all base pairs. Is of the form `[(1,15), (2,14), ...]`
    for each_base_pair_line in base_pair_lines:
        # each_base_pair_line is of the form: `1 A.G1   A.C15    G-C WC    19-XIX    cWW  cW-W`
        each_base_pair_line = each_base_pair_line.split()
        bp_first = each_base_pair_line[1]
        bp_second = each_base_pair_line[2]
        
        # Get the base pair type as well
        bp_type = each_base_pair_line[4]
        
        bp_first_num = int(''.join(filter(str.isdigit, bp_first)))  # Remove all non-digits from each string of the form `A.G1`
        bp_second_num = int(''.join(filter(str.isdigit, bp_second)))
        
        all_base_pairs.append((bp_first_num, bp_second_num, bp_type))
        
            
    # Find the nucleotides that are non-canonical pairs
    non_canonical_nums = list()
    if len(all_base_pairs) > 0:
        for each_bp in all_base_pairs:
            bp_one, bp_two, bp_type = each_bp
            if bp_type != "WC" and bp_type != "rWC" and bp_type != "Wobbles":
                non_canonical_nums += [bp_one, bp_two]
                
    # Find the nucleotides that are canonical pairs
    all_canonical_nums = list()
    if len(all_base_pairs) > 0:
        for each_bp in all_base_pairs:
            bp_one, bp_two, bp_type = each_bp
            if bp_type == "WC" or bp_type == "rWC" or bp_type == "Wobbles":
                all_canonical_nums += [bp_one, bp_two]
                
                
    # Find all hairpin loop nucleotides
    hairpin_loop_lines = list() # 
    for line_idx, each_line in enumerate(each_pdb_dssr_lines):
        if "List of" in each_line and "hairpin loop" in each_line:
            for each_hairpin_loop_line_idx, each_hairpin_loop_line in enumerate(each_pdb_dssr_lines[line_idx+1:]):
                if "****" in each_hairpin_loop_line:
                    break
                if len(each_hairpin_loop_line) > 0:
                    hairpin_loop_lines.append(each_hairpin_loop_line)
            break
            
    hairpin_loop_nums = list()
    for each_hairpin_loop_line_idx, each_hairpin_loop_line in enumerate(hairpin_loop_lines):
        if "hairpin loop" in each_hairpin_loop_line:
            # cur_hairpin_loop_line is of the form `nts=7 CUCAACU A.C6,A.U7,A.C8,A.A9,A.A10,A.C11,A.U12`
            cur_hairpin_loop_line = hairpin_loop_lines[each_hairpin_loop_line_idx+3]
            _, _, cur_hairpin_nucs = cur_hairpin_loop_line.strip().split()
            cur_hairpin_nucs = cur_hairpin_nucs.split(',')
            cur_hairpin_nucs = [int(''.join(filter(str.isdigit, x))) for x in cur_hairpin_nucs]
            # Get only the middle 4-5 nucleotides of the hairpin - the nucleotides at the very edge
            while len(cur_hairpin_nucs) > 5:
                cur_hairpin_nucs = cur_hairpin_nucs[1:-1]
            
            hairpin_loop_nums += cur_hairpin_nucs
    
    
    # Find all multiplet nucleotides
    multiplet_lines = list() # 
    for line_idx, each_line in enumerate(each_pdb_dssr_lines):
        if "List of" in each_line and "multiplet" in each_line:
            for each_multiplet_line_idx, each_multiplet_line in enumerate(each_pdb_dssr_lines[line_idx+1:]):
                if "****" in each_multiplet_line:
                    break
                if len(each_multiplet_line) > 0:
                    multiplet_lines.append(each_multiplet_line)
            break

    multiplet_nums = list() # is of the form `['1 nts=3 AAU A.A4,A.A6,A.U20', '2 nts=3 UCG A.U5,A.C7,A.G19', '3 nts=3 UAA A.U8,A.A17,A.A18']`
    for each_multiplet_line in multiplet_lines:
        _, _, _, multiplet_nucleotides = each_multiplet_line.strip().split()
        # multiplet_nucleotides is of the form `A.A4,A.A6,A.U20`
        for each_multiplet_nucleotide in multiplet_nucleotides.split(','):
            # each_multiplet_nucleotide is of the form `A.A4`. Have to remove all alphabets and special characters
            each_multiplet_nucleotide_num = int(''.join(filter(str.isdigit, each_multiplet_nucleotide)))
            multiplet_nums.append(each_multiplet_nucleotide_num)
#     if len(lone_paired_nums) == 0:
#         continue

    # Get all the lines from the summary of all nucleotides
    summary_lines = list() # 
    for line_idx, each_line in enumerate(each_pdb_dssr_lines):
        # The summary of each nucleotide has a small note that ends with this line
        if "break: no backbone linkage between" in each_line:
            for each_summary_line_idx, each_summary_line in enumerate(each_pdb_dssr_lines[line_idx+1:]):
                if "****" in each_summary_line:
                    break
                if len(each_summary_line) > 0:
                    summary_lines.append(each_summary_line)
            break
    
    # Find the nucleotides that are unpaired
    unpaired_nums = list()
    if len(all_base_pairs) > 0:
        bp_ones, bp_twos, bp_types = zip(*all_base_pairs)
        all_paired_nums = list(set(bp_ones + bp_twos))
        
        for each_summary_line in summary_lines:
            nuc_num, nuc_letter, _, nuc_num_in_chain, _, _ = each_summary_line.strip().split()
            nuc_num_in_chain = int(''.join(filter(str.isdigit, nuc_num_in_chain)))
            if nuc_num_in_chain not in all_paired_nums:
                unpaired_nums.append(nuc_num_in_chain)
    else:
        for each_summary_line in summary_lines:
            nuc_num, nuc_letter, _, nuc_num_in_chain, _, _ = each_summary_line.strip().split()
            nuc_num_in_chain = int(''.join(filter(str.isdigit, nuc_num_in_chain)))
            unpaired_nums.append(nuc_num_in_chain)
    
    # Find all the pseudoknotted nucleotides
    pseudoknot_nums = list()
    for each_pseudoknot_line in summary_lines:
        # Each line is of the form `1  G ( A.G1  0.008  anti,~C3'-endo,BI,canonical,non-pair-contact,helix-end,stem-end`
        each_pseudoknot_line = each_pseudoknot_line.strip().split()
        if each_pseudoknot_line[2] in ['[', ']', '{', '}', '<', '>']:
            pseudoknot_nums.append(int(each_pseudoknot_line[0]))
     
    
    # Find all canonical nested nucleotides
    canonical_nested_nums = list()
    if len(all_base_pairs) > 0:
        bp_ones, bp_twos, bp_types = zip(*all_base_pairs)
        all_paired_nums = list(set(bp_ones + bp_twos))
        for each_canonical_nested_pair_line in summary_lines:
            # Each line is of the form `1  G ( A.G1  0.008  anti,~C3'-endo,BI,canonical,non-pair-contact,helix-end,stem-end`
            nuc_num, nuc_letter, _, _, _, each_nucleotide_attributes = each_canonical_nested_pair_line.strip().split()
            each_nucleotide_attributes = each_nucleotide_attributes.split(',')
            is_pseudoknot = False # Checks if this nucleotide is part of a pseudoknot
            is_nested = False # Checks if this nucleotide is part of a nested base pair
            for each_attribute in each_nucleotide_attributes:
                if "helix" in each_attribute:
                    is_nested = True
                elif "pseudoknot" in each_attribute:
                    is_pseudoknot = True
            if int(nuc_num) in all_paired_nums and is_nested and not is_pseudoknot:
                canonical_nested_nums.append(int(nuc_num))
            
    
    # Find all nucleotides belonging to lone base pairs
    # Defined as "single tertiary base pair not associated with a helix" from https://www.biorxiv.org/content/10.1101/677310v1.full
    lone_base_pairs_nums = list()
    if len(all_base_pairs) > 0:
        bp_ones, bp_twos, bp_types = zip(*all_base_pairs)
        all_paired_nums = list(set(bp_ones + bp_twos))
        for each_canonical_nested_pair_line in summary_lines:
            # Each line is of the form `1  G ( A.G1  0.008  anti,~C3'-endo,BI,canonical,non-pair-contact,helix-end,stem-end`
            nuc_num, nuc_letter, _, _, _, each_nucleotide_attributes = each_canonical_nested_pair_line.strip().split()
            each_nucleotide_attributes = each_nucleotide_attributes.split(',')
            is_in_helix = False # Checks if this nucleotide is part of a helix
            for each_attribute in each_nucleotide_attributes:
                if "helix" in each_attribute or "stem" in each_attribute:
                    is_in_helix = True
            if int(nuc_num) in all_paired_nums and not is_in_helix:
                lone_base_pairs_nums.append(int(nuc_num))        
                
    # The numbers until now are the nucleotide numbers 1..N, but might be some M..N if this is some chain other than A
    # Have to change these numbers to be 0..N specific to this chain
    nuc_num_in_chain_to_zero_index_mapping = dict()
    for each_canonical_nested_pair_line in summary_lines:
        # Each line is of the form `1  G ( A.G1  0.008  anti,~C3'-endo,BI,canonical,non-pair-contact,helix-end,stem-end`
        nuc_num, nuc_letter, _, nuc_num_in_chain, _, _ = each_canonical_nested_pair_line.strip().split()
        nuc_num = int(nuc_num) - 1 # This is originally 1..N
        nuc_num_in_chain = int(''.join(filter(str.isdigit, nuc_num_in_chain))) # This is either 1..N or M..N depending on the chain
        nuc_num_in_chain_to_zero_index_mapping[str(nuc_num_in_chain)] = nuc_num
        
    def correct_mapping_in_nuc_num(given_input_list):
        return_list = list()
        for each_nuc_num in given_input_list:
            return_list.append(nuc_num_in_chain_to_zero_index_mapping[str(each_nuc_num)])
        return return_list
    

    
    if len(rna_object.dssr_full_seq) != len(nuc_num_in_chain_to_zero_index_mapping):
        print(f"SKIPPING {each_pdb_dssr} due to mismatch of number of residues")

#         print(each_pdb_dssr)
#         print(nuc_num_in_chain_to_zero_index_mapping)
#         print(pdb_idx, rna_object.dssr_full_seq, len(rna_object.dssr_full_seq))
#         print(unpaired_nums)
    
        return None
    
    try:
        unpaired_nums = correct_mapping_in_nuc_num(unpaired_nums)
        lone_base_pairs_nums = correct_mapping_in_nuc_num(lone_base_pairs_nums)
        pseudoknot_nums = correct_mapping_in_nuc_num(pseudoknot_nums)
        multiplet_nums = correct_mapping_in_nuc_num(multiplet_nums)
        non_canonical_nums = correct_mapping_in_nuc_num(non_canonical_nums)
        all_canonical_nums = correct_mapping_in_nuc_num(all_canonical_nums)
        hairpin_loop_nums = correct_mapping_in_nuc_num(hairpin_loop_nums)
        canonical_nested_nums = correct_mapping_in_nuc_num(canonical_nested_nums)
        
        bp_ones, bp_twos, bp_types = zip(*all_base_pairs)
        all_paired_nums = list(set(bp_ones + bp_twos))
        
        reshaped_all_base_pairs = list()
        for i in range(len(bp_ones)):
            reshaped_all_base_pairs.append(bp_ones[i])
            reshaped_all_base_pairs.append(bp_twos[i])
        reshaped_all_base_pairs = correct_mapping_in_nuc_num(reshaped_all_base_pairs)
        # reshaped_all_base_pairs has all the base pairs equivalent to doing .reshape(-1) on the numbers
        
    except Exception as e:
        print(f"Unable to correct mappings for {each_pdb_dssr}", e)
        return None
        
#     print(each_pdb_file, each_pdb_dssr, len(rna_object.dssr_full_seq), rna_object.dssr_full_seq)
#     print(f"All base pairs: {all_base_pairs}")
#     print(f"All unpaired bases: {unpaired_nums}")
#     print(f"All non-canonicals: {non_canonical_nums}")
#     print(f"All canonicals: {all_canonical_nums}")
#     print(f"All hairpins: {hairpin_loop_nums}")
#     print(f"All multiplets: {multiplet_nums}")
#     print(f"All pseudoknots: {pseudoknot_nums}")
#     print(f"All canonical nested: {canonical_nested_nums}")
#     print(f"All lone base pairs: {lone_base_pairs_nums}")


    return [len(rna_object.dssr_full_seq), all_gt_angles, all_predicted_angles, nuc_num_in_chain_to_zero_index_mapping, unpaired_nums, lone_base_pairs_nums, pseudoknot_nums, multiplet_nums, non_canonical_nums, all_canonical_nums, hairpin_loop_nums, canonical_nested_nums, reshaped_all_base_pairs] 

In [8]:
total_num_nucleotides = 0

unpaired_errors = [list() for _ in range(9)]
lone_pairs_errors = [list() for _ in range(9)]
pseudoknots_errors = [list() for _ in range(9)]
multiplets_errors = [list() for _ in range(9)]
noncanonical_pairs_errors = [list() for _ in range(9)]
all_canonical_pairs_errors = [list() for _ in range(9)]
canonical_nested_pairs_errors = [list() for _ in range(9)]
hairpin_loop_errors = [list() for _ in range(9)]

for pdb_idx, each_test_pdb in tqdm(enumerate(testing_pdbs), total=len(testing_pdbs)):
    
    regions_output = get_all_regions(each_test_pdb, pdbs_path)
    if regions_output is not None:
        len_seq, all_gt_angles, all_predicted_angles, nuc_num_in_chain_to_zero_index_mapping, unpaired_nums, lone_base_pairs_nums, pseudoknot_nums, multiplet_nums, non_canonical_nums, all_canonical_nums, hairpin_loop_nums, canonical_nested_nums, reshaped_all_base_pairs = regions_output
    else:
        continue
        
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, unpaired_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        unpaired_errors[angle_idx] += all_errors[angle_idx]
    
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, lone_base_pairs_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        lone_pairs_errors[angle_idx] += all_errors[angle_idx]
    
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, pseudoknot_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        pseudoknots_errors[angle_idx] += all_errors[angle_idx]
    
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, multiplet_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        multiplets_errors[angle_idx] += all_errors[angle_idx]
    
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, non_canonical_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        noncanonical_pairs_errors[angle_idx] += all_errors[angle_idx]
        
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, all_canonical_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        all_canonical_pairs_errors[angle_idx] += all_errors[angle_idx]
    
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, canonical_nested_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        canonical_nested_pairs_errors[angle_idx] += all_errors[angle_idx]
        
    return_maes, all_errors = calculate_mae_for_positions(all_predicted_angles, all_gt_angles, hairpin_loop_nums, prediction_method="RNA Transformer Decoder")
    for angle_idx,each_angle_error in enumerate(all_errors):
        hairpin_loop_errors[angle_idx] += all_errors[angle_idx]
        
    total_num_nucleotides += len_seq
        
    if pdb_idx > 1000:
        break

def print_region_maes(region_errors, region_name="Unknown"):
    print(f"{region_name} ({len(region_errors[0])/total_num_nucleotides*100:.2f}\%)", end=',')
    for each_angle_idx, each_angle_errors in enumerate(region_errors):
        try:
            mae = sum(each_angle_errors)/len(each_angle_errors)
        except:
            mae = 0
        print(f"{mae:.2f}", end=',')
    print()
        

print(f"MAEs for various regions of TorRNA")
for each_angle_name in all_dihedral_angle_names:
    print(each_angle_name, end=',')
print()

print_region_maes(unpaired_errors, region_name="Unpaired")
print_region_maes(lone_pairs_errors, region_name="Lone Pairs")
print_region_maes(pseudoknots_errors, region_name="Pseudoknots")
print_region_maes(multiplets_errors, region_name="Multiplets")
print_region_maes(noncanonical_pairs_errors, region_name="Non-Canonical Pairs")
print_region_maes(all_canonical_pairs_errors, region_name="All Canonical Pairs")
print_region_maes(canonical_nested_pairs_errors, region_name="Canonical Nested Pairs")
print_region_maes(hairpin_loop_errors, region_name="Hairpin Loops")

  2%|▏         | 4/172 [00:23<16:09,  5.77s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5LYU_B.out due to mismatch of number of residues


  6%|▌         | 10/172 [00:47<10:34,  3.92s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6O7H_G.out due to mismatch of number of residues


  7%|▋         | 12/172 [00:56<12:01,  4.51s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5WTI_B.out due to mismatch of number of residues


 10%|█         | 18/172 [01:19<09:46,  3.81s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7UZX_G.out not enough values to unpack (expected 3, got 0)


 12%|█▏        | 20/172 [01:26<09:07,  3.60s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6BK8_E.out due to mismatch of number of residues


 12%|█▏        | 21/172 [01:30<08:54,  3.54s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2XDB_G.out due to mismatch of number of residues


 13%|█▎        | 22/172 [01:35<10:17,  4.11s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5ZWO_G.out not enough values to unpack (expected 3, got 0)


 14%|█▍        | 24/172 [01:42<09:14,  3.75s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1U6B_C.out not enough values to unpack (expected 3, got 0)


 16%|█▌        | 27/172 [01:54<10:04,  4.17s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//3WQY_C.out '18'


 20%|██        | 35/172 [02:28<09:31,  4.17s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1C9S_W.out not enough values to unpack (expected 3, got 0)


 24%|██▍       | 41/172 [02:55<09:17,  4.26s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8I45_A.out due to mismatch of number of residues


 27%|██▋       | 46/172 [03:12<07:28,  3.56s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5MWI_A.out not enough values to unpack (expected 3, got 0)


 29%|██▉       | 50/172 [03:30<09:10,  4.51s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2DU3_D.out '18'


 31%|███▏      | 54/172 [03:46<08:27,  4.30s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5AH5_C.out due to mismatch of number of residues


 34%|███▍      | 59/172 [04:03<06:29,  3.45s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7XPL_G.out not enough values to unpack (expected 3, got 0)


 42%|████▏     | 72/172 [04:46<04:56,  2.96s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6NUE_H.out not enough values to unpack (expected 3, got 0)


 42%|████▏     | 73/172 [04:48<04:32,  2.76s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2ZUE_B.out due to mismatch of number of residues


 44%|████▍     | 76/172 [04:57<04:24,  2.76s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7WWV_M.out not enough values to unpack (expected 3, got 0)


 47%|████▋     | 81/172 [05:12<04:46,  3.15s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1F7U_B.out '18'


 49%|████▉     | 84/172 [05:22<04:46,  3.26s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//4RDX_C.out due to mismatch of number of residues


 55%|█████▍    | 94/172 [05:55<03:40,  2.83s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8FZT_A.out not enough values to unpack (expected 3, got 0)


 58%|█████▊    | 100/172 [06:09<03:00,  2.51s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5EEU_W.out not enough values to unpack (expected 3, got 0)


 60%|██████    | 104/172 [06:20<03:06,  2.74s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8AW3_1.out '55'


 66%|██████▌   | 113/172 [06:46<02:41,  2.73s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8BVM_U.out due to mismatch of number of residues


 68%|██████▊   | 117/172 [06:57<02:30,  2.74s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5OIK_P.out not enough values to unpack (expected 3, got 0)


 69%|██████▉   | 119/172 [07:02<02:09,  2.44s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7UOB_T.out not enough values to unpack (expected 3, got 0)


 70%|███████   | 121/172 [07:10<02:43,  3.21s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5U30_B.out '75'


 72%|███████▏  | 123/172 [07:13<02:03,  2.52s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7SBA_Z.out not enough values to unpack (expected 3, got 0)


 73%|███████▎  | 125/172 [07:18<01:56,  2.48s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7ZHG_5.out not enough values to unpack (expected 3, got 0)


 75%|███████▌  | 129/172 [07:32<02:20,  3.27s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7K16_P.out due to mismatch of number of residues


 83%|████████▎ | 143/172 [08:11<01:28,  3.04s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1MNX_A.out due to mismatch of number of residues


 84%|████████▎ | 144/172 [08:14<01:19,  2.84s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1GTN_W.out not enough values to unpack (expected 3, got 0)


 85%|████████▍ | 146/172 [08:19<01:08,  2.65s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5LWJ_A.out '35'


 90%|████████▉ | 154/172 [08:43<00:51,  2.87s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5Z71_A.out not enough values to unpack (expected 3, got 0)


 91%|█████████ | 156/172 [08:47<00:42,  2.65s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5ED1_B.out not enough values to unpack (expected 3, got 0)


 92%|█████████▏| 159/172 [08:57<00:40,  3.15s/it]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1E8O_E.out '15'


 94%|█████████▍| 162/172 [09:07<00:31,  3.18s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6BSI_R.out due to mismatch of number of residues


100%|██████████| 172/172 [09:38<00:00,  3.37s/it]

MAEs for various regions of TorRNA
alpha,beta,gamma,delta,chi,epsilon,zeta,eta,theta,
Unpaired (27.77\%),61.56,29.32,48.21,22.45,32.14,26.75,46.76,48.18,48.83,
Lone Pairs (5.99\%),44.43,22.18,37.18,16.28,22.48,22.96,42.93,25.10,36.16,
Pseudoknots (2.71\%),38.30,18.51,30.23,11.88,13.38,18.73,27.96,23.60,27.61,
Multiplets (9.24\%),50.89,24.89,42.53,17.36,25.07,22.88,42.60,27.60,37.04,
Non-Canonical Pairs (25.61\%),47.24,24.14,41.32,15.58,24.48,22.42,37.94,25.94,33.89,
All Canonical Pairs (50.51\%),29.51,15.06,27.54,7.82,9.37,13.80,15.18,11.10,14.43,
Canonical Nested Pairs (46.49\%),33.32,16.89,30.45,9.01,12.43,15.13,18.51,13.95,17.30,
Hairpin Loops (11.72\%),62.57,26.55,42.77,22.75,28.52,26.57,46.18,48.98,45.43,





In [9]:
# ------------------- Check SPOT-RNA-1D predictions in various regions -------------------

import os
if os.path.exists(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl"):
	with open(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl", "rb") as input_file:
		print(f"Found RNA objects precomputed")
		rna_objects = pickle.load(input_file)
else:
	rna_objects = list()
	print(f"Making RNA objects")
	for each_pdb_file in tqdm(testing_pdbs):
		try:
			rna_objects.append(RNA(each_pdb_file, calc_rna_fm_embeddings=True, load_dssr_dihedrals=True, load_coords=False))
		except:
			pass
	with open(f"/ssd_scratch/users/sriram.devata/rna_objects_temp.pkl", "wb") as output_file:
		pickle.dump(rna_objects, output_file)
    
total_num_nucleotides = 0

unpaired_errors = [list() for _ in range(9)]
lone_pairs_errors = [list() for _ in range(9)]
pseudoknots_errors = [list() for _ in range(9)]
multiplets_errors = [list() for _ in range(9)]
noncanonical_pairs_errors = [list() for _ in range(9)]
all_canonical_pairs_errors = [list() for _ in range(9)]
canonical_nested_pairs_errors = [list() for _ in range(9)]
hairpin_loop_errors = [list() for _ in range(9)]

spot_rna_pred_location = "/home2/sriram.devata/rna_project/SPOT-RNA-1D/outputs/"
for each_pdb_idx in tqdm(range(len(testing_pdbs))):
    
    try:
        pdb_file_name = testing_pdbs[each_pdb_idx]
        rna_object = rna_objects[each_pdb_idx]
        gt_angles = rna_object.dssr_torsion_angles
        # print(pdb_file_name, len(rna_object.dssr_full_seq))

        pdb_code = pdb_file_name.split('/')[-1].replace(".pdb", "")
        spot_rna_pred_file = f"{spot_rna_pred_location}/{pdb_code}.txt"

        with open(spot_rna_pred_file, "r") as f:
            spot_rna_pred_lines = f.readlines()

        each_rna_predicted_angles = list()
        each_rna_gt_angles = list()
        # print(pdb_code, gt_angles.shape)
        for each_residue_idx in range(len(gt_angles)):
            spot_rna_pred_line = spot_rna_pred_lines[each_residue_idx+2]	# first two lines are column headers and an empty line
            cur_res_spot_rna_pred_angles = [float(x) for x in spot_rna_pred_line.strip().split()[2:]]	# first two columns are number and basename
    
                        # correcting the order of the angles, SPOT-RNA-1D has a different ordering of angles
            index_correction = [0, 1, 2, 3, 6, 4, 5, 7, 8]
            cur_res_spot_rna_pred_angles = [cur_res_spot_rna_pred_angles[x] for x in index_correction]

            cur_res_gt_angles = make_rad_angle_tensor_to_deg_angle_list(gt_angles[each_residue_idx])

            each_rna_predicted_angles.append(cur_res_spot_rna_pred_angles)
            each_rna_gt_angles.append(cur_res_gt_angles)


        regions_output = get_all_regions(pdb_code, pdbs_path)
        if regions_output is not None:
            len_seq, all_gt_angles, all_predicted_angles, nuc_num_in_chain_to_zero_index_mapping, unpaired_nums, lone_base_pairs_nums, pseudoknot_nums, multiplet_nums, non_canonical_nums, all_canonical_nums, hairpin_loop_nums, canonical_nested_nums, reshaped_all_base_pairs = regions_output
        else:
            continue
        
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, unpaired_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            unpaired_errors[angle_idx] += all_errors[angle_idx]
        
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, lone_base_pairs_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            lone_pairs_errors[angle_idx] += all_errors[angle_idx]
            
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, pseudoknot_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            pseudoknots_errors[angle_idx] += all_errors[angle_idx]
            
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, multiplet_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            multiplets_errors[angle_idx] += all_errors[angle_idx]
            
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, non_canonical_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            noncanonical_pairs_errors[angle_idx] += all_errors[angle_idx]
        
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, all_canonical_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            all_canonical_pairs_errors[angle_idx] += all_errors[angle_idx]
        
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, canonical_nested_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            canonical_nested_pairs_errors[angle_idx] += all_errors[angle_idx]
            
        return_maes, all_errors = calculate_mae_for_positions(each_rna_predicted_angles, all_gt_angles, hairpin_loop_nums, prediction_method="RNA Transformer Decoder")
        for angle_idx,each_angle_error in enumerate(all_errors):
            hairpin_loop_errors[angle_idx] += all_errors[angle_idx]
        
        total_num_nucleotides += len_seq

        if each_pdb_idx > 1000:
            break
            
    except Exception as e:
        print("Skipping because", e)
        break
        pass

def print_region_maes(region_errors, region_name="Unknown"):
    print(f"{region_name} ({len(region_errors[0])/total_num_nucleotides*100:.2f}\%)", end=',')
    for each_angle_idx, each_angle_errors in enumerate(region_errors):
        try:
            mae = sum(each_angle_errors)/len(each_angle_errors)
        except:
            mae = 0
        print(f"{mae:.2f}", end=',')
    print()
        

print(f"MAEs for various regions of SPOT-RNA-1D")
for each_angle_name in all_dihedral_angle_names:
    print(each_angle_name, end=',')
print()

print_region_maes(unpaired_errors, region_name="Unpaired")
print_region_maes(lone_pairs_errors, region_name="Lone Pairs")
print_region_maes(pseudoknots_errors, region_name="Pseudoknots")
print_region_maes(multiplets_errors, region_name="Multiplets")
print_region_maes(noncanonical_pairs_errors, region_name="Non-Canonical Pairs")
print_region_maes(all_canonical_pairs_errors, region_name="All Canonical Pairs")
print_region_maes(canonical_nested_pairs_errors, region_name="Canonical Nested Pairs")
print_region_maes(hairpin_loop_errors, region_name="Hairpin Loops")


Found RNA objects precomputed


  2%|▏         | 4/172 [00:08<04:21,  1.55s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5LYU_B.out due to mismatch of number of residues


  6%|▌         | 10/172 [00:12<02:17,  1.18it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6O7H_G.out due to mismatch of number of residues


  7%|▋         | 12/172 [00:15<03:09,  1.19s/it]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5WTI_B.out due to mismatch of number of residues


 10%|█         | 18/172 [00:21<02:26,  1.05it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7UZX_G.out not enough values to unpack (expected 3, got 0)


 12%|█▏        | 20/172 [00:22<01:57,  1.29it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6BK8_E.out due to mismatch of number of residues


 13%|█▎        | 22/172 [00:23<01:35,  1.58it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2XDB_G.out due to mismatch of number of residues
Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5ZWO_G.out not enough values to unpack (expected 3, got 0)


 14%|█▍        | 24/172 [00:25<01:40,  1.47it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1U6B_C.out not enough values to unpack (expected 3, got 0)


 16%|█▌        | 27/172 [00:27<02:02,  1.18it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//3WQY_C.out '18'


 20%|██        | 35/172 [00:34<01:23,  1.64it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1C9S_W.out not enough values to unpack (expected 3, got 0)


 24%|██▍       | 41/172 [00:37<01:16,  1.72it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8I45_A.out due to mismatch of number of residues


 27%|██▋       | 47/172 [00:40<01:01,  2.03it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5MWI_A.out not enough values to unpack (expected 3, got 0)


 29%|██▉       | 50/172 [00:43<01:24,  1.44it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2DU3_D.out '18'


 31%|███▏      | 54/172 [00:46<01:35,  1.23it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5AH5_C.out due to mismatch of number of residues


 34%|███▍      | 59/172 [00:49<01:08,  1.64it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7XPL_G.out not enough values to unpack (expected 3, got 0)


 42%|████▏     | 72/172 [00:57<00:52,  1.90it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6NUE_H.out not enough values to unpack (expected 3, got 0)


 42%|████▏     | 73/172 [00:58<01:10,  1.40it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//2ZUE_B.out due to mismatch of number of residues


 44%|████▍     | 76/172 [01:00<01:00,  1.58it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7WWV_M.out not enough values to unpack (expected 3, got 0)


 47%|████▋     | 81/172 [01:03<00:51,  1.77it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1F7U_B.out '18'


 49%|████▉     | 84/172 [01:06<01:13,  1.19it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//4RDX_C.out due to mismatch of number of residues


 55%|█████▍    | 94/172 [01:17<01:07,  1.15it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8FZT_A.out not enough values to unpack (expected 3, got 0)


 58%|█████▊    | 100/172 [01:18<00:23,  3.10it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5EEU_W.out not enough values to unpack (expected 3, got 0)


 61%|██████    | 105/172 [01:21<00:31,  2.12it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8AW3_1.out '55'


 66%|██████▌   | 113/172 [01:27<00:41,  1.44it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//8BVM_U.out due to mismatch of number of residues


 68%|██████▊   | 117/172 [01:30<00:38,  1.41it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5OIK_P.out not enough values to unpack (expected 3, got 0)


 69%|██████▉   | 119/172 [01:31<00:44,  1.19it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7UOB_T.out not enough values to unpack (expected 3, got 0)


 70%|███████   | 121/172 [01:34<00:50,  1.01it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5U30_B.out '75'


 72%|███████▏  | 124/172 [01:35<00:24,  1.97it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7SBA_Z.out not enough values to unpack (expected 3, got 0)


 73%|███████▎  | 125/172 [01:36<00:30,  1.54it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7ZHG_5.out not enough values to unpack (expected 3, got 0)


 75%|███████▌  | 129/172 [01:40<00:39,  1.08it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//7K16_P.out due to mismatch of number of residues


 83%|████████▎ | 143/172 [01:53<00:27,  1.04it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1MNX_A.out due to mismatch of number of residues


 84%|████████▎ | 144/172 [01:54<00:25,  1.09it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1GTN_W.out not enough values to unpack (expected 3, got 0)


 85%|████████▍ | 146/172 [01:55<00:18,  1.41it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5LWJ_A.out '35'


 90%|████████▉ | 154/172 [02:01<00:13,  1.33it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5Z71_A.out not enough values to unpack (expected 3, got 0)


 91%|█████████ | 156/172 [02:03<00:12,  1.32it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//5ED1_B.out not enough values to unpack (expected 3, got 0)


 92%|█████████▏| 159/172 [02:05<00:11,  1.10it/s]

Unable to correct mappings for /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//1E8O_E.out '15'


 94%|█████████▍| 162/172 [02:08<00:08,  1.14it/s]

SKIPPING /home2/sriram.devata/rna_project/rna_structure/data/raw_files/all_torrna_pdbs//6BSI_R.out due to mismatch of number of residues


100%|██████████| 172/172 [02:17<00:00,  1.25it/s]

MAEs for various regions of SPOT-RNA-1D
alpha,beta,gamma,delta,chi,epsilon,zeta,eta,theta,
Unpaired (27.77\%),63.52,29.97,48.91,23.69,32.58,26.79,47.87,50.83,50.05,
Lone Pairs (5.99\%),46.32,23.19,37.55,18.67,23.55,22.73,43.84,27.46,36.57,
Pseudoknots (2.71\%),40.46,18.61,31.47,13.76,14.58,18.80,28.78,25.31,28.30,
Multiplets (9.24\%),53.55,25.33,42.99,18.81,25.29,22.78,42.76,28.56,37.53,
Non-Canonical Pairs (25.61\%),50.10,24.64,41.66,17.91,25.48,22.66,38.76,28.14,34.86,
All Canonical Pairs (50.51\%),32.72,15.66,28.26,11.66,12.15,15.44,16.28,15.48,16.05,
Canonical Nested Pairs (46.49\%),36.56,17.33,31.25,12.41,14.54,16.73,19.47,17.94,18.79,
Hairpin Loops (11.72\%),62.66,27.14,42.92,23.57,28.47,26.55,46.69,51.52,46.73,





In [17]:
torrna_maes_string="""
MAEs for various regions of TorRNA
alpha,beta,gamma,delta,chi,epsilon,zeta,eta,theta,
Unpaired (27.77\%),61.56,29.32,48.21,22.45,32.14,26.75,46.76,48.18,48.83,
Lone Pairs (5.99\%),44.43,22.18,37.18,16.28,22.48,22.96,42.93,25.10,36.16,
Pseudoknots (2.71\%),38.30,18.51,30.23,11.88,13.38,18.73,27.96,23.60,27.61,
Multiplets (9.24\%),50.89,24.89,42.53,17.36,25.07,22.88,42.60,27.60,37.04,
Non-Canonical Pairs (25.61\%),47.24,24.14,41.32,15.58,24.48,22.42,37.94,25.94,33.89,
All Canonical Pairs (50.51\%),29.51,15.06,27.54,7.82,9.37,13.80,15.18,11.10,14.43,
Canonical Nested Pairs (46.49\%),33.32,16.89,30.45,9.01,12.43,15.13,18.51,13.95,17.30,
Hairpin Loops (11.72\%),62.57,26.55,42.77,22.75,28.52,26.57,46.18,48.98,45.43,"""

spotrna1d_maes_string="""
MAEs for various regions of SPOT-RNA-1D
alpha,beta,gamma,delta,chi,epsilon,zeta,eta,theta,
Unpaired (27.77\%),63.52,29.97,48.91,23.69,32.58,26.79,47.87,50.83,50.05,
Lone Pairs (5.99\%),46.32,23.19,37.55,18.67,23.55,22.73,43.84,27.46,36.57,
Pseudoknots (2.71\%),40.46,18.61,31.47,13.76,14.58,18.80,28.78,25.31,28.30,
Multiplets (9.24\%),53.55,25.33,42.99,18.81,25.29,22.78,42.76,28.56,37.53,
Non-Canonical Pairs (25.61\%),50.10,24.64,41.66,17.91,25.48,22.66,38.76,28.14,34.86,
All Canonical Pairs (50.51\%),32.72,15.66,28.26,11.66,12.15,15.44,16.28,15.48,16.05,
Canonical Nested Pairs (46.49\%),36.56,17.33,31.25,12.41,14.54,16.73,19.47,17.94,18.79,
Hairpin Loops (11.72\%),62.66,27.14,42.92,23.57,28.47,26.55,46.69,51.52,46.73,"""


print("Region,alpha,beta,gamma,delta,chi,epsilon,zeta,eta,theta,".replace(',', ' & '), "\\\\")
print("\midrule")
for torrna_line,spotrna1d_line in zip(torrna_maes_string.split('\n'),spotrna1d_maes_string.split('\n')):
    if '%' not in torrna_line:
        continue

    torrna_line = torrna_line.split(',')
    spotrna1d_line = spotrna1d_line.split(',')
    assert torrna_line[0] == spotrna1d_line[0]
    
    print(f"{torrna_line[0]}", end=' & ')
    for each_torrna_mae,each_spotrna_mae in zip(torrna_line[1:-1],spotrna1d_line[1:-1]):
        # print(f"\\shortstack{{ {each_torrna_mae}\\\\({each_spotrna_mae}) }}", end=' & ')
        if float(each_torrna_mae) < float(each_spotrna_mae):
            print(f"\\shortstack{{ {each_torrna_mae}\\\\({each_spotrna_mae}) }}", end=' & ')
        else:
            print(f"\\shortstack{{ {each_torrna_mae}\\\\(\\textbf {{ {each_spotrna_mae} }} ) }}", end=' & ')
            
#         print(f"{each_torrna_mae}({each_spotrna_mae})", end=' & ')
    print("\\\\")

Region & alpha & beta & gamma & delta & chi & epsilon & zeta & eta & theta &  \\
\midrule
Unpaired (27.77\%) & \shortstack{ 61.56\\(63.52) } & \shortstack{ 29.32\\(29.97) } & \shortstack{ 48.21\\(48.91) } & \shortstack{ 22.45\\(23.69) } & \shortstack{ 32.14\\(32.58) } & \shortstack{ 26.75\\(26.79) } & \shortstack{ 46.76\\(47.87) } & \shortstack{ 48.18\\(50.83) } & \shortstack{ 48.83\\(50.05) } & \\
Lone Pairs (5.99\%) & \shortstack{ 44.43\\(46.32) } & \shortstack{ 22.18\\(23.19) } & \shortstack{ 37.18\\(37.55) } & \shortstack{ 16.28\\(18.67) } & \shortstack{ 22.48\\(23.55) } & \shortstack{ 22.96\\(\textbf { 22.73 } ) } & \shortstack{ 42.93\\(43.84) } & \shortstack{ 25.10\\(27.46) } & \shortstack{ 36.16\\(36.57) } & \\
Pseudoknots (2.71\%) & \shortstack{ 38.30\\(40.46) } & \shortstack{ 18.51\\(18.61) } & \shortstack{ 30.23\\(31.47) } & \shortstack{ 11.88\\(13.76) } & \shortstack{ 13.38\\(14.58) } & \shortstack{ 18.73\\(18.80) } & \shortstack{ 27.96\\(28.78) } & \shortstack{ 23.60\\(25.3