In [None]:
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
# %matplotlib widget

import os
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm
sns.set_palette(['#1E1E1E', '#BB3524', '#F5D54A', '#384827', '#282F44'])
sns.set_context('paper')
sns.set_style({'axes.axisbelow': True, 
               'axes.edgecolor': '.15',
               'axes.facecolor': 'white',
               'axes.grid': False, 
               'axes.labelcolor': '.15', 
               'figure.facecolor': 'white', 
               'grid.color': '.15',
               'grid.linestyle': ':', 
               'grid.alpha': .5, 
               'image.cmap': 'Greys', 
               'legend.frameon': False, 
               'legend.numpoints': 1, 
               'legend.scatterpoints': 1,
               'lines.solid_capstyle': 'butt', 
               'axes.spines.right': False, 
               'axes.spines.top': False,  
               'text.color': '.15',  
               'xtick.top': False, 
               'ytick.right': False, 
               'xtick.color': '.15',
               'xtick.direction': 'out', 
               'ytick.color': '.15', 
               'ytick.direction': 'out', 
              })


import matplotlib

FONT_SIZE_PT = 5
matplotlib.rcParams['font.family'] = 'Arial'
matplotlib.rcParams['font.size'] = FONT_SIZE_PT
matplotlib.rcParams['axes.labelsize'] = FONT_SIZE_PT
matplotlib.rcParams['axes.titlesize'] = FONT_SIZE_PT
matplotlib.rcParams['figure.titlesize'] = FONT_SIZE_PT
matplotlib.rcParams['xtick.labelsize'] = FONT_SIZE_PT
matplotlib.rcParams['ytick.labelsize'] = FONT_SIZE_PT
matplotlib.rcParams['legend.fontsize'] = FONT_SIZE_PT
matplotlib.rcParams['legend.title_fontsize'] = FONT_SIZE_PT

matplotlib.rcParams['xtick.major.size'] = matplotlib.rcParams['ytick.major.size'] = 2
matplotlib.rcParams['xtick.major.width'] = matplotlib.rcParams['ytick.major.width'] = 0.5


matplotlib.rcParams['xtick.minor.size'] = matplotlib.rcParams['ytick.minor.size'] = 1

matplotlib.rcParams['xtick.minor.width'] = matplotlib.rcParams['ytick.minor.width'] = 0.5

matplotlib.rcParams['axes.linewidth'] = 0.5
matplotlib.rcParams['lines.linewidth'] = 0.5
matplotlib.rcParams['grid.linewidth'] = 0.25
matplotlib.rcParams['patch.linewidth'] = 0.25
matplotlib.rcParams['lines.markeredgewidth'] = 0.25
matplotlib.rcParams['lines.markersize'] = 2

FIVE_MM_IN_INCH = 0.19685
DPI = 600
matplotlib.rcParams['figure.figsize'] = (10 * FIVE_MM_IN_INCH, 9 * FIVE_MM_IN_INCH)
matplotlib.rcParams['savefig.dpi'] = DPI
matplotlib.rcParams['figure.dpi'] = DPI // 4


#http://phyletica.org/matplotlib-fonts/
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

pd.set_option("display.max_columns", 200)

In [None]:
import palettable

In [None]:
# When executed by snakemake, a `snakemake` object is injected into the notebook namespace;
# otherwise fall back to hard-coded debug parameters.
HAVE_SNAKEMAKE = 'snakemake' in locals()

if HAVE_SNAKEMAKE:

    input_consolidated_stats = snakemake.input.consolidated_tsv

    param_factors_x = snakemake.params.chip_features_x
    param_factors_y = snakemake.params.chip_features_y
    param_cell_line = snakemake.params.cell_line

    output_plot_dir = snakemake.output.output_dir

    # Optional params; None means "use the helpers' defaults" (resolved in a later cell)
    param_hue_palette_factor_type = snakemake.params.get('hue_palette_factor_type', None)
    param_shape_palette_factor_type = snakemake.params.get('shape_palette_factor_type', None)
    param_input_header_separator = snakemake.params.get('input_header_sep', '__')
    param_limits_log2 = snakemake.params.get('limits_log2', None)

    # Default to hmean
    param_agg_op = snakemake.params.get('chip_feature_agg_op', 'hmean')
    param_main_stat_col = snakemake.params.get('chip_feature_main_stat_col', 'normalised_mi')

