In [5]:
import pathlib

# Change into the notebooks directory (quietly) so the relative notebook
# names below resolve against it.
%cd -q "/home/ebertp/work/code/cubi/project-run-hgsvc-hybrid-assemblies/notebooks"
# resolve(strict=True) raises FileNotFoundError if a config notebook is missing.
_PROJECT_CONFIG_NB = str(pathlib.Path("00_project_config.ipynb").resolve(strict=True))
_PLOT_CONFIG_NB = str(pathlib.Path("05_plot_config.ipynb").resolve(strict=True))

# Execute both config notebooks. Presumably these provide the project-level
# names used below (e.g. HGSVC_SAMPLES, save_figure) -- verify in the notebooks.
%run $_PROJECT_CONFIG_NB
%run $_PLOT_CONFIG_NB

# Short identifier of this notebook, passed to get_nb_stamp() below.
_MYNAME="sample-summary"
# NOTE(review): get_nb_stamp is defined outside this file (presumably in one
# of the config notebooks); its return value is not used in this cell.
_MYSTAMP=get_nb_stamp(_MYNAME)

# All figures created in this cell are written below this folder.
_MY_OUT_PATH = PLOT_OUT_SUPPL_FIG.joinpath("sample_info")

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def continental_group_bar_chart(axes):
    """Draw one bar per (supergroup, sex) combination of ``HGSVC_SAMPLES``.

    The bar height is the number of unique values in the ``sample`` column
    for that combination; the bar colour is derived from ``rgb_rel_super``
    via ``tupleize``. Tick labels show the first character of ``sex`` and
    the count on two lines.

    Args:
        axes: matplotlib axes to draw into.

    Returns:
        The same ``axes`` object.

    Raises:
        AssertionError: if the bar heights do not add up to ``HGSVC_TOTAL``.
    """
    per_group = HGSVC_SAMPLES.groupby(
        ["supergroup", "sex", "rgb_rel_super"]
    )["sample"].nunique()

    heights = []
    tick_labels = []
    bar_colors = []
    legend_entries = set()

    for (supergroup, sex, rgb), n_samples in per_group.items():
        bar_color = tupleize(rgb)
        tick_labels.append(f"{sex[0]}\n{n_samples}")
        heights.append(n_samples)
        bar_colors.append(bar_color)
        # one legend entry per supergroup/colour pair, deduplicated by the set
        legend_entries.add((supergroup, bar_color))

    # sanity check: every sample must be represented in exactly one bar
    assert sum(heights) == HGSVC_TOTAL

    # bars sit at 1..N; the same positions carry the tick labels
    positions = list(range(1, len(tick_labels) + 1))
    axes.bar(positions, heights, color=bar_colors)
    axes.set_xticks(positions)
    axes.set_xticklabels(tick_labels)

    y_ticks = [3, 6, 9, 12]
    axes.set_yticks(y_ticks)
    axes.set_yticklabels([str(tick) for tick in y_ticks])
    axes.set_ylabel("Count")
    axes.set_xlabel("Samples: continental groups and sex")

    # keep only the bottom spine
    for side in ("top", "right", "left"):
        axes.spines[side].set_visible(False)

    legend_handles = build_patch_legend(sorted(legend_entries))
    axes.legend(handles=legend_handles, loc="best")

    return axes


def create_contgroup_bars():
    """Render the continental-group bar chart and save it once per extension.

    A new 8x6 figure is created, filled by ``continental_group_bar_chart``
    and passed to ``save_figure`` for every entry of ``DEFAULT_PLOT_EXT``.
    Output files are named ``fig_s2_pe_sample-pop-bars.<ext>`` below
    ``_MY_OUT_PATH``.

    Returns:
        None
    """
    fig, ax = plt.subplots(figsize=(8,6))
    try:
        continental_group_bar_chart(ax)

        for ext in DEFAULT_PLOT_EXT:
            out_path = _MY_OUT_PATH.joinpath(f"fig_s2_pe_sample-pop-bars.{ext}")
            save_figure(out_path, fig)
    finally:
        # Close exactly this figure (not whatever pyplot considers "current"),
        # and do so even if plotting or saving failed, so it is not leaked.
        plt.close(fig)
    return None


