In [None]:
import numpy as np
import re
import gzip

import pandas as pd
from tqdm import tqdm

In [2]:
# Species evaluated in every benchmark below (60 entries, alphabetical order).
species_list = [
    'alligator_mississippiensis',
    'anolis_carolinensis',
    'anopheles_gambiae',
    'apis_mellifera',
    'arabidopsis_thaliana',
    'aspergillus_nidulans',
    'bos_taurus',
    'brachypodium_distachyon',
    'caenorhabditis_elegans',
    'canis_lupus_familiaris',
    'columba_livia',
    'coprinopsis_cinerea',
    'cryptococcus_neoformans',
    'danio_rerio',
    'daphnia_carinata',
    'dictyostelium_discoideum',
    'drosophila_melanogaster',
    'eimeria_maxima',
    'entamoeba_histolytica',
    'equus_caballus',
    'gallus_gallus',
    'giardia_intestinalis',
    'glycine_max',
    'gorilla_gorilla',
    'homo_sapiens',
    'hordeum_vulgare',
    'leishmania_donovani',
    'lotus_japonicus',
    'manduca_sexta',
    'medicago_truncatula',
    'mus_musculus',
    'neurospora_crassa',
    'nicotiana_tabacum',
    'oreochromis_niloticus',
    'oryctolagus_cuniculus',
    'oryza_sativa',
    'oryzias_latipes',
    'ovis_aries',
    'pan_troglodytes',
    'phoenix_dactylifera',
    'plasmodium_falciparum',
    'rattus_norvegicus',
    'rhizophagus_irregularis',
    'saccharomyces_cerevisiae',
    'schizophyllum_commune',
    'schizosaccharomyces_pombe',
    'selaginella_moellendorffii',
    'setaria_viridis',
    'solanum_lycopersicum',
    'strongylocentrotus_purpuratus',
    'sus_scrofa',
    'taeniopygia_guttata',
    'toxoplasma_gondii',
    'tribolium_castaneum',
    'trichoplax_adhaerens',
    'triticum_aestivum',
    'trypanosoma_brucei',
    'ustilago_maydis',
    'xenopus_laevis',
    'zea_mays',
]

# Process predictions for the transcript-level benchmark

### TIS Transformer

In [12]:
def init_dict_full_transcripts():
    """Return an empty column-oriented store for transcript-level predictions.

    Keys are ``species``, ``seq_number``, ``label`` and ``preds`` (in that
    order); each maps to its own, independent empty list.
    """
    column_names = ("species", "seq_number", "label", "preds")
    return {column: [] for column in column_names}

TIS_transformer_full_transcripts_preds = init_dict_full_transcripts()

In [None]:
def extract_atg_positions(species):
    """Classify every ATG on each annotated-TIS transcript of one species.

    Reads the softmasked mRNA test-set FASTA of ``species``. For every record
    whose header contains ``TIS=1``, all ATG start positions on the upper-cased
    sequence (other than the annotated TIS given by ``ATG_pos`` in the header)
    are split into upstream/downstream of the TIS and same/different reading
    frame. Only the first sequence line after a header is read, so records are
    expected to be single-line FASTA.

    Parameters
    ----------
    species : str
        Species name used to build the input file path.

    Returns
    -------
    dict
        ``seq_number`` (from the header) -> dict with keys ``TIS_pos``,
        ``upstream_same_frame``, ``upstream_diff_frame``,
        ``downstream_same_frame`` and ``downstream_diff_frame``; the last four
        are ascending lists of 0-based positions.

    Raises
    ------
    ValueError
        If the labelled ``ATG_pos`` is not the start of an ATG in the sequence.
    """
    fasta_path = f"../../data/data_evaluation/input_testsets/mRNA_testsets_processed/input_testset_{species}_softmasked.fasta.gz"
    atg_lookup = {}
    # (TIS_pos, seq_number) of the latest "TIS=1" header still waiting for its
    # sequence line; None when no such header is pending.
    pending = None

    with gzip.open(fasta_path, "rt") as handle:
        for raw_line in handle:
            if raw_line.startswith(">"):
                if "TIS=1" in raw_line:
                    pending = (
                        int(raw_line.split("ATG_pos=")[1].split("|")[0]),
                        raw_line.split("seq_number=")[1].split("|")[0],
                    )
                continue

            if pending is not None:
                tis_pos, seq_number = pending
                sequence = raw_line.strip().upper()
                # 0-based start of every ATG triplet (overlapping matches allowed)
                atg_starts = [pos for pos in range(len(sequence) - 2) if sequence[pos:pos + 3] == "ATG"]
                # Exclude the annotated TIS itself (raises ValueError if it is not an ATG)
                atg_starts.remove(tis_pos)

                categories = {
                    "upstream_same_frame": [],
                    "upstream_diff_frame": [],
                    "downstream_same_frame": [],
                    "downstream_diff_frame": [],
                }
                for pos in atg_starts:
                    in_frame = (pos - tis_pos) % 3 == 0
                    if pos < tis_pos:
                        key = "upstream_same_frame" if in_frame else "upstream_diff_frame"
                    else:
                        key = "downstream_same_frame" if in_frame else "downstream_diff_frame"
                    categories[key].append(pos)

                atg_lookup[seq_number] = {"TIS_pos": tis_pos, **categories}

            # Any sequence line closes the pending record
            pending = None

    return atg_lookup
    

