### ULTRA

In [1]:
from __future__ import print_function
import json
from matplotlib import pyplot as plt
import os
import sys
import glob
from time import time
import itertools
from itertools import islice, chain
from struct import *
import shutil
import pandas as pd
import argparse
import errno
import math

# import pickle
import dill as pickle 
import gffutils
import pysam
import numpy as np
import multiprocessing as mp
from multiprocessing import Pool
from collections import defaultdict
# from collections import OrderedDict

# from modules import create_splice_graph as splice_graph
# from modules import graph_chainer 

from uLTRA.modules import create_augmented_gene as augmented_gene 
from uLTRA.modules import mem_wrapper 
from uLTRA.modules import colinear_solver 
from uLTRA.modules import help_functions
from uLTRA.modules import classify_read_with_mams
from uLTRA.modules import classify_alignment2
from uLTRA.modules import sam_output
from uLTRA.modules import align
from uLTRA.modules import prefilter_genomic_reads

from new_modules import functions
from new_modules import evaluate_exons
from new_modules import evaluate_splice_annotations
from new_modules import get_diff_loc_reads
from uLTRA.evaluation import plot_correctness_per_exon_size
from uLTRA.evaluation import plots
from uLTRA.evaluation import venn_diagram

In [2]:
# Disable
def blockPrint():
    """Silence ``print`` by redirecting ``sys.stdout`` to the OS null device."""
    null_stream = open(os.devnull, 'w')
    sys.stdout = null_stream

# Restore
def enablePrint():
    """Point ``sys.stdout`` back at the interpreter's original stream, ``sys.__stdout__``."""
    original_stream = sys.__stdout__
    sys.stdout = original_stream

def output_final_alignments(ultra_alignments_path, path_indexed_aligned, path_unindexed_aligned):
    """
        Merge uLTRA alignments with the alignments of the alternative aligner, in place.

        For every read in the uLTRA SAM file that also has a non-secondary alternative alignment:
          * a uLTRA record with flag == 4 is replaced by the alternative alignment,
          * a non-secondary uLTRA record is replaced when check_alignment_fit gives a negative
            score difference (alternative aligner fits better).
        All records from the unindexed SAM file are appended afterwards. The merged result is
        written to '<ultra_alignments_path>tmp' and then moved over ultra_alignments_path.
        A summary and a binned table of score differences are printed.

        ultra_alignments_path: path of the uLTRA SAM file, as str or bytes (callers in this file
            pass pysam's AlignmentFile.filename).
        path_indexed_aligned: SAM file with alternative alignments of reads also given to uLTRA.
        path_unindexed_aligned: SAM file with alignments of reads that uLTRA did not attempt.
    """
    # Accept both bytes and str paths; the previous unconditional .decode() failed on str.
    ultra_path = os.fsdecode(ultra_alignments_path)
    tmp_merged_path = ultra_path + 'tmp'

    # read in all reads from alternative aligner that was also mapped by uLTRA
    alt_alignments_file = pysam.AlignmentFile(path_indexed_aligned, "r", check_sq=False)
    alt_alignments = { read.query_name : read for read in alt_alignments_file.fetch(until_eof=True) if not read.is_secondary }

    alignment_infile = pysam.AlignmentFile( ultra_path, "r" )
    tmp_merged_outfile = pysam.AlignmentFile( tmp_merged_path, "w", template= alignment_infile)
    replaced_unaligned_cntr = 0
    scoring_dict = defaultdict(int)
    for read in alignment_infile.fetch():
        if read.query_name in alt_alignments:
            if read.flag == 4:
                read = alt_alignments[ read.query_name ] # replace unmapped uLTRA read with alternative alignment if mapped
                replaced_unaligned_cntr += 1
            elif not read.is_secondary:
                ultra_scoring_diff, _ = check_alignment_fit(read,  alt_alignments[ read.query_name ])
                scoring_dict[ultra_scoring_diff] += 1
                if ultra_scoring_diff < 0:
                    read = alt_alignments[ read.query_name ] # replace uLTRA read with alternative alignment if better fit

        tmp_merged_outfile.write(read)
    alignment_infile.close()
    # The alternative alignments have all been written by now; release the handle (was leaked before).
    alt_alignments_file.close()

    # add all reads that we did not attempt to align with uLTRA
    # these reads had a primary alignment to unindexed regions by other pre-processing aligner (minimap2 as of now)
    not_attempted_cntr = 0
    unindexed = pysam.AlignmentFile(path_unindexed_aligned, "r")
    for read in unindexed.fetch():
        tmp_merged_outfile.write(read)
        if not read.is_secondary:
            not_attempted_cntr += 1
    unindexed.close()
    tmp_merged_outfile.close()
    print("{0} reads were not attempted to be aligned with ultra, instead alternative aligner was used.".format(not_attempted_cntr))
    print("{0} reads with primary alignments were replaced with alternative aligner because they were unaligned with uLTRA.".format(replaced_unaligned_cntr))
    print("{0} primary alignments had best fit with uLTRA.".format(sum([v for k,v in scoring_dict.items() if k > 0])))
    print("{0} primary alignments had equal fit.".format(scoring_dict[0]))
    print("{0} primary alignments had best fit with alternative aligner.".format(sum([v for k,v in scoring_dict.items() if k < 0])))

    # counts[i] accumulates the number of alignments whose score difference is strictly below
    # bin_boundaries[i]; the per-bin numbers printed below are differences of neighbouring entries.
    bin_boundaries = [-2**32, -100,-50,-20,-10,-5,-4,-3,-2,-1, 0, 1, 2, 3, 4, 5, 10, 20, 50, 100, 2**32]
    n = len(bin_boundaries)
    counts = [0]*n
    start_next = 0
    for k in sorted(scoring_dict.keys()):
        # Keys are visited in increasing order, so boundaries already known to be <= k
        # can be skipped for all following keys.
        for i in range(start_next, n):
            b = bin_boundaries[i]
            if k < b:
                counts[i] += scoring_dict[k]
            else:
                start_next = i


    print("Table of score-difference between alignment methods (negative number: alternative aligner better fit, positive number is uLTRA better fit)")
    print("Score is calculated as sum(identities) - sum(ins, del, subs)")
    print("Format: Score difference: Number of primary alignments ")
    for i in range(len(counts)-1):
        print("[{0} - {1}): {2}".format(bin_boundaries[i],bin_boundaries[i+1], counts[i+1] - counts[i]))

    shutil.move(tmp_merged_path, ultra_path)

