# Sequence Identity Split Creation

This notebook creates the Sequence Identity split used to evaluate gRNAde on biologically dissimilar clusters of RNAs.
We cluster the sequences based on nucleotide similarity using CD-HIT (Fu et al., 2012; invoked here as `cd-hit-est`) with an identity threshold of 80%, and then assign whole clusters to the training, validation and test sets.

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

import os
import subprocess
import numpy as np
import torch
from tqdm import tqdm
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

In [None]:
def _parse_cd_hit_clusters(lines):
    """Map each sequence ID listed in a CD-HIT `.clstr` file to its cluster number.

    Lines starting with ">" open a new cluster (e.g. ">Cluster 0"); every other
    non-blank line describes one member, whose ID sits between ">" and "...".
    """
    seq_idx_to_cluster = {}
    current_cluster = None
    for line in lines:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) instead of crashing
            continue
        if line.startswith(">"):
            current_cluster = int(line.split(" ")[1])
        else:
            sequence_id = int(line.split(">")[1].split("...")[0])
            seq_idx_to_cluster[sequence_id] = current_cluster
    return seq_idx_to_cluster


def run_cd_hit_est(
        input_sequences, 
        identity_threshold = 0.9,
        word_size = 2,
        input_file = "input",
        output_file = "output",
        cleanup = False,
    ):
    """Cluster nucleotide sequences with the external `cd-hit-est` program.

    The sequences are written to `input_file` in FASTA format, `cd-hit-est` is
    run on that file, and its two outputs are read back: the FASTA file at
    `output_file` and the cluster listing at `output_file + ".clstr"`.

    Reference: https://manpages.ubuntu.com/manpages/impish/man1/cd-hit-est.1.html

    Args:
        input_sequences: Records accepted by `SeqIO.write`. Each record ID must
            be an integer written as a string, because IDs are parsed back
            with `int()`.
        identity_threshold: Sequence identity threshold, passed as `-c`.
        word_size: Word size for sequence comparison, passed as `-n`.
            NOTE(review): which word sizes are valid appears to depend on the
            identity threshold -- confirm against the cd-hit-est manual.
        input_file: Path of the FASTA file written for cd-hit-est to read.
        output_file: Path of the FASTA file cd-hit-est writes; the cluster
            listing is expected next to it with a ".clstr" suffix.
        cleanup: If True, delete the three intermediate files before returning
            (also when an error occurs). Defaults to False, which leaves the
            files on disk as before.

    Returns:
        A tuple `(clustered_sequences, seq_idx_to_cluster)`:
        the records parsed from `output_file`, and a dict mapping each integer
        sequence ID found in the `.clstr` file to its integer cluster number.

    Raises:
        subprocess.CalledProcessError: If cd-hit-est exits with a non-zero status.
    """
    cluster_file = output_file + ".clstr"

    # Write input sequences to the input file read by cd-hit-est
    SeqIO.write(input_sequences, input_file, "fasta")

    try:
        # Run CD-HIT-EST
        cmd = [
            "cd-hit-est",
            "-i", input_file,
            "-o", output_file,
            "-c", str(identity_threshold), # Sequence identity threshold (e.g., 0.9 for 90%)
            "-n", str(word_size),          # Word size for sequence comparison
        ]
        subprocess.run(cmd, check=True)

        # Read clustered sequences from the output file
        clustered_sequences = list(SeqIO.parse(output_file, "fasta"))

        # Process the clustering output
        with open(cluster_file, "r") as f:
            seq_idx_to_cluster = _parse_cd_hit_clusters(f)
    finally:
        # Delete intermediate files only when explicitly requested
        if cleanup:
            for path in (input_file, output_file, cluster_file):
                if os.path.exists(path):
                    os.remove(path)

    return clustered_sequences, seq_idx_to_cluster

In [None]:
# Load the processed dataset; each entry provides its sequence under the "seq" key
data_list = torch.load(os.path.join("../data/", "processed.pt"))

# Build one record per entry, using the entry's position in data_list as the
# record ID so that clustering results can be mapped back to data_list
seq_list = [
    SeqRecord(Seq(entry["seq"]), id=str(position))
    for position, entry in enumerate(data_list)
]
print(len(seq_list))

In [None]:
# Run the clustering with an 80% sequence-identity threshold
# (noted by the original author as the lowest currently possible)
clustered_sequences, seq_idx_to_cluster = run_cd_hit_est(
    seq_list,
    identity_threshold=0.8,
    word_size=3,
)

In [None]:
# Number of clusters: clustered_sequences holds the records that were read back
# from the FASTA file written by cd-hit-est
len(clustered_sequences)