else:
    print("No snakemake -- DEBUG MODE")

    # Scratch directory for outputs produced while developing interactively.
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    _OUTDIR = '.nb-testing-outputs'
    os.makedirs(_OUTDIR, exist_ok=True)

    # Parameters that identify which upstream consolidated table to load
    _bin_size = 1000
    _pseudocount = 100
    _min_periods = 1

    param_cell_line = 'K562'
    input_consolidated_stats = f'../../output/final/analysis/params_{_bin_size}bp_pc_{_pseudocount}_mp_{_min_periods}/{param_cell_line}/consolidated_tables/bedstats_consolidated_{param_cell_line}_{_bin_size}bp_params_pc_{_pseudocount}_mp_{_min_periods}_from_bed.csv.gz'

    param_correlation_method = 'kendall'
    param_input_header_separator = '__'

    param_factors_x = ["state:1_TssA"]
    param_factors_y = ['state:7_Enh']

    # Alternative example selection (several datasets per axis):
    # param_factors_x = ['H3K4me1-K562-ENCFF540NGG', 'H3K4me1-K562-ENCFF759NWD']
    # param_factors_y = ['H3K4me3-K562-ENCFF689QIJ', 'H3K4me3-K562-ENCFF706WUF','H3K4me3-K562-ENCFF885FQN']

    param_hue_palette_factor_type = None
    param_shape_palette_factor_type = None

    # Sensible x and y- lim for scatterplot (None: derive from the data later)
    param_limits_log2 = None

    output_plot_dir = os.path.join(_OUTDIR, 'pairwise_plots')

    param_agg_op = 'hmean'

    param_main_stat_col = 'normalised_mi'

# Multiple-testing correction settings used for the MWU statistics
param_fdr_method = 'fdr_bh'
param_alpha = 0.05

In [None]:
# Echo the statistic being compared (defaults to 'normalised_mi')
param_main_stat_col

In [None]:
import re


def _make_filename_safe(text):
    """Replace every run of characters outside [a-zA-Z0-9_-] with a single underscore."""
    return re.sub('[^a-zA-Z0-9_-]+', '_', text)


# Filename-safe versions of the parameters, used to build the output file names below.
# Several factors on one axis are joined with '_and_'.
param_factors_x_safe = '_and_'.join([_make_filename_safe(x) for x in param_factors_x])
param_factors_y_safe = '_and_'.join([_make_filename_safe(y) for y in param_factors_y])
param_cell_line_safe = _make_filename_safe(param_cell_line)

# Show the values that actually go into the file names
param_factors_x_safe, param_factors_y_safe, param_cell_line_safe

In [None]:
def _factors_label(factors):
    """Axis label for a factor list: the single factor itself, or `agg_op(f1, f2, ...)`."""
    if len(factors) == 1:
        return factors[0]
    return f"{param_agg_op}({', '.join(factors)})"


param_factors_x_label = _factors_label(param_factors_x)
param_factors_y_label = _factors_label(param_factors_y)

param_factors_x_label, param_factors_y_label

In [None]:
# Reload the project-local `helpers` module when it changes on disk
# (autoreload mode 1 only reloads modules registered with %aimport).
%load_ext autoreload
%aimport helpers
%autoreload 1

In [None]:
# Make sure the destination for all plots and tables exists
os.makedirs(output_plot_dir, exist_ok=True)

In [None]:
# Fall back to the project-wide default palettes when none were passed in
if param_hue_palette_factor_type is None:
    param_hue_palette_factor_type = helpers.get_default_hue_palette_factor_type()

if param_shape_palette_factor_type is None:
    param_shape_palette_factor_type = helpers.get_default_shape_palette_factor_type()

# Self-documenting f-strings (`{x=}`) print both the name and the value
print(f'{param_hue_palette_factor_type=}')
print(f'{param_shape_palette_factor_type=}')

# Input

The previous step of the pipeline assembled the data into an easily digestible format.

The goal of this notebook is to compare the mutual information responses of two sets of factors, `param_factors_x` and `param_factors_y`.

This is what the data looks like as we import it:

In [None]:
# The flat CSV headers encode a two-level column index as "<header><separator><column>";
# split them back into a pandas MultiIndex.
data = pd.read_csv(input_consolidated_stats, index_col=0)
_header_tuples = [column.split(param_input_header_separator) for column in data.columns]
data.columns = pd.MultiIndex.from_tuples(_header_tuples, names=['header', 'column'])

data

