In [None]:
# default_exp clone_analysis

In [None]:
# hide
%load_ext autoreload
%autoreload 2

In [None]:
# export
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Clone Analysis Functions

## data wrangling functions

In [None]:
# export
def _combine_agg_functions(additional_agg_functions):
    if additional_agg_functions is None:
        additional_agg_functions = {}

    agg_functions = {"label": "count", "area_um2": [np.mean, np.std]}
    return {**agg_functions, **additional_agg_functions}

In [None]:
# export
def _individual_filter_condition(
    df, filtered_col_name: str, query: str, clone_channel: str, agg_functions
):
    if query is not None:
        df = df.query(query)

    temp_df = (df.groupby(["int_img", clone_channel]).agg(agg_functions)).copy()

    temp_df.columns = pd.MultiIndex.from_tuples(
        [(filtered_col_name,) + a for a in temp_df.columns]
    )
    return temp_df

In [None]:
# export
def query_df_groupby_by_clone_channel(
    df, queries: dict, clone_channel: str = "C1", additional_agg_functions: dict = None,
):
    """Aggregate ``df`` once per query and join the results side by side.

    Args:
        df: input DataFrame; its index is reset into columns before processing.
        queries: mapping of result label -> query expression. Each pair is
            handed to ``_individual_filter_condition``, which treats a ``None``
            expression as "no filter". Must contain at least one entry,
            otherwise ``reduce`` raises ``TypeError``.
        clone_channel: name of the column used for grouping alongside the image.
        additional_agg_functions: extra aggregations merged into the defaults
            by ``_combine_agg_functions``, e.g.
            ``{"mean_intensity": [np.mean, np.std]}``.

    Returns:
        The per-query tables outer-joined on their index, in the iteration
        order of ``queries``.
    """
    aggregations = _combine_agg_functions(additional_agg_functions)
    flat_df = df.reset_index()

    per_query_tables = [
        _individual_filter_condition(
            flat_df, label, expression, clone_channel, aggregations
        )
        for label, expression in queries.items()
    ]

    def _outer_join_on_index(left, right):
        # Outer join keeps groups that are present for only some of the queries.
        return pd.merge(left, right, how="outer", left_index=True, right_index=True)

    return reduce(_outer_join_on_index, per_query_tables)

## data visualization functions

In [None]:
# export
def create_stack_bar_plot(
    df,
    df_error_bar=None,
    x_figSize=2.5,
    y_figSize=2.5,
    y_label=None,
    y_axis_start=0,
    y_axis_limit=None,
    color_pal=sns.color_palette(palette="Blues_r"),
    bar_width=0.8,
):
    """Draw ``df`` as a stacked bar chart on a new figure.

    Args:
        df: DataFrame to plot; one bar per row, one stacked segment per column.
        df_error_bar: optional error values forwarded to ``DataFrame.plot`` as ``yerr``.
        x_figSize: figure width in inches.
        y_figSize: figure height in inches.
        y_label: y-axis label; ``None`` leaves it empty.
        y_axis_start: lower y-axis limit; pass ``None`` to keep matplotlib's
            automatic lower bound.
        y_axis_limit: upper y-axis limit; ``None`` keeps the automatic upper bound.
        color_pal: colours for the stacked segments. The default palette is
            created once, when the function is defined.
        bar_width: relative bar width forwarded to ``DataFrame.plot``.

    Side effects:
        Calls ``sns.set(style="ticks")``, which changes matplotlib's global
        rcParams, and leaves the new figure as pyplot's current figure.
    """
    # The style has to be active before the axes are created: rcParams changed
    # afterwards only affect figures created later, which previously left the
    # first figure of a session styled differently from all following ones.
    sns.set(style="ticks")

    _, ax = plt.subplots(figsize=(x_figSize, y_figSize))

    ax = df.plot(
        kind="bar",
        stacked=True,
        color=color_pal,
        width=bar_width,
        ax=ax,
        yerr=df_error_bar,
        capsize=4,
    )
    ax.set_ylabel(y_label)
    sns.despine(ax=ax)
    ax.xaxis.set_tick_params(width=1)
    ax.yaxis.set_tick_params(width=1)
    ax.tick_params(axis="both", which="major", pad=1)
    plt.setp(ax.spines.values(), linewidth=1)

    # `y_axis_start` used to be silently ignored. set_ylim leaves a bound
    # untouched when it is None, so both limits can be passed unconditionally.
    ax.set_ylim(bottom=y_axis_start, top=y_axis_limit)

    handles, labels = ax.get_legend_handles_labels()

    # Reverse the entries so the legend reads top-down like the stacked
    # segments (the first column is drawn at the bottom of each bar).
    ax.legend(
        handles[::-1], labels[::-1], bbox_to_anchor=(1, 1), loc="upper left"
    )

