In [None]:
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
# %matplotlib widget

import os
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm
# Global seaborn look: a five-colour palette, 'paper' scaling and a minimal
# white style without top/right spines.
sns.set_palette(['#1E1E1E', '#BB3524', '#F5D54A', '#384827', '#282F44'])
sns.set_context('paper')

_DARK_GREY = '.15'  # shared colour for edges, labels, text and ticks
_seaborn_style = {
    # Axes and figure background
    'axes.axisbelow': True,
    'axes.edgecolor': _DARK_GREY,
    'axes.facecolor': 'white',
    'axes.labelcolor': _DARK_GREY,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'figure.facecolor': 'white',
    'text.color': _DARK_GREY,
    'image.cmap': 'Greys',
    # Grid (switched off by default, dotted when enabled)
    'axes.grid': False,
    'grid.color': _DARK_GREY,
    'grid.linestyle': ':',
    'grid.alpha': .5,
    # Legend
    'legend.frameon': False,
    'legend.numpoints': 1,
    'legend.scatterpoints': 1,
    # Lines
    'lines.solid_capstyle': 'butt',
    # Ticks: outward-facing, bottom/left only
    'xtick.top': False,
    'ytick.right': False,
    'xtick.color': _DARK_GREY,
    'ytick.color': _DARK_GREY,
    'xtick.direction': 'out',
    'ytick.direction': 'out',
}
sns.set_style(_seaborn_style)


import matplotlib

# --- Typography: every text element is rendered at the same point size ---
FONT_SIZE_PT = 5
matplotlib.rcParams['font.family'] = 'Arial'
for _size_key in (
    'font.size',
    'axes.labelsize',
    'axes.titlesize',
    'figure.titlesize',
    'xtick.labelsize',
    'ytick.labelsize',
    'legend.fontsize',
    'legend.title_fontsize',
):
    matplotlib.rcParams[_size_key] = FONT_SIZE_PT

# --- Ticks: identical geometry on both axes ---
for _axis in ('xtick', 'ytick'):
    matplotlib.rcParams.update({
        f'{_axis}.major.size': 2,
        f'{_axis}.major.width': 0.5,
        f'{_axis}.minor.size': 1,
        f'{_axis}.minor.width': 0.5,
    })

# --- Stroke widths and marker size ---
matplotlib.rcParams.update({
    'axes.linewidth': 0.5,
    'lines.linewidth': 0.5,
    'grid.linewidth': 0.25,
    'patch.linewidth': 0.25,
    'lines.markeredgewidth': 0.25,
    'lines.markersize': 2,
})

# --- Figure size and resolution ---
FIVE_MM_IN_INCH = 0.19685
DPI = 600
matplotlib.rcParams.update({
    'figure.figsize': (10 * FIVE_MM_IN_INCH, 9 * FIVE_MM_IN_INCH),
    'savefig.dpi': DPI,
    # On-screen rendering uses a quarter of the export resolution
    'figure.dpi': DPI // 4,
})

# --- Font embedding in vector output ---
#http://phyletica.org/matplotlib-fonts/
for _backend in ('pdf', 'ps'):
    matplotlib.rcParams[f'{_backend}.fonttype'] = 42

pd.set_option("display.max_columns", 200)

In [None]:
# Notebook parameters.
#
# When a `snakemake` object is present in the namespace, inputs, outputs and
# parameters are read from it. Otherwise the notebook falls back to hard-coded
# debug values and writes its outputs into a local scratch directory.
HAVE_SNAKEMAKE = 'snakemake' in locals()

if HAVE_SNAKEMAKE:
    # Code assumes a list..
    input_consolidated_stats = [snakemake.input.consolidated_tsv]

    param_correlation_method = snakemake.params['correlation_method']
    param_input_header_separator = snakemake.params.get('input_header_sep', '__')
    threads = snakemake.threads

    output_stats_mi_tsv = snakemake.output.stats_mi_tsv
    output_stats_corr_tsv = snakemake.output.stats_corr_tsv

    output_heatmap_mi_pdf = snakemake.output.heatmap_mi_pdf
    output_heatmap_corr_pdf = snakemake.output.heatmap_corr_pdf
    output_diagnostic_plots = snakemake.output.diagnostic_plots

    # None means "use the defaults provided by helpers" (resolved in a later cell)
    param_hue_palette_factor_type = snakemake.params.get('hue_palette_factor_type', None)
    param_shape_palette_factor_type = snakemake.params.get('shape_palette_factor_type', None)

else:
    print("No snakemake -- DEBUG MODE")

    _OUTDIR = '.nb-testing-outputs'
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(_OUTDIR, exist_ok=True)

    _bin_size = 1000
    _pseudocount = 100
    _min_periods = 1

    # Parameter tag that appears both in the directory and in the file name
    _params_tag = f'pc_{_pseudocount}_mp_{_min_periods}'

    input_consolidated_stats = [
        '../../output/final/analysis/'
        f'params_{_bin_size}bp_{_params_tag}/{_cell_line}/consolidated_tables/'
        f'bedstats_consolidated_{_cell_line}_{_bin_size}bp_params_{_params_tag}_from_bed.csv.gz'
        for _cell_line in ['K562']
    ]

    param_correlation_method = 'kendall'
    param_input_header_separator = '__'

    output_heatmap_mi_pdf = os.path.join(_OUTDIR, 'heatmap.mi.pdf')
    output_heatmap_corr_pdf = os.path.join(_OUTDIR, 'heatmap.corr.pdf')

    output_stats_mi_tsv = os.path.join(_OUTDIR, 'stats.mi.tsv')
    output_stats_corr_tsv = os.path.join(_OUTDIR, 'stats.corr.tsv')

    output_diagnostic_plots = os.path.join(_OUTDIR, 'diagnostic')

    param_hue_palette_factor_type = None
    param_shape_palette_factor_type = None

    threads = 8

# Multiple-testing correction settings, shared by both modes
param_fdr_method = 'fdr_bh'
param_alpha = 0.05