In [None]:
# Rows whose factor type is anything other than 'protein'
data[data[('metadata', 'factor_type')] != 'protein']

To match what we do in `summarise_results.py.ipynb`, let's re-aggregate the data by factor:

In [None]:
# see helpers.reaggregate_by_factor
data = helpers.reaggregate_by_factor(data, factor_col=('metadata', 'factor'))
data['marcs_feature_significant_category'] = data['marcs_feature_significant_category'].fillna('No data')

Other than that, we only need the MI data for the factors of interest. Fetch it:

In [None]:
def _select_main_stat(factors):
    """Columns (param_main_stat_col, factor) of `data`, one per factor, in the given order."""
    wanted_columns = [(param_main_stat_col, factor) for factor in factors]
    return data[wanted_columns]


data_stats_x = _select_main_stat(param_factors_x)
data_stats_y = _select_main_stat(param_factors_y)

In [None]:
# First rows of the x-axis factor columns
data_stats_x.head()

In [None]:
# First rows of the y-axis factor columns
data_stats_y.head()

When more than one factor is selected for an axis, we combine the selected columns into a single value per row.


We will use `param_agg_op`:

In [None]:
print(f'{param_agg_op=}')

# 'hmean' maps to helpers.nan_aware_hmean (presumably a NaN-tolerant harmonic mean);
# any other value is handed straight to DataFrame.agg.
if param_agg_op == 'hmean':
    _op = helpers.nan_aware_hmean
else:
    _op = param_agg_op

print(f'{_op=}')

In [None]:
# Reduce each axis to a single value per row using the chosen aggregation
_aggregated_columns = {}
_aggregated_columns[(param_main_stat_col, 'factors_x')] = data_stats_x.agg(_op, axis=1)
_aggregated_columns[(param_main_stat_col, 'factors_y')] = data_stats_y.agg(_op, axis=1)

data_stats = pd.DataFrame(_aggregated_columns)
data_stats.head()

We will also need some metadata

In [None]:
# The metadata block is copied because comparison columns are added to it below;
# the MARCS categories are only read.
data_meta = data[['metadata']].copy()
data_marcs = data[['marcs_feature_significant_category']]

In [None]:
# Order of MARCS features taken from helpers; drives the per-feature plot pages and heatmap columns below
MARCS_FEATURE_ORDER = helpers.MARCS_FEATURE_ORDER

In [None]:
# Metadata block before the comparison columns are added
data_meta

Add the names of the factors that we are comparing to the metadata:

In [None]:
# Record, on every row, which factors (';'-joined) went into each axis
for _axis_suffix, _axis_factors in (('x', param_factors_x), ('y', param_factors_y)):
    data_meta[('metadata', f'comparison_factor_{_axis_suffix}')] = ';'.join(_axis_factors)

data_meta

At this point, it makes life easier to convert the stat to `log2` scale. For MI the zeroes are not an issue because we're handling smoothing at the counting step.

In [None]:
# log2-transform the aggregated statistic of both axes into a parallel '<stat>_log2' block
_log2_block = f'{param_main_stat_col}_log2'
for _axis_column in ('factors_x', 'factors_y'):
    data_stats[(_log2_block, _axis_column)] = data_stats[(param_main_stat_col, _axis_column)].apply(np.log2)


In [None]:
# Statistics including the log2-transformed columns
data_stats

We can now calculate the log diff (`log2(x) - log2(y)`)

In [None]:
# log2(x) - log2(y): positive when the x-axis factors explain more than the y-axis factors
_log2_x = data_stats[(f'{param_main_stat_col}_log2', 'factors_x')]
_log2_y = data_stats[(f'{param_main_stat_col}_log2', 'factors_y')]
data_stats[(f'{param_main_stat_col}_log2_diff', 'factors_x_minus_factors_y')] = _log2_x - _log2_y
data_stats

And this should be everything

In [None]:
# Metadata, statistics and MARCS categories side by side (DataFrame.join defaults to a left join on the index)
data_scatterplot = (
    data_meta
    .join(data_stats)
    .join(data_marcs)
)
data_scatterplot

Note that this table contains all factors, not just proteins. The non-protein rows are:

In [None]:
# Non-protein rows of the assembled table
data_scatterplot.loc[data_scatterplot['metadata', 'factor_type'] != 'protein']

In [None]:
# Each factor must appear exactly once (downstream code selects rows by index)
assert data_scatterplot.index.is_unique

