In [None]:
import pathlib as pl
import json

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.stats import mannwhitneyu as mwu
from statsmodels.stats.multitest import multipletests

# Use font type 42 for both vector backends (PDF and PostScript) so that the
# exported figures share the same font embedding setting.
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Short provenance note that is echoed at the top of the cell output.
desc="""
This notebook uses the output of the rules
--- 20_preprocessing.smk::prep_t2t_seq_class_cache_file
--- 80_est_assm_errors.smk::merge_agg_seqclass_errors
as input, and produces figure for the supplement.
"""

# Global switch read by the plotting functions: write PNG/PDF files when True.
save_plots = True

print(desc)

def find_repo_base(start_path):
    """Return the closest ancestor of ``start_path`` named "project-male-assembly".

    ``start_path`` itself is checked first, then each of its parents from the
    innermost to the outermost.

    Args:
        start_path: a ``pathlib`` path object to start the search from.

    Returns:
        The matching path object, or ``None`` if neither ``start_path`` nor any
        of its parents carries the repository name.
    """
    repo_name = "project-male-assembly"
    # ``parents`` is a finite sequence, so the search terminates at the
    # filesystem root (or at "." for relative paths) instead of re-checking
    # the same path forever when the repository folder is not an ancestor.
    for candidate in (start_path, *start_path.parents):
        if candidate.name == repo_name:
            return candidate
    return None


# All locations are derived from the directory the notebook is executed in.
# strict=True makes a missing directory/file fail right here instead of later
# in the middle of plotting.
exec_dir = pl.Path('.').resolve(strict=True)
paths_json = exec_dir.parent.joinpath("plotting_paths.json").resolve(strict=True)
# read_text() opens and closes the file itself, so no file handle is left
# dangling (the bare open() used before was never closed).
plotting_paths = json.loads(paths_json.read_text())

# Repository root; looked up for reference, not used by the code below.
repo_dir = find_repo_base(exec_dir)
# Input data root ("wd_dir") and figure output folder ("fig_out") as
# configured in plotting_paths.json; both must already exist.
wd_dir = pl.Path(plotting_paths["wd_dir"]).resolve(strict=True)
out_dir = pl.Path(plotting_paths["fig_out"]).resolve(strict=True)

print('Execution directory: ', exec_dir)
print('Working directory: ', wd_dir)
print('Output directory: ', out_dir)
print('=================================')

# Reference table of the T2T-Y sequence classes; the size of each class is
# derived from its start/end coordinates.
t2t_classes_file = wd_dir / "annotations" / "T2T.chrY-seq-classes.tsv"
t2t = pd.read_csv(t2t_classes_file, sep='\t', header=0)
t2t['region_size'] = t2t['end'].sub(t2t['start'])

# Per-sample table with the merged sequence class / error annotation.
sample_classes_file = (
    wd_dir / "data" / "error_cluster"
    / "SAMPLES.HIFIRW.ONTUL.na.chrY.mrg-seqclass-errors.tsv"
)
samples = pd.read_csv(sample_classes_file, sep='\t', header=0)

# Samples excluded from all plots.
drop = ['NA24385', 'HG02666', 'NA19384', 'HG01457', 'NA18989', 'HG03456']

samples = samples[~samples['sample'].isin(drop)].copy()

# Sample IDs treated as the high-coverage group.
hc_sample_ids = [
    "HC02666", "HC18989", "HC19384", "HC01457",
    "HG00358", "HG01890", "NA19317", "NA19347",
    "HG03471"
]
# NOTE(review): num_hc counts the listed IDs, not the IDs actually present in
# the table, whereas num_lc is derived from the table content.
num_hc = len(hc_sample_ids)
# One entry per table row (duplicates included) for every remaining sample
# that is not in the high-coverage list.
lc_sample_ids = samples.loc[
    ~samples['sample'].isin(hc_sample_ids), 'sample'
].tolist()
num_lc = len(set(lc_sample_ids))
num_all = num_hc + num_lc

# Independent copies of the two coverage groups.
hc_samples = samples[samples['sample'].isin(hc_sample_ids)].copy()
lc_samples = samples[samples['sample'].isin(lc_sample_ids)].copy()

def plot_region_size_stats(t2t, samples, top_label, out_prefix=None):
    """Plot how often and how completely each sequence class was annotated.

    For every row of ``t2t`` a bar and a box are drawn side by side on one axis:
    the bar shows the percentage of samples that have at least one record for
    that sequence class (rounded to whole percent), the box shows the per-record
    region size relative to the reference region size (percent, one decimal).

    Args:
        t2t: reference table; the columns ``name``, ``region_size``, ``red``,
            ``green``, ``blue`` and ``alpha`` are read.
        samples: per-sample table; the columns ``sample``, ``region_type`` and
            ``region_size`` are read.
        top_label: title placed above the axis.
        out_prefix: string prepended to the output file name; ``None`` (the
            default) is treated as "no prefix".

    Returns:
        None.

    Raises:
        AssertionError: if the set of ``region_type`` values in ``samples`` and
            the set of ``name`` values in ``t2t`` are not identical.

    Side effects:
        If the module-level ``save_plots`` is true, the figure is written as
        PNG and PDF into the module-level ``out_dir``.
    """
    # A missing prefix must not end up as the literal text "None" in the
    # file name.
    prefix = "" if out_prefix is None else out_prefix

    total_samples = samples['sample'].nunique()

    # Both tables have to describe exactly the same set of sequence classes.
    ref_labels = set(t2t['name'])
    sample_labels = set(samples['region_type'])
    unknown = sample_labels - ref_labels
    assert len(unknown) == 0, f"region types not in reference table: {unknown}"
    unseen = ref_labels - sample_labels
    assert len(unseen) == 0, f"reference classes without sample records: {unseen}"

    bars, xpos_bars = [], []
    boxes, xpos_boxes = [], []
    xlabels, xpos_labels = [], []
    colors = []

    for idx, row in enumerate(t2t.itertuples(index=False)):
        # Each class occupies three x units: bar, box, and one unit of gap.
        bar_x = 3 * idx + 1
        box_x = bar_x + 1

        sample_regions = samples.loc[samples['region_type'] == row.name, :]
        has_sequence = round(
            sample_regions['sample'].nunique() / total_samples * 100, 0
        )
        sample_sizes = (
            sample_regions['region_size'] / row.region_size * 100
        ).round(1).values

        bars.append(has_sequence)
        xpos_bars.append(bar_x)
        boxes.append(sample_sizes)
        xpos_boxes.append(box_x)

        # The tick label sits centered between the bar and the box.
        xpos_labels.append(bar_x + 0.5)
        xlabels.append(row.name)
        colors.append((row.red, row.green, row.blue, row.alpha))

    fig, ax = plt.subplots(figsize=(20, 8))
    fig_name = f'{prefix}assm_annotated_size_per_seqclass'

    ax.bar(
        xpos_bars,
        bars,
        width=0.9,
        align='center',
        color=colors
    )

    bplot = ax.boxplot(
        boxes,
        positions=xpos_boxes,
        widths=0.9,
        showmeans=False,
        patch_artist=True,
        medianprops={
            'color': 'black'
        }
    )

    # Give each box the color of its sequence class.
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_edgecolor(color)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Set after boxplot() so these ticks replace the ones boxplot() installs.
    ax.set_xticks(xpos_labels)
    ax.set_xticklabels(xlabels, fontsize=16, rotation=90)
    ax.tick_params(axis='y', which='major', labelsize=14)

    ax.set_ylabel(
        'Bars: Y seq. class assembled in samples (%)\nBoxes: annotated size relative to T2T-Y (%)',
        fontsize=18
    )

    ax.set_title(top_label, fontsize=18)

    if save_plots:
        print(f"Dumping figures to {out_dir}")
        # Save through the figure object so the output does not depend on
        # which figure pyplot currently considers active.
        fig.savefig(
            out_dir / f'{fig_name}.png',
            dpi=150, bbox_inches='tight'
        )
        fig.savefig(
            out_dir / f'{fig_name}.pdf',
            bbox_inches='tight'
        )

    return


def plot_region_error_stats(t2t, samples, top_label, out_prefix=""):
    """Plot flagged (putatively erroneous) sequence per sequence class.

    Two vertically stacked panels sharing the x axis are drawn, with one x
    position per row of ``t2t``:

    * upper panel: bar with the percentage of samples that have at least one
      record with ``errors_bp > 0`` in that sequence class (whole percent);
    * lower panel: box plot of ``errors_bp`` relative to the record's own
      ``region_size`` (percent, one decimal).

    Args:
        t2t: reference table; the columns ``name``, ``red``, ``green``,
            ``blue`` and ``alpha`` are read.
        samples: per-sample table; the columns ``sample``, ``region_type``,
            ``region_size`` and ``errors_bp`` are read.
        top_label: title text (see the note at ``set_title`` below).
        out_prefix: string prepended to the output file name; ``None`` is
            treated like the empty string.

    Returns:
        Tuple ``(boxes, colors, xlabels)``: the per-class arrays shown as
        boxes, the per-class RGBA tuples, and the per-class names, all in the
        row order of ``t2t``.

    Raises:
        AssertionError: if the set of ``region_type`` values in ``samples`` and
            the set of ``name`` values in ``t2t`` are not identical.

    Side effects:
        If the module-level ``save_plots`` is true, the figure is written as
        PNG and PDF into the module-level ``out_dir``.
    """
    prefix = "" if out_prefix is None else out_prefix

    total_samples = samples['sample'].nunique()

    # Both tables have to describe exactly the same set of sequence classes.
    ref_labels = set(t2t['name'])
    sample_labels = set(samples['region_type'])
    unknown = sample_labels - ref_labels
    assert len(unknown) == 0, f"region types not in reference table: {unknown}"
    unseen = ref_labels - sample_labels
    assert len(unseen) == 0, f"reference classes without sample records: {unseen}"

    bars = []
    boxes = []
    xlabels = []
    colors = []
    # Bar, box and tick label of a class all share one x position.
    xlocs = []

    for xloc, row in enumerate(t2t.itertuples(index=False), start=1):
        sample_regions = samples.loc[samples['region_type'] == row.name, :]

        # Percentage of samples with at least one flagged base in this class.
        flagged = sample_regions['errors_bp'] > 0
        samples_with_errors = sample_regions.loc[flagged, 'sample'].nunique()
        has_errors = round(samples_with_errors / total_samples * 100, 0)

        # Flagged fraction per record, relative to the record's own size.
        pct_flagged = (
            sample_regions['errors_bp'] / sample_regions['region_size'] * 100
        ).round(1).values

        bars.append(has_errors)
        boxes.append(pct_flagged)
        xlocs.append(xloc)
        xlabels.append(row.name)
        colors.append((row.red, row.green, row.blue, row.alpha))

    fig, axes = plt.subplots(figsize=(20, 16), nrows=2, ncols=1, sharex=True, sharey=False)
    fig_name = f'{prefix}assm_error_by_seqclass'

    # upper panel / first axis: percent samples with flagged regions
    ax_pct_smp = axes[0]
    ax_pct_smp.bar(
        xlocs,
        bars,
        width=0.8,
        align='center',
        color=colors
    )

    ax_pct_smp.spines['top'].set_visible(False)
    ax_pct_smp.spines['right'].set_visible(False)
    ax_pct_smp.tick_params(axis='y', which='major', labelsize=14)
    ax_pct_smp.set_ylabel(
        'Samples with flagged regions (%)\n(putative assembly errors)',
        fontsize=18
    )

    # lower panel / second axis: percent flagged bp
    ax_pct_bp = axes[1]
    bplot = ax_pct_bp.boxplot(
        boxes,
        positions=xlocs,
        widths=0.8,
        showmeans=False,
        patch_artist=True,
        medianprops={
            'color': 'black'
        }
    )

    # Give each box the color of its sequence class.
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_edgecolor(color)

    ax_pct_bp.spines['top'].set_visible(False)
    ax_pct_bp.spines['right'].set_visible(False)

    # Set after boxplot() so these ticks replace the ones boxplot() installs.
    ax_pct_bp.set_xticks(xlocs)
    ax_pct_bp.set_xticklabels(xlabels, fontsize=16, rotation=90)
    ax_pct_bp.tick_params(axis='y', which='major', labelsize=14)

    ax_pct_bp.set_ylabel(
        'Flagged bp relative to assembled size (%)',
        fontsize=18
    )

    # NOTE(review): the title is attached to the lower panel, i.e. it is
    # rendered between the two panels - confirm this placement is intended.
    ax_pct_bp.set_title(top_label, fontsize=18)

    if save_plots:
        print(f"Dumping figures to {out_dir}")
        # Save through the figure object so the output does not depend on
        # which figure pyplot currently considers active.
        fig.savefig(
            out_dir / f'{fig_name}.png',
            dpi=150, bbox_inches='tight'
        )
        fig.savefig(
            out_dir / f'{fig_name}.pdf',
            bbox_inches='tight'
        )

    return boxes, colors, xlabels


def plot_pairwise_region_errors(group1, group2, labels, colors, num_group1=9, num_group2=36):
    """Compare two sample groups per sequence class with paired box plots.

    For each sequence class a two-sided Mann-Whitney U test is computed between
    the matching entries of ``group1`` and ``group2``; the p-values are
    corrected with the Benjamini-Hochberg procedure and the result tuple of
    ``multipletests`` is printed. The values of both groups are then drawn as
    boxes on a shared axis (``group1`` at x = 0, 1, ..., ``group2`` at
    x = 0.5, 1.5, ...).

    Args:
        group1: sequence of per-class value arrays for the first group
            (labelled "High" coverage on the x axis).
        group2: sequence of per-class value arrays for the second group
            (labelled "lower" coverage); must have as many entries as
            ``group1``.
        labels: per-class tick labels, in the same order as the groups.
        colors: per-class box colors, in the same order as the groups; the
            same color is used for both boxes of a class.
        num_group1: number of samples in ``group1``, shown in the x axis label.
        num_group2: number of samples in ``group2``, shown in the x axis label.

    Returns:
        None.

    Raises:
        ValueError: if ``group1`` and ``group2`` differ in length.

    Side effects:
        Prints the corrected test results. If the module-level ``save_plots``
        is true, the figure is written as PNG and PDF into the module-level
        ``out_dir``.
    """
    if len(group1) != len(group2):
        # zip() below would silently drop the surplus classes and the tick
        # computation would fail much later with a broadcasting error.
        raise ValueError(
            "group1 and group2 must contain the same number of sequence classes: "
            f"{len(group1)} vs. {len(group2)}"
        )

    mwu_pval = [
        mwu(g1, g2, method="auto").pvalue for g1, g2 in zip(group1, group2)
    ]
    mwu_pval_corr = multipletests(
        mwu_pval, alpha=0.05, method="fdr_bh", is_sorted=False, returnsorted=False
    )
    print(mwu_pval_corr)

    fig, ax = plt.subplots(figsize=(20, 8))

    # The two boxes of a class sit half a unit apart.
    xlocs_g1 = np.arange(0, len(group1))
    xlocs_g2 = np.arange(0.5, len(group2))

    box_style = {
        "widths": 0.4,
        "showmeans": False,
        "patch_artist": True,
        "medianprops": {"color": "black"},
    }
    boxes_g1 = ax.boxplot(group1, positions=xlocs_g1, **box_style)
    boxes_g2 = ax.boxplot(group2, positions=xlocs_g2, **box_style)

    for bplot in (boxes_g1, boxes_g2):
        for patch, color in zip(bplot['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_edgecolor(color)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # One tick per class, centered between its two boxes; set after the
    # boxplot() calls so these ticks replace the ones boxplot() installs.
    xlocs_labels = (xlocs_g1 + xlocs_g2) / 2
    ax.set_xticks(xlocs_labels)
    ax.set_xticklabels(labels, fontsize=16, rotation=90)
    ax.tick_params(axis='y', which='major', labelsize=14)

    ax.set_ylabel(
        'Flagged bp relative to assembled size (%)',
        fontsize=18
    )

    ax.set_xlabel(
        'Assembled Y sequence classes\n'
        f'High (n={num_group1}) vs. lower (n={num_group2}) coverage assemblies',
        fontsize=18
    )

    if save_plots:
        fig_name = 'assm_error_by_seqclass_low-vs-high'
        print(f"Dumping figures to {out_dir}")
        # Save through the figure object so the output does not depend on
        # which figure pyplot currently considers active.
        fig.savefig(
            out_dir / f'{fig_name}.png',
            dpi=150, bbox_inches='tight'
        )
        fig.savefig(
            out_dir / f'{fig_name}.pdf',
            bbox_inches='tight'
        )

    return
    

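# NOTE: leftover argument list, apparently the defaults of statsmodels'
# multipletests(), kept for reference:
#   pvals, alpha=0.05, method='hs', is_sorted=False, returnsorted=False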
    
# Driver section: create the size and error figures for three sample sets
# (all / high-coverage / lower-coverage) and finally the direct comparison
# of the two coverage groups. The file name prefix passed as last argument
# keeps the outputs of the three sets apart.

# plot all samples
# (colors and labels are taken from this call and reused for the pairwise
# plot; flagged_bp_all is not used further in this cell)
_ = plot_region_size_stats(t2t, samples, f"(all samples, N={num_all})", "all_")
flagged_bp_all, colors, labels = plot_region_error_stats(
    t2t, samples, f"(all samples, N={num_all})", "all_"
)

# plot high-coverage samples
_ = plot_region_size_stats(t2t, hc_samples, f"(high-coverage samples, N={num_hc})", "highcov_")
flagged_bp_hc, _, _ = plot_region_error_stats(
    t2t, hc_samples, f"(high-coverage samples, N={num_hc})", "highcov_"
)

# plot lower-coverage samples
_ = plot_region_size_stats(t2t, lc_samples, f"(lower-coverage samples, N={num_lc})", "avgcov_")
flagged_bp_lc, _, _ = plot_region_error_stats(
    t2t, lc_samples, f"(lower-coverage samples, N={num_lc})", "avgcov_"
)

# plot high-vs-lower coverage samples
# (the per-class value lists of both groups follow the row order of t2t,
# the same order as labels and colors)
_ = plot_pairwise_region_errors(flagged_bp_hc, flagged_bp_lc, labels, colors)




This notebook uses the output of the rules
--- 20_preprocessing.smk::prep_t2t_seq_class_cache_file
--- 80_est_assm_errors.smk::merge_agg_seqclass_errors
as input, and produces figure for the supplement.

Execution directory:  /home/ebertp/work/code/marschall-lab/project-male-assembly/notebooks/plotting/errors
Working directory:  /home/ebertp/work/projects/sig_chry/paper
Output directory:  /home/ebertp/work/projects/sig_chry/paper/output/figures
Dumping figures to /home/ebertp/work/projects/sig_chry/paper/output/figures
Dumping figures to /home/ebertp/work/projects/sig_chry/paper/output/figures
Dumping figures to /home/ebertp/work/projects/sig_chry/paper/output/figures
Dumping figures to /home/ebertp/work/projects/sig_chry/paper/output/figures
Dumping figures to /home/ebertp/work/projects/sig_chry/paper/output/figures