In [None]:
#export
def plot_stat_annotation(
    x_indexes: tuple,
    y: float,
    p_values: list,
    sep: float = None,
    text_colors: list = None,
    ax=None,
    fontsize=18,
):
    """Draw a significance bracket between two x positions with stacked labels.

    Args:
        x_indexes: pair ``(x1, x2)`` of x coordinates the bracket connects.
        y: y coordinate of the bracket's feet; the first label is anchored here.
        p_values: texts to draw above each other at the bracket's centre
            (e.g. the strings returned by ``pvals_to_stat_anots``).
        sep: height of the bracket; consecutive labels are spaced ``4 * sep``
            apart. Defaults to ``y / 50``.
        text_colors: one colour per entry of ``p_values``; defaults to black
            for all. If it is shorter than ``p_values`` the surplus labels are
            not drawn (``zip`` stops at the shorter sequence).
        ax: matplotlib Axes to draw on; defaults to pyplot's current axes.
        fontsize: font size of the labels.
    """
    if ax is None:
        # Same target the pyplot-level plot/text calls would draw on.
        ax = plt.gca()

    if sep is None:
        sep = y / 50

    if text_colors is None:
        text_colors = ["k"] * len(p_values)

    x1, x2 = x_indexes
    ax.plot([x1, x1, x2, x2], [y, y + sep, y + sep, y], lw=1.5, c="k")

    x_center = (x1 + x2) / 2
    for i, (pval, col) in enumerate(zip(p_values, text_colors)):
        ax.text(
            x_center,
            y + i * sep * 4,
            pval,
            ha="center",
            va="bottom",
            color=col,
            fontsize=fontsize,
        )

In [None]:
#export
def pvals_to_stat_anots(
    pvals_arr,
    pval_thresholds=(0.0001, 0.001, 0.01, 0.05, 1),
    annotations=("****", "***", "**", "*", r"$^{ns}$"),
):
    """Translate p-values into significance annotation strings.

    Each p-value is assigned the annotation of the first threshold it does not
    exceed: with the defaults, ``0 <= p <= 0.0001`` gives ``"****"``,
    ``0.0001 < p <= 0.001`` gives ``"***"``, and so on.

    Args:
        pvals_arr: 1-D array-like (or Series) of p-values.
        pval_thresholds: ascending upper bin edges; any sequence is accepted.
        annotations: one label per threshold, in the same order.

    Returns:
        The annotations as strings, in input order (a Series when a Series is
        passed). Values that fall in no bin (negative, above the last
        threshold, or NaN) come out as the string ``"nan"``.
    """
    # tuple() lets callers pass the thresholds as a list as well.
    bin_edges = (0,) + tuple(pval_thresholds)

    # Bins are right-closed; include_lowest additionally closes the first bin
    # on the left so that p == 0 (e.g. from floating-point underflow) is
    # labelled as most significant instead of falling outside every bin.
    return pd.cut(
        pvals_arr, bins=bin_edges, labels=annotations, include_lowest=True
    ).astype(str)