We can save the output at this point:

In [None]:
# Save the assembled table next to the plots, with the row index written as a regular column
_scatterplot_table_path = os.path.join(output_plot_dir, 'data_scatterplot.csv')
data_scatterplot.reset_index().to_csv(_scatterplot_table_path, index=False)


# Natural scale plots

A sanity-check scatterplot before making the stylised version:

In [None]:
x_col = f'{param_main_stat_col}', 'factors_x'
y_col = f'{param_main_stat_col}', 'factors_y'
c_col = f'{param_main_stat_col}_log2_diff', 'factors_x_minus_factors_y'

# Range spanning both axes, used for the dotted x == y reference line
_diag_low = min(data_scatterplot[x_col].min(), data_scatterplot[y_col].min())
_diag_high = max(data_scatterplot[x_col].max(), data_scatterplot[y_col].max())

ax = plt.gca()

# Points coloured by log2(x) - log2(y), clipped to [-4, 4]
ax.scatter(
    x=data_scatterplot[x_col],
    y=data_scatterplot[y_col],
    c=data_scatterplot[c_col],
    cmap='RdBu_r',
    vmin=-4, vmax=4,
    edgecolor='black',
)

ax.plot([_diag_low, _diag_high], [_diag_low, _diag_high], linestyle=':')
ax.set_xlabel(x_col)
ax.set_ylabel(y_col)

In [None]:
# MARCS significance categories per factor (one column per MARCS feature)
data_marcs

In [None]:
# Scratch: an empty figure, only used to inspect the default figure bounding box below
fig = plt.figure()

In [None]:
# Bounding box of the scratch figure (display units)
fig.bbox

In [None]:
# (x0, y0, width, height) of the scratch figure's bounding box.
# Referenced explicitly rather than via IPython's `_` (last output), which depends on execution order.
fig.bbox.bounds

In [None]:
import textwrap
from matplotlib.backends.backend_pdf import PdfPages
from contextlib import suppress
import gc
import matplotlib.ticker as ticker

# Labels of the MARCS significance categories that are highlighted in the plots
RECRUITED_FEATURE_GROUP = 'Strongly recruited'
EXCLUDED_FEATURE_GROUP = 'Strongly excluded'

# Scatter kwargs per category: significant groups are larger, darker-edged and drawn on top
# (zorder=1); 'No data' uses an 'x' marker. Each group gets its own dict copy.
_significant_marker = dict(edgecolor='#525252', linewidth=.35, zorder=1, s=8, alpha=.8)
FEATURES_SHAPE_MAP = {
    'Neither': dict(edgecolor='#CDCDCD', linewidth=0.35, zorder=0, s=4, alpha=.8),
    'No data': dict(zorder=0, s=4, marker='x', alpha=.8),
    RECRUITED_FEATURE_GROUP: dict(_significant_marker),
    EXCLUDED_FEATURE_GROUP: dict(_significant_marker),
}

# Fill colours per category: grey background, red-ish for recruited, blue for excluded
_background_grey = '#bdbdbd'
FEATURES_PALETTE = {
    'Neither': _background_grey,
    'No data': _background_grey,
    RECRUITED_FEATURE_GROUP: '#BA5047',
    EXCLUDED_FEATURE_GROUP: '#4B82B6',
}

def _equalise_tick_spacing(ax):
    """Give both axes of `ax` the same major-tick spacing.

    The shared spacing is the larger of the spacings matplotlib's default
    ``AutoLocator`` would choose for the current x and y limits, so x and y tick
    labels end up with the same precision.
    """
    default_ticker = ticker.AutoLocator()
    # tick_values() returns evenly spaced ticks; the median of their differences is that spacing
    spacing_x = np.median(np.diff(default_ticker.tick_values(*ax.get_xlim())))
    spacing_y = np.median(np.diff(default_ticker.tick_values(*ax.get_ylim())))
    spacing = np.max([spacing_x, spacing_y])

    ax.xaxis.set_major_locator(ticker.MultipleLocator(spacing))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(spacing))


