In [2]:
import pathlib
import functools
import matplotlib.pyplot as plt
import matplotlib.patches as patch
import math
import random
import numpy as np
import scipy.stats as stat

random.seed()

%cd -q "/home/ebertp/work/code/cubi/project-run-hgsvc-hybrid-assemblies/notebooks"
_PROJECT_CONFIG_NB = str(pathlib.Path("00_project_config.ipynb").resolve(strict=True))
_PLOT_CONFIG_NB = str(pathlib.Path("05_plot_config.ipynb").resolve(strict=True))
_ASSM_STATS_NB = str(pathlib.Path("10_assm_stats.ipynb").resolve(strict=True))

%run $_PROJECT_CONFIG_NB
%run $_PLOT_CONFIG_NB
%run $_ASSM_STATS_NB

_MYNAME="plot-qv-estimates"
_MYSTAMP=get_nb_stamp(_MYNAME)

_MY_OUT_PATH = PLOT_OUT_MAIN_FIG1.joinpath("panels")

ASSEMBLER = "verkko"
assert ASSEMBLER in ["verkko", "hifiasm"]


def compute_qv(num_errors, ref_size):
    p = num_errors / ref_size
    try:
        q = -10 * math.log10(p)
    except ValueError:
        return 99
    return int(round(q, 0))


def deselect_old_version(sample):

    # 2024-12-06 fix
    # the Verkko Merqury summary tables contain the QV data
    # for both HG00514/v1 and HG00514/v2 (fix), need to drop
    # the former and rename the latter

    if "HG00514-fix" in sample:
        return True
    elif "HG00514" in sample:
        return False
    else:
        return True


def read_merqury_qv_estimates(assembler=ASSEMBLER, aggregate="sample"):

    table_files = PROJECT_DATA_ROOT.joinpath(
            "2024_merqury"
    ).glob(f"{assembler}*.tab")

    summary_columns = [
        "sample", "asm_unit", "merqury_error_kmer", "merqury_total_kmer",
        "merqury_qv_est", "merqury_error_rate"
    ]
    contig_columns = [
        "sample", "sequence", "merqury_error_kmer", "merqury_total_kmer",
        "merqury_qv_est", "merqury_error_rate"
    ]
    
    concat = []
    for table_file in table_files:
        if "contigs" in table_file.name:
            use_names = contig_columns
        else:
            use_names = summary_columns
        df = pd.read_csv(
            table_file, sep="\t",
            header=None, names=use_names
        )
        if "asm_unit" not in df.columns:
            df["asm_unit"] = df["sequence"].apply(set_assembly_unit)
        df["assembler"] = assembler
        no_errors = df["merqury_error_kmer"] == 0
        df.loc[no_errors, "merqury_qv_est"] = 99
        df["merqury_qv_est"] = df["merqury_qv_est"].astype(float)
        if assembler == "verkko":
            df["sample"] = df["sample"].apply(lambda s: s.replace("-unassigned", ""))
        concat.append(df)      
    
    mrg = pd.concat(concat, axis=0, ignore_index=False)
    no_seq = pd.isnull(mrg["sequence"])
    mrg.loc[no_seq, "sequence"] = mrg.loc[no_seq, "asm_unit"]

    # 2024-12-06 fix entries / keep only Verkko HG00514 v2
    if assembler == "verkko":
        drop_old_hg00514 = mrg["sample"].apply(deselect_old_version)
        mrg = mrg.loc[drop_old_hg00514, :].copy()
        mrg["sample"] = mrg["sample"].replace({"HG00514-fix": "HG00514"})
    
    add_phased = []
    for sample, qv_est in mrg.groupby("sample"):
        select_h1 = qv_est["sequence"] == "hap1"
        select_h2 = qv_est["sequence"] == "hap2"
        
        qv1 = qv_est.loc[select_h1, "merqury_qv_est"].iloc[0]
        qv2 = qv_est.loc[select_h2, "merqury_qv_est"].iloc[0]
        wt_qv = round((qv1 + qv2)/2, 2)
        
        add_phased.append(
            (sample, "phased", "total", -1, -1, wt_qv, -1, assembler)
        )
    add_phased = pd.DataFrame.from_records(add_phased,
        columns=[
            "sample", "asm_unit", "sequence",
            "merqury_error_kmer", "merqury_total_kmer",
            "merqury_qv_est", "merqury_error_rate", "assembler"
        ]
    )
    mrg = pd.concat([mrg, add_phased], axis=0, ignore_index=False)
    mrg.sort_values(["sample", "asm_unit", "sequence"], inplace=True)
        
    if aggregate == "sample":
        mrg = mrg.loc[mrg["asm_unit"] == "phased", :].copy()
        assert mrg.shape[0] == HGSVC_TOTAL
    
    #df["qv_est"] = df["qv_est"].round(0).astype(int)
    #df.rename({"qv_est": "merqury_qv_est"}, axis=1, inplace=True)
    return mrg