In [None]:
%load_ext autoreload
%autoreload 1

In [None]:
%aimport helpers

In [None]:
# Create the diagnostic-plot output directory if we need to.
# exist_ok=True makes this a no-op when the directory already exists and
# avoids the race between a separate isdir() check and makedirs().
os.makedirs(output_diagnostic_plots, exist_ok=True)

In [None]:


# Fall back to the defaults from helpers when no palette was configured
if param_hue_palette_factor_type is None:
    param_hue_palette_factor_type = helpers.get_default_hue_palette_factor_type()

if param_shape_palette_factor_type is None:
    param_shape_palette_factor_type = helpers.get_default_shape_palette_factor_type()

# Echo the effective settings; the `=` specifier labels each value with its name
print(f'{param_hue_palette_factor_type=}')
print(f'{param_shape_palette_factor_type=}')

# Input

The previous step of the pipeline assembled the data into an easily digestible format.

The goal of this notebook is to provide some sort of summary statistic regarding the relationships between MARCS data and the ChIP-seq data.
Particularly, we will be interested to see whether on average, a particular histone/chromatin feature ChIP-seq explains more or less entropy of proteins associated with MARCS feature compared to a random protein.

This is what the data looks like as we import it:

In [None]:
def _read_consolidated_table(path):
    """Read one consolidated table and split its flat header into two levels.

    Each column name is split on `param_input_header_separator`; the parts
    become the ('header', 'column') levels of a column MultiIndex.
    """
    table = pd.read_csv(path, index_col=0)
    table.columns = pd.MultiIndex.from_tuples(
        [name.split(param_input_header_separator) for name in table.columns],
        names=['header', 'column'],
    )
    return table


# One table per input file, keyed by the file's base name
data = {
    os.path.basename(filename): _read_consolidated_table(filename)
    for filename in input_consolidated_stats
}

In [None]:
# Preview the first loaded table (first entry of `data`)
next(iter(data.values()))

In [None]:
# Column dtypes of the first loaded table
next(iter(data.values())).dtypes

As the MARCS annotations are aggregated by `marcs_gene_label` column, 
and the latter is aggregated by `factor`, in order to avoid some of the double counting, 
we will reaggregate the dataframe by Factor column, taking the mean:

In [None]:
# Collapse every loaded table to one row per factor (see helpers.py) and give
# factors without a MARCS category an explicit 'No data' label.
data_reaggregated = {}

for k, df in data.items():
    _reaggregated = helpers.reaggregate_by_factor(df, factor_col=('metadata', 'factor'))

    _categories = _reaggregated['marcs_feature_significant_category']
    _reaggregated['marcs_feature_significant_category'] = _categories.fillna('No data')

    data_reaggregated[k] = _reaggregated

The aggregated data looks like this (note that it is now indexed by `factor`):

In [None]:
# Preview the first re-aggregated table
next(iter(data_reaggregated.values()))

We will now try to summarise this data.

Primarily we will be doing a Mann-Whitney-U statistical test on the scores.
We will be comparing a MARCS category, e.g. "Recruited by H3K4me3", to proteins in "other" category, specifically proteins neither recruited, nor excluded by H3K4me3, and proteins that we have no estimate for.
For visualisation purposes, we will additionally compute the mean log2 difference between the scores for proteins in MARCS category vs scores in the control category.


This is implemented in the `get_stats` function in `helpers.py`:

In [None]:
helpers.get_stats??

The `normalised_mi` stats can be computed in a fairly easy way using this `get_stats` function.
We do not need to worry about the normalised_mi being equal to zero due to smoothing upstream.

In [None]:
# MARCS category labels used as the two test groups in the statistics and plots below
recruited_group = 'Strongly recruited'
excluded_group = 'Strongly excluded'

In [None]:
# Smallest normalised MI over all re-aggregated tables; the next cell asserts
# that it is strictly positive. This iterates over `data_reaggregated`
# explicitly instead of relying on a `df` variable left over from an earlier loop.
min(_frame['normalised_mi'].min().min() for _frame in data_reaggregated.values())

In [None]:
# Guard: a zero normalised MI is not expected in any dataset
for df in data_reaggregated.values():
    assert df['normalised_mi'].min().min() > 0

# Per-dataset statistics; p-values are left uncorrected here because the
# correction is applied jointly further down.
_mi_stats_by_dataset = {}
for k, df in data_reaggregated.items():
    _mi_stats_by_dataset[k] = helpers.get_stats(
        df,
        column='normalised_mi',
        control_groups=['Neither', 'No data'],
        test_groups=[recruited_group, excluded_group],
        fdr_method=None,
    )

# Stack the per-dataset results, labelling the new outer index level 'dataset'
stats_mi = pd.concat(
    list(_mi_stats_by_dataset.values()),
    keys=list(_mi_stats_by_dataset.keys()),
)
stats_mi.index = stats_mi.index.set_names('dataset', level=0)

# With a single dataset the extra level only gets in the way
_n_mi_datasets = stats_mi.index.get_level_values('dataset').nunique()
if _n_mi_datasets == 1:
    stats_mi = stats_mi.droplevel('dataset')

# Joint multiple-testing correction, then order by corrected p-value
stats_mi = helpers.adjust_pvals(stats_mi, fdr_method=param_fdr_method, fdr_alpha=param_alpha)
stats_mi = stats_mi.sort_values(by='p-val corrected')

stats_mi

In [None]:
# Persist the MI statistics as a tab-separated file
stats_mi.to_csv(output_stats_mi_tsv, sep='\t')

Correlation can be exactly equal to zero. In such cases, to avoid division by zero, we need to replace these values with a small non-zero value.
In this particular case we will use the smallest absolute non-zero correlation from the dataset:

In [None]:
_corr_column = f'{param_correlation_method}_correlation'
_corr_stats_by_dataset = {}