def batch(dictionary, size, batch_type):
    """
        Split a dict of accession -> sequence into consecutive sub-dicts (batches).

        Entries are added to the current batch in iteration order; a batch is closed as soon
        as the summed sequence length of its entries reaches `size`. Whatever is left after
        the last entry forms a final, smaller batch.

        dictionary: mapping accession -> sequence (anything supporting len()).
        size: nucleotide count at which a batch is closed.
        batch_type: unused; kept for compatibility with existing callers.

        returns: list of dicts, together containing every entry of `dictionary` exactly once.
    """
    batches = []
    sub_dict = {}
    curr_nt_count = 0
    for acc, seq in dictionary.items():
        sub_dict[acc] = seq
        curr_nt_count += len(seq)
        if curr_nt_count >= size:
            batches.append(sub_dict)
            sub_dict = {}
            curr_nt_count = 0

    # Keep the trailing batch whenever it holds entries. Testing the nucleotide count instead
    # dropped a tail made of empty sequences and divided by zero when size == 0.
    if sub_dict:
        batches.append(sub_dict)

    return batches


def load_reference(args):
    """
        Read the reference sequences from args.ref.

        returns:
        1. dict accession -> sequence, as parsed by help_functions.readfq
        2. dict accession -> sequence length
    """
    # The comprehension consumes the parser completely inside the with-block, so the
    # file handle is closed deterministically (it was left open before).
    with open(args.ref, "r") as ref_file:
        refs = {acc : seq for acc, (seq, _) in help_functions.readfq(ref_file)}
    refs_lengths = { acc : len(seq) for acc, seq in refs.items()}
    return refs, refs_lengths

def prep_splicing(args, refs_lengths):
    """
        Build (or reuse) the gffutils annotation database and pickle the structures
        returned by augmented_gene.create_graph_from_exon_parts into the index folder.

        The index folder is args.index when given (passed to help_functions.mkdir_p first),
        otherwise args.outfolder. An existing 'database.db' in that folder is reused.
    """
    if args.index:
        index_folder = args.index
        help_functions.mkdir_p(index_folder)
    else:
        index_folder = args.outfolder

    database = os.path.join(index_folder, 'database.db')

    if os.path.isfile(database):
        print("Database found in directory using this one.")
        print("If you want to recreate the database, please remove the file: {0}".format(database))
        print()
    else:
        build_options = dict(dbfn=database, force=True, keep_order=True,
                             merge_strategy='merge', sort_attribute_values=True)
        if args.disable_infer:
            build_options.update(disable_infer_genes=True, disable_infer_transcripts=True)
        gffutils.create_db(args.gtf, **build_options)
    # Whether reused or freshly created, the database is opened from disk.
    db = gffutils.FeatureDB(database, keep_order=True)

    (segment_to_ref, parts_to_segments, splices_to_transcripts,
     transcripts_to_splices, all_splice_pairs_annotations,
     all_splice_sites_annotations, segment_id_to_choordinates,
     segment_to_gene, gene_to_small_segments, flank_choordinates,
     max_intron_chr, exon_choordinates_to_id, chr_to_id, id_to_chr) = augmented_gene.create_graph_from_exon_parts(
        db, args.flank_size, args.small_exon_threshold, args.min_segm, refs_lengths)

    # dump to pickle here, one file per structure
    pickled_structures = [
        (segment_to_ref, 'segment_to_ref.pickle'),
        (splices_to_transcripts, 'splices_to_transcripts.pickle'),
        (transcripts_to_splices, 'transcripts_to_splices.pickle'),
        (parts_to_segments, 'parts_to_segments.pickle'),
        (all_splice_pairs_annotations, 'all_splice_pairs_annotations.pickle'),
        (all_splice_sites_annotations, 'all_splice_sites_annotations.pickle'),
        (segment_id_to_choordinates, 'segment_id_to_choordinates.pickle'),
        (segment_to_gene, 'segment_to_gene.pickle'),
        (gene_to_small_segments, 'gene_to_small_segments.pickle'),
        (flank_choordinates, 'flank_choordinates.pickle'),
        (max_intron_chr, 'max_intron_chr.pickle'),
        (exon_choordinates_to_id, 'exon_choordinates_to_id.pickle'),
        (chr_to_id, 'chr_to_id.pickle'),
        (id_to_chr, 'id_to_chr.pickle'),
    ]
    for structure, filename in pickled_structures:
        help_functions.pickle_dump(index_folder, structure, filename)

    
def check_alignment_fit(aln_ultra, aln_other):
    """
        Compare how well two alignments of the same read fit.

        Each alignment is scored from its cigartuples as
        (bases in operation 7) - (bases in operations 1, 2 and 8).

        returns:
        1. score(aln_ultra) - score(aln_other); positive if uLTRA fits better
        2. the 'XC' tag of the uLTRA alignment
    """
    penalised_ops = (1, 2, 8) # cigar IDs for INS, DEL, SUBS

    def fit_score(aln):
        identical = 0
        differing = 0
        for op, span in aln.cigartuples:
            if op == 7:
                identical += span
            elif op in penalised_ops:
                differing += span
        return identical - differing

    ultra_score = fit_score(aln_ultra)
    other_score = fit_score(aln_other)
    return ultra_score - other_score, aln_ultra.get_tag('XC')