def set_assembly_unit(seq_name):

    if any(h in seq_name for h in ["h1tg", "haplotype1"]):
        return "hap1"
    elif any(h in seq_name for h in ["h2tg", "haplotype2"]):
        return "hap2"
    elif "unassigned" in seq_name:
        return "unassigned"
    elif "genome" in seq_name:
        return "wg"
    else:
        raise ValueError(seq_name)


def read_variant_qv_estimates(assembler=ASSEMBLER, aggregate="sample"):
    """For Verkko, add QV estimates (as weighted averages)
    for the respective assembly units (hap1, hap2, phased etc.)
    in the same way as for Merqury
    """

    root_folder = PROJECT_DATA_ROOT.joinpath(
        "2023_var_qv_est", "hgsvc", f"{assembler}"
    )
    table_files = list(root_folder.glob("*.tsv"))
    assert len(table_files) == HGSVC_TOTAL
    merge = []
    combinations = [
        ("hap1",), ("hap2",), ("hap1", "hap2"),
        ("unassigned",)
    ]
    labels = ["hap1", "hap2", "phased", "unassigned"]
    for table_file in table_files:
        sample = table_file.name.rsplit(".", 4)[0]
        assembly = table_file.name.rsplit(".", 3)[0]
        df = pd.read_csv(table_file, sep="\t", header=0)
        df["asm_unit"] = df["seq_name"].apply(set_assembly_unit)
        df["seq_name"] = df["seq_name"].replace({"genome": "total"}, inplace=False)
        # in the Merqury post-processing, the global QV estimates
        # are simply derived as weighted averages; implement the
        # same strategy here for comparability
        global_estimates = []
        for comb, label in zip(combinations, labels):
            subset = df.loc[df["asm_unit"].isin(comb), :]
            if subset.empty:
                continue
            total_error = subset["num_errors"].sum()
            total_adj_len = subset["adj_length"].sum()
            total_seq_len = subset["seq_length"].sum()
            qv_est = round(np.average(subset["qv"].values, weights=subset["adj_length"].values), 4)
            global_estimates.append(
                [sample, label, "total", total_seq_len, total_adj_len, total_error, qv_est]
            )
        global_estimates = pd.DataFrame(
            global_estimates,
            columns=["sample", "asm_unit", "seq_name", "seq_length", "adj_length", "num_errors", "qv"]
        )
        df["sample"] = sample
        df["assembly"] = assembly
        df = pd.concat([df, global_estimates], axis=0, ignore_index=False)
        df.rename({"seq_name": "sequence"}, axis=1, inplace=True)
        merge.append(df)
    merge = pd.concat(merge, axis=0, ignore_index=False)
    merge.rename({"qv": "variant_qv_est"}, axis=1, inplace=True)
    if aggregate == "sample":
        merge = merge.loc[merge["asm_unit"] == "phased", :].copy()
        assert merge.shape[0] == HGSVC_TOTAL
    return merge        