for k, df in data_reaggregated.items():
    # Report the smallest non-zero absolute correlation of this dataset
    _abs_corr = df[_corr_column].stack().abs()
    min_nonzero_abs_corr = _abs_corr[_abs_corr > 0].min()

    print(f'{k=}: {min_nonzero_abs_corr=}')

    # Plain (non-log) differences; p-values are corrected jointly below
    _corr_stats_by_dataset[k] = helpers.get_stats(
        df,
        column=_corr_column,
        control_groups=['Neither', 'No data'],
        test_groups=[recruited_group, excluded_group],
        log2=False,
        fdr_method=None,
    )

# Stack the per-dataset results, labelling the new outer index level 'dataset'
stats_corr = pd.concat(
    list(_corr_stats_by_dataset.values()),
    keys=list(_corr_stats_by_dataset.keys()),
)
stats_corr.index = stats_corr.index.set_names('dataset', level=0)

# With a single dataset the extra level only gets in the way
_n_corr_datasets = stats_corr.index.get_level_values('dataset').nunique()
if _n_corr_datasets == 1:
    stats_corr = stats_corr.droplevel('dataset')

# Joint multiple-testing correction, then order by corrected p-value
stats_corr = helpers.adjust_pvals(stats_corr, fdr_method=param_fdr_method, fdr_alpha=param_alpha)
stats_corr = stats_corr.sort_values(by='p-val corrected')

stats_corr

In [None]:
# Persist the correlation statistics as a tab-separated file
stats_corr.to_csv(output_stats_corr_tsv, sep='\t')

Now that we have the stats dataframe we can plot the results.
We will use heatmap to visualise the matrices of the results.

Let's prepare these matrices:

In [None]:
# Wide matrices for plotting, one entry per statistic ('mi', 'corr'):
#   matrices      - effect size per cell
#   matrix_masks  - True where the cell is NOT significant
#   matrix_pvals  - corrected p-value per cell
matrices = {}
matrix_masks = {}
matrix_pvals = {}

# The matrix whose row/column layout all other matrices are aligned to
main_matrix = 'mi'

# Index levels that become the columns of the wide matrices
_wide_levels = ['marcs_feature', 'group']

for name, stat_df, diff_col in [('mi', stats_mi, 'mean_log2_diff'), ('corr', stats_corr, 'mean_diff')]:
    matrices[name] = stat_df[diff_col].unstack(_wide_levels)

    # Unstacking can introduce missing cells and leave an object-dtype frame.
    # Cast back to bool explicitly so that `~` is a logical NOT regardless of
    # whether pandas downcasts the result of fillna().
    _significant = stat_df['significant'].unstack(_wide_levels).fillna(False).astype(bool)
    matrix_masks[name] = ~_significant

    matrix_pvals[name] = stat_df['p-val corrected'].unstack(_wide_levels)

# Equalise indices: give every matrix the rows and columns of the main matrix
_reference_index = matrices[main_matrix].index
_reference_columns = matrices[main_matrix].columns

for name in matrices.keys():

    for dict_ in [matrices, matrix_masks, matrix_pvals]:
        dict_[name] = dict_[name].reindex(index=_reference_index, columns=_reference_columns)

    # Mask=True implies "not significant", therefore fill should be with True.
    # Reindexing may have introduced missing cells, so restore the bool dtype too.
    matrix_masks[name] = matrix_masks[name].fillna(True).astype(bool)

# Let's drop columns that are always null (in all matrices)
always_null = pd.Series(True, index=matrices['mi'].columns)
for m in matrices.values():
    always_null &= m.isnull().all()

to_drop = always_null[always_null].index
print(f"Dropping these columns as they are always null:\n{to_drop}")
for name in matrices.keys():
    matrices[name] = matrices[name].loc(axis=1)[~always_null]
    matrix_masks[name] = matrix_masks[name].loc(axis=1)[~always_null]
    matrix_pvals[name] = matrix_pvals[name].loc(axis=1)[~always_null]


# Let's also drop the null rows in MI matrix (the same rows are removed from
# every matrix so that they stay aligned)
always_null_row = matrices['mi'].isnull().all(axis=1)
for name in matrices.keys():
    matrices[name] = matrices[name].loc[~always_null_row]
    matrix_masks[name] = matrix_masks[name].loc[~always_null_row]
    matrix_pvals[name] = matrix_pvals[name].loc[~always_null_row]

Now that we have the matrices, let's compute the linkage for their rows and columns.


We will be clustering only the `main_matrix`. The earlier version was concatenating the matrices but that is counterproductive

We will be using correlation distance for both rows and columns,
with complete linkage and optimal leaf ordering for both.

NaNs will be filled with zero

In [None]:
from scipy.cluster import hierarchy as hcluster

In [None]:
def _cluster_rows(frame):
    """Hierarchically cluster the rows of `frame`.

    Missing values are replaced by zero; rows are compared with correlation
    distance and merged with complete linkage and optimal leaf ordering.
    Returns the linkage matrix and the row labels in dendrogram leaf order.
    """
    link = hcluster.linkage(
        frame.fillna(0),
        metric='correlation',
        method='complete',
        optimal_ordering=True,
    )
    leaf_positions = hcluster.dendrogram(link, no_plot=True)['leaves']
    return link, frame.index[leaf_positions]


# Only the main matrix is clustered (no concatenation across matrices)
row_concatenated_matrix = matrices[main_matrix] # pd.concat(matrices.values(), keys=matrices.keys(), axis=1)
col_concatenated_matrix = matrices[main_matrix] # pd.concat(matrices.values(), keys=matrices.keys(), axis=0)

linkage_rows, linkage_rows_order = _cluster_rows(row_concatenated_matrix)
# Columns are clustered by treating them as rows of the transposed matrix
linkage_cols, linkage_cols_order = _cluster_rows(col_concatenated_matrix.T)


# Seaborn Heatmaps

Now that we have the matrices we can simply plot them.
First, plot the (rather ugly) seaborn heatmaps to make sure we're not plotting the data wrong.

MI Matrix:

In [None]:
_name = 'mi'
_matrix, _mask = matrices[_name], matrix_masks[_name]

# Helps with interpretation of labels: MultiIndex rows are flattened to
# "<second level>-<first level>"
yticklabels = (
    ['-'.join([x[1], x[0]]) for x in _matrix.index]
    if isinstance(_matrix.index, pd.MultiIndex)
    else _matrix.index
)

# The effect matrix and its mask must be aligned cell for cell
assert _matrix.columns.equals(_mask.columns)
assert _matrix.index.equals(_mask.index)

# '*' marks the cells that are not masked, i.e. the significant ones
_annot = _mask.applymap(lambda masked: '' if masked else '*')

# Reuse the precomputed clustering so this plot matches the styled heatmaps
_row_linkage, _col_linkage = linkage_rows, linkage_cols

_clustermap_kwargs = dict(
    annot=_annot,
    fmt='',
    cmap='RdBu_r',
    center=0,
    robust=True,
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    figsize=(FIVE_MM_IN_INCH*30, FIVE_MM_IN_INCH*27),
    yticklabels=yticklabels,
    xticklabels=1,
    linewidth=0.1,
    linecolor='black',
)
_cmap = sns.clustermap(_matrix.fillna(0), **_clustermap_kwargs)  #mask=_mask

# Hide the tick marks but keep the labels
for _heatmap_axis in (_cmap.ax_heatmap.xaxis, _cmap.ax_heatmap.yaxis):
    _heatmap_axis.set_tick_params(length=0)

_cmap.cax.set_ylabel(f"Average difference in\n {_name}, log2")
_cmap.ax_col_dendrogram.set_title(f"{_name}")

# _cmap.savefig(output_heatmap_mi_pdf)

Correlation

In [None]:
_name = 'corr'
_matrix = matrices[_name]
_mask = matrix_masks[_name]

# Helps with interpretation of labels: MultiIndex rows are flattened to
# "<second level>-<first level>"
if isinstance(_matrix.index, pd.MultiIndex):
    yticklabels = ['-'.join([x[1], x[0]]) for x in _matrix.index]
else:
    yticklabels = _matrix.index

# The effect matrix and its mask must be aligned cell for cell
assert _matrix.columns.equals(_mask.columns)
assert _matrix.index.equals(_mask.index)

# '*' marks the cells that are not masked, i.e. the significant ones
_annot = _mask.applymap(lambda x: '*' if not x else '')

# Reuse the precomputed clustering so this plot matches the MI heatmap
_row_linkage = linkage_rows
_col_linkage = linkage_cols

_cmap = sns.clustermap(
    _matrix.fillna(0), #mask=_mask,
    annot=_annot,
    cmap='RdBu_r',
    row_linkage=_row_linkage,
    col_linkage=_col_linkage,
    figsize=(FIVE_MM_IN_INCH*30, FIVE_MM_IN_INCH*30),
    yticklabels=yticklabels,
    xticklabels=1,
    center=0,
    fmt='',
    linewidth=0.1, linecolor='black',
    robust=True,
)

# Hide the tick marks but keep the labels
_cmap.ax_heatmap.xaxis.set_tick_params(length=0)
_cmap.ax_heatmap.yaxis.set_tick_params(length=0)
# This matrix holds the plain 'mean_diff' column (statistics computed with
# log2=False), so the colourbar label must not claim a log2 scale.
_cmap.cax.set_ylabel(f"Average difference in\n {_name}")
_cmap.ax_col_dendrogram.set_title(f"{_name}")

# _cmap.savefig(output_heatmap_corr_pdf)

# Styled heatmaps

Now let's redo the heatmaps as styled scatterplots 

In [None]:
from matplotlib.gridspec import GridSpec

# --- Data preparation for the styled MI heatmap ---
_name = 'mi'
_title = 'normalised MI'
_matrix = matrices[_name].copy()
_mask = matrix_masks[_name].copy()
_pvals = matrix_pvals[_name].copy()

_row_order = linkage_rows_order.copy()
_col_order = linkage_cols_order.copy()

_row_linkage = linkage_rows
_col_linkage = linkage_cols

# Plot coordinates of each row/column. Leaves are placed at 5, 15, 25, ...
# so that the scatter points line up with the dendrogram leaves.
_row_coords = pd.Series(
    np.arange(len(_row_order), 0, -1) * 10.0 - 5, # the weird multiplication are due to matplotlib's dendrogram madness
    # Reversed order matches the dendrogram better
    index=list(reversed(_row_order)),
    name='row_coord'
)

_col_coords = pd.Series(
    np.arange(1, len(_col_order)+1) * 10.0 - 5, # the weird multiplications are due to matplotlib's dendrogram madness
    index=_col_order,
    name='col_coord'
)

# Don't mess with the wide format, convert all into narrow
# To do this make sure we have normal indices column names
# (MultiIndex labels are flattened by joining their levels with '-')
for _df in [_matrix, _mask, _pvals, _row_coords, _col_coords]:
    if isinstance(_df.index, pd.MultiIndex):
        _df.index = ['-'.join(ix) for ix in _df.index]

    if isinstance(_df, pd.DataFrame) and isinstance(_df.columns, pd.MultiIndex):
        _df.columns = ['-'.join(ix) for ix in _df.columns]

# A couple operations to force the matrix into long format:
# one row per (row_name, col_name) with 'matrix', 'mask' and 'pvals' columns,
# plus the plot coordinates of that cell
_matrix_long = pd.concat([_matrix, _mask, _pvals], keys=['matrix', 'mask', 'pvals'])
_matrix_long = _matrix_long.stack()
_matrix_long.index.names = ['measurement', 'row_name', 'col_name']
_matrix_long = _matrix_long.unstack('measurement')
_matrix_long = _matrix_long.join(_row_coords, on='row_name')
_matrix_long = _matrix_long.join(_col_coords, on='col_name')

