In [1]:
%config InlineBackend.figure_format = 'retina'

import colorsys
import itertools
import os
import re
import string
from collections import defaultdict

import arviz as az
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import seaborn as sns
import xarray as xr
from formulae import design_matrices
from matplotlib import colors as _colors
from matplotlib.ticker import MaxNLocator
from pymc.distributions.transforms import Interval
from pytensor.printing import Print
from pytensor.tensor import TensorVariable

# Apply the two seaborn style sheets before touching the colour cycle.
plt.style.use(['seaborn-v0_8-colorblind', 'seaborn-v0_8-dark'])

from cycler import cycler

# Take the colours of the active property cycle and install them again in a
# permuted order (entries 1<->2 and 4<->5 are swapped).
current_cycler = plt.rcParams['axes.prop_cycle']
colors = current_cycler.by_key()['color']
new_order = [0, 2, 1, 3, 5, 4]
new_colors = [colors[position] for position in new_order]
plt.rcParams.update({
    'axes.prop_cycle': cycler(color=new_colors),
    'figure.constrained_layout.use': True,
})

subset = True



In [None]:
import graphviz

fontsize = "10"

# Directed graph holding the variable nodes and their dependencies
dot = graphviz.Digraph()

# Node name -> number shown in parentheses underneath the name
variables = dict(X1="14", X2="8", X3="8", Y1="7", Y2="4")

for var, dim in variables.items():
    dot.node(var, f"{var}\n({dim})", fontsize=fontsize)

# Every X node points at both Y nodes; Y1 additionally points at Y2
edges = [
    (x_name, y_name)
    for x_name in ("X1", "X2", "X3")
    for y_name in ("Y1", "Y2")
]
edges.append(("Y1", "Y2"))

for edge in edges:
    dot.edge(*edge)

# Last expression of the cell: render the graph
dot

In [2]:
def build_design_matrices(
    formula,
    df,
    query=None,
    drop=None,
):
    """Build one flat design-matrix DataFrame from a formulae model formula.

    The common (fixed) terms and the group-specific terms produced by
    ``design_matrices(formula, df)`` are concatenated column-wise. Group
    columns are named ``"<term>[<k>]"`` where ``k`` counts the columns of
    that term's slice starting at 0.

    Parameters
    ----------
    formula : str
        Model formula passed to ``formulae.design_matrices``.
    df : pandas.DataFrame
        Data the formula is evaluated on.
    query : str, optional
        ``DataFrame.query`` expression; only the matching rows are returned.
    drop : iterable of str, optional
        Column names to remove from the result; names that are not present
        are ignored.

    Returns
    -------
    pandas.DataFrame
        The selected rows of the combined design matrix.

    Raises
    ------
    ValueError
        If the formula yields neither common nor group terms.
    """
    dm = design_matrices(formula, df)

    parts = []
    if dm.common is not None:
        parts.append(dm.common.as_dataframe())

    # Previously `dm_groups` was only bound inside this branch, so a formula
    # without group terms crashed with a NameError at the concat step.
    if dm.group is not None:
        colnames = []
        for group, vals in dm.group.slices.items():
            n_cols = vals.stop - vals.start
            colnames.extend(f"{group}[{i}]" for i in range(n_cols))
        parts.append(pd.DataFrame(dm.group.design_matrix, columns=colnames))

    if not parts:
        raise ValueError(f"Formula {formula!r} produced no design matrix terms.")

    combined = pd.concat(parts, axis=1)

    if drop:
        combined = combined.drop(columns=list(set(drop).intersection(combined.columns)))

    if query:
        # `df.query(...).index` holds index *labels*; translate them to row
        # positions so `iloc` is also correct when `df` does not carry a
        # default 0..n-1 index.
        # NOTE(review): assumes the design matrix has one row per row of `df`
        # in the same order - confirm for data containing missing values.
        positions = np.flatnonzero(df.index.isin(df.query(query).index))
    else:
        positions = np.arange(df.shape[0])

    return combined.iloc[positions]

In [3]:
def _make_col_label(col):
    if col[0] == "1":
        col = col[1:].replace("|", "")
    col = col.split("[")[0]
    return col

def _make_label_for_hdi(label, val):
    val = str(val)
    val = val.replace(".0", "")
    new_label = label.replace("[", "=").replace("]", "").split("|")
    if new_label[0] == "1":
        return new_label[-1]
    if len(new_label) == 1:
        return f"{label}={val}"
    else:
        return f"={val}|".join(new_label)
    
def _make_label_for_input_group(col):
    """Return the name of the module-level design matrix that contains `col`.

    The globals ``r``, ``b`` and ``c`` are checked in that order and the first
    one whose ``.columns`` contains `col` decides the result.

    Raises
    ------
    ValueError
        If `col` is in none of the three.
    """
    if col in r.columns:
        return "r"
    if col in b.columns:
        return "b"
    if col in c.columns:
        return "c"
    raise ValueError(f"Column {col} not found in any of the design matrices.")

