In [None]:
import os
import subprocess
import sys
import uuid
from glob import glob
import shutil
from dataclasses import dataclass
from collections import defaultdict

import pandas as pd
import numpy as np
from scipy import stats
from matplotlib import pyplot as plt
from matplotlib_venn import venn3, venn3_circles,venn3_unweighted
import sourmash

from multiprocessing import Pool
# from multiprocess import Process, Manager
import uuid
import seaborn as sns
import gzip

In [None]:


def gz_line_count(file):
    """
        Count the lines of a gzip-compressed file (like ``zcat file | wc -l``,
        except that a final line lacking a trailing newline is also counted).
    """
    total = 0
    with gzip.open(file, "rb") as handle:
        for _line in handle:
            total += 1
    return total
def get_read_set(a):
    """Return the names (``str``) of primary, mapped reads in the SAM/BAM file *a*.

    Records with flag bit 4 (unmapped) or 256 (secondary) set -- together the
    mask 260 -- are skipped.

    NOTE: a later cell redefines ``get_read_set`` with a ``samtools view``
    implementation returning ``bytes`` names; in a top-to-bottom run that later
    definition is the one used by the read-set comparison.
    """
    # pysam is not imported at the top of the notebook, so import it where it is used.
    import pysam

    with pysam.AlignmentFile(a, 'r', require_index=False) as aa:
        return {r.query_name for r in aa.fetch(until_eof=True) if r.flag & 260 == 0}



In [None]:
# Root folder holding the per-sample tool outputs and (under eskapee/genomes) the reference genomes.
refdir = "ncbi_eskapee_plassembler"

In [None]:
# Sample sheet; the 'Assembly Accession' and 'Assembly BioSample Accession' columns are used below.
samples_df = pd.read_csv('samples.target.tsv', sep='\t')

In [None]:
# Samples (top-level folder names under `refdir`) that have all three HyPlAs
# iteration assemblies (it0, it1, it2).
samples_have = set.intersection(*(
    {path.split('/')[0] for path in glob(f'*/hyplass_b/assembly.final.it{iteration}.fasta', root_dir=refdir)}
    for iteration in range(3)
))

In [None]:
# BioSample accession -> NCBI assembly accession (a later row wins on duplicate BioSamples).
sadict = dict(zip(samples_df['Assembly BioSample Accession'], samples_df['Assembly Accession']))

In [None]:
# Keep only samples whose HyPlAs outputs are all present.
sadict = {biosample: accession for biosample, accession in sadict.items() if biosample in samples_have}

In [None]:
# Unique per-run figure folder, so repeated runs never overwrite earlier plots.
plotdir = f"plots-{uuid.uuid4()}"

In [None]:
# Create the per-run figure folder (no-op if it already exists).
os.makedirs(plotdir, exist_ok=True)


In [None]:
# Shared keyword arguments for every `savefig` call: 300 dpi and a tight bounding box.
# `bbox_inches='tight'` is what trims surrounding whitespace; `pad_inches` only sets
# the padding size and is ignored unless a tight bbox is requested.
PLOT_PARAMS = dict(dpi=300, bbox_inches='tight')

In [None]:
@dataclass
class knowledge:
    """Central path/configuration registry for the notebook (instantiated once as ``K``).

    Most template fields are ``str.format`` patterns that the accessor classes
    fill with ``sample=...`` (and ``no=...`` for short-read mates).
    """
    srformat: str  # short-read FASTQ template; placeholders {sample} and {no} (mate 1 or 2)
    lrformat: str  # long-read FASTQ template; placeholder {sample}
    short_read_to_ref_mapping: str  # sorted BAM of short reads mapped to the reference; {sample}
    short_read_coverage: str  # `samtools coverage` output for that BAM; {sample}
    reformat: str  # reference genome FASTA template (downloaded on demand); {sample}
    samples: list[str]  # BioSample accessions included in the analysis
    sample2genomeacc: dict[str, str]  # BioSample accession -> NCBI assembly accession
    sample_plasmid_blacklist_path: str  # per-sample blacklist TSV template; {sample}
    experiment_plasmid_blacklist_path: str  # combined blacklist TSV for all samples
    method_asm_fmt: dict[str,str]  # method name -> plasmid assembly FASTA template; {sample}
    method_chosen_lr_fmt: dict[str,str]  # method name -> coverage TSV of its selected long reads; {sample}
    results_df_tsv: str  # output path of the detailed per-plasmid results table
    hyplass_initial_selection_fmt: str  # folder of HyPlAs' initial long-read selection FASTQs; {sample}
refpath = f"{refdir}/eskapee/genomes"
datapath = refdir
analysispath = "analysis"

# Templates below keep literal `{sample}` / `{no}` placeholders (doubled braces in
# the f-strings) that are filled later with `.format(...)`.
K = knowledge(
    srformat=f"{datapath}/{{sample}}/qc/{{sample}}.sr{{no}}.fastq.gz",
    lrformat=f"{datapath}/{{sample}}/qc/trim.lr.fastq.gz",
    short_read_to_ref_mapping=f"{analysispath}/{{sample}}/sorted.bam",
    short_read_coverage=f"{analysispath}/{{sample}}/coverage.txt",
    reformat=f"{refpath}/{{sample}}.fa",
    sample_plasmid_blacklist_path=f"{analysispath}/{{sample}}.badplasmid.tsv",
    experiment_plasmid_blacklist_path=f"{analysispath}/combined_badplasmid.tsv",
    samples=list(sadict),
    sample2genomeacc=sadict,
    method_asm_fmt={
        "HyPlAs-1": f"{datapath}/{{sample}}/hyplass_b/assembly.final.it1.fasta",
        "HyPlAs-2": f"{datapath}/{{sample}}/hyplass_b/assembly.final.it2.fasta",
        "Plassembler-Raven": f"{datapath}/{{sample}}/plassembler/plassembler_plasmids.fasta",
        "Plassembler-Flye": f"{datapath}/{{sample}}/plassembler_flye/plassembler_plasmids.fasta",
    },
    method_chosen_lr_fmt={
        "HyPlAs-1": f"{datapath}/{{sample}}/hyplass_b_cov/plasmid_selected_i1.coverage.tsv",
        "HyPlAs-2": f"{datapath}/{{sample}}/hyplass_b_cov/plasmid_selected_i2.coverage.tsv",
        "Plassembler-Raven": f"{datapath}/{{sample}}/plassembler-qc/toplasmids.coverage.tsv",
        "Plassembler-Flye": f"{datapath}/{{sample}}/plassembler_flye-qc/toplasmids.coverage.tsv",
        "All": f"{datapath}/{{sample}}/qc/toplasmids.coverage.tsv",
    },
    results_df_tsv=f"{analysispath}/all_results_length.tsv",
    hyplass_initial_selection_fmt=f"{datapath}/{{sample}}/hyplass_a/plasmid_long_reads",
)