def prep_seqs(args, refs, refs_lengths):
    """
        Extract the reference sequences for the pickled coordinate structures and
        pickle them, together with reference length tables, into the index folder.

        The index folder is args.index when given, otherwise args.outfolder. Warnings are
        printed for sequences present in only one of annotation and reference fasta.
        Unless args.use_NAM_seeds is set, part and flank sequences are passed through
        augmented_gene.mask_abundant_kmers before being stored.
    """
    index_folder = args.index if args.index else args.outfolder

    def load_index(filename):
        return help_functions.pickle_load(os.path.join(index_folder, filename))

    parts_to_segments = load_index('parts_to_segments.pickle')
    segment_id_to_choordinates = load_index('segment_id_to_choordinates.pickle')
    segment_to_ref = load_index('segment_to_ref.pickle')
    flank_choordinates = load_index('flank_choordinates.pickle')
    exon_choordinates_to_id = load_index('exon_choordinates_to_id.pickle')
    chr_to_id = load_index('chr_to_id.pickle')
    id_to_chr = load_index('id_to_chr.pickle')

    print( "Number of ref seqs in gff:", len(parts_to_segments.keys()))

    # Split the fasta accessions into annotated ones (re-keyed by chromosome id) and the rest.
    not_in_annot = {acc for acc in refs if acc not in chr_to_id}
    refs_id = {chr_to_id[acc]: seq for acc, seq in refs.items() if acc in chr_to_id}

    refs_id_lengths = {}
    for acc_id, seq in refs_id.items():
        refs_id_lengths[acc_id] = len(seq)
    help_functions.pickle_dump(index_folder, refs_id_lengths, 'refs_id_lengths.pickle')
    help_functions.pickle_dump(index_folder, refs_lengths, 'refs_lengths.pickle')

    print( "Number of ref seqs in fasta:", len(refs.keys()))

    not_in_ref = set(chr_to_id.keys()).difference(refs.keys())
    if not_in_ref:
        print("Warning: Detected {0} sequences that are in annotation but not in reference fasta. Using only sequences present in fasta. The following sequences cannot be detected in reference fasta:\n".format(len(not_in_ref)))
        for missing in not_in_ref:
            print(missing)

    if not_in_annot:
        print("Warning: Detected {0} sequences in reference fasta that are not in annotation:\n".format(len(not_in_annot)))
        for unannotated in not_in_annot:
            print(unannotated, "with length:{0}".format(len(refs[unannotated])))

    extract = augmented_gene.get_sequences_from_choordinates
    ref_part_sequences = extract(parts_to_segments, refs_id)
    ref_flank_sequences = extract(flank_choordinates, refs_id)

    if not args.use_NAM_seeds: # not using NAM seeds
        for sequences in (ref_part_sequences, ref_flank_sequences):
            augmented_gene.mask_abundant_kmers(sequences, args.min_mem, args.mask_threshold)

    ref_part_sequences = help_functions.update_nested(ref_part_sequences, ref_flank_sequences)
    ref_segment_sequences = extract(segment_id_to_choordinates, refs_id)
    ref_exon_sequences = extract(exon_choordinates_to_id, refs_id)

    outputs = (
        (segment_id_to_choordinates, 'segment_id_to_choordinates.pickle'),
        (ref_part_sequences, 'ref_part_sequences.pickle'),
        (ref_segment_sequences, 'ref_segment_sequences.pickle'),
        (ref_flank_sequences, 'ref_flank_sequences.pickle'),
        (ref_exon_sequences, 'ref_exon_sequences.pickle'),
    )
    for structure, filename in outputs:
        help_functions.pickle_dump(index_folder, structure, filename)

def _iter_read_seqs(reads_path):
    """Yield (accession, sequence) pairs parsed by help_functions.readfq; the file is closed when exhausted."""
    with open(reads_path, 'r') as reads_file:
        for acc, (seq, _qual) in help_functions.readfq(reads_file):
            yield acc, seq


def _write_trimmed_reads(records, out_path, polyA_limit, reverse_complement=False):
    """Write (accession, sequence) records to out_path after help_functions.remove_read_polyA_ends; return out_path."""
    with open(out_path, 'w') as out_file:
        for acc, seq in records:
            trimmed = help_functions.remove_read_polyA_ends(seq, polyA_limit, 5)
            if reverse_complement:
                trimmed = help_functions.reverse_complement(trimmed)
            out_file.write('>{0}\n{1}\n'.format(acc, trimmed))
    return out_path


def _merge_sam_batches(batch_paths, merged_path, refs_lengths):
    """Concatenate the records of the per-batch SAM files into merged_path; return the merged file's filename."""
    merged = pysam.AlignmentFile(merged_path, "w", reference_names=list(refs_lengths.keys()), reference_lengths=list(refs_lengths.values()))
    for path in batch_paths:
        samfile = pysam.AlignmentFile(path, "r")
        for read in samfile.fetch():
            merged.write(read)
        samfile.close()
    merged.close()
    return merged.filename


