# Dataset creation pipeline

In [None]:
import pandas as pd
import requests
import gzip
import os
import io
from collections import defaultdict
import time
import random
import math
from tqdm import tqdm 
from Bio import SeqIO

### Download TSA records
First, download the master table (saved as `tsa_lookup_mastertable.csv`) from: https://www.ncbi.nlm.nih.gov/Traces/wgs/?page=1&view=tsa

Run the cells below to download the raw per-species TSA FASTA files. The example uses `'PRI|MAM|ROD'`, but it is easily extended to other species families.

In [None]:
# Species families of interest: "PRI" (primates), "MAM" (mammals), "ROD" (rodents),
# written as one "|"-separated pattern.
species = 'PRI|MAM|ROD'
tsa_master = pd.read_csv("tsa_lookup_mastertable.csv")

# List every distinct value of the "div_s" column so the available codes can be inspected.
division_codes = tsa_master['div_s'].unique()
print(division_codes)

In [None]:
# Species families to process and the folder that receives the raw downloads.
# "PRI" (primates), "MAM" (mammals), "ROD" (rodents), combined into one "|"-separated pattern.
species = 'PRI|MAM|ROD'
tsa_master = pd.read_csv("tsa_lookup_mastertable.csv")

# Build the output path once, then make sure the directory exists.
tsa_raw_output_dir = f"tsa_species_fasta_output_raw_{species}"
os.makedirs(tsa_raw_output_dir, exist_ok=True)

In [None]:
# Keep only the master-table rows whose "div_s" value matches the selected species families.
# `species` is used as a regular expression, so 'PRI|MAM|ROD' matches any of the three codes.
filtered_df = tsa_master[tsa_master['div_s'].str.contains(species, na=False)]
species_char_count = defaultdict(int)      # output file stem -> characters written
filtered_entries_count = defaultdict(int)  # output file stem -> entries dropped by the header filter

# (connect, read) timeouts in seconds; without a timeout a stalled connection blocks the loop forever.
REQUEST_TIMEOUT = (10, 120)

for index, row in tqdm(filtered_df.iterrows(), total=filtered_df.shape[0], desc="Processing rows"):
    # Best-effort loop: a failure on one row is reported and the remaining rows are still processed.
    try:
        if pd.isna(row['organism_an']):
            # tqdm.write keeps the progress bar intact (a plain print would break it).
            tqdm.write(f"Skipping row {index} - no organism name")
            continue

        species_name = row['organism_an']
        filename = species_name.replace(' ', '_')
        prefix = row["prefix_s"]  # get prefix
        identifier = prefix[0]  # get identifier (first letter)
        # Construct the download URL; the last two characters of the prefix are removed.
        url = f"https://ftp.ncbi.nlm.nih.gov/genbank/tsa/{identifier}/tsa.{prefix[:-2]}.1.fsa_nt.gz"
        tqdm.write(f"Downloading {url} for {species_name}...")

        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        if response.status_code == 200:
            output_file = os.path.join(tsa_raw_output_dir, f"{filename}.fasta")
            fasta_content = gzip.decompress(response.content).decode('utf-8')

            kept_entries = []
            # Split only where ">" begins a line, so a ">" inside a description
            # cannot cut a record in two.
            for entry in ("\n" + fasta_content).split("\n>")[1:]:
                # Very coarse-grained filtering: drop entries whose header line mentions "mRNA".
                header = entry.split('\n', 1)[0]
                if "mRNA" in header:
                    filtered_entries_count[filename] += 1
                else:
                    # Always end a record with exactly one newline so that records appended
                    # from a later download for the same species start on their own line.
                    kept_entries.append(">" + entry.rstrip("\n") + "\n")

            filtered_content = "".join(kept_entries)
            # Append mode: several master-table rows can map to the same species file.
            with open(output_file, 'a') as outfile:
                outfile.write(filtered_content)
            species_char_count[filename] += len(filtered_content)
        else:
            tqdm.write(f"Failed to download {url}. Status code: {response.status_code}")

        time.sleep(random.uniform(0.5, 1.5))  # delay

    except Exception as e:
        tqdm.write(f"Error processing row {index}: {e}")