# Bin the p-values
# If you change the bins, change the legend code below too
# include_lowest=True keeps a p-value of exactly 0 in the first bin; without it
# the left-open interval (0, 0.01] would leave such a cell unbinned and it
# would not be drawn at all.
_matrix_long['pvals_binned'] = pd.cut(
    _matrix_long['pvals'],
    bins=[0, 0.01, 0.05, 1.0],
    include_lowest=True,
)
assert len(_matrix_long.pvals_binned.cat.categories) == 3

# Marker style per p-value bin, from most to least significant
_pval_kwarg_dict = dict(zip(
    _matrix_long.pvals_binned.cat.categories,
    [
        dict(s=30, marker='s', edgecolor='#525252', linewidth=0.45),
        dict(s=20, marker='s', edgecolor='#525252', linewidth=0.45),
        dict(s=10, marker='s', edgecolor='#CDCDCD', linewidth=0.45),
    ],
))
_ns_pval_group = _matrix_long.pvals_binned.cat.categories[-1]

# --- Drawing of the styled MI heatmap ---
# Layout (3 x 2 grid): legend top-left, row dendrogram middle-left,
# heatmap middle-right, column dendrogram bottom-right.

# Force the colour scale for min and max
_vmin = -2.1 # don't make this a round number (will help with the ticks)
_vmax = 2.1


_cmap = 'RdBu_r'

n_rows, n_cols = _matrix.shape

fig = plt.figure(
    figsize=(
        # four mm * (columns + dendrogram) + [labels]
        4/5*FIVE_MM_IN_INCH*(n_cols+5) + FIVE_MM_IN_INCH*5, 
        # four mm * (rows + dendrogram)
        4/5*FIVE_MM_IN_INCH*(n_rows+5+5),
    ),
    constrained_layout=True,
)

# The dendrogram/legend strips get a weight of 5, the heatmap one unit per row/column
gs = GridSpec(3, 2, width_ratios=[5, n_cols], height_ratios=[5, n_rows, 5], wspace=0.05, hspace=0.05)

ax_heatmap = fig.add_subplot(gs[1,1])

# Labels go on the top and right edges; the left and bottom neighbours are the dendrograms
ax_heatmap.xaxis.tick_top()
ax_heatmap.yaxis.tick_right()

# Keep all four spines so the heatmap is drawn inside a full frame
sns.despine(ax=ax_heatmap, bottom=False, right=False, top=False, left=False)
ax_heatmap.xaxis.set_tick_params(length=0)
ax_heatmap.yaxis.set_tick_params(length=0)

# Draw heatmap
# One scatter call per p-value bin, so each bin gets its own marker style.
# All calls share vmin/vmax/cmap; `_colours` (the last collection drawn) is
# reused below to build the colourbar.
for _pval_group, _submatrix in _matrix_long.groupby('pvals_binned'):
    
     _colours = ax_heatmap.scatter(
        _submatrix['col_coord'], 
        _submatrix['row_coord'], 
        c=_submatrix['matrix'], 
        vmin=_vmin, vmax=_vmax, 
        cmap=_cmap,
        **_pval_kwarg_dict[_pval_group],
    )


# Draw dendrograms
# Axes are shared with the heatmap so the leaves stay aligned with its rows/columns
ax_col_dendrogram = fig.add_subplot(gs[2,1], sharex=ax_heatmap)
ax_row_dendrogram = fig.add_subplot(gs[1,0], sharey=ax_heatmap)

for ax, link, orient in [
    (ax_col_dendrogram, _col_linkage, 'bottom'),
    (ax_row_dendrogram, _row_linkage, 'left')
]:
    
    hcluster.dendrogram(
        link, 
        ax=ax, 
        orientation=orient, 
        link_color_func= lambda x: 'black'
    )
    
    # Strip spines, tick marks and tick labels from the dendrogram axes
    sns.despine(ax=ax, left=True, bottom=True, right=True, top=True)
    
    ax.xaxis.set_tick_params(length=0)
    ax.yaxis.set_tick_params(length=0)
    
    for tick in ax.get_yticklabels():
        tick.set_visible(False)
    
    for tick in ax.get_xticklabels():
        tick.set_visible(False)
    
    
# Set ticks
# (one tick per row/column, placed at its plot coordinate and labelled with its name)
_yticks = _row_coords.sort_values()
ax_heatmap.set_yticks(_yticks.values)
ax_heatmap.set_yticklabels(_yticks.index)

_xticks = _col_coords.sort_values()
ax_heatmap.set_xticks(_xticks.values)
ax_heatmap.set_xticklabels(_xticks.index, rotation=90)

# Inverting yaxis matches seaborn 
ax_heatmap.invert_yaxis() 

# Add legends
ax_legend = fig.add_subplot(gs[0,0])

# Colourbar
# (drawn in an inset occupying the top 30% of the legend axes)
cax = ax_legend.inset_axes([0.0, 0.7, 1.0, 0.3])
fig.colorbar(_colours, ax=ax_legend, cax=cax, orientation='horizontal')
cax.set_title(f"Difference in {_title}, log2")
cax.minorticks_on()


# P-val legend
# (one grey sample marker per p-value bin, in the same order as the bins)
ax_legend.text(0.5, 0.4, 'MWU p-val (B/H):', ha='center', va='center')
assert len(_pval_kwarg_dict) == 3
ax_legend.scatter(0.25, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[0]])
ax_legend.scatter(0.5, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[1]])
ax_legend.scatter(0.75, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[2]])

ax_legend.text(0.25, 0.15, '<0.01', ha='center', va='top')
ax_legend.text(0.5, 0.15, '<0.05', ha='center', va='top')
ax_legend.text(0.75, 0.15, 'n.s', ha='center', va='top')

# Fixed 0-1 coordinate system for the hand-placed legend items; no axes decorations
ax_legend.set_ylim([0, 1])
ax_legend.set_xlim([0, 1])
ax_legend.axis('off')


plt.savefig(output_heatmap_mi_pdf, bbox_inches='tight')