In [14]:
def get_preds(species, species_atgs_dict, TIS_transformer_data):
    """Append TIS Transformer scores for every ATG on each annotated-TIS transcript.

    Reads ``../../data/data_evaluation/TIS_transformer/preds/testset/{species}.npy``.
    Each record is indexed as ``record[0]`` (header string) and ``record[1]``
    (per-position scores). Records whose header has ``TIS=1`` contribute one
    row for the annotated TIS (label 1.0) plus one row per other ATG listed in
    ``species_atgs_dict`` (label 0.0).

    Parameters
    ----------
    species : str
        Species name; used for the file path and stored in every row.
    species_atgs_dict : dict
        Mapping produced by ``extract_atg_positions`` (``seq_number`` ->
        ``TIS_pos`` and the four ATG position lists).
    TIS_transformer_data : list
        Extended in place.

    Returns
    -------
    list
        ``TIS_transformer_data`` itself, extended with tuples
        ``(pred, label, ATG_type, source, TSS_annotated, species, seq_number)``.

    Raises
    ------
    KeyError
        If a ``TIS=1`` record has no entry in ``species_atgs_dict``.
    """
    # allow_pickle=True is needed for object arrays; only load trusted pipeline output this way.
    records = np.load(f'../../data/data_evaluation/TIS_transformer/preds/testset/{species}.npy', allow_pickle=True)
    non_tis_types = ("upstream_same_frame", "upstream_diff_frame",
                     "downstream_same_frame", "downstream_diff_frame")

    for record in records:
        header = record[0]
        # Parse the label by key (as get_tis_transformer_preds does) rather than
        # by field index, so the header's field order does not matter.
        if float(header.split("TIS=")[1].split("|")[0]) != 1.0:
            continue

        seq_number = header.split("seq_number=")[1].split("|")[0]
        source = header.split("source=")[1].split("|")[0]
        TSS_annotated = header.split("TSS_annotated=")[1].split("|")[0].strip()

        atg_positions = species_atgs_dict[seq_number]
        # Indexed with lists of positions below, so presumably a NumPy array — TODO confirm
        preds = record[1]

        TIS_transformer_data.append(
            (preds[atg_positions["TIS_pos"]], 1.0, "TIS", source, TSS_annotated, species, seq_number)
        )
        for atg_type in non_tis_types:
            TIS_transformer_data.extend(
                (pred, 0.0, atg_type, source, TSS_annotated, species, seq_number)
                for pred in preds[atg_positions[atg_type]]
            )

    return TIS_transformer_data

In [27]:
# Collect TIS Transformer scores for every ATG on each annotated-TIS transcript
# of all species, then save them as a single gzipped CSV.
TIS_transformer_data = []

for species in tqdm(species_list, desc="Processing species"):
    # ATG positions/categories from the softmasked test-set FASTA
    species_atgs_dict = extract_atg_positions(species)
    # get_preds appends to and returns the same list
    TIS_transformer_data = get_preds(species, species_atgs_dict, TIS_transformer_data)

# One row per ATG: score, binary label, ATG category and transcript metadata
TIS_transformer_df_full = pd.DataFrame(TIS_transformer_data, columns=["preds", "label", "ATG_type", "annotation_source", "TSS_annotated", "species", "seq_number"])

# Save the processed transcript-level predictions
TIS_transformer_df_full.to_csv('../../data/data_evaluation/preds_processed/transcripts/TIS_transformer_df_transcripts.csv.gz', index=False, compression = "gzip")

Processing species: 100%|██████████| 60/60 [01:38<00:00,  1.64s/it]


### NetStart 2.0 and ablations

In [89]:
def get_netstart_preds(preds_subpath):
    """Collect NetStart scores for every ATG on each annotated-TIS transcript.

    For each species in ``species_list`` this reads
    ``../../data/data_evaluation/{preds_subpath}/preds_{species}.csv.gz``
    (columns read: ``entry_line``, ``atg_position``, ``preds``). It keeps rows
    whose ``entry_line`` header contains ``TIS=1`` and labels each ATG relative
    to the annotated TIS (``ATG_pos`` in the header). Each ATG is classified as
    ``TIS``, or as upstream/downstream in the same or a different reading frame.
    Species whose file is missing, or that have no ``TIS=1`` rows, are skipped.

    NOTE: the test-set section further down defines another
    ``get_netstart_preds`` that returns a pair of frames. Once that cell has
    run, it shadows this transcript-level version.

    Parameters
    ----------
    preds_subpath : str
        Directory under ``../../data/data_evaluation`` holding the prediction files.

    Returns
    -------
    pandas.DataFrame
        One row per ATG with columns ``position``, ``preds``, ``label``,
        ``ATG_type``, ``annotation_source``, ``TSS_annotated``, ``species`` and
        ``seq_number``. The frame is empty, but has these columns, if nothing
        was found.
    """
    output_columns = ["position", "preds", "label", "ATG_type",
                      "annotation_source", "TSS_annotated", "species", "seq_number"]
    frames = []

    for species in tqdm(species_list, desc=f"Processing data in {preds_subpath}"):
        try:
            df = pd.read_csv(f"../../data/data_evaluation/{preds_subpath}/preds_{species}.csv.gz", compression = "gzip")
        except FileNotFoundError:
            continue

        # Keep only transcripts with an annotated TIS. The .copy() makes this an
        # independent frame (no SettingWithCopyWarning downstream).
        df = df[df['entry_line'].str.contains("TIS=1", na=False)].copy()
        if df.empty:
            # Row-wise apply on an empty frame used to crash here
            continue

        header = df['entry_line']
        tis_pos = header.str.extract(r'ATG_pos=(\d+)', expand=False).astype(int)
        atg_pos = df['atg_position']
        same_frame = (atg_pos % 3) == (tis_pos % 3)
        upstream = atg_pos < tis_pos
        downstream = atg_pos > tis_pos

        # The conditions are exhaustive for integer positions. The explicit str
        # default keeps np.select's result dtype string-only.
        atg_type = np.select(
            [atg_pos == tis_pos,
             upstream & same_frame,
             upstream & ~same_frame,
             downstream & same_frame,
             downstream & ~same_frame],
            ['TIS', 'upstream_same_frame', 'upstream_diff_frame',
             'downstream_same_frame', 'downstream_diff_frame'],
            default="",
        )

        # Vectorized construction replaces the slow per-row dict building
        frames.append(pd.DataFrame({
            "position": atg_pos.to_numpy(),
            "preds": df['preds'].to_numpy(),
            "label": (atg_pos == tis_pos).astype(float).to_numpy(),
            "ATG_type": atg_type,
            "annotation_source": header.str.extract(r'source=([^|]+)', expand=False).to_numpy(),
            "TSS_annotated": header.str.extract(r'TSS_annotated=([^|]+)', expand=False).to_numpy(),
            "species": species,
            "seq_number": header.str.extract(r'seq_number=([^|]+)', expand=False).to_numpy(),
        }))

    if not frames:
        return pd.DataFrame(columns=output_columns)
    return pd.concat(frames, ignore_index=True)