def _augment_with_feature_category(data, marcs_feature):
    """Return ``(copy_of_data, new_column_key)`` with a per-feature hue/shape column added.

    The new column ``('marcs_feature_significant_category_or_feature_type', marcs_feature)``
    holds the MARCS significance category for proteins. For every non-protein row it holds
    the factor type instead, so non-proteins keep their factor-type symbols rather than
    the 'No data' cross.
    """
    augmented = data.copy()
    new_column = ('marcs_feature_significant_category_or_feature_type', marcs_feature)
    augmented[new_column] = augmented[('marcs_feature_significant_category', marcs_feature)].copy()

    not_protein = augmented[('metadata', 'factor_type')] != 'protein'
    augmented.loc[not_protein, new_column] = augmented.loc[not_protein, ('metadata', 'factor_type')]
    return augmented, new_column


def make_scatterplots(data, *, x_col, y_col, xlabel, ylabel, title, filename, limits=None):
    """Write a multi-page PDF of styled x-vs-y scatterplots via ``helpers.make_plot``.

    The first page colours and shapes points by factor type. Each following page (one per
    entry of ``MARCS_FEATURE_ORDER``) colours and shapes them by that MARCS feature's
    significance category. The strongly recruited / excluded rows are passed to
    ``make_plot`` as ``label_strategy`` (presumably selecting which points get text labels).

    Parameters
    ----------
    data : pandas.DataFrame
        Table with two-level column keys (as ``data_scatterplot``). Its index is reset
        before plotting.
    x_col, y_col : tuple
        Column keys plotted on the x and y axes.
    xlabel, ylabel, title : str
        Axis labels and title, wrapped at 40 characters per line.
    filename : str
        Path of the PDF file to write.
    limits : (low, high), optional
        Applied to both axes. When given, ``make_plot`` is told not to change the limits.
    """
    max_symbols_per_line = 40

    def _wrap(text):
        return '\n'.join(textwrap.wrap(text, max_symbols_per_line))

    data = data.reset_index()

    # Arguments shared by every page
    kws = dict(
        x_col=x_col,
        xlabel=_wrap(xlabel),
        y_col=y_col,
        ylabel=_wrap(ylabel),
        annot_col=('metadata', 'factor'),
        approx_number_of_labels=20,
        lines=['xy'],
        do_not_change_limits=limits is not None,
    )

    with PdfPages(filename) as pdf:
        for marcs_feature in [None] + list(MARCS_FEATURE_ORDER):
            fig = plt.figure(figsize=(10 * FIVE_MM_IN_INCH, 10 * FIVE_MM_IN_INCH))
            try:
                ax_main = plt.gca()

                if limits is not None:
                    ax_main.set_xlim(*limits)
                    ax_main.set_ylim(*limits)

                if marcs_feature is None:
                    helpers.make_plot(
                        shape_col=('metadata', 'factor_type'),
                        shape_map=param_shape_palette_factor_type,
                        title=_wrap(title),
                        axes=ax_main,
                        hue_palette=param_hue_palette_factor_type,
                        hue_col=('metadata', 'factor_type'),
                        df=data,
                        **kws,
                    )
                else:
                    data_feature_augmented, hue_col = _augment_with_feature_category(data, marcs_feature)
                    shape_col = hue_col

                    is_significant = data_feature_augmented[shape_col].isin(
                        [RECRUITED_FEATURE_GROUP, EXCLUDED_FEATURE_GROUP]
                    )

                    helpers.make_plot(
                        shape_col=shape_col,
                        shape_map={**param_shape_palette_factor_type, **FEATURES_SHAPE_MAP},
                        title=_wrap(f'{title} vs MARCS Feature {marcs_feature}'),
                        axes=ax_main,
                        label_strategy=set(data_feature_augmented[is_significant].index),
                        hue_palette={**param_hue_palette_factor_type, **FEATURES_PALETTE},
                        adjust_text_kws=dict(lim=100),  # to reduce runtime
                        df=data_feature_augmented,
                        hue_col=hue_col,
                        **kws,
                    )

                # Make sure x and y ticks are of the same precision
                _equalise_tick_spacing(ax_main)

                sns.despine(ax=ax_main, offset=3)
                ax_main.legend(loc='center left', bbox_to_anchor=(1, 0.5))

                pdf.savefig(bbox_inches='tight')
            finally:
                # Always release the figure(s), even when plotting fails, to avoid
                # accumulating open figures (memory) across pages and calls.
                plt.close('all')
                gc.collect()


In [None]:
# Styled natural-scale scatterplots (one PDF page overall plus one per MARCS feature)
x_col = f'{param_main_stat_col}', 'factors_x'
y_col = f'{param_main_stat_col}', 'factors_y'