# Summary
# The loop variable is deliberately NOT called `species`: rebinding that global would
# corrupt the species-family pattern used by the cells above when they are re-run.
for fasta_name, char_count in species_char_count.items():
    dropped = filtered_entries_count[fasta_name]
    print(f"Created {fasta_name}.fasta with {char_count} characters (coarse grained filtered out {dropped} mRNA entries)")

### Filter out coding RNA
Filtering out coding RNA is done using MMseqs2 against the Swiss-Prot database

For ease of use and reproducibility, this runs on BioLib as shown below.

This is run on a preprocessed data-record: https://biolib.com/ncRNA-foundational-model/MAM-PRI-ROD-TSA/ for 'PRI|MAM|ROD'

In [None]:
# BioLib client used by the hosted pipeline steps in the following cells.
import biolib
# Log in before any application is loaded with biolib.load(...).
biolib.login()

In [None]:
# Right now this runs on a data-record containing all raw 'PRI|MAM|ROD' TSA records (created above).
# The app is not able to take input right now.
mmseqs2_code_filter = biolib.load('ncRNA-foundational-model/blast-species')

# Several jobs are started without blocking so they can run side by side.
# Determine the ranges and splits based on which species families are used.
# NOTE(review): no range below covers indices 10-11 or 40-43 — confirm this is intentional.
species_index_ranges = ['0-4', '5-9', '12-19', '20-29', '30-39', '44-49', '50-59', '60-66']
jobs = [
    mmseqs2_code_filter.cli(args=['--range', index_range], blocking=False)
    for index_range in species_index_ranges
]

# Collect the results in launch order.
for job in jobs:
    job.wait()  # wait for finish
    job.save_files(output_dir='tsa_species_fasta_non_coding_mmseqs2/', path_filter='*/*.fasta')  # download all results

### GC and entropy filtering

As an additional filtering step, sequences are retained only if they meet the following criteria:
- GC content is between 30% and 80%.
- Sequence length is at least 10 nucleotides.
- Shannon entropy is at least 1.75.

In [None]:
import os
import math
import argparse
from Bio import SeqIO
from tqdm import tqdm

# Module-level thresholds read by process_fasta_file.
MIN_GC = 0.3        # lowest accepted GC fraction (inclusive)
MAX_GC = 0.8        # highest accepted GC fraction (inclusive)
MIN_ENTROPY = 1.75  # lowest accepted Shannon entropy (log base 2)
MIN_LEN = 10        # lowest accepted sequence length, counted after "N" characters are removed

def shannon_entropy(seq):
    """Return the Shannon entropy (in bits) of the nucleotide composition of *seq*.

    Only the standard bases A, C, G, T and U are counted; every other symbol is
    ignored. Matching is case-insensitive, consistent with ``gc_content``, so
    soft-masked (lower-case) sequence is scored the same as upper-case sequence.

    Args:
        seq: Iterable of single-character strings (typically a ``str``).

    Returns:
        The entropy as a float; 0.0 when *seq* contains no standard base.
    """
    counts = {}
    for base in seq:
        symbol = base.upper()
        # Consider only standard DNA/RNA bases
        if symbol in "ACGTU":
            counts[symbol] = counts.get(symbol, 0) + 1

    total = sum(counts.values())
    if total == 0:
        return 0.0

    # Every stored count is at least 1, so log2 is always defined here.
    return -sum((count / total) * math.log2(count / total)
                for count in counts.values())

def gc_content(seq):
    """Return the fraction of G and C characters in *seq*, ignoring case.

    The denominator is the full length of the upper-cased sequence, so any
    other character still counts towards the length. An empty sequence gives 0.
    """
    normalized = seq.upper()
    if len(normalized) == 0:
        return 0
    gc_total = sum(normalized.count(base) for base in ("G", "C"))
    return gc_total / len(normalized)