In [90]:
#netstart_overall_df_full = get_netstart_preds(preds_subpath = "netstart2.0/preds/overall/testset")
#netstart_group_specific_df_full = get_netstart_preds(preds_subpath = "netstart2.0/preds/group_specific/testset")
#esm2_finetuned_ablation = get_netstart_preds(preds_subpath = "ablations/esm2_finetuned/testset")
#netstart1_ablation = get_netstart_preds(preds_subpath = "ablations/netstart1/testset")

#netstart_overall_df_full.to_csv('../../data/data_evaluation/preds_processed/transcripts/netstart_overall_df_transcripts.csv.gz', index=False, compression = "gzip")
#netstart_group_specific_df_full.to_csv('../../data/data_evaluation/preds_processed/transcripts/netstart_group_specific_df_transcripts.csv.gz', index=False, compression = "gzip")
#esm2_finetuned_ablation.to_csv('../../data/data_evaluation/preds_processed/transcripts/esm2_finetuned_ablation_df_transcripts.csv.gz', index=False, compression = "gzip")
#netstart1_ablation.to_csv('../../data/data_evaluation/preds_processed/transcripts/netstart1_ablation_df_transcripts.csv.gz', index=False, compression = "gzip")

Processing data in ablations/netstart1/testset: 100%|██████████| 60/60 [12:40<00:00, 12.68s/it]


# Process predictions for the test-set benchmark

In [5]:
def init_dict_testsets():
    """Return an empty column-oriented store for the test-set/genomic benchmark.

    Callers append one value per column per sequence and pass the dict to
    ``pd.DataFrame``. The key order below therefore becomes the column order,
    and every key gets its own independent empty list.
    """
    columns = (
        "species",
        "seq_type",
        "label",
        "preds",
        "annotation_source",
        "seq_number",
        "dataset_type",
        "first_downstream_intron",
        "TSS_annotated",
        "5_flanking_sequence_missing",
        "3_flanking_sequence_missing",
    )
    return {name: [] for name in columns}

### TIS Transformer

In [25]:
def get_tis_transformer_preds():
    """Collect TIS Transformer scores at the labelled position of every sequence.

    For each species in ``species_list`` two ``.npy`` files are read: the mRNA
    test set (``preds/testset``) and the genomic set (``preds/genomic``).
    Each record is indexed as ``record[0]`` (header string) and ``record[1]``
    (per-position scores). Only the score at the header's labelled position is
    kept, together with metadata parsed from the header. The labelled position
    is ``ATG_pos`` for the test set and ``TIS_position`` for the genomic set.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(test-set frame, genomic frame)``, both with the columns created by
        ``init_dict_testsets``. Genomic rows all have label 1.0.
    """
    testset_columns = init_dict_testsets()
    genomic_columns = init_dict_testsets()

    def between(header, key):
        # Text following ``key`` up to the next "|" separator
        return header.split(key)[1].split("|")[0]

    def push(columns, row):
        for name, value in row.items():
            columns[name].append(value)

    for species in species_list:
        # allow_pickle=True is needed for object arrays; only load trusted pipeline output this way.
        testset_records = np.load(f'../../data/data_evaluation/TIS_transformer/preds/testset/{species}.npy', allow_pickle=True)

        for record in testset_records:
            header = record[0]
            seq_number = between(header, "seq_number=")
            label = float(between(header, "TIS="))
            seq_type = between(header, "type=")
            source = between(header, "source=")
            labelled_pos = int(between(header, "ATG_pos="))
            tss_annotated = header.split("TSS_annotated=")[1].strip()
            score = record[1][labelled_pos]

            push(testset_columns, {
                "species": species,
                "seq_type": seq_type,
                "label": label,
                "preds": score,
                "annotation_source": source,
                "seq_number": seq_number,
                "dataset_type": "testset",
                "first_downstream_intron": None,
                "TSS_annotated": tss_annotated,
                "5_flanking_sequence_missing": None,
                "3_flanking_sequence_missing": None,
            })

        genomic_records = np.load(f'../../data/data_evaluation/TIS_transformer/preds/genomic/{species}.npy', allow_pickle=True)

        for record in genomic_records:
            header = record[0]
            gene_name = header.split("|")[0]
            source = between(header, "source=")
            labelled_pos = int(between(header, "TIS_position="))
            intron_start = int(between(header, "first_downstream_intron_start="))
            missing_5 = between(header, "5_flanking_sequence_missing=")
            missing_3 = between(header, "3_flanking_sequence_missing=")
            tss_annotated = header.split("TSS_annotated=")[1].strip()
            score = record[1][labelled_pos]

            push(genomic_columns, {
                "species": species,
                "seq_type": "Genomic TIS",
                "label": 1.0,
                "preds": score,
                "annotation_source": source,
                "seq_number": gene_name,
                "dataset_type": "genomic",
                "first_downstream_intron": intron_start,
                "TSS_annotated": tss_annotated,
                "5_flanking_sequence_missing": missing_5,
                "3_flanking_sequence_missing": missing_3,
            })

    return pd.DataFrame(testset_columns), pd.DataFrame(genomic_columns)