def get_hdi_dataframe(
    idata_posterior, 
    design_matrices,
    dimname, 
    hdi_prob=.95):
    """
    Summarise the posterior effect of every non-zero design-matrix value.

    For each column of `design_matrices` and each distinct non-zero value in
    that column, the posterior draws selected with
    ``idata_posterior.sel({dimname: column})`` are multiplied by the value and
    summarised by their mean, the quartiles and a central interval.

    Note: the interval bounds are equal-tailed quantiles
    (``(1 - hdi_prob) / 2`` and its complement), not a highest-density
    interval in the strict sense.

    Parameters
    ----------
    idata_posterior : xarray object supporting ``.sel``
        Filtered idata posterior.
    design_matrices : pandas.DataFrame
        Design matrix whose column names are coordinates of `dimname`.
    dimname : str
        Name of the dimension that indexes the coefficients.
    hdi_prob : float, default .95
        Probability mass covered by ``lower``..``upper``.

    Returns
    -------
    pandas.DataFrame
        Columns ``input_group, col, labels, mean, lower, q1, q3, upper``.
        Rows are ordered by input group (c, b, r) and, within a group, by the
        average ``mean`` of each ``col``; rows of one ``col`` stay together.
        Empty if no column holds a non-zero value.
    """
    tail = (1 - hdi_prob) / 2
    quantile_levels = [tail, .25, .75, 1 - tail]

    records = []
    for col in design_matrices.columns.tolist():
        for val in design_matrices[col].unique():
            if val == 0:
                continue

            adjusted_coeff = (idata_posterior.sel({dimname: col}) * val).values
            lower, q1, q3, upper = np.quantile(adjusted_coeff, quantile_levels)
            records.append({
                "input_group": _make_label_for_input_group(col),
                "col": _make_col_label(col),
                "labels": _make_label_for_hdi(col, val),
                "mean": np.mean(adjusted_coeff),
                "lower": lower,
                "q1": q1,
                "q3": q3,
                "upper": upper,
            })

    column_order = ["input_group", "col", "labels", "mean", "lower", "q1", "q3", "upper"]
    hdi_df = pd.DataFrame(records, columns=column_order)
    hdi_df["input_group"] = pd.Categorical(
        hdi_df["input_group"], categories=["c", "b", "r"], ordered=True
    )

    # Nothing to order (and np.concatenate would fail on an empty list).
    if hdi_df.empty:
        return hdi_df

    ordered_pairs = (
        hdi_df.groupby(["input_group", "col"], observed=True)["mean"]
        .mean()
        .reset_index()
        .sort_values(by=["input_group", "mean"])
    )

    # Match on group AND column: matching on the column alone repeats rows
    # whenever the same short column name occurs under two input groups.
    group_values = hdi_df["input_group"].astype(object).to_numpy()
    col_values = hdi_df["col"].to_numpy()
    positions = np.concatenate([
        np.flatnonzero((group_values == group_name) & (col_values == col_name))
        for group_name, col_name in zip(ordered_pairs["input_group"], ordered_pairs["col"])
    ])

    return hdi_df.iloc[positions]