xlabel = f"Fraction of entropy explained by {param_factors_x_label}"
ylabel = f"Fraction of entropy explained by {param_factors_y_label}"
title = f'{param_factors_x_label} vs {param_factors_y_label} (natural scale, {param_cell_line})'

_natural_scale_pdf_name = (
    f'pairwise_scatterplot_{param_cell_line_safe}_{param_factors_x_safe}'
    f'__vs__{param_factors_y_safe}_natural_scale.pdf'
)

make_scatterplots(
    data_scatterplot,
    x_col=x_col, y_col=y_col,
    xlabel=xlabel, ylabel=ylabel,
    title=title,
    filename=os.path.join(output_plot_dir, _natural_scale_pdf_name),
)

# Log2-scale plots

First, work out sensible axis limits for the data if they were not specified:

In [None]:
if param_limits_log2 is None:
    # The limits are deliberately asymmetric because we care more about factors with high
    # values of the statistic than low ones. The lower bound is the 2% quantile of each row's
    # smaller log2 value; the upper bound is one log2 unit above the overall maximum.
    _log2_values = data_scatterplot[f'{param_main_stat_col}_log2']
    low_limit = _log2_values.min(axis=1).quantile(0.02)
    high_limit = _log2_values.stack().max() + 1

    param_limits_log2 = (low_limit, high_limit)

print(f"{param_limits_log2=}")

A quick sanity check before making the styled plot:

In [None]:
x_col = f'{param_main_stat_col}_log2', 'factors_x'
y_col = f'{param_main_stat_col}_log2', 'factors_y'
c_col = f'{param_main_stat_col}_log2_diff', 'factors_x_minus_factors_y'

# Unstyled log2 scatter coloured by log2(x) - log2(y); the dotted line marks x == y
_colour_kws = dict(cmap='RdBu_r', vmin=-4, vmax=4, edgecolor='black')

ax = plt.gca()
ax.scatter(
    x=data_scatterplot[x_col],
    y=data_scatterplot[y_col],
    c=data_scatterplot[c_col],
    **_colour_kws,
)

ax.plot(param_limits_log2, param_limits_log2, linestyle=':')
ax.set_xlabel(x_col)
ax.set_ylabel(y_col)
ax.set_xlim(*param_limits_log2)
ax.set_ylim(*param_limits_log2)

And a styled plot

In [None]:
# Styled log2-scale scatterplots, clipped to param_limits_log2 on both axes
x_col = f'{param_main_stat_col}_log2', 'factors_x'
y_col = f'{param_main_stat_col}_log2', 'factors_y'

xlabel = f"Fraction of entropy explained by {param_factors_x_label}, log2"
ylabel = f"Fraction of entropy explained by {param_factors_y_label}, log2"
title = f'{param_factors_x_label} vs {param_factors_y_label} (log scale, {param_cell_line})'

_log2_scale_pdf_name = (
    f'pairwise_scatterplot_{param_cell_line_safe}_{param_factors_x_safe}'
    f'__vs__{param_factors_y_safe}_log2_scale.pdf'
)

make_scatterplots(
    data_scatterplot,
    x_col=x_col, y_col=y_col,
    xlabel=xlabel, ylabel=ylabel,
    title=title,
    filename=os.path.join(output_plot_dir, _log2_scale_pdf_name),
    limits=param_limits_log2,
)

# Co-relationship with MARCS features

In [None]:
# Interactive inspection only: show the source of helpers.get_stats (IPython `??` syntax)
helpers.get_stats??

In [None]:
# Test the per-factor log2 difference (x minus y) of the MARCS test groups
# (strongly recruited / excluded) against the control groups ('Neither', 'No data').
# The column is derived from param_main_stat_col so it always matches the column computed
# above, instead of assuming the statistic is 'normalised_mi'.
stats = helpers.get_stats(
    data_scatterplot,
    column=f'{param_main_stat_col}_log2_diff',
    control_groups=['Neither', 'No data'],
    test_groups=[RECRUITED_FEATURE_GROUP, EXCLUDED_FEATURE_GROUP],
    fdr_method=param_fdr_method,
    fdr_alpha=param_alpha,
    log2=False,  # Data is already log2, no need to log2 again in the code.
)

# Rename the mean_diff to something more descriptive
# (`log2` because the values passed in are already log2 differences)
stats = stats.rename(columns={'mean_diff': 'mean_log2_diff_of_diffs'})