### AUGUSTUS

In [1]:
def _iter_augustus_sequences(gff_path):
    """Yield ``(header_line, start_codon_starts)`` for each sequence block of an AUGUSTUS GFF file.

    A block begins at a ``# ----- prediction on sequence number`` comment line.
    ``start_codon_starts`` holds the start coordinate (the field right after
    the ``start_codon`` feature type) of every ``start_codon`` line in that
    block. Lines before the first block header are ignored.
    """
    header = None
    start_positions = []
    with open(gff_path, "r") as file:
        for line in file:
            if line.startswith("# ----- prediction on sequence number"):
                if header is not None:
                    yield header, start_positions
                header = line
                start_positions = []
            if header is not None and "start_codon" in line:
                start_positions.append(int(line.split("start_codon\t")[1].split("\t")[0]))
    # Flush the final block; a file without any block yields nothing
    if header is not None:
        yield header, start_positions


def _augustus_labelled_start(header, pattern):
    """Return the position captured by ``pattern`` in ``header`` plus one, or None if absent.

    The +1 converts the 0-based labelled position to the 1-based coordinates
    of AUGUSTUS ``start_codon`` features.
    """
    match = re.search(pattern, header)
    return int(match.group(1)) + 1 if match else None


def get_augustus_preds(softmask):
    """Score AUGUSTUS as a binary TIS predictor on the test-set and genomic benchmarks.

    A sequence gets ``preds = 1.0`` when AUGUSTUS reports a ``start_codon``
    exactly at the labelled position, and ``0.0`` otherwise. This includes
    sequences whose header has no parsable labelled position.

    Parameters
    ----------
    softmask : bool
        Read predictions of the ``softmask`` (True) or ``no_softmask`` (False) run.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(testset_df, genomic_df)`` with the columns of ``init_dict_testsets``.

    Raises
    ------
    FileNotFoundError
        If any species' test-set or genomic GFF file is missing.
    """
    AUGUSTUS_testset_preds = init_dict_testsets()
    AUGUSTUS_genomic_preds = init_dict_testsets()
    mask_dir = "softmask" if softmask else "no_softmask"

    for species in species_list:
        testset_gff = f'../../data/data_evaluation/AUGUSTUS/preds/{mask_dir}/testset/{species}_preds.gff'
        for header, start_positions in _iter_augustus_sequences(testset_gff):
            seq_type = header.split("type=")[1].split("|")[0].replace("_", " ")
            seq_number = header.split("seq_number=")[1].split("|")[0]
            label = float(header.split("TIS=")[1].split("|")[0])
            source = header.split("source=")[1].split("|")[0]
            TSS_annotated = header.split("TSS_annotated=")[1].split(")")[0].strip()
            atg_pos = _augustus_labelled_start(header, r'ATG_pos=(\d+)')

            AUGUSTUS_testset_preds["species"].append(species)
            AUGUSTUS_testset_preds["seq_type"].append(seq_type)
            AUGUSTUS_testset_preds["label"].append(label)
            AUGUSTUS_testset_preds["preds"].append(1.0 if atg_pos in start_positions else 0.0)
            AUGUSTUS_testset_preds["annotation_source"].append(source)
            AUGUSTUS_testset_preds["seq_number"].append(seq_number)
            AUGUSTUS_testset_preds["dataset_type"].append("testset")
            AUGUSTUS_testset_preds["first_downstream_intron"].append(None)
            AUGUSTUS_testset_preds["TSS_annotated"].append(TSS_annotated)
            AUGUSTUS_testset_preds["5_flanking_sequence_missing"].append(None)
            AUGUSTUS_testset_preds["3_flanking_sequence_missing"].append(None)

        genomic_gff = f'../../data/data_evaluation/AUGUSTUS/preds/{mask_dir}/genomic/{species}_preds.gff'
        for header, start_positions in _iter_augustus_sequences(genomic_gff):
            gene_name = header.split("name = ")[1].split("|")[0]
            source = header.split("annotation_source=")[1].split("|")[0]
            first_downstream_intron = int(header.split("first_downstream_intron_start=")[1].split("|")[0])
            missing_5_flanking_seq = header.split("5_flanking_sequence_missing=")[1].split("|")[0]
            missing_3_flanking_seq = header.split("3_flanking_sequence_missing=")[1].split("|")[0]
            TSS_annotated = header.split("TSS_annotated=")[1].split(")")[0].strip()
            atg_pos = _augustus_labelled_start(header, r'TIS_position=(\d+)')

            AUGUSTUS_genomic_preds["species"].append(species)
            AUGUSTUS_genomic_preds["seq_type"].append("Genomic TIS")
            AUGUSTUS_genomic_preds["label"].append(1.0)
            AUGUSTUS_genomic_preds["preds"].append(1.0 if atg_pos in start_positions else 0.0)
            AUGUSTUS_genomic_preds["annotation_source"].append(source)
            AUGUSTUS_genomic_preds["seq_number"].append(gene_name)
            AUGUSTUS_genomic_preds["dataset_type"].append("genomic")
            AUGUSTUS_genomic_preds["first_downstream_intron"].append(first_downstream_intron)
            AUGUSTUS_genomic_preds["TSS_annotated"].append(TSS_annotated)
            AUGUSTUS_genomic_preds["5_flanking_sequence_missing"].append(missing_5_flanking_seq)
            AUGUSTUS_genomic_preds["3_flanking_sequence_missing"].append(missing_3_flanking_seq)

    augustus_testset_df = pd.DataFrame(AUGUSTUS_testset_preds)
    augustus_genomic_df = pd.DataFrame(AUGUSTUS_genomic_preds)

    return augustus_testset_df, augustus_genomic_df