def cache_qv_estimate_table(assembler=ASSEMBLER, aggregate="sample"):

    if aggregate == "sample":
        data_file = PROJECT_NB_CACHE.joinpath(f"cache.{assembler}-qvest.tsv.gz")
        if not data_file.is_file():
            merq_qv = read_merqury_qv_estimates(aggregate="sample")
            var_qv = read_variant_qv_estimates(aggregate="sample")
            qv_est = merq_qv.merge(var_qv, left_on=["sample", "sequence", "asm_unit"], right_on=["sample", "sequence", "asm_unit"])
            qv_est.to_csv(data_file, sep="\t", header=True, index=False)
    elif aggregate == "sequence":
        # create full table output / supplementary table
        data_file = PROJECT_BASE.joinpath(
            "annotations", "autogen", f"{ASSEMBLER}_qv-est.tsv"
        )
        if not data_file.is_file():
            merq_qv = read_merqury_qv_estimates(aggregate="sequence")
            var_qv = read_variant_qv_estimates(aggregate="sequence")
            qv_est = merq_qv.merge(var_qv, left_on=["sample", "sequence", "asm_unit"], right_on=["sample", "sequence", "asm_unit"])
            qv_est.rename(
                {
                    "num_errors": "variant_num_errors",
                    "error_bp": "merqury_error_bp",
                    "error_rate": "merqury_error_rate",
                    "total_adj_bp": "merqury_adj_length",
                    "adj_length": "variant_adj_length"
                }, axis=1, inplace=True
            )
            with open(data_file, "w") as table:
                _ = table.write("# AUTOGEN TABLE - DO NOT EDIT\n")
                _ = table.write(f"# {_MYSTAMP}\n")
                qv_est.to_csv(table, sep="\t", header=True, index=False)
    else:
        raise ValueError(aggregate)
    
    return data_file
       

def add_vector_jitter(values):
    return [v+random.gauss(0, 0.2) for v in values]