In [None]:
# Folder for cached intermediates (BAMs, coverage tables, blacklists, results).
os.makedirs(analysispath, exist_ok=True)

In [None]:
def generate_fasta(fasta):
    """Iterate over the records of a (optionally gzip-compressed) FASTA file.

    Yields ``(name, header, sequence)`` tuples:
      * ``header``   -- the header line without the leading ``>``;
      * ``name``     -- ``header`` up to its first space;
      * ``sequence`` -- the record's sequence lines concatenated.

    Paths ending in ``.gz`` are read with :func:`gzip.open`.  Blank lines are
    ignored, records with an empty sequence are still yielded, and an empty
    file yields nothing.  Sequence lines appearing before any header are
    reported under the placeholder name/header ``"NULL"``.
    """
    opener = gzip.open if fasta.endswith(".gz") else open
    name, header = "NULL", None
    seq = []
    # `with` guarantees the handle is closed even if the consumer stops early.
    with opener(fasta, "rt") as handle:
        for line in handle:
            line = line.rstrip("\n")
            if not line:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            if line.startswith(">"):
                # Emit the previous record, if there was one.
                if header is not None or seq:
                    yield (name, "NULL" if header is None else header, "".join(seq))
                header = line[1:]
                name = header.split(" ")[0]
                seq = []
            else:
                seq.append(line)
    if header is not None or seq:
        yield (name, "NULL" if header is None else header, "".join(seq))
    
def make_hashes_gen(ref_path, N,K,Scale, whitelist={}, need_circular=False):
    """Lazily sketch FASTA records of *ref_path* with sourmash MinHash.

    Args:
        ref_path: FASTA (optionally ``.gz``) read via ``generate_fasta``.
        N, K, Scale: forwarded as ``n``, ``ksize`` and ``scaled`` to ``sourmash.MinHash``.
        whitelist: if non-empty, only record names contained in it are sketched.
        need_circular: if true, records whose header lacks the substring
            ``"circular"`` are skipped.

    Yields:
        ``(record_name, MinHash)`` pairs.
    """
    for record_name, record_header, record_seq in generate_fasta(ref_path):
        if need_circular and "circular" not in record_header:
            continue
        if whitelist and record_name not in whitelist:
            continue
        sketch = sourmash.MinHash(n=N, ksize=K, scaled=Scale)
        sketch.add_sequence(record_seq, force=True)
        yield record_name, sketch
def make_hashes(ref_path, mashparam, whitelist={}, need_circular=False):
    """Return ``list(make_hashes_gen(ref_path, *mashparam, ...))``, or ``[]`` on failure.

    Best-effort by design: an unreadable or missing assembly counts as "no
    contigs".  The failure is reported on stderr instead of being silently
    swallowed.  Only ``Exception`` subclasses are caught, so KeyboardInterrupt
    and SystemExit still propagate.

    Args:
        ref_path: FASTA to sketch.
        mashparam: ``(N, K, Scale)`` MinHash parameters.
        whitelist: optional set of record names to keep (empty = keep all).
        need_circular: only keep records whose header mentions "circular".
    """
    try:
        return list(make_hashes_gen(ref_path, *mashparam, whitelist=whitelist, need_circular=need_circular))
    except Exception as err:
        print(f"make_hashes: could not sketch {ref_path}: {err!r}", file=sys.stderr)
        return []

In [None]:
def all2allcompare(a,b,minimum_req_match=0.75):
    """Greedily pair sketches of *a* with sketches of *b* by Jaccard similarity.

    Entries of *a* are processed in order.  Each is matched to the MOST similar
    entry of *b* that is not yet claimed, provided the similarity is strictly
    greater than ``minimum_req_match``.  Every entry of *b* is matched at most once.

    Args:
        a: sequence of ``(name, sketch)``; sketches must provide ``jaccard(other)``.
        b: sequence of ``(name, sketch)`` to match against.
        minimum_req_match: Jaccard similarity a pair must exceed.

    Returns:
        ``{"a-count": len(a), "b-count": len(b), "matches": [(a_name, b_name), ...]}``.
    """
    b_names = [name for name, _ in b]
    b_used = set()
    matches = []
    for a_name, a_sketch in a:
        similarities = np.array([a_sketch.jaccard(b_sketch) for _, b_sketch in b], dtype=float)
        # Visit candidates from most to least similar (stable: ties keep input order).
        for idx in np.argsort(-similarities, kind="stable"):
            if not similarities[idx] > minimum_req_match:
                break  # every remaining candidate is at most this similar (or NaN)
            if b_names[idx] in b_used:
                continue
            b_used.add(b_names[idx])
            matches.append((a_name, b_names[idx]))
            break
    return {
        "a-count": len(a),
        "b-count": len(b),
        "matches": matches,
    }
def compute_lr_stats3s(A,B,C):
    """Return the sizes of the 7 regions of a 3-set Venn diagram.

    The order matches ``matplotlib_venn.venn3`` subsets:
    (A only, B only, A&B not C, C only, A&C not B, B&C not A, A&B&C).
    """
    ab = A & B
    regions = (
        A - B - C,
        B - A - C,
        ab - C,
        C - A - B,
        (A & C) - B,
        (C & B) - A,
        ab & C,
    )
    return [len(region) for region in regions]

In [None]:


def remove_outliers(df, field, threshold=3):
    """Drop rows of *df* whose ``field`` value is a z-score outlier.

    A row is kept when ``|z| < threshold``.  Here z is the population z-score
    (``ddof=0``) computed over the non-missing values of the column.

    Edge cases:
      * Rows with a missing value in *field* are dropped.
      * A single NaN no longer poisons the whole column (previously every
        z-score became NaN and all rows were dropped).
      * If the column has no spread (constant, or a single value), z is
        undefined; only the missing-value rows are dropped.

    Args:
        df: input DataFrame (not modified).
        field: column to test.
        threshold: z-score cut-off (default 3).

    Returns:
        The filtered DataFrame.
    """
    values = df[field].to_numpy(dtype=float)
    present = ~np.isnan(values)
    if not present.any() or np.nanstd(values) == 0:
        return df[present]
    z = stats.zscore(values, nan_policy="omit")
    return df[np.abs(z) < threshold]

