In [None]:
import pandas as pd
from plotnine import *

# Read the viral identification report (one row per contig, with an
# 'assembly' column and one or more per-tool result columns).
report = pd.read_csv(str(snakemake.input))
tools = []

# All assemblies present in the report. Counts are indexed by assembly so that
# every tool's column is aligned to the correct assembly (assigning positional
# groupby results would silently misalign them), and assemblies in which a
# tool found nothing get an explicit 0 instead of being dropped.
assemblies = pd.Index(report['assembly'].dropna().unique(), name='assembly')
tool_counts = pd.DataFrame(index=assemblies)


def _count_per_assembly(tool_report, assemblies):
    """Return the number of non-null contig_id values per assembly.

    Assemblies absent from ``tool_report`` are included with a count of 0,
    and the result is indexed by ``assemblies`` so it aligns with
    ``tool_counts``.
    """
    return (
        tool_report.groupby('assembly')['contig_id']
        .count()
        .reindex(assemblies, fill_value=0)
    )


# Filter contigs based on the config parameters, one tool at a time.
# MGV: keep contigs labelled 'Viral'.
if snakemake.params.run_mgv:
    mgv_report = report[report['MGV_viral'] == 'Viral']
    tools.append('MGV')
    tool_counts['MGV'] = _count_per_assembly(mgv_report, assemblies)

# VirFinder: keep contigs scoring at or above the configured threshold.
if snakemake.params.run_vf:
    vf_report = report[(report['VirFinder_score'] >= snakemake.params.vf_score)]
    tools.append('VirFinder')
    tool_counts['VirFinder'] = _count_per_assembly(vf_report, assemblies)

# VirSorter: 'complete_phage' rows keep their Category_number; all other rows
# are shifted by +3. Contigs whose shifted category is in the configured
# vs_cat list are kept. Vectorized equivalent of the row-wise apply, which
# also works on an empty report.
if snakemake.params.run_vs:
    report['VirSorter_cat'] = report['Category_number'].where(
        report['Category_text'] == 'complete_phage',
        report['Category_number'] + 3,
    )
    vs_report = report[report['VirSorter_cat'].isin(snakemake.params.vs_cat)]
    tools.append('VirSorter')
    tool_counts['VirSorter'] = _count_per_assembly(vs_report, assemblies)

# VirSorter2: keep contigs whose max score meets the configured threshold.
if snakemake.params.run_vs2:
    vs2_report = report[report['VirSorter2_max_score'] >= snakemake.params.vs2_score]
    tools.append('VirSorter2')
    tool_counts['VirSorter2'] = _count_per_assembly(vs2_report, assemblies)

# DeepVirFinder: keep contigs scoring at or above the configured threshold.
if snakemake.params.run_dvf:
    dvf_report = report[(report['DeepVirFinder_score'] >= snakemake.params.dvf_score)]
    tools.append('DeepVirFinder')
    tool_counts['DeepVirFinder'] = _count_per_assembly(dvf_report, assemblies)

# VIBRANT: keep contigs with any VIBRANT_viruses value.
if snakemake.params.run_vb:
    vb_report = report[report['VIBRANT_viruses'].notnull()]
    tools.append('VIBRANT')
    tool_counts['VIBRANT'] = _count_per_assembly(vb_report, assemblies)

# Kraken2: keep contigs whose kraken_classification is 'C'.
if snakemake.params.run_kraken2:
    kraken2_report = report[report['kraken_classification'] == 'C']
    tools.append('Kraken2')
    tool_counts['Kraken2'] = _count_per_assembly(kraken2_report, assemblies)

if not tools:
    raise ValueError(
        "No viral identification tools are enabled in the config; nothing to plot."
    )

# Expose 'assembly' as a regular column, then reshape to long format:
# one row per (assembly, tool) with columns assembly / variable / value.
tool_counts = tool_counts.reset_index()
enrich_score = tool_counts.melt(id_vars=['assembly'], value_vars=tools)

# Box plot of per-assembly contig counts for each enabled tool.
enrich_score_plot = (
    ggplot(enrich_score)
    + geom_boxplot(aes(x='variable', y='value'))
    + theme(figure_size=(16, 8))
)

enrich_score_plot.save(str(snakemake.output), dpi=600)