def population_legend_matrix(axes):
    """Draw the population colour code as a grid of labelled, coloured cells.

    Every continental group ("supergroup") found in ``HGSVC_SAMPLES`` gets
    one cell coloured from its ``rgb_rel_super`` value, followed by one cell
    per population coloured from ``rgb_rel_pop``. Cells are filled row by
    row in a 6x7 grid; unused cells stay white.

    Args:
        axes: matplotlib axes to draw into.

    Returns:
        The same ``axes`` object.

    Raises:
        IndexError: if more cells are needed than the grid provides.
    """
    pop_lexsort = HGSVC_SAMPLES.sort_values(["supergroup", "population"], inplace=False)

    # we have 28 populations plus 6 continental groups
    # = 34 fields
    nrows = 6
    ncols = 7
    n_cells = nrows * ncols
    # all-ones RGB == white background for cells that are never filled
    rgb_matrix = np.ones((nrows, ncols, 3), dtype=float)
    labels = []

    # These groups always start at a fixed cell; all values are multiples of
    # ncols, i.e. the first cell of rows 2 to 5.
    # NOTE(review): if the preceding groups need more cells than are available
    # before a reset position, the reset silently overwrites filled cells.
    cgroup_pos_reset = {
        "AMR": 14,
        "EAS": 21,
        "EUR": 28,
        "SAS": 35
    }

    def fill_cell(cell_idx, color, label):
        """Colour the cell at flat index ``cell_idx`` and record its label."""
        if cell_idx >= n_cells:
            raise IndexError(
                f"legend cell {cell_idx} for label '{label}' does not fit "
                f"into the {nrows}x{ncols} grid"
            )
        row_idx, col_idx = divmod(cell_idx, ncols)
        rgb_matrix[row_idx, col_idx, :] = color
        labels.append((row_idx, col_idx, label))

    last_cgroup = None
    pos_idx = 0
    # groupby yields the keys in sorted order, so all populations of one
    # continental group arrive consecutively and each group header is
    # emitted exactly once.
    for (cgroup, pop), samples in pop_lexsort.groupby(["supergroup", "population"]):
        if cgroup != last_cgroup:
            # jump to the fixed start cell if this group has one
            pos_idx = cgroup_pos_reset.get(cgroup, pos_idx)
            fill_cell(pos_idx, tupleize(samples["rgb_rel_super"].iloc[0]), cgroup)
            last_cgroup = cgroup
            pos_idx += 1

        fill_cell(pos_idx, tupleize(samples["rgb_rel_pop"].iloc[0]), pop)
        pos_idx += 1

    axes.imshow(rgb_matrix, interpolation="none")
    # Major ticks on the cell borders (-0.5, 0.5, ...); they are hidden below
    # and only serve as anchors for the white grid lines separating the cells.
    axes.set_xticks(np.arange(-.5, ncols-0.4, 1))
    axes.set_yticks(np.arange(-.5, nrows-0.4, 1))
    axes.grid(which='major', axis="both", color="white", linestyle="-", linewidth=1)
    axes.tick_params(
        which="both", axis="both",
        top=False, bottom=False, left=False, right=False,
        labeltop=False, labelbottom=False, labelleft=False, labelright=False
    )
    axes.spines["left"].set_visible(False)
    axes.spines["bottom"].set_visible(False)

    # continental group cells get black text, population cells white text
    cgroup_labels = ("AFR", "AFR\nAMR", "AMR", "EAS", "EUR", "SAS")
    for (row, col, label) in labels:
        if label == "AFR|AMR":
            # break the combined group name over two lines
            label = "AFR\nAMR"
        color = "black" if label in cgroup_labels else "white"
        axes.text(
            col, row, label, ha="center", va="center",
            fontsize=MPL_TEXT_SIZE-2, color=color, fontweight="bold"
        )

    axes.set_ylabel("Continental groups / populations")
    return axes


def create_pop_legend_matrix():
    """Render the population colour-code matrix and save it once per extension.

    A new 20x4 figure is created, filled by ``population_legend_matrix``
    and passed to ``save_figure`` for every entry of ``DEFAULT_PLOT_EXT``.
    Output files are named ``fig_s1_pe_color-code.<ext>`` below
    ``_MY_OUT_PATH``.

    Returns:
        None
    """
    fig, ax = plt.subplots(figsize=(20,4))
    try:
        population_legend_matrix(ax)

        for ext in DEFAULT_PLOT_EXT:
            out_path = _MY_OUT_PATH.joinpath(f"fig_s1_pe_color-code.{ext}")
            save_figure(out_path, fig)
    finally:
        # Close exactly this figure (not whatever pyplot considers "current"),
        # and do so even if plotting or saving failed, so it is not leaked.
        plt.close(fig)
    return None


# Create and save both figures; binding the (None) return value to "_"
# keeps the notebook cell from echoing it.
_ = create_contgroup_bars()
_ = create_pop_legend_matrix()