def _run_alignment(args):
    """Body of align_reads; may leave args.reads pointing at the pre-filtered reads file."""
    stats = {}
    stats["dataset"] = args.name

    if args.nr_cores > 1:
        if(not mp.get_context()):
            mp.set_start_method('spawn')
        print(mp.get_context())
        print("Environment set:", mp.get_context())
        print("Using {0} cores.".format(args.nr_cores))

    if args.index:
        if os.path.isdir(args.index):
            index_folder = args.index
        else:
            print("The index folder specified for alignment is not found. You specified: ", args.index )
            print("Build  the index to this folder, or specify another forder where the index has been built." )
            sys.exit()
    else:
        index_folder = args.outfolder

    ref_part_sequences = help_functions.pickle_load( os.path.join(index_folder, 'ref_part_sequences.pickle') )
    refs_id_lengths = help_functions.pickle_load( os.path.join(index_folder, 'refs_id_lengths.pickle') )
    refs_lengths = help_functions.pickle_load( os.path.join(index_folder, 'refs_lengths.pickle') )

    # here reads change
    if not args.disable_mm2:
        print("Filtering reads aligned to unindexed regions with minimap2 ")
        nr_reads_to_ignore, path_reads_to_align = prefilter_genomic_reads.main(ref_part_sequences, args.ref, args.reads,
                                                                               args.outfolder, index_folder, args.nr_cores,
                                                                               args.genomic_frac, args.mm2_ksize, args.minimap_path)
        args.reads = path_reads_to_align
        print("Done filtering. Reads filtered:{0}".format(nr_reads_to_ignore))

    # Write the indexed reference parts as fasta; headers are "<chr_id>^<start>^<stop>".
    ref_path = os.path.join(args.outfolder, "refs_sequences.fa")
    with open(ref_path, 'w') as refs_file:
        for sequence_id, seq in ref_part_sequences.items():
            chr_id, start, stop = unpack('LLL', sequence_id)
            refs_file.write(">{0}\n{1}\n".format(str(chr_id) + str("^") + str(start) + "^" + str(stop), seq))

    del ref_part_sequences

    ######### FIND MEMS WITH MUMMER #############
    #############################################

    mummer_start = time()
    # Only the multi-core slaMEM path loads and batches the reads during seeding;
    # every other path loads them in the alignment stage below.
    reads = None
    read_batches = None
    polyA_limit = args.reduce_read_ployA
    if args.use_NAM_seeds:
        print("Processing reads for MEM finding")
        args.reads_tmp = _write_trimmed_reads(_iter_read_seqs(args.reads), os.path.join(args.outfolder, 'reads_tmp.fq'), polyA_limit)
        mem_wrapper.find_nams_strobemap(args.outfolder, args.reads_tmp, ref_path, args.outfolder, args.nr_cores, args.min_mem)
    else: # Use slaMEM
        if args.nr_cores == 1:
            print("Processing reads for MEM finding")
            args.reads_tmp = _write_trimmed_reads(_iter_read_seqs(args.reads), os.path.join(args.outfolder, 'reads_tmp.fq'), polyA_limit)
            print("Finished processing reads for MEM finding ")

            mummer_out_path =  os.path.join( args.outfolder, "seeds_batch_-1.txt" )
            print("Running MEM finding forward")
            mem_wrapper.find_mems_slamem(args.slamem_path, args.mummer_path, args.outfolder, args.reads_tmp, ref_path, mummer_out_path, args.min_mem)
            print("Completed MEM finding forward")

            print("Processing reverse complement reads for MEM finding")
            args.reads_rc = _write_trimmed_reads(_iter_read_seqs(args.reads), os.path.join(args.outfolder, 'reads_rc.fq'), polyA_limit, reverse_complement=True)
            print("Finished processing reverse complement reads for MEM finding")

            mummer_out_path =  os.path.join(args.outfolder, "seeds_batch_-1_rc.txt" )
            print("Running MEM finding reverse")
            mem_wrapper.find_mems_slamem(args.slamem_path, args.mummer_path, args.outfolder, args.reads_rc, ref_path, mummer_out_path, args.min_mem)
            print("Completed MEM finding reverse")

        else: # multiprocess with slaMEM
            reads = dict(_iter_read_seqs(args.reads))
            total_nt = sum(len(seq) for seq in reads.values())
            batch_size = int(total_nt/int(args.nr_cores) + 1)
            print("batch nt:", batch_size, "total_nt:", total_nt)
            read_batches = batch(reads, batch_size, 'nt')

            batch_args = []
            for i, read_batch_dict in enumerate(read_batches):
                print(len(read_batch_dict))
                read_batch = _write_trimmed_reads(read_batch_dict.items(), os.path.join(args.outfolder, "reads_batch_{0}.fa".format(i)), polyA_limit)
                read_batch_rc = _write_trimmed_reads(read_batch_dict.items(), os.path.join(args.outfolder, "reads_batch_{0}_rc.fa".format(i)), polyA_limit, reverse_complement=True)
                mummer_batch_out_path =  os.path.join( args.outfolder, "seeds_batch_{0}.txt".format(i) )
                mummer_batch_out_path_rc =  os.path.join(args.outfolder, "seeds_batch_{0}_rc.txt".format(i) )
                batch_args.append( (args.slamem_path, args.mummer_path, args.outfolder, read_batch, ref_path, mummer_batch_out_path, args.min_mem ) )
                batch_args.append( (args.slamem_path, args.mummer_path, args.outfolder, read_batch_rc, ref_path, mummer_batch_out_path_rc, args.min_mem ) )

            pool = Pool(processes=int(args.nr_cores))
            pool.starmap(mem_wrapper.find_mems_slamem, batch_args)
            pool.close()
            pool.join()

        print("Time for slaMEM to find mems:{0} seconds.".format(time()-mummer_start))
    stats["mem_time"] = time()-mummer_start

    #############################################
    #############################################

    print("Starting aligning reads.")
    if reads is None:
        reads = dict(_iter_read_seqs(args.reads))
    stats["reads"] = len(reads)

    if args.nr_cores == 1:
        # Single-core runs used to leave align_time/merge_time unset, which made the
        # total_time computation at the end fail with KeyError.
        aligning_start = time()
        classifications, alignment_outfile_name = align.align_single(reads, refs_id_lengths, args, -1)
        stats["align_time"] = time()-aligning_start
        stats["merge_time"] = 0.0 # nothing to merge for a single batch
    else:
        # batch reads and mems up, write one SAM-file per batch, finally merge them
        aligning_start = time()
        if args.use_NAM_seeds:
            STROBEMAP_BATCH_SIZE=500000
            read_batches = strobemap_batching(reads, STROBEMAP_BATCH_SIZE, int(args.nr_cores))
        # (slaMEM: read_batches are the batches the seeds were computed for)
        print('Nr reads:', len(reads), "nr batches:", len(read_batches), [len(b) for b in read_batches])
        classifications, alignment_outfiles = align.align_parallel(read_batches, refs_id_lengths, args)

        print("Time to align reads:{0} seconds.".format(time()-aligning_start))
        stats["align_time"] = time()-aligning_start

        # Combine samfiles produced from each batch
        combine_start = time()
        alignment_outfile_name = _merge_sam_batches(alignment_outfiles, os.path.join(args.outfolder, args.prefix+".sam"), refs_lengths)
        print("Time to merge SAM-files:{0} seconds.".format(time() - combine_start))
        stats["merge_time"] = time()-combine_start

    # need to merge genomic/unindexed alignments with the uLTRA-aligned alignments
    if not args.disable_mm2:
        path_indexed_aligned = os.path.join(args.outfolder, "indexed.sam")
        path_unindexed_aligned = os.path.join(args.outfolder, "unindexed.sam")
        output_final_alignments(alignment_outfile_name, path_indexed_aligned, path_unindexed_aligned)

    # Tally classifications; classifications[acc] is indexed as (category, alignment coverage, ...).
    counts = defaultdict(int)
    alignment_coverage = 0
    for read_acc in reads:
        if read_acc in classifications:
            alignment_coverage += classifications[read_acc][1]
            counts[classifications[read_acc][0]] += 1
        else:
            counts['unaligned'] += 1

    print(counts)
    with open(os.path.join(args.outfolder, "counts.json"), "w") as counts_file:
        json.dump(functions.transform_categories(counts), counts_file)
    print("total alignment coverage:", alignment_coverage)

    if not args.keep_temporary_files:
        print("Deleting temporary files...")
        temporary_files = []
        for pattern in ("seeds_*", "mummer*", "slamem*", "reads_batch*", "minimap2*", "uLTRA_batch*"):
            temporary_files.extend(glob.glob(os.path.join(args.outfolder, pattern)))
        for filename in ("reads_after_genomic_filtering.fasta", "indexed.sam", "unindexed.sam",
                         "refs_sequences.fa", "reads_rc.fq", "reads_tmp.fq"):
            temporary_files.append(os.path.join(args.outfolder, filename))
        for f in temporary_files:
            if os.path.isfile(f):
                os.remove(f)
                print("removed:", f)
    print("Done.")
    # save stats
    stats["total_time"] = stats["align_time"] + stats["mem_time"] + stats["merge_time"]
    with open(os.path.join(args.outfolder, "stats.json"), "w") as stats_file:
        json.dump(stats, stats_file)

    return stats