### NetStart 2.0 and ablations

In [None]:
def _append_netstart_testset(store, df, species):
    """Append the NetStart score at each test-set sequence's labelled ATG (``ATG_pos``) to ``store``."""
    labelled_pos = df['entry_line'].str.extract(r'ATG_pos=(\d+)', expand=False).astype(int)
    matches = df[labelled_pos == df['atg_position']]
    n_rows = len(matches)
    if n_rows == 0:
        return

    header = matches['entry_line']
    store['species'].extend([species] * n_rows)
    store['seq_type'].extend(header.str.extract(r'type=([^|]+)', expand=False))
    store['label'].extend(header.str.extract(r'TIS=([^|]+)', expand=False).astype(float))
    store['preds'].extend(matches['preds'].astype(float))
    store['annotation_source'].extend(header.str.extract(r'source=([^|]+)', expand=False))
    store['seq_number'].extend(header.str.extract(r'seq_number=([^|]+)', expand=False))
    store['dataset_type'].extend(['testset'] * n_rows)
    store['first_downstream_intron'].extend([None] * n_rows)
    store['TSS_annotated'].extend(header.str.extract(r'TSS_annotated=(.+)$', expand=False))
    store['5_flanking_sequence_missing'].extend([None] * n_rows)
    store['3_flanking_sequence_missing'].extend([None] * n_rows)


def _append_netstart_genomic(store, df, species):
    """Append the NetStart score at each gene's labelled TIS (``TIS_position``) to ``store``."""
    labelled_pos = df['entry_line'].str.extract(r'TIS_position=(\d+)', expand=False).astype(int)
    matches = df[labelled_pos == df['atg_position']]
    n_rows = len(matches)
    if n_rows == 0:
        return

    header = matches['entry_line']
    store['species'].extend([species] * n_rows)
    store['seq_type'].extend(['Genomic TIS'] * n_rows)
    store['label'].extend([1.0] * n_rows)
    store['preds'].extend(matches['preds'].astype(float))
    store['annotation_source'].extend(header.str.extract(r'source=([^|]+)', expand=False))
    # Gene name is the first "|"-separated field of the header
    store['seq_number'].extend(header.str.extract(r'([^|]+)', expand=False))
    # 'genomic', consistent with the TIS Transformer, AUGUSTUS and Tiberius genomic sets
    store['dataset_type'].extend(['genomic'] * n_rows)
    store['first_downstream_intron'].extend(
        header.str.extract(r'first_downstream_intron_start=([^|]+)', expand=False).astype(int))
    store['TSS_annotated'].extend(header.str.extract(r'TSS_annotated=(.+)$', expand=False))
    store['5_flanking_sequence_missing'].extend(
        header.str.extract(r'5_flanking_sequence_missing=([^|]+)', expand=False))
    store['3_flanking_sequence_missing'].extend(
        header.str.extract(r'3_flanking_sequence_missing=([^|]+)', expand=False))


def get_netstart_preds(preds_subpath):
    """Collect NetStart scores at the labelled position for the test-set and genomic benchmarks.

    For each species in ``species_list`` this reads
    ``../../data/data_evaluation/{preds_subpath}/testset/preds_{species}.csv.gz``
    and ``.../genomic/preds_{species}.csv.gz`` (columns read: ``entry_line``,
    ``atg_position``, ``preds``). Only the row whose ``atg_position`` equals the
    header's labelled position is kept per sequence. A missing file skips only
    that set for that species.

    NOTE: this redefines the transcript-level ``get_netstart_preds`` from an
    earlier cell.

    Parameters
    ----------
    preds_subpath : str
        Directory under ``../../data/data_evaluation`` holding the
        ``testset/`` and ``genomic/`` prediction folders.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(testset_df, genomic_df)`` with the columns of ``init_dict_testsets``.
    """
    netstart_testset_preds = init_dict_testsets()
    netstart_genomic_preds = init_dict_testsets()
    base_dir = f"../../data/data_evaluation/{preds_subpath}"

    for species in tqdm(species_list, desc=f"Processing {preds_subpath}"):
        # The two sets are read independently, so a missing test-set file does
        # not skip the species' genomic predictions.
        try:
            testset_df = pd.read_csv(f"{base_dir}/testset/preds_{species}.csv.gz", compression = "gzip")
        except FileNotFoundError:
            pass
        else:
            _append_netstart_testset(netstart_testset_preds, testset_df, species)

        try:
            genomic_df = pd.read_csv(f"{base_dir}/genomic/preds_{species}.csv.gz", compression = "gzip")
        except FileNotFoundError:
            pass
        else:
            _append_netstart_genomic(netstart_genomic_preds, genomic_df, species)

    return pd.DataFrame(netstart_testset_preds), pd.DataFrame(netstart_genomic_preds)