In [None]:
from matplotlib.gridspec import GridSpec

# --- Data preparation for the styled correlation heatmap ---
_name = 'corr'
_title = f'{param_correlation_method}\ncorrelation'
_matrix = matrices[_name].copy()
_mask = matrix_masks[_name].copy()
_pvals = matrix_pvals[_name].copy()

_row_order = linkage_rows_order.copy()
_col_order = linkage_cols_order.copy()

_row_linkage = linkage_rows
_col_linkage = linkage_cols

# Plot coordinates of each row/column. Leaves are placed at 5, 15, 25, ...
# so that the scatter points line up with the dendrogram leaves.
_row_coords = pd.Series(
    np.arange(len(_row_order), 0, -1) * 10.0 - 5, # the weird multiplication are due to matplotlib's dendrogram madness
    # Reversed order matches the dendrogram better
    index=list(reversed(_row_order)),
    name='row_coord'
)

_col_coords = pd.Series(
    np.arange(1, len(_col_order)+1) * 10.0 - 5,  # the weird multiplication are due to matplotlib's dendrogram madness
    index=_col_order,
    name='col_coord'
)

# Don't mess with the wide format, convert all into narrow
# To do this make sure we have normal indices column names
# (MultiIndex labels are flattened by joining their levels with '-')
for _df in [_matrix, _mask, _pvals, _row_coords, _col_coords]:
    if isinstance(_df.index, pd.MultiIndex):
        _df.index = ['-'.join(ix) for ix in _df.index]

    if isinstance(_df, pd.DataFrame) and isinstance(_df.columns, pd.MultiIndex):
        _df.columns = ['-'.join(ix) for ix in _df.columns]

# A couple operations to force the matrix into long format:
# one row per (row_name, col_name) with 'matrix', 'mask' and 'pvals' columns,
# plus the plot coordinates of that cell
_matrix_long = pd.concat([_matrix, _mask, _pvals], keys=['matrix', 'mask', 'pvals'])
_matrix_long = _matrix_long.stack()
_matrix_long.index.names = ['measurement', 'row_name', 'col_name']
_matrix_long = _matrix_long.unstack('measurement')
_matrix_long = _matrix_long.join(_row_coords, on='row_name')
_matrix_long = _matrix_long.join(_col_coords, on='col_name')

# Bin the p-values
# If you change the bins, change the legend code below too
# include_lowest=True keeps a p-value of exactly 0 in the first bin; without it
# the left-open interval (0, 0.01] would leave such a cell unbinned and it
# would not be drawn at all.
_matrix_long['pvals_binned'] = pd.cut(
    _matrix_long['pvals'],
    bins=[0, 0.01, 0.05, 1.0],
    include_lowest=True,
)
assert len(_matrix_long.pvals_binned.cat.categories) == 3

# Marker style per p-value bin, from most to least significant
_pval_kwarg_dict = dict(zip(
    _matrix_long.pvals_binned.cat.categories,
    [
        dict(s=30, marker='s', edgecolor='#525252', linewidth=0.45),
        dict(s=20, marker='s', edgecolor='#525252', linewidth=0.45),
        dict(s=10, marker='s', edgecolor='#CDCDCD', linewidth=0.45),
    ],
))
_ns_pval_group = _matrix_long.pvals_binned.cat.categories[-1]

# --- Drawing of the styled correlation heatmap ---
# Layout (3 x 2 grid): legend top-left, row dendrogram middle-left,
# heatmap middle-right, column dendrogram bottom-right.

# Force the colour scale for min and max
_vmin = -0.25 # don't make this a round number (will help with the ticks)
_vmax = 0.25
_cmap = 'RdBu_r'

n_rows, n_cols = _matrix.shape

fig = plt.figure(
    figsize=(
        # four mm * (columns + dendrogram) + [labels]
        4/5*FIVE_MM_IN_INCH*(n_cols+5) + FIVE_MM_IN_INCH*5, 
        # four mm * (rows + dendrogram)
        4/5*FIVE_MM_IN_INCH*(n_rows+5+5),
    ),
    constrained_layout=True,
)

# The dendrogram/legend strips get a weight of 5, the heatmap one unit per row/column
gs = GridSpec(3, 2, width_ratios=[5, n_cols], height_ratios=[5, n_rows, 5], wspace=0.05, hspace=0.05)

ax_heatmap = fig.add_subplot(gs[1,1])

# Labels go on the top and right edges; the left and bottom neighbours are the dendrograms
ax_heatmap.xaxis.tick_top()
ax_heatmap.yaxis.tick_right()

# Keep all four spines so the heatmap is drawn inside a full frame
sns.despine(ax=ax_heatmap, bottom=False, right=False, top=False, left=False)
ax_heatmap.xaxis.set_tick_params(length=0)
ax_heatmap.yaxis.set_tick_params(length=0)

# Draw heatmap
# One scatter call per p-value bin, so each bin gets its own marker style.
# All calls share vmin/vmax/cmap; `_colours` (the last collection drawn) is
# reused below to build the colourbar.
for _pval_group, _submatrix in _matrix_long.groupby('pvals_binned'):
    
     _colours = ax_heatmap.scatter(
        _submatrix['col_coord'], 
        _submatrix['row_coord'], 
        c=_submatrix['matrix'], 
        vmin=_vmin, vmax=_vmax, 
        cmap=_cmap,
        **_pval_kwarg_dict[_pval_group],
    )


# Draw dendrograms
# Axes are shared with the heatmap so the leaves stay aligned with its rows/columns
ax_col_dendrogram = fig.add_subplot(gs[2,1], sharex=ax_heatmap)
ax_row_dendrogram = fig.add_subplot(gs[1,0], sharey=ax_heatmap)