def align_reads(args):
    """
        Seed and align the reads in args.reads and write alignments, counts and timings.

        Steps: load the pickled index, optionally pre-filter reads with
        prefilter_genomic_reads (unless args.disable_mm2), find seeds (strobemap NAMs when
        args.use_NAM_seeds, slaMEM otherwise), align on one core or in parallel batches,
        merge with the pre-filter alignments, write counts.json and stats.json to
        args.outfolder and, unless args.keep_temporary_files, delete intermediate files.

        Exits via sys.exit() when args.index is given but is not a directory.

        returns: the stats dict ("dataset", "reads", "mem_time", "align_time",
                 "merge_time", "total_time") that is also written to stats.json.
    """
    reads_start = args.reads
    try:
        return _run_alignment(args)
    finally:
        # args is reused by the caller, so put the original reads path back even on failure.
        args.reads = reads_start

def initialize_dump(outfolder):
    """Leave an existing directory untouched; otherwise hand outfolder to help_functions.mkdir_p."""
    if not os.path.isdir(outfolder):
        help_functions.mkdir_p(outfolder)
        
def ultra(args):
    """
        Run the complete uLTRA pipeline for args, or reuse an earlier run.

        If '<args.outfolder>/reads.sam' already exists the pipeline is skipped and the
        stats stored in '<args.outfolder>/stats.json' are returned. Otherwise the dump
        folder is initialised, the index is built and the reads are aligned.

        returns: the stats dict (loaded from stats.json, or as returned by align_reads).
    """
    reads = os.path.join(args.outfolder, "reads.sam")
    if os.path.exists(reads):
        # Close the stats file deterministically instead of leaking the handle.
        with open(os.path.join(args.outfolder, "stats.json"), "r") as stats_file:
            return json.load(stats_file)

    # initialize dump folder
    initialize_dump(args.outfolder)
    refs, refs_lengths = load_reference(args)
    prep_splicing(args, refs_lengths)
    prep_seqs(args, refs, refs_lengths)
    return align_reads(args)

### Minimap

In [3]:
import subprocess
import os

# Module-level SAM path; nothing in the visible code reads it.
# NOTE(review): looks like a leftover from testing — confirm before removing.
output = "output/test.sam"

# set minimap arguments based on supplementary data
def set_minimap_args(args):
    """
        Assemble the minimap2 command line (as an argument list) for args.

        The parameter set is selected by substring of args.ref, "ALZ" taking precedence
        over "SIRV":
          * "ALZ":  k=14, 19 threads, no -w option
          * "SIRV": k=13, 4 threads, -w 5, plus --splice-flank=no --secondary=no -C 5
          * other:  k=13, 4 threads, -w 5
        When args.bed is not None it is passed with --junc-bed. Output always goes to
        '<args.outfolder>/reads.sam'.
    """
    is_alz = "ALZ" in args.ref
    is_sirv = (not is_alz) and ("SIRV" in args.ref)

    kmer, threads = (14, 19) if is_alz else (13, 4)
    window = 5 # minimum window size, not passed for ALZ
    max_gap = "500k" # Maximum gap on the reference (effective with -xsplice/--splice).

    command = [args.minimap_path, "-a", args.ref, args.reads]
    if args.bed is not None:
        command += ["--junc-bed", args.bed]
    command += ["-k", str(kmer), "--eqx", "-t", str(threads), "-ax", "splice"]
    if not is_alz:
        command += ["-w", str(window)]
    command += ["-G", max_gap]
    if is_sirv:
        command += ["--splice-flank=no", "--secondary=no", "-C", "5"]
    command += ["-o", os.path.join(args.outfolder, "reads.sam")]
    return command
    

def minimap2(args):
    """Run minimap2 for one dataset/tool combination and record its run time.

    The run is skipped entirely when ``<args.outfolder>/reads.sam`` already
    exists, so repeated executions of the notebook do not re-align.

    Parameters:
        args: run configuration; this function reads ``args.outfolder`` and
            ``args.name`` and forwards ``args`` to ``set_minimap_args``.

    Side effects:
        Calls ``initialize_dump(args.outfolder)``, launches the command built
        by ``set_minimap_args`` and writes ``stats.json`` (keys ``dataset``
        and ``total_time``) into ``args.outfolder``.

    Raises:
        subprocess.CalledProcessError: if the aligner exits with a non-zero
            status.
    """
    reads = os.path.join(args.outfolder, "reads.sam")
    if os.path.exists(reads):
        # Alignment output is already present; nothing to do.
        return

    stats = {"dataset": args.name}

    initialize_dump(args.outfolder)
    mm2_start = time()

    subprocess.check_call(set_minimap_args(args), env=os.environ)
    stats["total_time"] = time() - mm2_start

    # Use a context manager so the stats file is flushed and closed promptly.
    with open(os.path.join(args.outfolder, "stats.json"), "w") as stats_file:
        json.dump(stats, stats_file)



### deSALT

In [4]:
def set_desalt_args(args):
    """Build the ``deSALT aln`` command line for one run.

    Parameters:
        args: run configuration; reads ``desalt_path``, ``desalt_index``,
            ``reads``, ``ref``, ``desalt_gtf`` and ``outfolder``.

    Returns:
        list[str]: the argument vector, ready for ``subprocess.check_call``.
        The output is always written to ``<args.outfolder>/reads.sam``.

    Notes:
        * When ``"SIRV"`` occurs in ``args.ref`` the ``--noncan`` value is 4
          and ``--max-intron-len`` is 200000; otherwise 9 and 500000.
        * ``--gtf <args.desalt_gtf>`` is added only when ``args.desalt_gtf``
          is not ``None``.
    """
    d = 10
    s = 2
    l = 14

    # SIRV references get their own splice-related settings.
    if "SIRV" in args.ref:
        noncan, max_intron_length = 4, 200000
    else:
        noncan, max_intron_length = 9, 500000

    command = [
        args.desalt_path, "aln", args.desalt_index, args.reads,
        "-d", str(d), "-s", str(s), "-l", str(l),
        "--noncan", str(noncan), "--max-intron-len", str(max_intron_length),
    ]
    if args.desalt_gtf is not None:
        command += ["--gtf", args.desalt_gtf]
    command += ["-o", os.path.join(args.outfolder, "reads.sam")]
    return command