### Tiberius

In [83]:
def _append_tiberius_row(store, row):
    """Append one value per column from ``row`` to the column store ``store``."""
    for column, value in row.items():
        store[column].append(value)


def _read_tiberius_gtf(gtf_path, id_of_line):
    """Parse a Tiberius GTF file into ``(seen_ids, starts_by_seqname)``.

    ``seen_ids`` is the set of IDs extracted with ``id_of_line`` from every
    feature line. ``starts_by_seqname`` maps each sequence name (first GTF
    column), in first-seen order, to the 0-based start positions of its
    ``cds_type=single``/``cds_type=initial`` CDS features. Blank lines and
    ``#`` comment lines are skipped.
    """
    seen_ids = set()
    starts_by_seqname = {}
    with open(gtf_path, "r") as file:
        for line in file:
            if not line.strip() or line.startswith("#"):
                continue
            seen_ids.add(id_of_line(line))
            fields = line.split("\t")
            starts = starts_by_seqname.setdefault(fields[0], [])
            if fields[2] == "CDS" and ("cds_type=single" in line or "cds_type=initial" in line):
                # GTF starts are 1-based; the header positions compared against are 0-based
                starts.append(int(fields[3]) - 1)
    return seen_ids, starts_by_seqname


def _collect_tiberius_testset(store, species, mask_dir):
    """Append one row per mRNA test-set sequence of ``species`` to ``store``."""
    seen_seq_numbers, starts_by_seqname = _read_tiberius_gtf(
        f'../../data/data_evaluation/Tiberius/preds/{mask_dir}/testset/testset_preds_{species}.gtf',
        lambda line: line.split("seq_number=")[1].split("|")[0],
    )

    # Sequences with Tiberius output: a hit if a single/initial CDS starts at ATG_pos
    for entry_line, starts in starts_by_seqname.items():
        atg_pos = int(entry_line.split("ATG_pos=")[1].split("|")[0])
        _append_tiberius_row(store, {
            "species": species,
            "seq_type": entry_line.split("type=")[1].split("|")[0].replace("_", " "),
            "label": float(entry_line.split("TIS=")[1].split("|")[0]),
            "preds": 1.0 if atg_pos in starts else 0.0,
            "annotation_source": entry_line.split("source=")[1].split("|")[0],
            "seq_number": entry_line.split("seq_number=")[1].split("|")[0],
            "dataset_type": "testset",
            "first_downstream_intron": None,
            "TSS_annotated": entry_line.split("TSS_annotated=")[1].split("|")[0],
            "5_flanking_sequence_missing": None,
            "3_flanking_sequence_missing": None,
        })

    # Sequences absent from the GTF got no prediction at all and are scored 0.0.
    # They are parsed exactly like the predicted ones above (same seq_type
    # formatting, label taken from the TIS= field).
    with gzip.open(f'../../data/data_evaluation/input_testsets/mRNA_testsets_processed/input_testset_{species}_softmasked.fasta.gz', "rt") as file:
        for line in file:
            if not line.startswith(">"):
                continue
            header = line.strip()  # drop the newline so the last header field is clean
            seq_number = header.split("seq_number=")[1].split("|")[0]
            if seq_number in seen_seq_numbers:
                continue
            _append_tiberius_row(store, {
                "species": species,
                "seq_type": header.split("type=")[1].split("|")[0].replace("_", " "),
                "label": float(header.split("TIS=")[1].split("|")[0]),
                "preds": 0.0,
                "annotation_source": header.split("source=")[1].split("|")[0],
                "seq_number": seq_number,
                "dataset_type": "testset",
                "first_downstream_intron": None,
                "TSS_annotated": header.split("TSS_annotated=")[1].split("|")[0],
                "5_flanking_sequence_missing": None,
                "3_flanking_sequence_missing": None,
            })


def _collect_tiberius_genomic(store, species, mask_dir):
    """Append one row per genomic test-set gene of ``species`` to ``store``."""
    seen_gene_names, starts_by_seqname = _read_tiberius_gtf(
        f'../../data/data_evaluation/Tiberius/preds/{mask_dir}/genomic/genomic_preds_{species}.gtf',
        lambda line: line.split("|")[1],
    )

    for entry_line, starts in starts_by_seqname.items():
        atg_pos = int(entry_line.split("TIS_position=")[1].split("|")[0])
        _append_tiberius_row(store, {
            "species": species,
            "seq_type": "Genomic TIS",
            "label": 1.0,
            "preds": 1.0 if atg_pos in starts else 0.0,
            "annotation_source": entry_line.split("source=")[1].split("|")[0],
            "seq_number": entry_line.split("|")[1],
            "dataset_type": "genomic",
            # int, as in the other tools' genomic frames
            "first_downstream_intron": int(entry_line.split("first_downstream_intron_start=")[1].split("|")[0]),
            "TSS_annotated": entry_line.split("TSS_annotated=")[1].split("|")[0].strip(),
            "5_flanking_sequence_missing": entry_line.split("5_flanking_sequence_missing=")[1].split("|")[0],
            "3_flanking_sequence_missing": entry_line.split("3_flanking_sequence_missing=")[1].split("|")[0],
        })

    # NOTE(review): GTF lines take the gene name from the 2nd "|" field, while
    # the FASTA header below uses the 1st field after ">". Confirm both refer to
    # the same ID, otherwise predicted genes are also appended as "missing".
    with gzip.open(f'../../data/data_evaluation/input_testsets/genomic_testsets/genes_extended_1000bp/genomic_testset_{species}.fasta.gz', "rt") as file:
        for line in file:
            if not line.startswith(">"):
                continue
            header = line.strip()  # drop the newline so the last header field is clean
            gene_name = header.split(">")[1].split("|")[0]
            if gene_name in seen_gene_names:
                continue
            _append_tiberius_row(store, {
                "species": species,
                "seq_type": "Genomic TIS",
                "label": 1.0,
                "preds": 0.0,
                "annotation_source": header.split("source=")[1].split("|")[0],
                "seq_number": gene_name,
                "dataset_type": "genomic",
                "first_downstream_intron": int(header.split("first_downstream_intron_start=")[1].split("|")[0]),
                "TSS_annotated": header.split("TSS_annotated=")[1].split("|")[0].strip(),
                "5_flanking_sequence_missing": header.split("5_flanking_sequence_missing=")[1].split("|")[0],
                "3_flanking_sequence_missing": header.split("3_flanking_sequence_missing=")[1].split("|")[0],
            })