for ax, link, orient in [
    (ax_col_dendrogram, _col_linkage, 'bottom'),
    (ax_row_dendrogram, _row_linkage, 'left')
]:
    
    hcluster.dendrogram(
        link, 
        ax=ax, 
        orientation=orient, 
        link_color_func= lambda x: 'black'
    )
    
    # Strip spines, tick marks and tick labels from the dendrogram axes
    sns.despine(ax=ax, left=True, bottom=True, right=True, top=True)
    
    ax.xaxis.set_tick_params(length=0)
    ax.yaxis.set_tick_params(length=0)
    
    for tick in ax.get_yticklabels():
        tick.set_visible(False)
    
    for tick in ax.get_xticklabels():
        tick.set_visible(False)
    
    
# Set ticks
# (one tick per row/column, placed at its plot coordinate and labelled with its name)
_yticks = _row_coords.sort_values()
ax_heatmap.set_yticks(_yticks.values)
ax_heatmap.set_yticklabels(_yticks.index)

_xticks = _col_coords.sort_values()
ax_heatmap.set_xticks(_xticks.values)
ax_heatmap.set_xticklabels(_xticks.index, rotation=90)

# Add legends
ax_legend = fig.add_subplot(gs[0,0])

# Colourbar
# (drawn in an inset occupying the top 30% of the legend axes)
cax = ax_legend.inset_axes([0.0, 0.7, 1.0, 0.3])
fig.colorbar(_colours, ax=ax_legend, cax=cax, orientation='horizontal')
cax.set_title(f"Mean difference in {_title}")
cax.minorticks_on()

# Flip the y axis of the heatmap (the row dendrogram shares this axis)
ax_heatmap.invert_yaxis()

# P-val legend
# (one grey sample marker per p-value bin, in the same order as the bins)
ax_legend.text(0.5, 0.4, 'MWU p-val (B/H):', ha='center', va='center')
assert len(_pval_kwarg_dict) == 3
ax_legend.scatter(0.25, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[0]])
ax_legend.scatter(0.5, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[1]])
ax_legend.scatter(0.75, 0.25, color='#bdbdbd', **_pval_kwarg_dict[_matrix_long.pvals_binned.cat.categories[2]])

ax_legend.text(0.25, 0.15, '<0.01', ha='center', va='top')
ax_legend.text(0.5, 0.15, '<0.05', ha='center', va='top')
ax_legend.text(0.75, 0.15, 'n.s', ha='center', va='top')

# Fixed 0-1 coordinate system for the hand-placed legend items; no axes decorations
ax_legend.set_ylim([0, 1])
ax_legend.set_xlim([0, 1])
ax_legend.axis('off')



plt.savefig(output_heatmap_corr_pdf, bbox_inches='tight')

Additionally, we should make some diagnostic scatterplots, one for each ChIP factor:

In [None]:
# Order of the MARCS features, as defined in helpers; used as the page order of the diagnostic PDFs
marcs_feature_order = helpers.MARCS_FEATURE_ORDER

In [None]:
# Distinct MARCS category labels present in the first re-aggregated table
next(iter(data_reaggregated.values()))['marcs_feature_significant_category'].stack().unique()

In [None]:
# hue_palette = {
#     'Neither': '#1E1E1E',
#     'No data': '#1E1E1E',
#     'Recruited': '#b2182b',
#     'Excluded': '#2166ac'
# }

# Marker styling per MARCS category. The key order matters: the diagnostic
# plots use the position of each key to place its category on the jitter axis.
_significant_marker = dict(edgecolor='#525252', linewidth=.35, zorder=1, s=8, alpha=.8)

shape_map = {}
shape_map['Neither'] = dict(edgecolor='#CDCDCD', linewidth=0.35, zorder=0, s=4, alpha=.8)
shape_map['No data'] = dict(zorder=0, s=5, marker='x', alpha=.8)
# Both test groups share one marker style; each gets its own copy of the dict
shape_map[recruited_group] = dict(_significant_marker)
shape_map[excluded_group] = dict(_significant_marker)

# Fill colour per MARCS category: grey for the control groups, red/blue for the test groups
_CONTROL_GREY = '#bdbdbd'
FEATURES_PALETTE = {
    'Neither': _CONTROL_GREY,
    'No data': _CONTROL_GREY,
    recruited_group: '#BA5047',
    excluded_group: '#4B82B6',
}

In [None]:
from matplotlib.gridspec import GridSpec
from numpy.random import RandomState
from matplotlib.backends.backend_pdf import PdfPages
# from contextlib import suppress

import gc