def deSALT(args):
    """Run deSALT for one dataset/tool combination and record its run time.

    The run is skipped entirely when ``<args.outfolder>/reads.sam`` already
    exists. If the path ``args.desalt_index`` does not exist, the index is
    built first with ``<args.desalt_path> index <args.ref> <args.desalt_index>``.

    Parameters:
        args: run configuration; this function reads ``outfolder``, ``name``,
            ``desalt_index``, ``desalt_path`` and ``ref`` and forwards
            ``args`` to ``set_desalt_args``.

    Side effects:
        Calls ``initialize_dump(args.outfolder)``, launches the external
        commands and writes ``stats.json`` into ``args.outfolder`` with the
        keys ``dataset``, ``total_time`` and, when the index was built,
        ``index_time``. ``total_time`` is measured from before the index
        build, so it includes the indexing time.

    Raises:
        subprocess.CalledProcessError: if either command exits with a
            non-zero status.
    """
    reads = os.path.join(args.outfolder, "reads.sam")
    if os.path.exists(reads):
        # Alignment output is already present; nothing to do.
        return

    stats = {"dataset": args.name}
    initialize_dump(args.outfolder)
    deSALT_start = time()

    if not os.path.exists(args.desalt_index):
        subprocess.check_call([args.desalt_path, "index", args.ref, args.desalt_index])
        stats["index_time"] = time() - deSALT_start

    subprocess.check_call(set_desalt_args(args), env=os.environ)
    stats["total_time"] = time() - deSALT_start

    # Use a context manager so the stats file is flushed and closed promptly.
    with open(os.path.join(args.outfolder, "stats.json"), "w") as stats_file:
        json.dump(stats, stats_file)

### graphmap2

In [5]:
def set_graphmap2_args(args):
    """Build the ``graphmap2 align`` command line for one run.

    Parameters:
        args: run configuration; reads ``graphmap_path``, ``gtf``, ``reads``,
            ``ref`` and ``outfolder``.

    Returns:
        list[str]: the argument vector, ready for ``subprocess.check_call``.
        ``--gtf <args.gtf>`` is included only when ``args.gtf`` is not
        ``None``; the output always goes to ``<args.outfolder>/reads.sam``.
    """
    command = [args.graphmap_path, "align"]
    if args.gtf is not None:
        command += ["--gtf", args.gtf]
    command += [
        "-d", args.reads, "-r", args.ref,
        "-o", os.path.join(args.outfolder, "reads.sam"),
    ]
    return command

def graphmap2(args):
    """Run graphmap2 for one dataset/tool combination and record its run time.

    The run is skipped entirely when ``<args.outfolder>/reads.sam`` already
    exists.

    Parameters:
        args: run configuration; this function reads ``args.outfolder`` and
            ``args.name`` and forwards ``args`` to ``set_graphmap2_args``.

    Side effects:
        Calls ``initialize_dump(args.outfolder)``, launches the command built
        by ``set_graphmap2_args`` and writes ``stats.json`` (keys ``dataset``
        and ``total_time``) into ``args.outfolder``.

    Raises:
        subprocess.CalledProcessError: if the aligner exits with a non-zero
            status.
    """
    reads = os.path.join(args.outfolder, "reads.sam")
    if os.path.exists(reads):
        # Alignment output is already present; nothing to do.
        return

    stats = {"dataset": args.name}
    initialize_dump(args.outfolder)
    graphmap2_start = time()
    subprocess.check_call(set_graphmap2_args(args), env=os.environ)
    stats["total_time"] = time() - graphmap2_start

    # Use a context manager so the stats file is flushed and closed promptly.
    with open(os.path.join(args.outfolder, "stats.json"), "w") as stats_file:
        json.dump(stats, stats_file)

In [6]:
from new_modules.arguments import arguments

def create_args(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, name, tool_name, mm2, desalt_index):
    """Create the ``arguments`` object for one dataset/tool run.

    All parameters are forwarded unchanged, in the same order, to the
    ``arguments`` constructor. Afterwards:

    * ``args.name`` is overwritten with the reads file name stripped of its
      directory and of its last extension (``data/x/reads_ens.fa`` ->
      ``reads_ens``).
    * If ``args.ont`` is true: ``min_mem = 17``, ``min_acc = 0.6``,
      ``mm2_ksize = 14``.
    * If ``args.isoseq`` is true: ``min_mem = 20``, ``min_acc = 0.8``.
      This block runs after the ONT one, so its values win if both flags
      are set.

    Returns:
        The configured ``arguments`` instance.
    """
    args = arguments(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, name, tool_name, mm2, desalt_index)

    # os.path handles reads paths without a directory component or without an
    # extension, which the previous rindex-based slicing rejected with
    # ValueError.
    args.name = os.path.splitext(os.path.basename(args.reads))[0]

    if args.ont:
        args.min_mem = 17
        args.min_acc = 0.6
        args.mm2_ksize = 14
        # args.alignment_threshold = 0.5
    if args.isoseq:
        args.min_mem = 20
        args.min_acc = 0.8
        # args.alignment_threshold = 0.5
    return args

![Datasets Used](assets/images/table1.png)