def get_tiberius_preds(softmask):
    """Score Tiberius as a binary TIS predictor on the test-set and genomic benchmarks.

    A sequence gets ``preds = 1.0`` when Tiberius predicts a single or initial
    CDS starting exactly at the labelled position. Sequences with no Tiberius
    output at all are added from the input FASTA with ``preds = 0.0``.

    Parameters
    ----------
    softmask : bool
        Read predictions of the ``softmask`` (True) or ``no_softmask`` (False) run.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(testset_df, genomic_df)`` with the columns of ``init_dict_testsets``.
        For a species whose files are missing, the affected set is skipped; the
        test-set and genomic sets are handled independently.
    """
    tiberius_testset_preds = init_dict_testsets()
    tiberius_genomic_preds = init_dict_testsets()
    mask_dir = "softmask" if softmask else "no_softmask"

    for species in species_list:
        try:
            _collect_tiberius_testset(tiberius_testset_preds, species, mask_dir)
        except FileNotFoundError:
            pass  # a missing test-set file does not skip the genomic set below
        try:
            _collect_tiberius_genomic(tiberius_genomic_preds, species, mask_dir)
        except FileNotFoundError:
            pass

    tiberius_df = pd.DataFrame(tiberius_testset_preds)
    tiberius_genomic_df = pd.DataFrame(tiberius_genomic_preds)

    return tiberius_df, tiberius_genomic_df

### Get all predictions

In [None]:
def _to_gzip_csv(frame, path):
    """Write ``frame`` to ``path`` as a gzip-compressed CSV, without the index column."""
    frame.to_csv(path, index=False, compression = "gzip")


# Test-set tables are written to preds_processed/testset/, genomic tables to preds_processed/genes/.

# --- TIS Transformer ---------------------------------------------------------
print("Getting TIS Transformer predictions")
TIS_transformer_df, TIS_transformer_genomic_df = get_tis_transformer_preds()
# Save processed predictions
_to_gzip_csv(TIS_transformer_df, '../../data/data_evaluation/preds_processed/testset/TIS_transformer_df.csv.gz')
_to_gzip_csv(TIS_transformer_genomic_df, '../../data/data_evaluation/preds_processed/genes/TIS_transformer_genomic_df.csv.gz')

# --- AUGUSTUS: run with and without softmasking, then save all four tables ---
print("Getting AUGUSTUS predictions")
augustus_df, augustus_genomic_df = get_augustus_preds(softmask = True)
augustus_no_softmask_df, augustus_genomic_no_softmask_df = get_augustus_preds(softmask = False)
_to_gzip_csv(augustus_df, '../../data/data_evaluation/preds_processed/testset/augustus_df.csv.gz')
_to_gzip_csv(augustus_genomic_df, '../../data/data_evaluation/preds_processed/genes/augustus_genomic_df.csv.gz')
_to_gzip_csv(augustus_no_softmask_df, '../../data/data_evaluation/preds_processed/testset/augustus_no_softmask_df.csv.gz')
_to_gzip_csv(augustus_genomic_no_softmask_df, '../../data/data_evaluation/preds_processed/genes/augustus_genomic_no_softmask_df.csv.gz')

# --- NetStart 2.0: overall model, then group-specific model ------------------
print("Getting NetStart 2.0 predictions")
netstart_overall_df, netstart_genomic_overall_df = get_netstart_preds(preds_subpath = "netstart2.0/preds/overall")
_to_gzip_csv(netstart_overall_df, '../../data/data_evaluation/preds_processed/testset/netstart_overall_df.csv.gz')
_to_gzip_csv(netstart_genomic_overall_df, '../../data/data_evaluation/preds_processed/genes/netstart_genomic_overall_df.csv.gz')

netstart_group_df, netstart_genomic_group_df = get_netstart_preds(preds_subpath = "netstart2.0/preds/group_specific")
_to_gzip_csv(netstart_group_df, '../../data/data_evaluation/preds_processed/testset/netstart_group_df.csv.gz')
_to_gzip_csv(netstart_genomic_group_df, '../../data/data_evaluation/preds_processed/genes/netstart_genomic_group_df.csv.gz')

# --- Tiberius: run with and without softmasking, then save all four tables ---
print("Getting Tiberius predictions")
tiberius_df, tiberius_genomic_df = get_tiberius_preds(softmask = True)
tiberius_no_softmask_df, tiberius_no_softmask_genomic_df = get_tiberius_preds(softmask = False)