In [4]:
def plot_hdi(
    idata_posterior, 
    design_matrices,
    dimname, 
    hdi_prob=.95,
    fontsize=7,
    title=None,
):
    """Draw a compact interval plot of the effects from `get_hdi_dataframe`.

    The left panel lists the row labels and the (bold, rotated) input-group
    names; the right panel shows, per row, a thin line from ``lower`` to
    ``upper``, a thicker line from ``q1`` to ``q3`` and a dot at the mean.
    A dashed vertical line marks zero.

    Parameters
    ----------
    idata_posterior, design_matrices, dimname, hdi_prob
        Forwarded unchanged to `get_hdi_dataframe`.
    fontsize : int, default 7
        Font size of labels, tick labels and title.
    title : str, None or False
        ``None`` uses a default title built from `dimname` and `hdi_prob`;
        ``False`` suppresses the title.

    Returns
    -------
    None
        The figure is created as a side effect.
    """
    nrows = len(design_matrices.columns)
    fig = plt.figure(figsize=(4, nrows/4))

    gs = fig.add_gridspec(
        nrows=nrows, 
        ncols=2,
        width_ratios=[.2, .3],
        hspace=0,
        wspace=0,
    )
    axes = gs.subplots()

    # Blank out the per-row grid axes; remember the style's face colour first
    # because it is reused as the marker fill colour below.
    grid_axes = axes.flatten()
    marker_colour = grid_axes[0].get_facecolor()
    for grid_ax in grid_axes:
        grid_ax.set_xticklabels([])
        grid_ax.set_yticklabels([])
        grid_ax.set_facecolor("none")

    # Vertical layout: rows of a new column advance by `y_increment`, further
    # rows of the same column by `y_increment_minor`.
    y_start = 0
    y_increment = 2.2
    y_increment_minor = 1.1

    # Computed once; it does not depend on the axes being styled above.
    hdi_df = get_hdi_dataframe(
        idata_posterior=idata_posterior,
        design_matrices=design_matrices,
        dimname=dimname,
        hdi_prob=hdi_prob
    )

    ax = fig.add_subplot(gs[:,-1])
    ys = []
    ypos = y_start
    col_prev = hdi_df.col.iloc[0]
    input_group_prev = hdi_df.input_group.iloc[0]
    ypos_group_start = y_start
    ypos_group = []  # (first y, last y) of each consecutive input group

    for i, (input_group, col, label, mean, lower, q1, q3, upper) in enumerate(hdi_df.values):

        if input_group != input_group_prev:
            ypos_group.append((ypos_group_start, ypos))
            input_group_prev = input_group
            ypos_group_start = ypos

        if (col == col_prev) and (i > 0):
            ypos += y_increment_minor
        else:
            ypos += y_increment
            col_prev = col

        ax.hlines(ypos, lower, upper, color="k", lw=.4, zorder=1)
        ax.hlines(ypos, q1, q3, color="k", lw=.8, zorder=1)
        ax.scatter(mean, ypos, marker="o", color=marker_colour, edgecolors="k", linewidth=.4,  s=5, alpha=1, zorder=2)
        ys.append(ypos)

    ypos_group.append((ypos_group_start, ypos))

    ax.set_yticks([])
    ax.tick_params(axis="x", labelsize=fontsize)
    ax1 = fig.add_subplot(gs[:,0], sharey=ax, frameon=False)

    for ypos, label in zip(ys, hdi_df.labels):
        ax1.annotate(label, (0.2, ypos), ha='left', va='center', fontsize=fontsize)

    xmin, xmax = ax1.get_xlim()

    # Group name centred on its rows, plus a separator line after each group
    for (ystart, yend), label in zip(ypos_group, hdi_df.input_group.unique()):
        ypos = ystart + (yend - ystart)/2
        ax1.annotate(label, (0, ypos), 
                     ha='center', va='bottom',
                     fontsize=fontsize, rotation=90, fontweight="bold")
        ax1.hlines(yend + y_increment/2, xmin=xmin, xmax=xmax, color="k", lw=.8, zorder=1)

    ax1.set_xticklabels([]);

    if title is False:
        title = ""
    elif title is None:
        title = f"Effects of {dimname} with HDI {hdi_prob:.0%}"

    ax.set_title(title, fontsize=fontsize, fontweight="bold", pad=0)

    # Upper limit: halfway between the last group's end and the current top
    _, ymax = ax.get_ylim()
    ymin = y_start
    ymax = (ymax - yend)/2 + yend
    ax.vlines(x=0, ymin=ymin, ymax=ymax, colors="k", ls="--", lw=.8, alpha=.3, zorder=0)
    ax.set_ylim(ymin, ymax)
    ax1.set_ylim(ymin, ymax)

In [6]:
# Collect one effect summary per response variable.
# `.assign` is used instead of writing columns into the frame returned by
# `get_hdi_dataframe` (an `iloc` selection), which avoids chained-assignment
# warnings; the intermediate list also gets its own name instead of sharing
# `hdi_df` with the final DataFrame.
# NOTE(review): `output` only ends up in the "response" column - the posterior
# itself is not selected per output, so every response of one family receives
# identical numbers. Confirm this is intended.
hdi_frames = []

for outputs, _dm, _dimname, _y in (
    (b_output, _dm_h, "input_h", "h"),
    (r_output, _dm_f, "input_f", "f"),
):
    for output in outputs:
        _df = get_hdi_dataframe(
            idata_posterior=idata.posterior.param,
            design_matrices=_dm,
            dimname=_dimname,
        ).assign(y=_y, response=output, param=_dimname)
        hdi_frames.append(_df)

hdi_df = pd.concat(hdi_frames)