def make_diagnostic_plots(data, *, filename, chip_factor, palette_features, shape_map_features,
                          marcs_feature_order=marcs_feature_order, approx_number_of_labels=20, logx=False, no_y=False):
    """Write a multi-page PDF of diagnostic scatterplots for one ChIP factor.

    The first page shows all factors, styled by ``('metadata', 'factor_type')``.
    Each following page highlights one MARCS feature from ``marcs_feature_order``,
    styling the points by that feature's significance category.

    Parameters
    ----------
    data : DataFrame
        Re-aggregated table with two-level columns; the columns
        ``('normalised_mi', chip_factor)``, ``('metadata', 'factor')``,
        ``('metadata', 'factor_type')`` and
        ``('marcs_feature_significant_category', <feature>)`` are read.
    filename : str
        Path of the PDF file to write.
    chip_factor : str
        Second-level column name selecting the ChIP dataset to plot.
    palette_features, shape_map_features : dict
        Colour and marker keyword arguments per MARCS category. The key order
        of ``shape_map_features`` defines the rows of the jitter axis.
    marcs_feature_order : iterable
        MARCS features to plot, one page each.
    approx_number_of_labels : int
        Passed through to ``helpers.make_plot``.
    logx : bool
        Plot log2 of the normalised MI on the x axis.
    no_y : bool
        Do not plot the correlation on the y axis. The overview page then shows
        the percentile rank of x, the per-feature pages a jittered category axis.
        This mode is also used when the correlation column is missing.
    """
    x_col = ('normalised_mi', chip_factor)
    y_col = (f'{param_correlation_method}_correlation', chip_factor)

    ylabel = "{} Correlation".format(param_correlation_method.capitalize())

    if logx:
        # Work on a copy so the caller's table is left untouched
        data = data.copy()
        data['normalised_mi'] = data['normalised_mi'].apply(np.log2)
        xlabel = "Normalised MI (log2)"
    else:
        xlabel = "Normalised MI"

    lines = ['y=0']
    yticks = True

    df = data.reset_index().copy()

    # If there is no y axis (no correlation),
    # plot ranks of x column instead
    if no_y or y_col not in df.columns:
        df[('rank_x', 'rank_x')] = df[x_col].rank(ascending=True, pct=True)

        y_col = ('rank_x', 'rank_x')
        ylabel = 'Rank'
        lines = None
        yticks = False

        no_y = True

    with PdfPages(filename) as pdf:

        # None stands for the overview page
        for marcs_feature in [None] + list(marcs_feature_order):

            fig = plt.figure(figsize=(10*FIVE_MM_IN_INCH, 10*FIVE_MM_IN_INCH))
            try:
                ax_main = fig.gca()

                if marcs_feature is None:

                    helpers.make_plot(
                        df=df,
                        x_col=x_col,
                        xlabel=xlabel,
                        y_col=y_col,
                        ylabel=ylabel,
                        annot_col=('metadata', 'factor'),
                        hue_col=('metadata', 'factor_type'),
                        shape_col=('metadata', 'factor_type'),
                        axes=ax_main,
                        legend=False,
                        title=chip_factor,
                        shape_map=param_shape_palette_factor_type,
                        hue_palette=param_hue_palette_factor_type,
                        approx_number_of_labels=approx_number_of_labels,
                        lines=lines,
                        do_not_change_limits=True,
                        adjust_text_kws=dict(lim=250),
                    )
                else:

                    # Hack: Propagate factor types into feature significant category
                    # so non-proteins appear with their symbols and not "x"
                    _df = df.copy()
                    new_column = ('marcs_feature_significant_category_or_feature_type', marcs_feature)
                    _df[new_column] = _df[('marcs_feature_significant_category', marcs_feature)].copy()

                    not_protein = _df[('metadata', 'factor_type')] != 'protein'
                    _df.loc[not_protein, new_column] = _df.loc[not_protein, ('metadata', 'factor_type')]

                    _hue_col = _shape_col = new_column

                    _ycol = y_col

                    # If we don't have y, then make a "jitter" axis across the shape_map_features
                    _hue_palette_order = list(shape_map_features.keys())
                    if no_y:
                        # Only proteins carry a MARCS category that can be placed on the jitter axis
                        _df = _df.loc[_df[('metadata', 'factor_type')] == 'protein'].copy()
                        _ycol = ('jitter_y', 'jitter_y')
                        _df[_ycol] = _df[_shape_col].apply(_hue_palette_order.index)

                        # Fixed seed keeps the jitter reproducible between runs
                        rng = RandomState(42)
                        _df[_ycol] += rng.uniform(-0.25, 0.25, len(_df))

                    helpers.make_plot(
                        df=_df,
                        x_col=x_col,
                        xlabel=xlabel,
                        y_col=_ycol,
                        ylabel=None if no_y else ylabel,
                        annot_col=('metadata', 'factor'),
                        approx_number_of_labels=approx_number_of_labels,
                        # This is hacky....
                        label_strategy=set(_df[_df[_shape_col].isin([recruited_group, excluded_group])].index),
                        hue_col=_hue_col,
                        axes=ax_main,
                        legend=False,
                        title=f'{chip_factor} vs MARCS {marcs_feature}',
                        hue_palette={**param_hue_palette_factor_type, **palette_features},
                        shape_col=_shape_col,
                        shape_map={**param_shape_palette_factor_type, **shape_map_features},
                        lines=lines,
                        do_not_change_limits=True,
                        adjust_text_kws=dict(lim=100),
                    )

                if not yticks:
                    ax_main.set_yticks([])

                sns.despine(ax=ax_main, offset=3)

                # Save this page's figure explicitly rather than whichever
                # figure happens to be current
                pdf.savefig(fig, bbox_inches='tight')
            finally:
                # There's some sort of memory leak going on, let's try to avoid that.
                # Running this in `finally` also releases the figure when plotting fails.
                plt.close(fig)
                plt.close('all')
                gc.collect()

    # There's some sort of memory leak going on, let's try to avoid that
    plt.close('all')
    gc.collect()

In [None]:
import re

# Anything outside [a-zA-Z0-9_-] is collapsed into a single '-' in file names.
# Raw string: '\-' and '\_' are invalid escape sequences in a normal string literal.
_unsafe_filename_chars = re.compile(r'[^a-z0-9A-Z\-\_]+')

for dataset, dataframe in tqdm(data_reaggregated.items(), total=len(data_reaggregated)):
    # Use the columns of the dataset being plotted, not of a frame left over
    # from an earlier cell
    diagnostic_plot_columns = dataframe['normalised_mi'].columns

    # The dataset name is only needed to disambiguate files when there are several datasets
    prefix = (_unsafe_filename_chars.sub('-', dataset) + '-') if len(data_reaggregated) > 1 else ''

    for col in tqdm(diagnostic_plot_columns):
        fname = _unsafe_filename_chars.sub('-', col)

        for logx in [False, True]:
            for no_y in [False, True]:
                # Columns starting with 'state' are only plotted in the no_y variant
                if col.startswith('state') and not no_y:
                    continue

                suffix = 'log2-scale' if logx else 'natural-scale'
                if no_y:
                    suffix += '-no_y'

                full_filename = os.path.join(output_diagnostic_plots, f'diagnostic-{prefix}{fname}-{suffix}.pdf')

                make_diagnostic_plots(
                    dataframe, chip_factor=col,
                    approx_number_of_labels=10 if no_y else 30,
                    palette_features=FEATURES_PALETTE,
                    shape_map_features=shape_map,
                    logx=logx,
                    no_y=no_y,
                    filename=full_filename,
                )