def plot_sample_vs_sample_qv(axes):

    qv_cache_file = cache_qv_estimate_table(aggregate="sample")
    qv_data = pd.read_csv(qv_cache_file, sep="\t", header=0)
    qv_data["sid"] = qv_data["sample"].apply(lambda x: x.split(".")[0])
    qv_data.set_index("sid", inplace=True)
    
    female_qv_var = add_vector_jitter([qv_data.at[sample, "variant_qv_est"] for sample in HGSVC_FEMALES])
    female_qv_kmer = add_vector_jitter([qv_data.at[sample, "merqury_qv_est"] for sample in HGSVC_FEMALES])
    #female_colors = [get_super_color(sample) for sample in HGSVC_FEMALES]
    female_colors = [get_pop_color(sample) for sample in HGSVC_FEMALES]

    # 2024-05 --- for some reason, HG00514 is a strong outlier; make this explicit
    # and adjust axes limits ignoring that sample
    hg00514_x = int(round(qv_data.at["HG00514", "variant_qv_est"], 0))
    hg00514_y = int(round(qv_data.at["HG00514", "merqury_qv_est"], 0))
    
    median_female_var = sorted(female_qv_var)[HGSVC_FEMALE//2]
    median_female_kmer = sorted(female_qv_kmer)[HGSVC_FEMALE//2]
    
    male_qv_var = add_vector_jitter([qv_data.at[sample, "variant_qv_est"] for sample in HGSVC_MALES])
    male_qv_kmer = add_vector_jitter([qv_data.at[sample, "merqury_qv_est"] for sample in HGSVC_MALES])
    #male_colors = [get_super_color(sample) for sample in HGSVC_MALES]
    male_colors = [get_pop_color(sample) for sample in HGSVC_MALES]
    
    median_male_var = sorted(male_qv_var)[HGSVC_MALE//2]
    median_male_kmer = sorted(male_qv_kmer)[HGSVC_MALE//2]

    median_all_x = sorted(male_qv_var + female_qv_var)[HGSVC_TOTAL//2]
    median_all_y = sorted(male_qv_kmer + female_qv_kmer)[HGSVC_TOTAL//2]

    print("median QV variant: ", median_all_x)
    print("median QV k-mer: ", median_all_y)
    
    axes.scatter(
        female_qv_var,
        female_qv_kmer,
        c=female_colors,
        label="female",
        marker=FEMALE_MARKER,
        edgecolors="none"
    )
    
    axes.scatter(
        male_qv_var,
        male_qv_kmer,
        c=male_colors,
        label="male",
        marker=MALE_MARKER,
        edgecolors="none"
    )

    # add arrow for HG00514
    #axes.annotate(
    #    f"Outlier not shown\nHG00514 x:{hg00514_x} / y:{hg00514_y}",
    #    (50.5, 52.5), fontsize=10
    #)
    
    #ax.axhline(median_fy, 0, 0.95, zorder=0, color="grey", ls="dashed")
    #ax.axvline(median_fx, 0, 0.95, zorder=0, color="grey", ls="dashed")
    axes.axhline(median_all_y, 0.01, 0.99, zorder=0, color="grey", ls="dashed")
    axes.axvline(median_all_x, 0.01, 0.99, zorder=0, color="grey", ls="dashed")
        
    axes.set_xlabel("QV estimate (variant-based)")
    axes.set_ylabel("QV estimate (kmer-based)")
    
    axes.spines["top"].set_visible(False)
    axes.spines["right"].set_visible(False)
    
    handles = get_line_legend(
        [
            {
                "marker": FEMALE_MARKER,
                "label": "female",
                "linestyle": "",
                "color": "black"
            },
            {
                "marker": MALE_MARKER,
                "label": "male",
                "linestyle": "",
                "color": "black"
            },
            {
                "linestyle": "dashed",
                "label": "median",
                "color": "grey"
            }
        ]
    )
    
    axes.legend(handles=handles, loc="lower right")

    if False:
        all_x_vals = male_qv_var + female_qv_var
        min_x = min(all_x_vals)
        max_x = max(all_x_vals)
        print(min_x)
        print(max_x)
        
        all_y_vals = male_qv_kmer + female_qv_kmer
        min_y = min(all_y_vals)
        max_y = max(all_y_vals)
        print(min_y)
        print(max_y)
        raise

    if ASSEMBLER == "verkko":
    
        axes.set_xlim(51,58)
        axes.set_ylim(48,63)
    
        axes.plot(
            np.arange(48, 64),
            np.arange(48, 64),
            ls="dotted",
            lw=1,
            zorder=0,
            c="lightgrey"
        )

    if ASSEMBLER == "hifiasm":
        
        axes.set_xlim(48,60)
        axes.set_ylim(48,62)
    
        axes.plot(
            np.arange(48, 64),
            np.arange(48, 64),
            ls="dotted",
            lw=1,
            zorder=0,
            c="lightgrey"
        )
        
    return axes


def create_sample_qv_scatter(skip_plot=False):

    fig, ax = plt.subplots(figsize=(8,8))
    ax = plot_sample_vs_sample_qv(ax)
    for ext in DEFAULT_PLOT_EXT:
        if not skip_plot:
            out_path = _MY_OUT_PATH.joinpath(f"fig1_panel_qv-est.{ASSEMBLER}.{ext}")
            save_figure(out_path, fig)
    plt.close()
    return plot_sample_vs_sample_qv
        
cache_qv_estimate_table(aggregate="sequence")
cache_qv_estimate_table(aggregate="sample")

get_qv_panel = create_sample_qv_scatter(True)

#_ = plot_unassigned_stats()
#_ = plot_unassigned_vs_read_stats("ontul_cov", "ONT-UL cov. (x-fold)")
#_ = plot_unassigned_vs_read_stats("hifi_cov", "HiFi cov. (x-fold)")
#_ = plot_unassigned_vs_read_stats("ontul_n50", "ONT-UL N50 (kbp)")

median QV variant:  53.37269860193465
median QV k-mer:  58.46979496878602