In [None]:
def get_feature_importance(df, groupby:list):
    """Rank groups of effects by magnitude and sign consistency.

    Per group, the importance is the geometric mean of ``|mean|`` multiplied
    by a consistency score ``|share of means >= 0 - share of means <= 0|``
    (1 when all effects share a sign, 0 when they are evenly split).

    The input frame is left untouched; helper columns are added to a copy.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"mean"`` column plus the `groupby` columns.
    groupby : list
        Column names to group by.

    Returns
    -------
    pandas.DataFrame
        One row per group with ``log_abs_post_mean``, ``positive``,
        ``negative``, ``geometric_post_mean``, ``consistency_score`` and
        ``importance``, sorted by descending importance with a fresh index.

    Notes
    -----
    A mean of exactly 0 has ``log(0) = -inf`` and therefore forces the
    geometric mean of its group to 0.
    """
    assert "mean" in df.columns.tolist(), "Requre mean column in dataframe"

    means = df["mean"]
    # `assign` returns a new frame, so the caller's DataFrame is not mutated
    scored = df.assign(
        positive=(means >= 0).astype(int),
        negative=(means <= 0).astype(int),
        log_abs_post_mean=np.log(np.abs(means)),
    )

    summary = (
        scored.groupby(groupby)[["log_abs_post_mean", "positive", "negative"]]
        .mean()
        .reset_index()
    )
    summary["geometric_post_mean"] = np.exp(summary["log_abs_post_mean"])
    summary["consistency_score"] = np.abs(summary["positive"] - summary["negative"])
    summary["importance"] = summary["geometric_post_mean"] * summary["consistency_score"]

    return summary.sort_values(by="importance", ascending=False).reset_index(drop=True)

In [None]:
# Same ranking at three granularities: per label, per (response family, column)
# and per column overall.
importance_by_label, importance_by_response, importance_overall = (
    get_feature_importance(hdi_df, groupby=keys)
    for keys in (["labels"], ["y", "col"], ["col"])
)

In [None]:
fig, ax = plt.subplots(figsize=(6, 10))

# One horizontal "lollipop" per label; the first row of the table is drawn at the top
n_labels = len(importance_by_label)
ys = np.arange(n_labels)[::-1]
xmax = importance_by_label["importance"].values

# Stem (hlines) and head (scatter) of each lollipop
ax.hlines(y=ys, xmin=0, xmax=xmax, lw=1)
ax.scatter(x=xmax, y=ys, marker="o", s=30, edgecolor="k", alpha=1)

# Axis cosmetics
ax.set_xlim(-.001)
ax.set_xlabel("Feature importance")
ax.set_yticks(ys)
ax.set_yticklabels(importance_by_label["labels"])
ax.tick_params(axis="y", labelsize=8)

In [None]:

# Ridge-style forest plot of the posterior; `var_names` is an empty list here.
forest_kwargs = dict(
    data=idata,
    var_names=[],
    combined=True,
    figsize=(12, 160),
    kind="ridgeplot",
    ridgeplot_truncate=False,
    ridgeplot_alpha=.7,
    ridgeplot_quantiles=[.5,],
    hdi_prob=.95,
    markersize=2,
    linewidth=1,
)
ax = az.plot_forest(**forest_kwargs)

# Dashed zero line spanning the y-range of the first returned axes
ymin, ymax = ax[0].get_ylim()
plt.vlines(x=0, ymin=ymin, ymax=ymax, colors="k", ls="--", lw=.8)

In [None]:
from scipy.stats import gaussian_kde