def process_fasta_file(input_filepath, output_filepath,
                       min_gc=None, max_gc=None, min_entropy=None, min_len=None):
    """Filter one FASTA file by length, GC content and Shannon entropy.

    Each record is evaluated on its upper-cased sequence with all "N"
    characters removed; records that pass are written unchanged (the original
    record, including any "N"s) to *output_filepath*.

    Args:
        input_filepath: Path of the FASTA file to read.
        output_filepath: Path of the FASTA file to write. Nothing is written
            when no record passes the filters.
        min_gc: Lowest accepted GC fraction (inclusive). Defaults to ``MIN_GC``.
        max_gc: Highest accepted GC fraction (inclusive). Defaults to ``MAX_GC``.
        min_entropy: Lowest accepted Shannon entropy. Defaults to ``MIN_ENTROPY``.
        min_len: Lowest accepted length after "N" removal. Defaults to ``MIN_LEN``.

    Returns:
        A ``(removed_sequences, total_sequences)`` tuple.
    """
    # Resolve omitted thresholds at call time so they follow the module-level constants.
    min_gc = MIN_GC if min_gc is None else min_gc
    max_gc = MAX_GC if max_gc is None else max_gc
    min_entropy = MIN_ENTROPY if min_entropy is None else min_entropy
    min_len = MIN_LEN if min_len is None else min_len

    filtered_records = []
    total_sequences = 0
    removed_sequences = 0

    print(f"Processing {os.path.basename(input_filepath)}...")
    # First pass only counts the records so the progress bar has a total.
    with open(input_filepath, "r") as handle:
        initial_count = sum(1 for _ in SeqIO.parse(handle, "fasta"))

    with open(input_filepath, "r") as handle:
        for record in tqdm(SeqIO.parse(handle, "fasta"), total=initial_count, unit="seq"):
            total_sequences += 1
            seq = str(record.seq).upper().replace("N", "")

            if len(seq) < min_len:
                removed_sequences += 1
                continue

            if min_gc <= gc_content(seq) <= max_gc and shannon_entropy(seq) >= min_entropy:
                filtered_records.append(record)
            else:
                removed_sequences += 1

    if filtered_records:
        SeqIO.write(filtered_records, output_filepath, "fasta")

    return removed_sequences, total_sequences


def main(argv=None):
    """Filter every ``.fasta`` file in a directory by GC content, entropy and length.

    Args:
        argv: Optional list of command-line arguments. ``None`` lets argparse
            read ``sys.argv[1:]``; passing a list allows calling this function
            directly, e.g. ``main(["--input_dir", "my_fastas"])``.
    """
    DEFAULT_OUTPUT_DIR = "non_coding_filtered_entropy"
    parser = argparse.ArgumentParser(
        description="Filter FASTA files by GC content, Shannon entropy and length.")
    parser.add_argument("--input_dir", required=True)
    parser.add_argument("--output_dir", default=DEFAULT_OUTPUT_DIR)
    # Defaults come from the module-level constants so the command line and the
    # documented thresholds cannot drift apart.
    parser.add_argument("--min_gc", type=float, default=MIN_GC)
    parser.add_argument("--max_gc", type=float, default=MAX_GC)
    parser.add_argument("--min_entropy", type=float, default=MIN_ENTROPY)
    parser.add_argument("--min_len", type=int, default=MIN_LEN)

    args = parser.parse_args(argv)
    input_dir = args.input_dir
    output_dir = args.output_dir
    # The writer does not create missing directories, so make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    total_removed_sequences = 0
    total_sequences_processed = 0

    fasta_files_found = False
    # Sorted so files are processed in the same order on every run.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.lower().endswith(".fasta"):
            continue
        fasta_files_found = True
        input_filepath = os.path.join(input_dir, filename)
        output_filepath = os.path.join(output_dir, filename)
        removed, processed = process_fasta_file(
            input_filepath,
            output_filepath,
            min_gc=args.min_gc,
            max_gc=args.max_gc,
            min_entropy=args.min_entropy,
            min_len=args.min_len,
        )
        total_removed_sequences += removed
        total_sequences_processed += processed

    if not fasta_files_found:
        print(f"No .fasta files found in {input_dir}")

    print(f"Total sequences filtered out from all files: {total_removed_sequences}")
    print(f"Total sequences processed from all files: {total_sequences_processed}")


if __name__ == "__main__":
    main()

### Homology filtering
Homology filtering is done using MMseqs2

This will generate the m8 files with all inter-species homology hits as well as their sequence identities.

This also runs on a precomputed data-record: https://biolib.com/ncRNA-foundational-model/non-coding-filtered-MAM-PRI-ROD-TSA/ for 'PRI|MAM|ROD'

In [None]:
homology_species = biolib.load('ncRNA-foundational-model/homology-species')