In [7]:
# Create a dict with the paths of each dataset
# Maps a dataset name to the settings read by real_data_pipeline / sim_data_pipeline:
#   ref, reads, gtf, bed, desalt_gtf : input file paths
#   ont, isoseq                      : flags; the driver loop runs the real-data pipeline
#                                      when either is true, the simulated one otherwise
#   disable_infer                    : forwarded to create_args
#   desalt_index                     : index path handed to deSALT
#   accessions_map                   : read only by sim_data_pipeline
# Entries that are commented out are datasets that are currently disabled.
dataset_dict = {
    # "test" : {
    #     "ref": "data/genome/SIRV_isoforms_multi-fasta_170612a.fasta" ,
    #     "reads": "uLTRA/test/reads.fa" ,
    #     "gtf": "data/genome/annotations/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.gtf",
    #     "bed":"data/genome/annotations/bed/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.bed",
    #     "desalt_gtf": "data/genome/annotations/desalt_gtf/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.gtf",
    #     "ont": False,
    #     "isoseq": True,
    #     "disable_infer": False,
    #     "desalt_index":"data/genome/indexes/sirv",
    # },
    "ENS": {"ref": "data/genome/GRCh38.p12.genome.fa" ,
            "reads": "data/simulated/reads_ens.fa" ,
            "gtf": "data/genome/annotations/gencode.v34.chr_patch_hapl_scaff.annotation.gtf" ,
            "bed": "data/genome/annotations/bed/gencode.v34.annotation.bed",
            # NOTE(review): "desalt_gtf" points at a .bed file inside desalt_gtf/ — confirm this is the intended annotation file.
            "desalt_gtf": "data/genome/annotations/desalt_gtf/gencode.v34.annotation.bed",
            "ont": False,
            "isoseq": False,
            "disable_infer": True,
            "desalt_index":"data/genome/indexes/human",
            "accessions_map": "data/simulated/accessions_map_ens.csv",
    },
    # "SIM_ANN": {"ref": "data/genome/GRCh38.p12_genomic.fna" ,
    #         "reads": "data/simulated/reads_sim/reads_sim.fa" ,
    #        "gtf": "data/genome/annotations/gencode.v34.chr_patch_hapl_scaff.annotation.gtf" ,
    #         "bed":"data/genome/annotations/bed/gencode.v34.annotation.bed",
    #         "desalt_gtf": "data/genome/annotations/desalt_gtf/gencode.v34.annotation.bed",
    #         "ont": False,
    #         "isoseq": False,
    #         "disable_infer": True,
    #         "desalt_index":"data/genome/indexes/human",
    #         "accessions_map": "data/simulated/accessions_map_sim.csv",
    # },
#     "SIM_NIC": {"ref": "data/genome/GRCh38.p13_genomic.fna" ,
#             "reads": "data/simulated/reads_nic.fa" ,
#             "gtf": "data/genome/annotations/gencode.v34.annotation.gff3" ,
#             "bed":"data/genome/annotations/bed/gencode.v34.annotation.bed",
#             "ont": False,
#             "isoseq": False,
#             "disable_infer": True,
#             "desalt_index":"data/genome/indexes/humanIndex",   
#             "accessions_map": "data/simulated/accessions_map_nic.csv",

#     },
#     "SIRV": {"ref": "data/genome/SIRV_isoforms_multi-fasta_170612a.fasta",
#             "reads": "data/SIRV/SIRV_processed.fastq",
#             "gtf": "data/genome/annotations/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.gtf",
#             "bed":"data/genome/annotations/bed/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.bed",
#             "desalt_gtf": "data/genome/annotations/desalt_gtf/SIRV_isoforms_Lot00141_multi-fasta-annotation_C_170612a.gtf",
#             "ont": True,
#             "isoseq": False,
#             "disable_infer": False,
#             "desalt_index":"data/genome/indexes/sirv",
#     },
    # "DROS": {"ref": "data/genome/DROS.BDGP6.28.all.fa",
    #         "reads": "data/DROS_processed.fastq",
    #         "gtf": "data/genome/annotations/Drosophila_melanogaster.BDGP6.28.100.gtf",
    #         "bed":"data/genome/annotations/bed/Drosophila_melanogaster.BDGP6.28.100.bed",
    #         "desalt_gtf": "data/genome/annotations/desalt_gtf/Drosophila_melanogaster.BDGP6.28.100.gtf",
    #         "ont": True,
    #         "isoseq": False,
    #         "disable_infer": True,
    #         "desalt_index":"data/genome/indexes/da",   
    # },
    # "ALZ": {"ref": "data/genome/GRCh38.p12.genome.fa" ,
    #         "reads": "data/simulated/reads_ens.fa" ,
    #         "gtf": "data/genome/annotations/gencode.v34.chr_patch_hapl_scaff.annotation.gtf" ,
    #         "bed": "data/genome/annotations/bed/gencode.v34.annotation.bed",
    #         "desalt_gtf": "data/genome/annotations/desalt_gtf/gencode.v34.annotation.bed",
    #         "ont": False,
    #         "isoseq": True,
    #         "disable_infer": True,
    #         "desalt_index":"data/genome/indexes/human",
    # }
}

In [None]:
def plot_time(directory):
    """Plot the recorded total run time of each tool as a bar chart.

    For every known tool name, ``<directory>/<tool>/stats.json`` is read and
    its ``total_time`` value is plotted. Tools without a ``stats.json`` (for
    example aligners that were not run) are left out of the chart instead of
    aborting the whole plot.

    Parameters:
        directory: folder that contains one sub-folder per tool.

    Side effects:
        Clears the current matplotlib figure and saves the chart as
        ``time.png``.
    """
    tools = ['uLTRA', 'uLTRA_mm2', 'minimap2', 'minimap2_GTF', 'deSALT', 'deSALT_GTF', 'graphmap2', 'graphmap2_GTF']
    plotted_tools = []
    times = []
    for tool in tools:
        try:
            with open(os.path.join(directory, tool, 'stats.json'), 'r') as stats_file:
                total_time = json.load(stats_file)['total_time']
        except FileNotFoundError:
            # This tool has no recorded run in `directory`; skip it.
            continue
        plotted_tools.append(tool)
        times.append(total_time)

    plt.clf()
    plt.bar(plotted_tools, times)
    plt.xlabel("Tool")
    plt.ylabel("Time (s)")
    # NOTE(review): the file name is relative, so the image lands in the
    # current working directory rather than in `directory` — confirm intent.
    plt.savefig("time.png")
        
    
def real_data_pipeline(dataset):
    """Run every aligner on a real dataset and write the combined reports.

    For each tool configuration the aligner is run (each runner skips work
    whose ``reads.sam`` already exists), its alignments are classified with
    ``evaluate_splice_annotations`` and, for a subset of the tools, the
    per-tool ``splice_annotations.csv`` is collected. The collected CSVs are
    concatenated into ``output/<dataset_name>/splice_annotations.csv`` and
    passed to ``get_diff_loc_reads``; finally the category and timing plots
    are produced.

    Parameters:
        dataset: one entry of ``dataset_dict`` (keys ``ref``, ``gtf``,
            ``bed``, ``desalt_gtf``, ``reads``, ``ont``, ``isoseq``,
            ``disable_infer``, ``desalt_index``).

    Note:
        The dataset name is not a parameter: it is taken from the
        module-level variable ``dataset_name`` set by the driver loop.
    """
    tool_names = ("uLTRA", "uLTRA_mm2", "minimap2", "minimap2_GTF", "deSALT", "deSALT_GTF")
    splice_annotations_csvs = []
    category_stats = {tool: {} for tool in tool_names}

    ref, gtf, bed = dataset["ref"], dataset["gtf"], dataset["bed"]
    desalt_gtf, reads = dataset["desalt_gtf"], dataset["reads"]
    ont, isoseq = dataset["ont"], dataset["isoseq"]
    disable_infer, desalt_index = dataset["disable_infer"], dataset["desalt_index"]

    # One row per run:
    # (tool name, bed passed on, desalt_gtf passed on, mm2 flag, runner,
    #  whether the tool's splice_annotations.csv joins the combined report).
    # The "minimap2" run gets no bed file and the "deSALT" run gets no GTF.
    # The graphmap2 / graphmap2_GTF runs are currently disabled.
    tool_runs = [
        ("uLTRA", bed, desalt_gtf, True, ultra, False),
        ("uLTRA_mm2", bed, desalt_gtf, False, ultra, True),
        ("minimap2", None, desalt_gtf, False, minimap2, True),
        ("minimap2_GTF", bed, desalt_gtf, False, minimap2, True),
        ("deSALT", bed, None, False, deSALT, False),
        ("deSALT_GTF", bed, desalt_gtf, False, deSALT, True),
    ]

    for toolname, tool_bed, tool_desalt_gtf, use_mm2, run_tool, collect_csv in tool_runs:
        args = create_args(ref, gtf, tool_bed, tool_desalt_gtf, reads, ont, isoseq,
                           disable_infer, dataset_name, toolname, use_mm2, desalt_index)
        run_tool(args)
        sam_file = os.path.join(args.outfolder, "reads.sam")
        category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, sam_file, toolname)
        if collect_csv:
            splice_annotations_csvs.append(os.path.join(args.outfolder, "splice_annotations.csv"))

    # Final outputs
    output_dir = os.path.join("output", dataset_name)

    # concordance (``args`` is the configuration of the last run, deSALT_GTF)
    per_tool_frames = [pd.read_csv(csv_path) for csv_path in splice_annotations_csvs]
    combined_splice_annotations = pd.concat(per_tool_frames)
    combined_splice_annotations_path = os.path.join(output_dir, "splice_annotations.csv")
    combined_splice_annotations.to_csv(combined_splice_annotations_path, index=False, encoding='utf-8-sig')
    get_diff_loc_reads.get_diff_loc_reads(args, combined_splice_annotations_path, output_dir)

    # extra category
    plots.category_plot(category_stats, output_dir)
    plot_time(output_dir)
    
    