def plot_density_with_ref_and_rope(data, reference=0, rope=None, ax=None, 
                                   add_legend=False, remove_yticks=True, remove_xaxis=False,
                                   title=None, xlabel=None, xtick_fontsize=None,
                                   title_fontsize=None, subplot_title_fontsize=None,
                                   percent_region=0.05):
    """Plot a Gaussian-KDE density with optional reference value and ROPE.

    Parameters
    ----------
    data : 1-D array-like
        Samples the density is estimated from.
    reference : float or None, default 0
        Value marked with a red dashed line; the area at or below it is
        shaded and the KDE mass below/above is annotated. ``None`` disables it.
    rope : (float, float) or None
        Bounds of the region of practical equivalence. Both bounds are marked
        in green, the area between them is shaded and its KDE mass annotated.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, a new 10x5 figure is created and shown.
    add_legend : bool
        Add a legend (on matplotlib versions that still provide
        ``Axes.is_first_col`` only for first-column axes). With ``False`` an
        existing legend is removed.
    remove_yticks, remove_xaxis : bool
        Hide the y ticks / replace the x axis by a plain baseline.
    title, xlabel : str, optional
        Subplot title and x label.
    xtick_fontsize, subplot_title_fontsize : optional
        Font sizes for x ticks / x label / legend and for the title.
    title_fontsize : optional
        Accepted for interface compatibility; not used by this function.
    percent_region : float or None, default 0.05
        If given, a black bar near the baseline spans the
        ``[percent_region, 1 - percent_region]`` quantile range of `data`.

    Returns
    -------
    None
    """
    # Remember whether the axes is ours: `ax` is rebound right below, so a
    # later `ax is None` test could never be true and the figure was never shown.
    created_axes = ax is None
    if created_axes:
        _, ax = plt.subplots(figsize=(10, 5))

    # Kernel density estimate on a grid padded by 1 on both sides
    density = gaussian_kde(data)
    xs = np.linspace(min(data) - 1, max(data) + 1, 1000)
    ys = density(xs)

    ax.plot(xs, ys, label='Density', color='blue')
    ax.fill_between(xs, 0, ys, color='blue', alpha=0.25)  # Light fill for entire density

    if remove_yticks:
        ax.set_yticks([])

    if remove_xaxis:
        ax.set_xticks([])
        ax.spines['bottom'].set_visible(False)
        ax.axhline(y=0, color='black', linewidth=1.2)  # Baseline closing the density
    else:
        ax.spines['bottom'].set_visible(True)
        ax.spines['bottom'].set_position(('data', 0))  # No gap between density and x-axis

    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)

    if title is not None:
        ax.set_title(title, fontsize=subplot_title_fontsize)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=xtick_fontsize)

    if xtick_fontsize is not None:
        ax.tick_params(axis='x', labelsize=xtick_fontsize)

    if add_legend and (ax.is_first_col() if hasattr(ax, 'is_first_col') else True):
        ax.legend(loc='upper right', fontsize=xtick_fontsize)

    if rope is not None:
        lower_ref_density = density(rope[0])[0]
        upper_ref_density = density(rope[1])[0]
        # vlines works in data coordinates, so each marker ends exactly on the
        # curve. (axvline's ymax is a fraction of the axes height, which does
        # not correspond to density / max density once margins are applied.)
        ax.vlines(x=rope[0], ymin=0, ymax=lower_ref_density, colors='green', linestyles='--')
        ax.vlines(x=rope[1], ymin=0, ymax=upper_ref_density, colors='green', linestyles='--')
        ax.fill_between(xs, 0, ys, where=(xs >= rope[0]) & (xs <= rope[1]), color='green', alpha=0.3)
        prob_in_rope = density.integrate_box_1d(rope[0], rope[1])
        ax.annotate(f'{prob_in_rope:.1%} in ROPE', 
                     xy=((rope[0]+rope[1])/2, max(lower_ref_density, upper_ref_density)/100), 
                     xytext=(0, 10), 
                     textcoords='offset points',
                     ha='center', va='bottom')

    if reference is not None:
        ref_density = density(reference)[0]
        ax.vlines(x=reference, ymin=0, ymax=ref_density, colors='red', linestyles='--')
        ax.fill_between(xs, 0, ys, where=(xs <= reference), color='red', alpha=0.3)
        prob_below = density.integrate_box_1d(-np.inf, reference)
        prob_above = 1 - prob_below
        ax.annotate(f'{prob_below:.1%} below < {reference} < {prob_above:.1%} above', 
                     xy=(reference, ref_density/20), 
                     xytext=(0, 10), 
                     textcoords='offset points',
                     ha='center', va='bottom')

    if (add_legend is False) and (ax.get_legend() is not None):
        ax.get_legend().remove()

    # Ensure the y-axis starts at 0
    ax.set_ylim(bottom=0)

    if percent_region is not None:
        y = ax.get_ylim()[1]*.01
        lower, upper = np.quantile(data, q=[percent_region, 1-percent_region])
        ax.hlines(xmin=lower, xmax=upper, y=y, color='black', linewidth=2)

    # Only show the figure when it was created here
    if created_axes:
        plt.show()

In [None]:
# Columns (or "col1|col2" pairs) to build queries for
f_cols = [

]

# Single-column query strings, e.g. "(x == 1)", whose pairs are skipped
exclude = [

]

# "col1|col2" -> name of the conditioned column, or "both_vary"
conditional_effect_grouping = {

}

# Requires `from collections import defaultdict` in the import cell.
queries = []                      # every generated query string
queries_idx = []                  # index labels of the rows matching each query
subplot_titles = []
query_label = defaultdict(str)    # query string -> short display label
f_query_map = defaultdict(list)   # grouping key -> query strings in that group

for col in f_cols:
    col_list = col.split("|")
    values = df[col_list].value_counts().index.to_list()

    if len(col_list) == 1:
        # One query per observed value of the column
        for val in [combo[0] for combo in values]:
            new_query = f"({col} == {val})"
            queries.append(new_query)
            queries_idx.append(df.query(new_query).index.tolist())
            f_query_map[col].append(new_query)
            query_label[new_query] = f"{col}={val}"
        continue

    # One query per observed combination of the two columns
    col1, col2 = col_list
    for val in values:
        query1 = f"({col1} == {val[0]})"
        query2 = f"({col2} == {val[1]})"
        if (query1 in exclude) or (query2 in exclude):
            continue

        new_query = query1 + " & " + query2

        # Without an explicit entry the second column is the conditioned one
        conditioned_col = conditional_effect_grouping.get(col, col2)
        if conditioned_col == "both_vary":
            map_key = f"{col1}|{col2}"
        elif conditioned_col == col2:
            map_key = f"{col1}|{col2}={val[1]}"
        else:
            map_key = f"{col1}={val[0]}|{col2}"
        f_query_map[map_key].append(new_query)

        # The label must name the first column, not the whole "col1|col2" key
        query_label[new_query] = f"{col1}={val[0]},{col2}={val[1]}"
        queries.append(new_query)
        queries_idx.append(df.query(new_query).index.tolist())