In [None]:
# Sanity check: which indices of data_list did not receive a cluster?
# idx_not_clustered is always defined (possibly empty), so later cells can use
# it regardless of the outcome of this check.
idx_not_clustered = list(set(list(range(len(data_list)))) - set(seq_idx_to_cluster.keys()))

if idx_not_clustered:
    # The original guess is that these sequences are too short to be clustered.
    # NOTE(review): cd-hit-est has a throw-away length option (`-l`) for short
    # sequences -- compare its value with the max length printed below to confirm.
    print("Number of missing indices after clustering: ", len(idx_not_clustered))

    # Length statistics of the sequences that were left out
    seq_lens = [len(data_list[idx]["seq"]) for idx in idx_not_clustered]
    print("Sequence lengths for missing indices:")
    print(f"    Distribution: {np.mean(seq_lens)} +- {np.std(seq_lens)}")
    print(f"    Max: {np.max(seq_lens)}, Min: {np.min(seq_lens)}")

In [None]:
# Count how many sequences were assigned to each cluster ID
cluster_ids, cluster_sizes = np.unique(
    list(seq_idx_to_cluster.values()),
    return_counts=True,
)

# Print some examples: the first 10 (cluster ID, number of sequences) pairs
for cluster_id, n_seqs in zip(cluster_ids[:10], cluster_sizes[:10]):
    print(cluster_id, n_seqs)

In [None]:
# Invert the clustering result:
#   seq_idx_to_cluster:      index in data_list -> cluster ID
#   cluster_to_seq_idx_list: cluster ID -> list of indices in data_list
cluster_to_seq_idx_list = {}
for seq_idx, cluster in seq_idx_to_cluster.items():
    # setdefault creates the member list the first time a cluster is seen
    cluster_to_seq_idx_list.setdefault(cluster, []).append(seq_idx)

In [None]:
# Cluster sizes: number of structures (total) in each cluster.
# Built in the order of cluster_ids, so that cluster_ids, cluster_sizes and
# cluster_sizes_structs line up entry by entry (np.unique returns sorted IDs,
# which need not match the insertion order of cluster_to_seq_idx_list).
cluster_sizes_structs = [
    sum(
        len(data_list[seq_idx]['coords_list'])
        for seq_idx in cluster_to_seq_idx_list[int(cluster_id)]
    )
    for cluster_id in cluster_ids
]

# Print some examples: cluster sequence size and structure size
print("cluster ID, # sequences, total # structures")
for cluster_id, n_seqs, n_structs in zip(cluster_ids[:10], cluster_sizes[:10], cluster_sizes_structs[:10]):
    print(cluster_id, n_seqs, n_structs)

In [None]:
test_idx_list = []
val_idx_list = []
train_idx_list = []

# Some heuristics
# * Whole clusters are added to the test set, then to the validation set, until
#   each holds at least EVAL_SET_TARGET_SIZE samples; everything else goes to
#   the train set. The size is checked before a cluster is added, so a set can
#   end up larger than the target.
# * Very large sequence clusters (MAX_EVAL_CLUSTER_SIZE sequences or more) are
#   never added to the validation or test set.
EVAL_SET_TARGET_SIZE = 200
MAX_EVAL_CLUSTER_SIZE = 100

for cluster, seq_idx_list in cluster_to_seq_idx_list.items():
    # Take the size from the member list itself: indexing cluster_sizes by
    # cluster ID is only correct when the IDs are exactly 0..K-1.
    small_enough = len(seq_idx_list) < MAX_EVAL_CLUSTER_SIZE

    if small_enough and len(test_idx_list) < EVAL_SET_TARGET_SIZE:
        test_idx_list += seq_idx_list
    elif small_enough and len(val_idx_list) < EVAL_SET_TARGET_SIZE:
        val_idx_list += seq_idx_list
    else:
        train_idx_list += seq_idx_list

In [None]:
# Add all the sequences that were not assigned any cluster into the training set.
# The check is on the total count, so re-running this cell does not add them twice.
n_assigned = len(test_idx_list) + len(val_idx_list) + len(train_idx_list)
if n_assigned != len(data_list):
    # Recomputed here so this cell does not depend on a variable that an earlier
    # cell may or may not have defined.
    idx_not_clustered = list(set(list(range(len(data_list)))) - set(seq_idx_to_cluster.keys()))
    train_idx_list += idx_not_clustered

# Every index of data_list must now belong to exactly one of the three sets
n_assigned = len(test_idx_list) + len(val_idx_list) + len(train_idx_list)
assert n_assigned == len(data_list), (
    f"split covers {n_assigned} samples but data_list has {len(data_list)}"
)

In [None]:
# Persist the split as a (train, val, test) tuple of index lists
torch.save(
    (train_idx_list, val_idx_list, test_idx_list),
    "../data/seq_identity_split.pt",
)