# get_stats is written for many datasets; here there is only one, 'factors_x_minus_factors_y'.
# .copy() so the column assignments below act on an independent frame, not a slice of `stats`.
stats = stats.loc['factors_x_minus_factors_y'].copy()
stats['comparison_factor_x'] = ';'.join(param_factors_x)
stats['comparison_factor_y'] = ';'.join(param_factors_y)

stats = stats.sort_values(by='p-val')
stats.to_csv(os.path.join(output_plot_dir, 'mwu_stats.csv'))
stats

# World's smallest heatmap

In [None]:
# Reshape the stats into one-row wide matrices: row = comparison title,
# columns = (marcs_feature, group), ordered by MARCS_FEATURE_ORDER.
_df = stats.copy()
_df['title'] = f'{param_factors_x_label} minus {param_factors_y_label}'
_df = _df.reset_index().set_index(['title', 'marcs_feature', 'group'])

_unstack_levels = ['marcs_feature', 'group']
matrix = _df['mean_log2_diff_of_diffs'].unstack(_unstack_levels)
# Missing cells count as "not significant". unstack() fills gaps with NaN (object dtype after
# fillna), and `~` on object-dtype Python bools may yield integers (-2/-1) instead of
# booleans, so cast to bool before inverting.
matrix_mask = ~(_df['significant'].unstack(_unstack_levels).fillna(False).astype(bool))
matrix_pvals = _df['p-val corrected'].unstack(_unstack_levels)

matrix = matrix[MARCS_FEATURE_ORDER]
matrix_mask = matrix_mask[MARCS_FEATURE_ORDER]
matrix_pvals = matrix_pvals[MARCS_FEATURE_ORDER]

In [None]:
# Mean log2 difference of differences per (MARCS feature, group)
matrix

In [None]:
# True where the difference is NOT significant
matrix_mask

In [None]:
# FDR-corrected p-values per (MARCS feature, group)
matrix_pvals

In [None]:
_matrix = matrix.copy()
_mask = matrix_mask.copy()
_pvals = matrix_pvals.copy()


_row_order = list(_matrix.index)
_col_order = ['-'.join(ix) for ix in _matrix.columns]

# Cell centres sit at 5, 15, 25, ... (10 * i - 5), presumably to line up with scipy-style
# dendrograms, whose leaves use the same spacing. Rows are numbered in reverse order.
_row_coords = pd.Series(
    np.arange(len(_row_order), 0, -1) * 10.0 - 5,
    index=_row_order,
    name='row_coord'
)

_col_coords = pd.Series(
    np.arange(1, len(_col_order) + 1) * 10.0 - 5,
    index=_col_order,
    name='col_coord'
)


# Don't mess with the wide format, convert all into narrow.
# To do this, flatten any MultiIndex rows/columns into 'a-b' strings.
for _obj in [_matrix, _mask, _pvals, _row_coords, _col_coords]:
    if isinstance(_obj.index, pd.MultiIndex):
        _obj.index = ['-'.join(ix) for ix in _obj.index]

    if isinstance(_obj, pd.DataFrame) and isinstance(_obj.columns, pd.MultiIndex):
        _obj.columns = ['-'.join(ix) for ix in _obj.columns]

# A couple of operations to force the matrix into long format:
# one row per (row_name, col_name) cell with 'matrix', 'mask', 'pvals' columns and plot coordinates
_matrix_long = pd.concat([_matrix, _mask, _pvals], keys=['matrix', 'mask', 'pvals'])
_matrix_long = _matrix_long.stack()
_matrix_long.index.names = ['measurement', 'row_name', 'col_name']
_matrix_long = _matrix_long.unstack('measurement')
_matrix_long = _matrix_long.join(_row_coords, on='row_name')
_matrix_long = _matrix_long.join(_col_coords, on='col_name')

# Bin the p-values (right-closed bins). include_lowest=True keeps p == 0 in the first bin;
# without it such cells would get a NaN bin and silently disappear from the heatmap.
# If you change the bins, change the legend code below too.
_matrix_long['pvals_binned'] = pd.cut(
    _matrix_long['pvals'], bins=[0, 0.01, 0.05, 1.0], include_lowest=True
)
assert len(_matrix_long.pvals_binned.cat.categories) == 3