# check for errors
# The stored indices are index labels, so select with `.loc`; `.equals` also
# treats NaNs in matching positions as equal.
for idx, query in zip(queries_idx, queries):
    assert df.loc[idx].equals(df.query(query)), f"Stored rows do not match query {query}"

In [None]:
# Helper output for manually adjusting the grouping: number of queries per key
for col in f_query_map:
    query = f_query_map[col]
    print(f"{col}: {len(query)}")

# Names of numeric variables; grouping keys containing one of them are treated
# as numeric further below.
numeric_vars = [

]

In [None]:
def scale_lightness(rgb, scale_l):
    """Return `rgb` with its HLS lightness multiplied by `scale_l`.

    `rgb` is an ``(r, g, b)`` triple; the scaled lightness is capped at 1.
    Hue and saturation are left unchanged.
    """
    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    capped_lightness = min(1, lightness * scale_l)
    return colorsys.hls_to_rgb(hue, capped_lightness, s=saturation)

def plot_density(data, reference=0, rope=None, 
                 percent_region=0.05, remove_xaxis=True,
                 get_reference_vals=False, 
                 return_colour=False,
                 kde_kwargs=None,
                 fill_kwargs=None,
                 **kwargs):
    """Draw a seaborn KDE with the central quantile region emphasised.

    The region between the `percent_region` and ``1 - percent_region``
    quantiles of `data` is filled at alpha 0.25, both tails at alpha 0.05,
    and the density outline is drawn in a darker shade of the fill colour.

    Parameters
    ----------
    data : 1-D array-like
        Samples to estimate the density from.
    reference : float or None, default 0
        If given, a dotted line marks this value and the annotation shows the
        fraction of `data` at or below it and the complement.
    rope, get_reference_vals
        Accepted for interface compatibility; not used by this function.
    percent_region : float, default 0.05
        Tail probability cut off on each side of the emphasised region.
    remove_xaxis : bool, default True
        Hide the x ticks and bottom spine.
    return_colour : bool, default False
        Return the face colour of the central fill.
    kde_kwargs : dict, optional
        Extra keyword arguments for ``sns.kdeplot``; they override the
        defaults (``cut``, ``color``, ``lw``) and are themselves overridden
        by `kwargs`.
    fill_kwargs : dict, optional
        Extra keyword arguments for the central ``fill_between`` (e.g. ``label``).
    **kwargs
        Forwarded to ``sns.kdeplot``. If ``ax`` is missing, a new 5x3.5
        figure is created and shown at the end.

    Returns
    -------
    The fill's face colour if `return_colour` is true, otherwise ``None``.
    """
    created_axes = "ax" not in kwargs

    kde_kwargs = {} if kde_kwargs is None else kde_kwargs
    fill_kwargs = {} if fill_kwargs is None else fill_kwargs

    # `kde_kwargs` used to be merged here and then ignored; it is now applied.
    combined_kwargs = {**kde_kwargs, **kwargs}
    if created_axes:
        _, new_ax = plt.subplots(figsize=(5, 3.5))
        combined_kwargs["ax"] = new_ax

    # Invisible first pass: only used to obtain the density curve coordinates
    ax = sns.kdeplot(data, **{"cut": 0, "color": "k", "lw": .6, **combined_kwargs, "alpha": 0})
    density_line = ax.get_lines()[-1]
    xs = density_line.get_xdata()
    ys = density_line.get_ydata()

    xmin, xmax = xs.min(), xs.max()

    ax.set_ylabel('')
    ax.set_yticks([])
    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)
    ax.hlines(y=0, xmin=xmin, xmax=xmax, color='black', linewidth=2.2)

    lower, upper = np.quantile(data, q=[percent_region, 1-percent_region])
    fill = ax.fill_between(xs, 0, ys, where=(xs >= lower) & (xs <= upper), alpha=0.25, **fill_kwargs)
    colour = fill.get_facecolor()
    # Tails lie outside [lower, upper]; with `&` this mask was always empty
    ax.fill_between(xs, 0, ys, where=(xs <= lower) | (xs >= upper), alpha=0.05, color=colour)

    outline_colour = scale_lightness(_colors.to_rgb(colour), .9)  # darker than fill
    ax = sns.kdeplot(data, **{"cut": 0, "color": outline_colour, "lw": .6, **combined_kwargs})

    if remove_xaxis:
        ax.set_xticks([])
        ax.spines['bottom'].set_visible(False)
    else:
        ax.spines['bottom'].set_visible(True)
        ax.tick_params(axis='x', labelsize=8)
    ax.spines['bottom'].set_position(('data', 0))  # No gap between density and x-axis

    if reference is not None:
        # Height of the curve at the last grid point at or below the reference
        at_or_below = np.flatnonzero(xs <= reference)
        ymax_idx = at_or_below.max() if at_or_below.size else np.argmin(xs)
        ymax = ys[ymax_idx]
        ax.vlines(x=reference, ymin=0, ymax=ymax, 
                  linestyle='dotted', lw=1, 
                  color=outline_colour,
                )
        # Fraction of the *samples* at or below the reference (the KDE grid
        # is evenly spaced, so counting grid points is not a probability)
        prob_below = float(np.mean(np.asarray(data) <= reference))
        prob_above = 1 - prob_below
        # Annotation is centred on the midpoint of the plotted x-range
        x_mid = (xs.min() + xs.max()) / 2
        ax.annotate(f'{prob_below:.1%} < {reference} < {prob_above:.1%}', 
                     xy=(x_mid, ymax/20), 
                     xytext=(0, 10), 
                     textcoords='offset points',
                     ha='center', va='bottom')

    ax.set_xlim(xmin, xmax)

    if created_axes:
        plt.show()

    if return_colour:
        return colour