# One non-blocking job per index range, so the searches can run side by side.
homology_index_ranges = ['0-14', '15-29', '30-44', '45-62']
jobs = [
    homology_species.cli(args=['--range', index_range], blocking=False)
    for index_range in homology_index_ranges
]

# Collect the m8 result files in launch order.
for job in jobs:
    job.wait()  # wait for finish
    job.save_files(output_dir="tsa_species_fasta_m8_files", path_filter="tmp_search/*_results.m8")

The m8 files are used with the non-coding FASTAs to filter based on the following thresholds:

| Sequence Identity | Minimum Inter-Species Hits |
| :---------------- | :------------------------- |
| 75%               | 1                          |
| 75%               | 3                          |
| 75%               | 5                          |
| 75%               | 7                          |
| 80%               | 1                          |
| 80%               | 3                          |
| 80%               | 5                          |
| 80%               | 7                          |
| 85%               | 1                          |
| 85%               | 3                          |
| 85%               | 5                          |
| 85%               | 7                          |
| 90%               | 1                          |
| 90%               | 3                          |
| 90%               | 5                          |
| 90%               | 7                          |
| 95%               | 1                          |
| 95%               | 3                          |
| 95%               | 5                          |
| 95%               | 7                          |
| 99%               | 1                          |
| 99%               | 3                          |
| 99%               | 5                          |
| 99%               | 7                          |

This can be done locally by running the below code. Adjust paths as needed. 

In [None]:
import os
import csv
from collections import defaultdict
import sys
from Bio.Seq import Seq
import argparse
from tqdm import tqdm

# === CONFIGURATION ===
# Directory holding the per-species "<species>_results.m8" files.
# NOTE(review): the download cell above saves m8 files under "tsa_species_fasta_m8_files" —
# confirm this path points at the same files.
M8_DIR = "m8_files/tmp_search"
# Directory with the per-species FASTA files; main() resolves it relative to this file.
FASTA_INPUT_DIR = "non_coding_filtered_entropy"
# Root directory for the per-combination output folders; also resolved relative to this file.
BASE_OUTPUT_DIR = "conserved_rnas_filtered_flanked"
# Minimum fraction of the other species that a query must hit to be kept.
QUERY_CONSERVATION_CUTOFFS = [0.0, 0.50]
# (minimum identity, minimum number of distinct hit species) pairs; each pair is run separately.
# NOTE(review): assumes the identity column of the m8 files is a 0-1 fraction — TODO confirm.
THRESHOLD_COMBINATIONS = [
    (0.99, 7), (0.99, 5), (0.99, 3), (0.99, 1),
    (0.95, 7), (0.95, 5), (0.95, 3), (0.95, 1),
    (0.90, 7), (0.90, 5), (0.90, 3), (0.90, 1),
    (0.85, 7), (0.85, 5), (0.85, 3), (0.85, 1),
    (0.80, 7), (0.80, 5), (0.80, 3), (0.80, 1),
    (0.75, 7), (0.75, 5), (0.75, 3), (0.75, 1),
]
# Hits with a shorter alignment length are ignored.
MIN_ALIGN_LEN = 20
# Hits with a larger e-value are ignored.
EVALUE_THRESHOLD = 1e-10
# Number of extra positions kept on each side of an aligned query region.
flank=20

def parse_fasta_stream(fasta_path):
    """Read a FASTA file into a dict keyed by sequence ID.

    The ID is the first whitespace-delimited token after ">". Each value is the
    complete record — header line followed by its sequence lines — joined with
    "\n" and with trailing whitespace stripped from every line. A later record
    with the same ID replaces an earlier one.
    """
    records = {}
    current_id = None
    buffered = []
    with open(fasta_path) as handle:
        label = f"Parsing FASTA: {os.path.basename(fasta_path)}"
        for raw_line in tqdm(handle, desc=label, leave=False):
            text = raw_line.rstrip()
            if not text.startswith(">"):
                buffered.append(text)
                continue
            # A new header closes the record collected so far.
            if current_id:
                records[current_id] = "\n".join(buffered)
            current_id = text[1:].split()[0]
            buffered = [text]
    # Store the final record, which no following header closes.
    if current_id:
        records[current_id] = "\n".join(buffered)
    return records