# Marker style per p-value bin, from most to least significant
_pval_marker_styles = [
    dict(s=30, marker='s', edgecolor='#525252', linewidth=0.45),  # p <= 0.01
    dict(s=20, marker='s', edgecolor='#525252', linewidth=0.45),  # 0.01 < p <= 0.05
    dict(s=10, marker='s', edgecolor='#CDCDCD', linewidth=0.45),  # not significant
]
_pval_kwarg_dict = dict(zip(_matrix_long.pvals_binned.cat.categories, _pval_marker_styles))
_ns_pval_group = _matrix_long.pvals_binned.cat.categories[-1]


In [None]:
from matplotlib.gridspec import GridSpec

# Force the colour scale for min and max
_vmin = -2.1  # don't make this a round number (will help with the ticks)
_vmax = 2.1
_title = 'Mean log2 difference\nin normed MI\n(to other proteins)'
_cmap = 'RdBu_r'

n_rows, n_cols = _matrix.shape

fig = plt.figure(
    figsize=(
        # four mm * (columns + dendrogram) + [labels]
        4/5*FIVE_MM_IN_INCH*(n_cols+5) + FIVE_MM_IN_INCH*5,
        # four mm * (rows + dendrogram)
        4/5*FIVE_MM_IN_INCH*(n_rows+5+5),
    ),
    constrained_layout=True,
)

# 3x2 grid: legends in the top-left cell, heatmap in the middle-right cell; the other
# cells only reserve space (dendrogram space per the size comments above).
# Binding the GridSpec to `fig` lets constrained_layout manage it.
gs = GridSpec(
    3, 2,
    figure=fig,
    width_ratios=[5, n_cols],
    height_ratios=[5, n_rows, 5],
    wspace=0.05, hspace=0.05,
)

ax_heatmap = fig.add_subplot(gs[1, 1])
ax_heatmap.invert_yaxis()
ax_heatmap.xaxis.tick_top()
ax_heatmap.yaxis.tick_right()

sns.despine(ax=ax_heatmap, bottom=False, right=False, top=False, left=False)
ax_heatmap.xaxis.set_tick_params(length=0)
ax_heatmap.yaxis.set_tick_params(length=0)

# Draw heatmap: one scatter call per p-value bin so each bin gets its own marker size/edge.
# Cells with a missing p-value have no bin and are not drawn.
_colours = None
for _pval_group, _submatrix in _matrix_long.groupby('pvals_binned', observed=True):
    _colours = ax_heatmap.scatter(
        _submatrix['col_coord'],
        _submatrix['row_coord'],
        c=_submatrix['matrix'],
        vmin=_vmin, vmax=_vmax,
        cmap=_cmap,
        **_pval_kwarg_dict[_pval_group],
    )
# All calls share vmin/vmax/cmap, so any handle can drive the colourbar
assert _colours is not None, 'No heatmap cells to draw: all corrected p-values are missing'

# Set ticks
_yticks = _row_coords.sort_values()
ax_heatmap.set_yticks(_yticks.values)
ax_heatmap.set_yticklabels(_yticks.index)

_xticks = _col_coords.sort_values()
ax_heatmap.set_xticks(_xticks.values)
ax_heatmap.set_xticklabels(_xticks.index, rotation=90)


# Add legends
ax_legend = fig.add_subplot(gs[0, 0])

# Colourbar
cax = ax_legend.inset_axes([0.0, 0.7, 1.0, 0.3])
fig.colorbar(_colours, ax=ax_legend, cax=cax, orientation='horizontal')
cax.set_title(_title)
cax.minorticks_on()


# P-val legend: one grey marker + label per bin, in bin order.
# pd.cut bins are right-closed, so their upper edges are inclusive, hence '≤'.
ax_legend.text(0.5, 0.4, 'MWU p-val (B/H):', ha='center', va='center')
assert len(_pval_kwarg_dict) == 3
_legend_entries = zip(
    (0.25, 0.5, 0.75),
    _matrix_long.pvals_binned.cat.categories,
    ('≤0.01', '≤0.05', 'n.s.'),
)
for _x, _category, _label in _legend_entries:
    ax_legend.scatter(_x, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_category])
    ax_legend.text(_x, 0.15, _label, ha='center', va='top')

ax_legend.set_ylim([0, 1])
ax_legend.set_xlim([0, 1])
ax_legend.axis('off')

fig.savefig(
    os.path.join(output_plot_dir, f'pairwise_scatterplot_{param_cell_line_safe}_{param_factors_x_safe}__vs__{param_factors_y_safe}_mwu_heatmap.pdf'),
    bbox_inches='tight'
)