_to_gzip_csv(tiberius_df, '../../data/data_evaluation/preds_processed/testset/tiberius_df.csv.gz')
_to_gzip_csv(tiberius_genomic_df, '../../data/data_evaluation/preds_processed/genes/tiberius_genomic_df.csv.gz')
_to_gzip_csv(tiberius_no_softmask_df, '../../data/data_evaluation/preds_processed/testset/tiberius_no_softmask_df.csv.gz')
_to_gzip_csv(tiberius_no_softmask_genomic_df, '../../data/data_evaluation/preds_processed/genes/tiberius_no_softmask_genomic_df.csv.gz')

# --- NetStart 2.0A ablation (ESM-2 fine-tuned) --------------------------------
print("Getting NetStart 2.0A predictions")
esm2_finetuned_df, esm2_finetuned_genomic_df = get_netstart_preds(preds_subpath = "ablations/esm2_finetuned")
_to_gzip_csv(esm2_finetuned_df, '../../data/data_evaluation/preds_processed/testset/esm2_finetuned_df.csv.gz')
_to_gzip_csv(esm2_finetuned_genomic_df, '../../data/data_evaluation/preds_processed/genes/esm2_finetuned_genomic_df.csv.gz')

# --- NetStart 1.0A ablation ---------------------------------------------------
print("Getting NetStart 1.0A predictions")
netstart1_df, netstart1_genomic_df = get_netstart_preds(preds_subpath = "ablations/netstart1")
# NOTE(review): only the genomic table is written for NetStart 1.0A; netstart1_df is
# computed but its test-set CSV is not saved here — confirm that this is intended.
_to_gzip_csv(netstart1_genomic_df, '../../data/data_evaluation/preds_processed/genes/netstart1_genomic_df.csv.gz')

Getting NetStart 1.0-ish predictions


Processing ablations/netstart1: 100%|██████████| 60/60 [29:48<00:00, 29.80s/it] 


## Get predictions with shuffled species

In [None]:
def get_netstart_preds_species_shuffled(preds_subpath):
    """Collect per-species test-set predictions stored under ``preds_subpath``.

    For each species in the global ``species_list`` this reads
    ``../../data/data_evaluation/{preds_subpath}/preds_{species}.csv.gz``,
    parses the ``key=value`` metadata embedded in every ``entry_line`` with
    vectorised regexes, and keeps only the rows whose ``atg_position`` equals
    the ``ATG_pos`` value parsed from that header. Species whose prediction
    file does not exist are skipped.

    Args:
        preds_subpath: Directory fragment, relative to ``../../data/data_evaluation/``,
            that holds the ``preds_{species}.csv.gz`` files.

    Returns:
        pandas.DataFrame built from the dict returned by ``init_dict_testsets()``,
        with one row per retained prediction. ``dataset_type`` is always
        ``'testset'``; ``first_downstream_intron`` and both flanking-sequence
        columns are ``None``.
    """
    collected = init_dict_testsets()

    # (target column, regex with a single capture group, dtype to cast to or None)
    header_fields = (
        ('seq_number', r'seq_number=([^|]+)', None),
        ('seq_type', r'type=([^|]+)', None),
        ('label', r'TIS=([^|]+)', float),
        ('source', r'source=([^|]+)', None),
        # Captures everything after the key up to the end of the header string.
        ('TSS_annotated', r'TSS_annotated=(.+)$', None),
        ('ATG_pos', r'ATG_pos=(\d+)', int),
    )

    for species in tqdm(species_list, desc=f"Processing {preds_subpath}"):
        try:
            species_df = pd.read_csv(
                f"../../data/data_evaluation/{preds_subpath}/preds_{species}.csv.gz",
                compression = "gzip",
            )

            # Pull every metadata field out of the header in one vectorised pass each.
            for column, pattern, cast_to in header_fields:
                parsed = species_df['entry_line'].str.extract(pattern)
                species_df[column] = parsed if cast_to is None else parsed.astype(cast_to)

            # Keep only the prediction made at the ATG position named in the header.
            kept = species_df[species_df['ATG_pos'] == species_df['atg_position']]
            if kept.empty:
                continue

            n_kept = len(kept)
            collected['species'].extend([species] * n_kept)
            collected['seq_type'].extend(kept['seq_type'])
            collected['label'].extend(kept['label'])
            collected['preds'].extend(kept['preds'].astype(float))
            collected['annotation_source'].extend(kept['source'])
            collected['seq_number'].extend(kept['seq_number'])
            collected['dataset_type'].extend(['testset'] * n_kept)
            # These columns are not available from test-set headers and stay empty.
            collected['first_downstream_intron'].extend([None] * n_kept)
            collected['TSS_annotated'].extend(kept['TSS_annotated'])
            collected['5_flanking_sequence_missing'].extend([None] * n_kept)
            collected['3_flanking_sequence_missing'].extend([None] * n_kept)
        except FileNotFoundError:
            continue

    return pd.DataFrame(collected)

In [None]:
#Get processed predictions with species label shuffled within overall systematic group
#netstart_species_unknown_df = get_netstart_preds_species_shuffled(preds_subpath = "netstart2.0/preds/overall/testset_species_unknown")
#netstart_species_unknown_df.to_csv('../../data/data_evaluation/preds_processed/testset/netstart_species_unknown_df.csv.gz', index=False, compression = "gzip")

#Get processed predictions with species label shuffled at phylum level (testset_species_phylum_level)
#netstart_phylum_df = get_netstart_preds_species_shuffled(preds_subpath = "netstart2.0/preds/overall/testset_species_phylum_level")
#netstart_phylum_df.to_csv('../../data/data_evaluation/preds_processed/testset/netstart_phylum_level_df.csv.gz', index=False, compression = "gzip")

Processing netstart2.0/preds/overall/testset_species_phylum_level: 100%|██████████| 60/60 [00:47<00:00,  1.27it/s]
