# Data Split Creation

This notebook creates data splits used to evaluate gRNAde on randomly split RNAs.

**Workflow:**
1. Order the samples based on some metric:
    - Avg. RMSD among available structures
    - Total structures available
2. Training, validation, and test splits become progressively harder.
    - Top 100 samples with highest metric -- test set.
    - Next 100 samples with highest metric -- validation set.
    - All remaining samples -- training set.
    - Very large (>= 1000 nts) or very small (<= 10 nts) RNAs -- training set.

Note that we exclude very large RNA samples (>= 1000 nts) from the test and validation sets and add them directly to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (<= 10 nts).

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

import os
import subprocess
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset
import seaborn as sns

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

from src.data_utils import get_avg_rmsds

In [None]:
# Read the processed dataset from disk and report how many samples it holds.
# NOTE: torch.load unpickles the file, so only point it at a trusted,
# locally generated processed.pt.
processed_path = os.path.join("../data/", "processed.pt")
data_list = torch.load(processed_path)
print(len(data_list))

# One SeqRecord per sample, built from the sample's "seq" entry.
# Each record's ID is the sample's position in data_list, as a string.
seq_list = [
    SeqRecord(Seq(sample["seq"]), id=str(sample_idx))
    for sample_idx, sample in enumerate(data_list)
]

# Per-sample average RMSD, computed by the project helper.
rmsd_list = get_avg_rmsds(data_list)

# Per-sample number of entries in "coords_list" (structures per sequence).
count_list = [len(sample["coords_list"]) for sample in data_list]

# All per-sample lists must line up index-for-index with data_list.
assert all(
    len(per_sample) == len(data_list)
    for per_sample in (seq_list, rmsd_list, count_list)
)

In [None]:
# RMSD split: rank the samples by their average RMSD.

# Pair every position in data_list with its avg. RMSD and order the pairs
# from highest to lowest RMSD. sorted() is stable, so samples with equal
# RMSD keep their relative data_list order.
# NOTE(review): if rmsd_list can contain NaN values, the resulting order is
# not well-defined -- confirm against get_avg_rmsds.
sorted_zipped = sorted(
    zip(range(len(data_list)), rmsd_list),
    key=lambda pair: pair[1],
    reverse=True,
)

# Split the ranked (index, rmsd) pairs back into two parallel tuples.
sorted_data_list_idx, sorted_rmsd_list = zip(*sorted_zipped)

In [None]:
# Assign every sample to exactly one split, walking the samples in the order
# of `sorted_zipped` (highest avg. RMSD first).
#
# Only sequences whose length is strictly between MIN_SEQ_LEN and MAX_SEQ_LEN
# may enter the test or validation set: the first NUM_TEST such samples form
# the test set, the next NUM_VAL form the validation set. Everything else,
# including all sequences outside that length range, goes to training.
NUM_TEST = 100
NUM_VAL = 100
MIN_SEQ_LEN = 10    # sequences of this length or shorter go straight to training
MAX_SEQ_LEN = 1000  # sequences of this length or longer go straight to training

test_idx_list = []
val_idx_list = []
train_idx_list = []

for idx, _avg_rmsd in sorted_zipped:
    # Only the ranking order matters here; the RMSD value itself is not used.
    seq_len = len(seq_list[idx])
    eligible_for_eval = MIN_SEQ_LEN < seq_len < MAX_SEQ_LEN

    if eligible_for_eval and len(test_idx_list) < NUM_TEST:
        test_idx_list.append(idx)
    elif eligible_for_eval and len(val_idx_list) < NUM_VAL:
        val_idx_list.append(idx)
    else:
        train_idx_list.append(idx)

In [None]:
# Sanity check: the three splits together must contain as many indices as
# there are samples in data_list.
assert sum(map(len, (test_idx_list, val_idx_list, train_idx_list))) == len(data_list)

In [None]:
# Persist the split as a (train, val, test) tuple of index lists.
split_indices = (train_idx_list, val_idx_list, test_idx_list)
torch.save(split_indices, "../data/random_rmsd_split.pt")