In [None]:
# Requires `defaultdict` (collections) and `re` from the import cell.

# Grouping keys that mention none of the numeric variables
density_dict = defaultdict(list)
for row, query in f_query_map.items():
    if not any(var in row for var in numeric_vars):
        density_dict[row] = query

# Compact display label per query, e.g. "(a == 1.0) & (b == 2.0)" -> "a=1,b=2"
density_labels = defaultdict(list)
for key, query in density_dict.items():
    for q in query:
        qtext = q.replace(" ", "").replace("==", "=")
        # Drop ".0" only where it ends a number, so "1.05" is not turned into "15"
        qtext = re.sub(r"\.0(?!\d)", "", qtext)
        qtext = qtext.replace("(", "").replace(")", "").replace("&", ",")
        density_labels[key].append(qtext)

In [None]:
def plot_custom_text(ax, xpos, ypos, annotate_str, colour, fontsize=7):
    """Write several text fragments side by side, each in its own colour.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; positions are in axes-fraction coordinates.
    xpos, ypos : float
        Left edge and vertical centre of the first fragment.
    annotate_str : sequence of str
        Text fragments, drawn left to right.
    colour : sequence
        One colour per fragment.
    fontsize : int, default 7
        Font size of all fragments.
    """
    spacing = 0.005  # gap between fragments, in axes-fraction units

    # `get_text_width` reports pixels, while `xpos` is an axes fraction:
    # convert with the axes width in pixels. (Dividing by the figure dpi
    # gives inches, which only matches for an axes exactly one inch wide.)
    # NOTE(review): uses the axes size at call time; a later layout pass
    # that resizes the axes will shift the fragments slightly.
    axes_width_px = ax.bbox.width

    for text, color in zip(annotate_str, colour):
        ax.text(xpos, ypos, text, fontsize=fontsize, color=color, ha='left', va='center', transform=ax.transAxes, alpha=1)

        # Advance by the rendered width of the fragment just drawn
        text_width = get_text_width(text, fontsize, ax)
        xpos += text_width / axes_width_px + spacing

def get_text_width(text, fontsize, ax):
    """Return the rendered width of `text` at `fontsize`, in window (pixel) units.

    A fully transparent probe text is added to `ax`, measured with the
    canvas renderer and removed again.
    """
    canvas_renderer = ax.figure.canvas.get_renderer()
    probe = ax.text(0, 0, text, fontsize=fontsize, alpha=0, transform=ax.transAxes)
    width = probe.get_window_extent(canvas_renderer).width
    probe.remove()
    return width