def parse_m8_stream(m8_path, species):
    """Collect the best hit per (query, hit species) pair from a tab-separated m8 file.

    Rows with fewer than 14 columns are skipped, as are hits against *species*
    itself. The hit species is the part of the target name before the first
    "__". For each pair, the hit with the highest identity (column 2) is kept;
    ties are broken by the higher query coverage (column 12).

    Returns:
        ``(best_hits, alignment_coords)`` where ``best_hits`` maps
        ``(query, hit_species)`` to ``(query, target, identity, align_len, qcov,
        evalue)`` and ``alignment_coords`` maps the same key to the integer
        values of columns 6 and 7 of the retained row.
    """
    top_hits = {}
    coords_by_pair = {}
    with open(m8_path) as handle:
        label = f"Parsing M8: {os.path.basename(m8_path)}"
        for row in tqdm(handle, desc=label, leave=False):
            cols = row.rstrip().split("\t")
            if len(cols) < 14:
                continue

            query_id = cols[0]
            target_id = cols[1]
            identity = float(cols[2])
            align_len = int(cols[3])
            evalue = float(cols[10])
            qcov = float(cols[12])

            target_species = target_id.split("__")[0]
            if target_species == species:
                continue

            pair = (query_id, target_species)
            incumbent = top_hits.get(pair)
            # Lexicographic comparison: identity decides first, coverage breaks ties.
            if incumbent is None or (identity, qcov) > (incumbent[2], incumbent[4]):
                top_hits[pair] = (query_id, target_id, identity, align_len, qcov, evalue)
                coords_by_pair[pair] = (int(cols[6]), int(cols[7]))
    return top_hits, coords_by_pair