def _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs):
    """Make sure a tool's exon-evaluation CSVs exist and record their paths."""
    correctness_csv = os.path.join(args.outfolder, "correctness_per_exon_size.csv")
    per_read_csv = os.path.join(args.outfolder, "results_per_read.csv")
    # Only evaluate when at least one of the two result files is missing.
    if not (os.path.exists(correctness_csv) and os.path.exists(per_read_csv)):
        evaluate_exons.evaluate_sim_reads(args, accessions_map)
    correctness_per_exon_size_csvs.append(correctness_csv)
    results_per_read_csvs.append(per_read_csv)


def sim_data_pipeline(dataset):
    """Run every aligner on a simulated dataset and write the combined reports.

    For each tool configuration the aligner is run (each runner skips work
    whose ``reads.sam`` already exists), its alignments are classified with
    ``evaluate_splice_annotations`` and the exon-level evaluation
    (``evaluate_exons.evaluate_sim_reads``) is run unless both of its CSVs
    are already present. The per-tool CSVs are concatenated into
    ``output/<dataset_name>/results_per_read.csv`` and
    ``output/<dataset_name>/correctness_per_exon_size.csv`` and the category
    and accuracy plots are produced.

    Parameters:
        dataset: one entry of ``dataset_dict`` (keys ``ref``, ``gtf``,
            ``bed``, ``desalt_gtf``, ``reads``, ``ont``, ``isoseq``,
            ``disable_infer``, ``desalt_index``, ``accessions_map``).

    Note:
        The dataset name is not a parameter: it is taken from the
        module-level variable ``dataset_name`` set by the driver loop.
    """
    results_per_read_csvs = []
    correctness_per_exon_size_csvs = []
    category_stats = {"uLTRA" :{}, "uLTRA_mm2" :{}, "minimap2" :{}, "minimap2_GTF" :{}, "deSALT" :{}, "deSALT_GTF" :{}}

    ref = dataset["ref"]
    gtf = dataset["gtf"]
    bed = dataset["bed"]
    desalt_gtf = dataset["desalt_gtf"]
    reads = dataset["reads"]
    ont = dataset["ont"]
    isoseq = dataset["isoseq"]
    disable_infer = dataset["disable_infer"]
    desalt_index = dataset["desalt_index"]
    accessions_map = dataset["accessions_map"]

    # ultra
    toolname = "uLTRA"
    args = create_args(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, dataset_name, toolname, True, desalt_index)
    ultra(args)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)

    # ultra_mm2
    toolname = "uLTRA_mm2"
    args = create_args(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, dataset_name, toolname, False, desalt_index)
    ultra(args)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)

    # minimap2 (no bed file)
    toolname = "minimap2"
    args = create_args(ref, gtf, None, desalt_gtf, reads, ont, isoseq, disable_infer, dataset_name, toolname, False, desalt_index)
    minimap2(args)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)

    # mm2_GTF
    toolname = "minimap2_GTF"
    args = create_args(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, dataset_name, toolname, False, desalt_index)
    minimap2(args)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)

    # desalt (no GTF)
    toolname = "deSALT"
    args = create_args(ref, gtf, bed, None, reads, ont, isoseq, disable_infer, dataset_name, toolname, False, desalt_index)
    deSALT(args)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)

    # desalt_GTF
    # Uses the same result file names as every other tool; this section
    # previously looked for "*_biological.csv" files although it calls the
    # same evaluate_sim_reads as the other sections.
    toolname = "deSALT_GTF"
    args = create_args(ref, gtf, bed, desalt_gtf, reads, ont, isoseq, disable_infer, dataset_name, toolname, False, desalt_index)
    deSALT(args)
    category_stats[toolname] = evaluate_splice_annotations.evaluate_splice_annotations(args, os.path.join(args.outfolder, "reads.sam"), toolname)
    _collect_sim_exon_results(args, accessions_map, correctness_per_exon_size_csvs, results_per_read_csvs)

    # Final outputs
    output_dir = os.path.join("output", dataset_name)

    results_per_read = pd.concat([pd.read_csv(f) for f in results_per_read_csvs])
    results_per_read_path = os.path.join(output_dir, "results_per_read.csv")
    results_per_read.to_csv(results_per_read_path, index=False, encoding='utf-8-sig')

    correctness_per_exon_size = pd.concat([pd.read_csv(f) for f in correctness_per_exon_size_csvs])
    correctness_per_exon_size_path = os.path.join(output_dir, "correctness_per_exon_size.csv")
    correctness_per_exon_size.to_csv(correctness_per_exon_size_path, index=False, encoding='utf-8-sig')

    # extra category
    plots.category_plot(category_stats, output_dir)
    plots.alignment_accuracy_plot(results_per_read_path, output_dir)


In [None]:
# Run the matching pipeline for every configured dataset.
# `dataset_name` has to stay a module-level loop variable: both pipelines
# read it as a global instead of receiving it as an argument.
for dataset_name, dataset in dataset_dict.items():
    # A dataset flagged as ONT or Iso-Seq is treated as real data.
    is_real = dataset["ont"] or dataset["isoseq"]
    pipeline = real_data_pipeline if is_real else sim_data_pipeline
    pipeline(dataset)