In [None]:
def plot_column_densities(
    ppc_idata,
    density_dict,
    density_labels,
    row_headers=["", ""],
    legend_pos=1,
    figwidth=None,
    figheight=None,
):
    """Plot, per grouping key, the densities of its queries and pairwise comparisons.

    The figure has one header row plus one row per key of `density_dict` and
    three columns: the key as row label, the overlaid densities (one per
    query), and one line of text per pair of queries with the share of
    ``v1 > v2`` among the sampled values.

    Rows for each query are looked up in the module-level ``df`` via
    ``df.query`` and 200 of their index labels are drawn with a fixed seed
    (with replacement only if fewer than 200 rows match).

    Parameters
    ----------
    ppc_idata : xarray object
        Selected with ``.sel(i=<index labels>)``; its values are averaged over
        axis 0 for the density.
    density_dict : dict
        Grouping key -> list of query strings.
    density_labels : dict
        Grouping key -> list of display labels, aligned with the queries.
    row_headers : sequence of str
        Headers for the density and comparison columns.
    legend_pos : int or str
        ``loc`` of the legend in each density panel.
    figwidth, figheight : float, optional
        Figure size; defaults are 5 and ``nrows / 1.5``.

    Returns
    -------
    None
    """
    rng = np.random.default_rng(seed=88)
    n_samples = 200
    ylabels = list(density_dict.keys())
    nrows = len(ylabels) + 1  # +1 for the header row

    if figwidth is None:
        figwidth = 5
    if figheight is None:
        figheight = nrows/1.5

    fig = plt.figure(figsize=(figwidth, figheight))
    gs = fig.add_gridspec(
        nrows=nrows, 
        ncols=3,
        width_ratios=[.4, 1, .5], 
        height_ratios=[.1] + [1]*(nrows-1)
    )

    # squeeze=False keeps the array 2-D even when there is only the header row
    axes = gs.subplots(squeeze=False)

    for ax in axes.flatten():
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_facecolor("none")

    # Row labels: pair each label with its own row (indexing with `row - 1`
    # from an enumerate starting at 0 rotated the labels by one row).
    for ylabel, ax in zip(ylabels, axes[1:,0]):
        ax.annotate(ylabel, xy=(.5, 0.5), 
                    xycoords="axes fraction", ha="center", 
                    va="center", fontsize=8)

    for header, ax in zip(row_headers, axes[0,1:]):
        ax.annotate(header, xy=(.5, 0.5), 
                xycoords="axes fraction", ha="center", 
                va="center", fontsize=9)

    stored_values = {}  # key -> label -> sampled values and fill colour

    # Middle column: overlaid densities
    for ax, (col_label, query) in zip(axes[1:,1], density_dict.items()):
        ax.clear()
        stored_values[col_label] = {}
        for l, q in enumerate(query): 
            label = density_labels[col_label][l]       
            indices = df.query(q).index
            replace = len(indices) < n_samples
            random_indices = rng.choice(indices, size=n_samples, replace=replace)
            values = ppc_idata.sel(i=random_indices).values
            kde_values = values.mean(axis=0)
            colour = plot_density(kde_values, ax=ax, reference=None, 
                        fill_kwargs=dict(label=label),
                        return_colour=True,)
            ax.legend(loc=legend_pos, fontsize=6)

            stored_values[col_label][label] = {"values": values, "colour": colour}

    # Right column: one line of text per pair of queries
    for ax, (col_label, query) in zip(axes[1:,2], density_dict.items()):
        ypos = .8
        for i, j in itertools.combinations(range(len(query)), 2):
            first = stored_values[col_label][density_labels[col_label][i]]
            second = stored_values[col_label][density_labels[col_label][j]]
            v1, v2 = first["values"], second["values"]

            if v1.ndim not in (1, 2):
                raise ValueError("ndim must be 1 or 2")
            # Share of v1 > v2 along axis 0, averaged over what remains
            mu_diff = ((v1 > v2).sum(axis=0) / v1.shape[0]).mean()

            colour = ["k", first["colour"], "k", second["colour"], "k"]
            annotate_str = ["P(", r"$c_1$", ">", r"$c_2$", rf") = {mu_diff:.1%}"]  
            plot_custom_text(ax, .2, ypos, annotate_str, colour)            
            ypos -= 0.25

    # Give all density panels the same x-range: the union of their limits.
    # (The old loop overwrote its accumulators, so the last panel won.)
    density_axes = axes[1:,1]
    if len(density_axes) > 0:
        limits = [axis.get_xlim() for axis in density_axes]
        xmin = min(low for low, _ in limits)
        xmax = max(high for _, high in limits)
        for axis in density_axes:
            axis.set_xlim(xmin, xmax)

    locator = MaxNLocator(4)
    axes[-1,1].xaxis.set_major_locator(locator)
    axes[-1,1].tick_params(axis='x', labelsize=8)
    axes[-1,1].set_xlabel("Count", fontsize=8);

In [None]:
# Requires `defaultdict` (collections) and `re` from the import cell.

# Grouping keys that mention at least one numeric variable
numeric_dict = defaultdict(list)
for row, query in f_query_map.items():
    if any(var in row for var in numeric_vars):
        numeric_dict[row] = query

# Compact display label per query, e.g. "(a == 1.0) & (b == 2.0)" -> "a=1,b=2"
numeric_labels = defaultdict(list)
for key, query in numeric_dict.items():
    for q in query:
        qtext = q.replace(" ", "").replace("==", "=")
        # Drop ".0" only where it ends a number, so "1.05" is not turned into "15"
        qtext = re.sub(r"\.0(?!\d)", "", qtext)
        qtext = qtext.replace("(", "").replace(")", "").replace("&", ",")
        numeric_labels[key].append(qtext)

In [None]:
# Helper output for manually adjusting the grouping: list the queries of every
# grouping key that mentions a numeric variable
for row, query in f_query_map.items():
    num_var = any(var in row for var in numeric_vars)
    if not num_var:
        continue
    print(f"{row}: {query}")