def main(argv=None):
    """Write conserved-sequence datasets for every threshold combination.

    For each species FASTA file and its matching "<species>_results.m8" file,
    queries are kept when they have qualifying hits (identity, alignment length
    and e-value thresholds) in enough distinct other species. For each
    (identity, min_species, query-conservation) combination, a separate output
    directory receives:

    * ``conserved_sequences.fasta`` — full sequences (skipped with ``--flank-only``),
    * ``conserved_alignment.fasta`` — aligned query regions extended by ``flank``,
    * ``stats.tsv`` — one summary row per species.

    NOTE(review): input and output directories are resolved through ``__file__``,
    which is normally undefined inside a notebook kernel — run this as a script.

    Args:
        argv: Optional list of command-line arguments. ``None`` lets argparse
            read ``sys.argv[1:]``.
    """
    # Function-scope import so this block does not depend on the import cell.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser()
    parser.add_argument("--flank-only", action="store_true", help="Only write flanked alignment dataset, skip full sequences")
    args = parser.parse_args(argv)

    fasta_dir = os.path.join(os.path.dirname(__file__), FASTA_INPUT_DIR)
    m8_dir = M8_DIR
    base_output_dir = os.path.join(os.path.dirname(__file__), BASE_OUTPUT_DIR)
    os.makedirs(base_output_dir, exist_ok=True)

    # Sorted so species are processed in the same order on every run.
    fasta_files = sorted(f for f in os.listdir(fasta_dir) if f.endswith(".fasta"))
    # Every species except the query species itself could contribute a hit.
    total_possible_species = len(fasta_files) - 1

    # ExitStack closes every output file even when processing fails part-way.
    with ExitStack() as stack:
        # Prepare output writers for each combination
        output_handles = {}
        stats_writers = {}
        alignment_handles = {}
        for identity, min_species in THRESHOLD_COMBINATIONS:
            for qcons in QUERY_CONSERVATION_CUTOFFS:
                combo = (identity, min_species, qcons)
                outdir = os.path.join(base_output_dir, f"run_{identity}_{min_species}-sp_querycons{qcons}")
                os.makedirs(outdir, exist_ok=True)
                output_handles[combo] = stack.enter_context(
                    open(os.path.join(outdir, "conserved_sequences.fasta"), "w"))
                stats_handle = stack.enter_context(
                    open(os.path.join(outdir, "stats.tsv"), "w", newline=''))
                alignment_handles[combo] = stack.enter_context(
                    open(os.path.join(outdir, "conserved_alignment.fasta"), "w"))
                stats_writers[combo] = csv.writer(stats_handle, delimiter='\t')
                stats_writers[combo].writerow(["species", "unique_species_hit", "conserved_rna_count"])

        print(f"Processing {len(fasta_files)} species...")
        for fasta_file in tqdm(fasta_files, desc="Species", unit="species"):
            species = os.path.splitext(fasta_file)[0].replace("_non_coding", "")
            fasta_path = os.path.join(fasta_dir, fasta_file)
            m8_path = os.path.join(m8_dir, f"{species}_results.m8")
            if not os.path.exists(m8_path):
                print(f"Warning: {m8_path} not found, skipping {species}")
                continue

            print(f"\nProcessing species: {species}")
            seqs = parse_fasta_stream(fasta_path)
            best_hits, alignment_coords = parse_m8_stream(m8_path, species)

            for identity, min_species in tqdm(THRESHOLD_COMBINATIONS, desc=f"Thresholds for {species}", leave=False):
                query2species = defaultdict(set)
                query2hits = defaultdict(list)
                for (query, hitsp), (q, t, ident, alen, qcov, evalue) in best_hits.items():
                    if ident >= identity and alen >= MIN_ALIGN_LEN and evalue <= EVALUE_THRESHOLD:
                        query2species[query].add(hitsp)
                        query2hits[query].append((hitsp, q, t, ident, alen, qcov, evalue))
                # Sorted so the output order does not depend on set iteration order,
                # which can differ between runs.
                filtered_queries = sorted(
                    q for q, sps in query2species.items() if len(sps) >= min_species)

                for qcons in QUERY_CONSERVATION_CUTOFFS:
                    combo = (identity, min_species, qcons)
                    out_handle = output_handles[combo]
                    alignment_handle = alignment_handles[combo]

                    conserved_ids = [
                        q for q in filtered_queries
                        if total_possible_species > 0
                        and (len(query2species[q]) / total_possible_species) >= qcons
                    ]

                    # Write sequences
                    for q in conserved_ids:
                        if q not in seqs:
                            continue
                        lines = seqs[q].split("\n")
                        # Joined once per query instead of once per hit.
                        sequence = "".join(lines[1:])
                        lines[0] = f">{species}__{lines[0][1:]}"
                        if not args.flank_only:
                            out_handle.write("\n".join(lines) + "\n")

                        for hitsp, *_ in query2hits[q]:
                            qstart, qend = alignment_coords.get((q, hitsp), (None, None))
                            if qstart is None or qend is None:
                                print(f"Warning: No alignment coords for {q}, {hitsp}")
                                continue
                            # Coordinates are treated as 1-based and inclusive; widen the
                            # window by `flank` on both sides and clamp it to the sequence.
                            start = max(0, min(qstart, qend) - 1 - flank)
                            end = min(len(sequence), max(qstart, qend) + flank)
                            subseq = sequence[start:end]
                            if qstart > qend:
                                # Reversed coordinates: emit the reverse complement.
                                subseq = str(Seq(subseq).reverse_complement())
                            header = f">{species}__{q}__{hitsp}__{qstart}_{qend}_flank{flank}"
                            if subseq:
                                alignment_handle.write(f"{header}\n{subseq}\n")
                            else:
                                print(f"Warning: Empty alignment for {header}")

                    unique_species_hit = len({sp for q in conserved_ids for sp in query2species[q]})
                    stats_writers[combo].writerow([species, unique_species_hit, len(conserved_ids)])
                    print(f"  [identity={identity}, min_species={min_species}, qcons={qcons}] {species}: {len(conserved_ids)} conserved, {unique_species_hit} unique species hit")

    print("All threshold and query conservation combinations processed.")


if __name__ == "__main__":
    main()

### Redundancy clustering
Redundancy clustering is done using MMseqs2 with 90% identity and 80% coverage

This will generate and output datasets with less redundancy since similar sequences are removed.

This is done on a precomputed data-record: https://biolib.com/ncRNA-foundational-model/TSA-conserved-before-clustering/ for 'PRI|MAM|ROD' and run through a BioLib application.

In [None]:
# Load the BioLib application that performs the redundancy clustering.
mmseqs_clustering = biolib.load('ncRNA-foundational-model/mmseqs-clustering')

# Running without arguments since data-record is preloaded
job = mmseqs_clustering.run()
job.wait() # wait for finish
# Download the job's output files into the local "final_datasets" directory.
job.save_files(output_dir='final_datasets')

### Final dataset(s)
The final dataset(s) can now be found in `final_datasets`