class sampleget:
    """Per-sample accessors for inputs, reference genomes and cached intermediates.

    Every method is a ``classmethod`` that resolves paths from the global ``K``
    configuration.  Methods that produce a file (reference download, read
    mapping, coverage, blacklist) return the existing path when the file is
    already present, and build it first otherwise.
    """

    @classmethod
    def short_reads(cls, sample_id):
        """Return the paired short-read FASTQ paths ``[R1, R2]``, or ``None`` if R1 is missing."""
        sr_paths = [K.srformat.format(sample=sample_id, no=mate) for mate in (1, 2)]
        if os.path.isfile(sr_paths[0]):
            return sr_paths
        return None

    @classmethod
    def plasmid_reference(cls, sample_id):
        """Return the reference genome FASTA of *sample_id*, downloading it on first use.

        The NCBI assembly accession comes from ``K.sample2genomeacc``.  The
        genome is fetched with the NCBI ``datasets`` CLI into a uniquely named
        temporary zip and unpacked.  Its first ``*.fna`` is moved to
        ``K.reformat``.  Temporary files are removed in every case.

        Returns:
            The FASTA path, or ``None`` (after a message on stderr) if the
            download fails or contains no ``.fna`` file.
        """
        ref_path = K.reformat.format(sample=sample_id)
        if os.path.isfile(ref_path):
            return ref_path
        accession = K.sample2genomeacc[sample_id]
        ref_dir = os.path.dirname(ref_path)
        if ref_dir:
            os.makedirs(ref_dir, exist_ok=True)

        tmp_dir = f"{sample_id}_{uuid.uuid4()}"
        tmp_zip = tmp_dir + ".zip"
        cmd = [
            "datasets", "download", "genome", "accession", accession,
            "--filename", tmp_zip,
            "--include", "genome",
            "--no-progressbar",
        ]
        try:
            ret = subprocess.run(cmd)
            if ret.returncode != 0:
                print(f"Something wrong with {cmd}", file=sys.stderr)
                return None
            os.makedirs(tmp_dir, exist_ok=True)
            subprocess.run(["unzip", tmp_zip, "-d", tmp_dir], capture_output=True)
            fna_files = glob(f"{tmp_dir}/ncbi_dataset/data/*/*.fna")
            if not fna_files:
                print(f"No .fna file in the download for {accession}", file=sys.stderr)
                return None
            shutil.move(fna_files[0], ref_path)
            return ref_path
        finally:
            # Clean up the temporary zip / extraction folder on success and failure alike.
            if os.path.exists(tmp_zip):
                os.remove(tmp_zip)
            shutil.rmtree(tmp_dir, ignore_errors=True)

    @classmethod
    def short_read_to_ref_mapping(cls, sample_id, mm2_param=("-a", "-t", "32", "-x", "sr"), st_param=("-@", "16", "-m", "8G")):
        """Map the sample's short reads to its reference; return the sorted BAM path.

        Runs ``minimap2`` into a temporary SAM next to the BAM, sorts it with
        ``samtools sort`` and deletes the SAM.  Extra arguments come from
        ``mm2_param`` / ``st_param`` (any sequence of strings).

        Returns:
            The BAM path, or ``None`` (after a stderr message) if an input is
            missing or a command fails.  A partially written BAM is removed so
            it is not mistaken for a cached result later.
        """
        mapping_path = K.short_read_to_ref_mapping.format(sample=sample_id)
        if os.path.isfile(mapping_path):
            return mapping_path
        ref_path = cls.plasmid_reference(sample_id)
        reads = cls.short_reads(sample_id)
        if ref_path is None or reads is None:
            print(f"Missing reference or short reads for {sample_id}", file=sys.stderr)
            return None
        mapping_dir = os.path.dirname(mapping_path)
        if mapping_dir:
            os.makedirs(mapping_dir, exist_ok=True)
        sam_path = mapping_path[:-4] + ".sam"

        cmd = ["minimap2", ref_path, *reads, "-o", sam_path, *mm2_param]
        ret = subprocess.run(cmd, capture_output=True)
        if ret.returncode != 0:
            print(f"Something wrong with {cmd}", file=sys.stderr)
            return None

        cmd = ["samtools", "sort", sam_path, "-o", mapping_path, *st_param]
        ret = subprocess.run(cmd)
        if ret.returncode != 0:
            print(f"Something wrong with {cmd}", file=sys.stderr)
            if os.path.exists(mapping_path):
                os.remove(mapping_path)
            return None
        os.remove(sam_path)
        return mapping_path

    @classmethod
    def short_read_coverage_on_ref(cls, sample_id):
        """Return the ``samtools coverage`` table of the sample's short-read BAM, computing it if needed.

        Returns ``None`` (after a stderr message) if the BAM cannot be produced
        or samtools fails.
        """
        cov_path = K.short_read_coverage.format(sample=sample_id)
        if os.path.isfile(cov_path):
            return cov_path
        bam_path = cls.short_read_to_ref_mapping(sample_id)
        if bam_path is None:
            print(f"No short-read mapping available for {sample_id}", file=sys.stderr)
            return None
        cmd = ["samtools", "coverage", bam_path, "-o", cov_path]
        ret = subprocess.run(cmd)
        if ret.returncode != 0:
            print(f"Something wrong with {cmd}", file=sys.stderr)
            return None
        return cov_path

    @classmethod
    def plasmid_blacklist(cls, sample_id, require_complete=True, require_min_length=1000, require_max_length=2_000_000):
        """Write (once) and return the path of the sample's reference-record blacklist.

        Each blacklisted record becomes one tab-separated line:
        ``name, reason, header, length, sample_id``.  ``reason`` concatenates
        ``;``-terminated tags:
          * ``chromosome;``        -- header mentions "chromosome";
          * ``chromosome_length;`` -- otherwise, longer than ``require_max_length``;
          * ``not_complete;``      -- header lacks "complete" (only if ``require_complete``);
          * ``too_short;``         -- shorter than ``require_min_length``.

        The file is written to a temporary name and renamed when complete, so
        an interrupted run never leaves a truncated, "cached" blacklist.

        Raises:
            FileNotFoundError: if the reference genome is not available.
        """
        bl_path = K.sample_plasmid_blacklist_path.format(sample=sample_id)
        if os.path.isfile(bl_path):
            return bl_path

        ref_path = cls.plasmid_reference(sample_id)
        if ref_path is None:
            raise FileNotFoundError(f"No reference genome available for {sample_id}")
        tmp_path = f"{bl_path}.{uuid.uuid4()}.tmp"
        try:
            with open(tmp_path, 'w') as hand:
                for name, comment, seq in generate_fasta(ref_path):
                    reason = ""
                    if "chromosome" in comment:
                        reason += "chromosome;"
                    elif len(seq) > require_max_length:
                        reason += "chromosome_length;"
                    if require_complete and "complete" not in comment:
                        reason += "not_complete;"
                    if len(seq) < require_min_length:
                        reason += "too_short;"
                    if reason:
                        print(name, reason, comment, len(seq), sample_id, sep="\t", file=hand)
            os.replace(tmp_path, bl_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        return bl_path

    @classmethod
    def load_coverage_file_df(cls, sample):
        """Load the short-read coverage table and add coverage relative to the chromosome.

        The first row is treated as the chromosome and removed.  Every
        remaining row gets ``relcov = meandepth / chromosome meandepth`` and
        ``chrcov`` (the chromosome's mean depth).
        """
        cov_file = cls.short_read_coverage_on_ref(sample)
        dfi = pd.read_csv(cov_file, sep="\t")
        # NOTE(review): assumes the first reference sequence is the chromosome -- TODO confirm
        # this holds for every NCBI download (e.g. genomes with several chromosomes).
        chr_cov = dfi.meandepth.iloc[0]
        dfi = dfi.drop(index=0)
        dfi['relcov'] = dfi.meandepth / chr_cov
        dfi['chrcov'] = chr_cov
        return dfi

    @classmethod
    def plasmid_assembly(cls, sample, method):
        """Return the plasmid assembly FASTA produced by *method* for *sample*.

        Raises:
            NotImplementedError: if the assembly file does not exist (running
                the assembler from here is not implemented).
        """
        asm_path = K.method_asm_fmt[method].format(sample=sample)
        if os.path.isfile(asm_path):
            return asm_path
        raise NotImplementedError(
            f"{method} assembly for {sample} not found at {asm_path}; generating it is not implemented"
        )

    @classmethod
    def assembly_hashes(cls, sample, method, mashparams=(0,17,15), require_circular=True):
        """MinHash-sketch the plasmid contigs assembled by *method* for *sample*.

        For Plassembler outputs with ``require_circular``, contigs are
        restricted to those marked ``circular`` in the ``plassembler_summary.tsv``
        next to the assembly.  If that summary cannot be read, no whitelist is
        applied.  ``require_circular`` is also forwarded to ``make_hashes``.
        """
        asm_path = cls.plasmid_assembly(sample, method)
        whitelist = {}
        # Method keys are capitalised ("Plassembler-Flye"), so compare case-insensitively.
        if "plassembler" in method.lower() and require_circular:
            summary_file = os.path.join(os.path.dirname(asm_path), "plassembler_summary.tsv")
            try:
                summary = pd.read_csv(summary_file, sep="\t")
                circular = summary.loc[summary["circularity"] == "circular", "contig"]
                # FASTA record names are strings, but pandas parses numeric contig ids as ints.
                whitelist = {str(contig) for contig in circular}
            except (OSError, KeyError, ValueError) as err:
                print(f"Could not read circularity from {summary_file}: {err!r}", file=sys.stderr)
                whitelist = {}
        return make_hashes(asm_path, mashparams, whitelist, require_circular)

    @classmethod
    def compare_the_assembly_with_gt(cls, sample, methods, mashparams=(0,17,15), require_circular=True, minimum_req_match=0.75, blacklist_enabled=True):
        """Match each method's plasmid contigs against the sample's reference records.

        Args:
            sample: sample id.
            methods: a single method name or an iterable of method names.
            mashparams: ``(n, ksize, scaled)`` for ``sourmash.MinHash``.
            require_circular: forwarded to :meth:`assembly_hashes`.
            minimum_req_match: Jaccard threshold forwarded to ``all2allcompare``.
            blacklist_enabled: skip reference records listed in the sample blacklist.

        Returns:
            ``{method: all2allcompare(reference_sketches, method_sketches)}``.
        """
        # A bare string must become a 1-tuple; iterating it directly would yield characters.
        if isinstance(methods, str):
            methods = (methods,)
        else:
            methods = tuple(methods)
        gt_path = cls.plasmid_reference(sample)
        bldf = None
        if blacklist_enabled:
            bl_file = cls.plasmid_blacklist(sample)
            bldf = pd.read_csv(bl_file, sep="\t", header=None, names=["id", "reason", "comment", "length", "sample"]).set_index("id")
        # `ksize` rather than `K`, to avoid shadowing the global configuration object.
        n, ksize, scale = mashparams
        gt_hashes = []
        for name, comment, seq in generate_fasta(gt_path):
            if bldf is not None and name in bldf.index:
                continue
            hashes = sourmash.MinHash(n=n, ksize=ksize, scaled=scale)
            hashes.add_sequence(seq, force=True)
            gt_hashes.append((name, hashes))
        return {
            method: all2allcompare(
                gt_hashes,
                cls.assembly_hashes(sample, method, mashparams, require_circular),
                minimum_req_match=minimum_req_match,
            )
            for method in methods
        }

    @classmethod
    def compare_plasmid_containment_to_main_chr(cls, sample, mashparams, blacklist_enabled=True):
        """Return ``{plasmid_name: containment of its sketch in the chromosome sketch}``.

        Reference records whose blacklist reason mentions "chromosome" (this
        includes ``chromosome_length``) are merged into one chromosome sketch.
        Other blacklisted records are skipped when ``blacklist_enabled``.
        Every remaining record is sketched as a plasmid.
        """
        ref_file = cls.plasmid_reference(sample)
        bl_file = cls.plasmid_blacklist(sample)
        bl_df = pd.read_csv(bl_file, sep="\t", header=None, names=["id", "reason", "comment", "length", "sample"]).set_index("id")
        is_chromosome = bl_df["reason"].str.contains("chromosome", regex=False, na=False).to_numpy(dtype=bool)
        chromosome_ids = set(bl_df.index[is_chromosome])
        blacklisted_ids = set(bl_df.index)

        n, ksize, scale = mashparams
        chr_hash = sourmash.MinHash(n=n, ksize=ksize, scaled=scale)
        plasmid_hashes = {}
        for name, comment, seq in generate_fasta(ref_file):
            if name in chromosome_ids:
                chr_hash.add_sequence(seq, force=True)
            elif blacklist_enabled and name in blacklisted_ids:
                continue
            else:
                plasmid_hashes[name] = sourmash.MinHash(n=n, ksize=ksize, scaled=scale)
                plasmid_hashes[name].add_sequence(seq, force=True)
        return {name: sketch.contained_by(chr_hash) for name, sketch in plasmid_hashes.items()}

    @classmethod
    def selected_lr_coverage_df(cls, sample, method):
        """Return ``[sample, method, #rname, coverage]`` rows from the method's selected-read coverage TSV."""
        cov_path = K.method_chosen_lr_fmt[method].format(sample=sample)
        df = pd.read_csv(cov_path, sep="\t")[["#rname", "coverage"]]
        df.insert(0, "sample", sample, True)
        df.insert(1, "method", method, True)
        return df

    @classmethod
    def selected_lr_counts_hyplass(cls, sample):
        """Count long reads in each category of HyPlAs' initial read selection.

        Reads are counted as FASTQ line count / 4.

        Returns:
            ``(counts, fractions)``: absolute read counts per category, and
            the same counts divided by the total number of long reads.
        """
        lr_folder = K.hyplass_initial_selection_fmt.format(sample=sample)
        total_read_count = gz_line_count(K.lrformat.format(sample=sample)) / 4
        types = {
            "plasmid": "plasmid.fastq.gz",
            "unmapped": "unmapped.fastq.gz",
            "unknown-both": "unknown_both.fastq.gz",
            "unknown-neither": "unknown_neither.fastq.gz",
        }
        ret = {k: gz_line_count(os.path.join(lr_folder, v)) // 4 for k, v in types.items()}
        return ret, {k: v / total_read_count for k, v in ret.items()}

In [None]:
class experimentget:
    """Experiment-wide aggregations over every sample in ``K.samples``."""

    @classmethod
    def collect_selected_lr_coverage_df(cls, blacklist_enabled=True):
        """Concatenate every sample's per-method selected-long-read coverage table.

        Rows whose ``#rname`` appears in the combined blacklist are dropped when
        ``blacklist_enabled``.
        """
        dfs = [
            sampleget.selected_lr_coverage_df(sample, method)
            for sample in K.samples
            for method in K.method_chosen_lr_fmt.keys()
        ]
        df = pd.concat(dfs, axis=0, ignore_index=True)
        if blacklist_enabled:
            bl_df = cls.plasmid_blacklist()
            df = df.loc[~df["#rname"].isin(bl_df.index)]
        return df

    @classmethod
    def collect_plasmid_chr_containments(cls, mashparams=(0,17,15), threads=32, blacklist_enabled=True):
        """Return ``{sample: {plasmid: chromosome containment}}`` for every sample.

        Runs serially when ``threads == 1``, otherwise in a ``multiprocessing.Pool``.
        """
        if threads == 1:
            return {
                sample: sampleget.compare_plasmid_containment_to_main_chr(sample, mashparams, blacklist_enabled=blacklist_enabled)
                for sample in K.samples
            }
        with Pool(threads) as pool:
            futures = [
                (sample, pool.apply_async(sampleget.compare_plasmid_containment_to_main_chr, (sample,), dict(mashparams=mashparams, blacklist_enabled=blacklist_enabled)))
                for sample in K.samples
            ]
            return {sample: future.get() for sample, future in futures}

    @classmethod
    def plasmid_blacklist(cls):
        """Concatenate all per-sample blacklists (indexed by record ``id``).

        The result is written to ``K.experiment_plasmid_blacklist_path`` and
        returned.  It is recomputed on every call.
        """
        dfs = []
        for k in (K.samples):
            bl_file = sampleget.plasmid_blacklist(k)
            dfs.append(pd.read_csv(bl_file, sep="\t", header=None, names=["id", "reason", "comment", "length", "sample"]).set_index("id"))
        df = pd.concat(dfs, axis=0)
        df.to_csv(K.experiment_plasmid_blacklist_path, sep="\t")
        return df

    @classmethod
    def coverage_length_df(cls, blacklist_enabled=True):
        """Per-record short-read coverage table for all samples, minus blacklisted records if enabled."""
        if blacklist_enabled:
            blacklisted = set(cls.plasmid_blacklist().index)
        else:
            blacklisted = set()
        df = pd.concat([sampleget.load_coverage_file_df(k) for k in (K.samples)], axis=0)
        return df.loc[~df["#rname"].isin(blacklisted)]

    @classmethod
    def collect_found_plasmid_sets(cls, method, blacklist_enabled=True, threads=32):
        """Return the set of reference plasmid names matched by *method* in any sample.

        ``method`` may be one method name or an iterable of names.  ``threads``
        is accepted for signature compatibility but unused (this runs serially).
        """
        found = set()
        for sample in K.samples:
            sr = sampleget.compare_the_assembly_with_gt(sample, methods=method, blacklist_enabled=blacklist_enabled)
            found.update(gt_name for comparison in sr.values() for gt_name, _ in comparison['matches'])
        return found

    @staticmethod
    def _accumulate_counts(sr, tp, fp, fn):
        """Add one sample's per-method TP/FP/FN (from ``all2allcompare`` results) to the running totals."""
        for method, comparison in sr.items():
            n_matches = len(comparison['matches'])
            tp[method] += n_matches
            fp[method] += comparison['b-count'] - n_matches
            fn[method] += comparison['a-count'] - n_matches

    @classmethod
    def collect_results(cls, blacklist_enabled=True, threads=32, minimum_req_match=0.75):
        """Total TP/FP/FN per method over all samples.

        TP = matched pairs, FP = unmatched assembled contigs, FN = unmatched
        reference plasmids.  Runs serially when ``threads == 1``, otherwise in a
        ``multiprocessing.Pool``.

        Returns:
            ``(tp, fp, fn)`` as ``defaultdict(int)`` keyed by method name.
        """
        tp, fp, fn = defaultdict(int), defaultdict(int), defaultdict(int)
        kwargs = dict(methods=tuple(K.method_asm_fmt.keys()), blacklist_enabled=blacklist_enabled, minimum_req_match=minimum_req_match)
        if threads == 1:
            for sample in K.samples:
                cls._accumulate_counts(sampleget.compare_the_assembly_with_gt(sample, **kwargs), tp, fp, fn)
        else:
            with Pool(threads) as pool:
                futures = [pool.apply_async(sampleget.compare_the_assembly_with_gt, (sample,), kwargs) for sample in K.samples]
                for future in futures:
                    cls._accumulate_counts(future.get(), tp, fp, fn)
        return tp, fp, fn

    @classmethod
    def collect_detailed_results(cls):
        """Build (and save to ``K.results_df_tsv``) a per-plasmid results table.

        Columns are ``sample``, ``chr_containment``, one boolean column per
        method in ``K.method_asm_fmt`` (plasmid found by that method), and
        ``endpos``, ``relcov``, ``chrcov`` from the coverage table.
        """
        foundby = {tool: cls.collect_found_plasmid_sets([tool]) for tool in K.method_asm_fmt.keys()}
        pcc = cls.collect_plasmid_chr_containments()
        pcc_df = pd.DataFrame.from_dict(
            {plasmid: dict(sample=sample, chr_containment=containment)
             for sample, per_plasmid in pcc.items()
             for plasmid, containment in per_plasmid.items()},
            orient="index",
        )
        for method in K.method_asm_fmt.keys():
            # Build each flag from pcc_df's OWN index so it stays aligned with the row it describes.
            pcc_df[method] = pcc_df.index.isin(foundby[method])
        cov_df = cls.coverage_length_df().set_index("#rname")
        results_df = pcc_df.join(cov_df[['endpos', 'relcov', 'chrcov']])
        results_df.to_csv(K.results_df_tsv, sep="\t")
        return results_df

    @classmethod
    def split_containment_df_by_two_methods(cls, df, m1, m2):
        """Split ``df.chr_containment`` by which of two methods MISSED each plasmid.

        Keys: ``"Both"`` (missed by both), ``m1`` (missed by m1 only), ``m2``
        (missed by m2 only), ``"Neither"`` (found by both).
        """
        neither = df.loc[(df[m1] & df[m2])].chr_containment
        m2_missed = df.loc[(df[m1] & (~df[m2]))].chr_containment
        m1_missed = df.loc[(~df[m1] & (df[m2]))].chr_containment
        both = df.loc[(~df[m1] & (~df[m2]))].chr_containment
        return {"Both": both, f"{m1}": m1_missed, f"{m2}": m2_missed, "Neither": neither}

In [None]:
class plotget:
    @classmethod
    def ground_truth_plasmid_length_histogram(cls):
        df = experimentget.coverage_length_df()
        df = remove_outliers(df, 'relcov')
        df = remove_outliers(df, 'endpos')
        fig=plt.figure(figsize=(8,8))
        plt.hist(np.log10(df.endpos),bins=50, log=False)
        plt.xlabel('log_10(Plasmid length)')
        plt.ylabel('count')
        tick_lbls = [1000,10000,30000,50000,100000,150000,200000]
        tick_locs = np.log10(tick_lbls)
        plt.xticks(tick_locs, tick_lbls,rotation=40)
        plt.savefig(f"{plotdir}/gt_plasmid_length.pdf", **PLOT_PARAMS)

    @classmethod
    def plasmid_length_x_coverage_scatter(cls):
        df = experimentget.coverage_length_df()
        sdf = df.loc[df.endpos < 20000]
        sdf = remove_outliers(sdf, 'relcov')
        sdf = remove_outliers(sdf, 'endpos')
        ldf = df.loc[df.endpos >= 20000]
        ldf = remove_outliers(ldf, 'relcov')
        ldf = remove_outliers(ldf, 'endpos')
        fig = plt.figure(figsize=(8,8))
        ax = plt.gca()
        ax.scatter( sdf['endpos'],sdf['relcov'],  alpha=0.5,  label='small plasmid')
        ax.scatter( ldf['endpos'],ldf['relcov'],  alpha=0.5, label='large plasmid')
        ax.set_ylabel('log(coverage) (relative to the chromosome)')
        ax.set_xlabel('log(length)')
        ax.set_yscale('log')
        ax.set_xscale('log')
        plt.legend()
        plt.savefig(f"{plotdir}/plasmid_length_to_coverage.pdf", **PLOT_PARAMS)

    @classmethod
    def tools_precision_recall_f1_table(cls,blacklist_enabled=True, minimum_req_match=0.75):
        tp, fp, fn = experimentget.collect_results(threads=32,blacklist_enabled=blacklist_enabled, minimum_req_match=minimum_req_match)
        
        precision = {k: v/(v+fp[k])for k,v in tp.items()}
        recall = {k: v/(v+fn[k])for k,v in tp.items()}
        f1 = {k: 2*v/(2*v+fn[k]+fp[k])for k,v in tp.items()}

        df = pd.DataFrame.from_dict(dict({"Precision":precision, "Recall":recall, "F1-score":f1}))
        return df;
    @classmethod 
    def df2table(cls, df, caption):
        ret = [
            "\\begin{table}",
            df.to_latex(float_format="{:.4f}".format),
            f"\\caption{{{caption}}}",
            "\\end{table}",
        ]
        return "\n".join(ret)
    
    @classmethod
    def chr_containment_box_plots(cls, m1, m2):
        df = experimentget.collect_detailed_results()
        split = experimentget.split_containment_df_by_two_methods(df.loc[df.endpos<20000], m1, m2)
#Small Plasmid chromosome containments. Split into 4 sets Missed by both, missedby hyplass, missed by plassembler and not missed by either
        # split = split_df_by_two_methods(results_df.loc[results_df.length<20000], "HyPlAss-2", "Plassembler-Flye")
        fig = plt.figure(layout='tight')
        ax1 = sns.boxplot(split)
        ax1.set_ylim([0,1])
        ax1.set_xticks(ticks=[0,1,2,3],labels=ax1.get_xticklabels(),rotation=15)
        ax1.tick_params(axis='both', labelsize=15)
        plt.savefig(f"{plotdir}/small_plasmid_chr_containment_plot.pdf", **PLOT_PARAMS)
        plt.show()
        fig = plt.figure(layout='tight')
#Large Plasmid chromosome containments. Split into 4 sets Missed by both, missedby hyplass, missed by plassembler and not missed by either
        # split = split_df_by_two_methods(results_df.loc[results_df.length>=20000], "HyPlAss-2", "Plassembler-Flye")
        split = experimentget.split_containment_df_by_two_methods(df.loc[df.endpos>=20000], m1, m2)

        ax2 = sns.boxplot(split)
        ax2.set_ylim([0,1])
        ax2.set_xticks(ticks=[0,1,2,3],labels=ax2.get_xticklabels(),rotation=15)
        ax2.tick_params(axis='both', labelsize=15)
        plt.savefig(f"{plotdir}/large_plasmid_chr_containment_plot.pdf",**PLOT_PARAMS)
        return ax1, ax2
    @classmethod
    def plasmid_lr_coverage_thresholds_plot(cls):

        df = experimentget.collect_selected_lr_coverage_df()
        thresholds = np.arange(50,100,3)
        T = defaultdict(list)
        
        for x,y in df.groupby("method"):
            for t in thresholds:
                T[x].append(np.sum(y.coverage >= t))
        plt.figure(figsize=(8,8))
        ax = sns.lineplot(T)
        ax.set_xticks(range(len(thresholds)),labels=thresholds)
        ax.tick_params(axis='both',labelsize=13)
        ax.set_xlabel("Percentage Covered to be accepted as TP",fontsize=16)
        ax.set_ylabel("Number of Plasmids",fontsize=16)
        plt.savefig(f"{plotdir}/plasmid_coverage_comparison_thresholds.pdf", **PLOT_PARAMS)
        return ax
    @classmethod
    def tp_fp_fn_over_thresholds_plot(cls, thresholds=np.arange(00.1,1.0, 0.1)):
        rlist = defaultdict(list)
        for t in thresholds:
            tp, fp, fn = experimentget.collect_results(threads=32,minimum_req_match=t)
            rlist['TP'].append(tp)
            rlist['FP'].append(fp)
            rlist['FN'].append(fn)
        axes = []
        for plot in ["TP", "FP", "FN"]:
            df = pd.DataFrame.from_dict(rlist[plot]).set_index(thresholds)
            axes.append(sns.lineplot(df,markers=True))
            plt.title(f'{plot.upper()}')
            plt.savefig(f"{plotdir}/plasmid_{plot}_over_thresholds.pdf", **PLOT_PARAMS)
            plt.show()
        return axes
    @classmethod
    def precision_recall_f1_over_thresholds_plot(cls, thresholds=np.arange(00.1,1.0, 0.1)):
        rlist = defaultdict(list)
        for t in thresholds:
            tp, fp, fn = experimentget.collect_results(threads=32,minimum_req_match=t)
                
            precision = {k: v/(v+fp[k])for k,v in tp.items()}
            recall = {k: v/(v+fn[k])for k,v in tp.items()}
            f1 = {k: 2*v/(2*v+fn[k]+fp[k])for k,v in tp.items()}

            rlist['Precision'].append(precision)
            rlist['Recall'].append(recall)
            rlist['F1'].append(f1)
        axes = []
        for plot in rlist.keys():
            df = pd.DataFrame.from_dict(rlist[plot]).set_index(thresholds)
            axes.append(sns.lineplot(df,markers=True))
            plt.title(f'{plot}')
            plt.savefig(f"{plotdir}/plasmid_{plot}_over_thresholds.pdf", **PLOT_PARAMS)
            plt.show()
        return axes

In [None]:
# Log-log scatter of short-read coverage (relative to the chromosome) vs. plasmid length, small (<20 kb) vs. large.
plotget.plasmid_length_x_coverage_scatter()

In [None]:
# Histogram of log10 ground-truth plasmid lengths after z-score outlier removal.
plotget.ground_truth_plasmid_length_histogram()

In [None]:
# Precision / recall / F1 per method at a 0.75 Jaccard match threshold, blacklist applied.
prf_df = plotget.tools_precision_recall_f1_table(minimum_req_match=0.75, blacklist_enabled=True)
prf_df

In [None]:
# LaTeX version of the accuracy table, ready to paste into the manuscript.
latex_table = plotget.df2table(prf_df, caption="Plasmid prediction accuracy statistics.")
print(latex_table)

In [None]:
# Chromosome-containment box plots for small (<20 kb) and large plasmids, grouped by which tool(s) missed them.
plotget.chr_containment_box_plots("HyPlAs-2", "Plassembler-Flye")

In [None]:
# Per method: number of plasmids whose selected-long-read coverage reaches each threshold.
plotget.plasmid_lr_coverage_thresholds_plot()

In [None]:
# NOTE(review): identical to the previous cell (re-draws and re-saves the same figure); candidate for removal.
plotget.plasmid_lr_coverage_thresholds_plot()

In [None]:
# Raw TP / FP / FN counts per method at a 0.5 Jaccard match threshold.
tp, fp, fn = experimentget.collect_results(minimum_req_match=0.5, threads=32)
tpfpfndf = pd.DataFrame({"TP": tp, "FP": fp, "FN": fn})
tpfpfndf

In [None]:
# Blacklist explanation -- reason tags written by sampleget.plasmid_blacklist; one record
# may carry several ';'-terminated tags, so the counts below are per combined reason string.
## chromosome: listed as chromosome in the reference FASTA header
## chromosome_length: not listed as chromosome in the FASTA header, but longer than the 2 Mb cutoff;
##   treated as a chromosome (per the original note, confirmed on the NCBI website)
## too_short: shorter than 1 kb
## not_complete: the FASTA header does not say "complete"
bl_df = experimentget.plasmid_blacklist()
np.unique(bl_df.reason, return_counts=True)

In [None]:
# NOTE(review): same computation as the earlier TP/FP/FN cell at minimum_req_match=0.5
# (blacklist_enabled=True is already the default), so this only re-displays that table.
tp, fp, fn = experimentget.collect_results(threads=32,minimum_req_match=0.5,blacklist_enabled=True)
tpfpfndf = pd.DataFrame.from_dict(dict(TP=tp, FP=fp, FN=fn))
tpfpfndf

In [None]:
# TP / FP / FN per method as the Jaccard match threshold varies (one figure each).
plotget.tp_fp_fn_over_thresholds_plot()

In [None]:
# Precision / recall / F1 per method as the Jaccard match threshold varies (one figure each).
plotget.precision_recall_f1_over_thresholds_plot()

In [None]:
# Per-plasmid table: sample, chromosome containment, per-method found flags, length and coverage
# (also written to K.results_df_tsv).
df = experimentget.collect_detailed_results()

In [None]:
# Combined per-sample blacklist (also written to K.experiment_plasmid_blacklist_path).
blist = experimentget.plasmid_blacklist()
blist

In [None]:
from scipy.stats import mannwhitneyu


def _split_four(df, hsteps):
    """Partition *df* by whether HyPlAs-<hsteps> and Plassembler-Flye found each plasmid.

    Returns a dict of row subsets with these keys:
      * ``HP``     -- found by both;
      * ``H_not``  -- found by Plassembler-Flye only (missed by HyPlAs);
      * ``P_not``  -- found by HyPlAs only (missed by Plassembler-Flye);
      * ``HP_not`` -- found by neither.
    """
    hyplas = df[f"HyPlAs-{hsteps}"]
    plassembler = df["Plassembler-Flye"]
    return dict(
        HP=df.loc[hyplas & plassembler],
        H_not=df.loc[~hyplas & plassembler],
        P_not=df.loc[hyplas & ~plassembler],
        HP_not=df.loc[~hyplas & ~plassembler],
    )


def split_to_four_len(df, hsteps):
    """Number of plasmids in each ``_split_four`` group, plus ``Total``."""
    counts = {key: len(part) for key, part in _split_four(df, hsteps).items()}
    counts["Total"] = len(df)
    return counts


def split_to_four_cont(df, hsteps):
    """Mean chromosome containment of each ``_split_four`` group, plus ``Total``."""
    means = {key: np.mean(part.chr_containment) for key, part in _split_four(df, hsteps).items()}
    means["Total"] = np.mean(df.chr_containment)
    return means


def split_to_four_manwhitneyu(df, hsteps):
    """Exact Mann-Whitney U test: containment of plasmids missed only by HyPlAs vs. missed only by Plassembler-Flye."""
    groups = _split_four(df, hsteps)
    return mannwhitneyu(groups["H_not"].chr_containment, groups["P_not"].chr_containment, method='exact')

In [None]:
# Inspect the detailed per-plasmid results (consider df.head() for a shorter view).
df

In [None]:
# Plasmid counts per detection group (HyPlAs-`hsteps` vs. Plassembler-Flye), for all
# plasmids and split at `threshold` bp into small and large ones.
hsteps = 2
threshold = 20000
missing_plasmid_df = pd.DataFrame()
for column, rows in (
    ("All", df),
    ("Small", df.loc[df.endpos < threshold]),
    ("Large", df.loc[df.endpos >= threshold]),
):
    missing_plasmid_df[column] = split_to_four_len(rows, hsteps)


In [None]:
# Rows: detection groups (HP, H_not, P_not, HP_not, Total); columns: All / Small / Large plasmids.
missing_plasmid_df

In [None]:
# Mean chromosome containment per detection group, for all / small / large plasmids.
hsteps = 2
threshold = 20000
missing_containment_df = pd.DataFrame()
for column, rows in (
    ("All", df),
    ("Small", df.loc[df.endpos < threshold]),
    ("Large", df.loc[df.endpos >= threshold]),
):
    missing_containment_df[column] = split_to_four_cont(rows, hsteps)


In [None]:
# Rows: detection groups (HP, H_not, P_not, HP_not, Total); columns: All / Small / Large plasmids.
missing_containment_df

In [None]:
# Exact Mann-Whitney U: containment of plasmids found only by Plassembler-Flye vs. only by HyPlAs-{hsteps}.
split_to_four_manwhitneyu(df,hsteps)

In [None]:
def get_read_set(a):
    """Return the names (as ``bytes``) of primary, mapped reads in SAM/BAM file *a*.

    Runs ``samtools view -F 260`` (260 = 4 unmapped + 256 secondary) and keeps
    the first tab-separated field (QNAME) of every non-empty output line.
    Empty lines are skipped: the trailing newline would otherwise add ``b""`` to
    every set and inflate the shared Venn region.

    Raises:
        RuntimeError: if samtools exits with a non-zero status (e.g. a missing
            file), instead of silently returning an empty/partial set.
    """
    cmd = ["samtools", "view", "-F", "260", a]
    ret = subprocess.run(cmd, capture_output=True)
    if ret.returncode != 0:
        stderr = ret.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"Something wrong with {cmd}: {stderr}")
    return {line.split(b"\t", 1)[0] for line in ret.stdout.splitlines() if line}


In [None]:
# Long-read overlap between three read sets per sample: ground truth ("All"),
# Plassembler-Flye, and HyPlAs iteration i (i = 0, 1, 2).  Region counts follow the
# venn3 subset order produced by compute_lr_stats3s.
# NOTE: the per-iteration result is deliberately NOT named `stats`; that name is the
# `scipy.stats` module used by remove_outliers, and rebinding it would break that helper.
sum_stats = [np.zeros(7) for _ in range(3)]
all_stats = defaultdict(list)

for s in K.samples:
    plassembler_sam = f'{datapath}/{s}/plassembler_flye-qc/toplasmids.sam'
    gt_bam = f'{datapath}/{s}/qc/toplasmids.bam'

    plassembler_reads = get_read_set(plassembler_sam)
    gt_reads = get_read_set(gt_bam)

    for i in range(3):
        hyplass_sam = f"{datapath}/{s}/hyplass_b_cov/plasmid_selected_i{i}.sam"
        hyplass_reads = get_read_set(hyplass_sam)
        overlap_counts = compute_lr_stats3s(gt_reads, plassembler_reads, hyplass_reads)
        all_stats[s].append(overlap_counts)
        sum_stats[i] += overlap_counts
    
    


In [None]:
# One unweighted Venn diagram per HyPlAs iteration, from the read-set overlaps summed over samples.
# Saved with the shared PLOT_PARAMS like every other figure in the notebook.
for i in range(3):
    venn3_unweighted([int(x) for x in sum_stats[i]], set_labels=['All', 'Plassembler', f'HyPlAs-{i}'])
    plt.savefig(f"{plotdir}/read_set_venn_{i}.pdf", **PLOT_PARAMS)
    plt.show()

In [None]:
def tools_overlap(df, t1, t2):
    """Return boolean row masks ``(found only by t1, found by both, found only by t2)``.

    *t1* and *t2* name boolean "found by method" columns of *df*.
    """
    found1 = df[t1]
    found2 = df[t2]
    only_first = found1 & ~found2
    shared = found1 & found2
    only_second = ~found1 & found2
    return (only_first, shared, only_second)

In [None]:
# Plasmid counts: found only by Plassembler-Raven, by both, only by Plassembler-Flye.
list(map(sum, tools_overlap(df, 'Plassembler-Raven', 'Plassembler-Flye')))