# Visualization utilities and functions

Plotting helpers collected here to avoid cluttering the analysis notebooks.

In [None]:
import ast
import csv
import glob
import math
import os
import re
from collections import Counter

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import numpy.polynomial.polynomial as poly
import pandas as pd
import seaborn as sns
import scipy.stats as stats
from matplotlib import cm
from matplotlib.colors import ListedColormap, TwoSlopeNorm, Normalize
from mpl_toolkits import mplot3d

# Render matplotlib figures inline in the notebook (IPython magic, not plain Python).
%matplotlib inline
# Print numpy Polynomial objects using unicode notation (e.g. superscript powers).
np.polynomial.set_default_printstyle('unicode')

## Visualization settings

In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

# Register the Helvetica Light font file with matplotlib so that the
# 'Helvetica Light' family requested in the seaborn theme settings resolves.
# The path is machine-specific: when the file is absent (e.g. running on
# another machine) warn and fall back to the default font instead of letting
# addfont() raise and abort the rest of the cell.
font_path = '/n/home10/msak/utils/helvetica/helvetica-light.ttf'
if os.path.isfile(font_path):
    fm.fontManager.addfont(font_path)
else:
    warnings.warn(f'Font file not found, falling back to default font: {font_path}')

In [None]:
SMALL_SIZE = 9
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

# Baseline text sizes and resolution for every figure.  Each key is the
# '<group>.<name>' form of a plt.rc(group, name=...) setting; several of these
# are overridden again by the seaborn theme applied afterwards in this file.
plt.rcParams.update({
    'font.size': SMALL_SIZE,            # default size for any text
    'axes.titlesize': MEDIUM_SIZE,      # axes title
    'axes.labelsize': MEDIUM_SIZE,      # x / y axis labels
    'axes.labelpad': 5,
    'xtick.labelsize': SMALL_SIZE,      # tick labels
    'ytick.labelsize': SMALL_SIZE,
    'legend.fontsize': SMALL_SIZE,
    'figure.titlesize': BIGGER_SIZE,    # figure suptitle
    'figure.dpi': 300,
})

In [None]:
_new_black = '#373737'   # soft near-black used in place of pure black

sns.set_theme(
    style='ticks',
    font_scale=0.6,
    rc={
        # Fonts: keep SVG text as text and embed TrueType (Type 42) in PDF/PS
        # so exported figures stay editable.
        'font.family': 'Helvetica Light',
        'svg.fonttype': 'none',
        'text.usetex': False,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        # Sizes and spacing.
        'font.size': 9,
        'axes.labelsize': 12,
        'axes.titlesize': 12,
        'axes.labelpad': 4,
        'axes.linewidth': 0.8,
        'axes.titlepad': 8,
        'lines.linewidth': 1,
        'legend.fontsize': 10,
        'legend.title_fontsize': 10,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        # Short, thin, tight ticks: identical for x/y and major/minor.
        **{
            f'{axis}tick.{which}.{prop}': setting
            for axis in ('x', 'y')
            for which in ('major', 'minor')
            for prop, setting in (('size', 2), ('pad', 1), ('width', 0.5))
        },
        # Avoid pure black unless necessary (axes.titlecolor falls back to
        # text.color, so it is not set explicitly).
        'patch.force_edgecolor': False,
        **dict.fromkeys(
            ('text.color', 'patch.edgecolor', 'hatch.color', 'axes.edgecolor',
             'axes.labelcolor', 'xtick.color', 'ytick.color'),
            _new_black,
        ),
    },
)

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Colour maps built from a neutral grey, white and an accent colour:
#   <accent>_cmap      : white -> accent
#   <accent>_grey_cmap : grey -> white -> accent
# After this cell `colors` holds the blue palette.
colors = ['#999999', 'white', '#AC4B4E']        # grey, white, red
red_cmap, red_grey_cmap = (
    LinearSegmentedColormap.from_list('custom_red', colors[1:]),
    LinearSegmentedColormap.from_list('custom_red_grey', colors),
)

colors = ['#999999', 'white', '#3e75c2']        # grey, white, blue
blue_cmap, blue_grey_cmap = (
    LinearSegmentedColormap.from_list('custom_blue', colors[1:]),
    LinearSegmentedColormap.from_list('custom_blue_grey', colors),
)

## Visualizing landscapes

Given a landscape solution file, plot the landscape for doubles (pairs of catalysts).

In [None]:
import csv
import ast
from typing import Dict, Tuple, Union

def csv_to_kv(file: str,
              coop: bool = False
              ) -> Union[Dict[int, float], Tuple[Dict[Tuple[int, int], float],
                                                 Dict[Tuple[int, int], float]]]:
    """
    Read a two-column CSV (key, value) into dictionaries.

    Rows whose first field cannot be parsed as a key (headers, junk) and
    completely empty rows are skipped; the second field is converted with
    ``float`` (a row with a valid key but no second field raises IndexError).

    Parameters
    ----------
    file : str
        Path to the CSV file.
    coop : bool, default False
        True for a cooperativity matrix (“*_coop_mat.csv”) whose keys are
        tuple strings such as "(3, 7)"; False for integer catalyst ids.

    Returns
    -------
    dict | (dict, dict)
        • coop False ->  {cat_id: rate}
        • coop True  ->  (entries with value >= 1, entries with value < 1)
    """
    def parse_key(field):
        """Return the parsed key, or None when *field* is a header/invalid."""
        if not coop:
            try:
                return int(field)
            except ValueError:
                return None
        try:
            candidate = ast.literal_eval(field)
        except (ValueError, SyntaxError):
            return None
        if isinstance(candidate, tuple) and len(candidate) == 2:
            return candidate
        return None

    rates: Dict[int, float] = {}
    pos_dict: Dict[Tuple[int, int], float] = {}
    neg_dict: Dict[Tuple[int, int], float] = {}

    with open(file, newline='') as handle:
        for fields in csv.reader(handle):
            if not fields:
                continue
            key = parse_key(fields[0])
            if key is None:
                continue
            value = float(fields[1])
            if not coop:
                rates[key] = value
            elif value >= 1:
                pos_dict[key] = value
            else:
                neg_dict[key] = value

    if coop:
        return pos_dict, neg_dict
    return rates

In [None]:
def viz_coop_zero_center(N, coop_file, simple=False):
    """
    Plot a lower-triangular heat-map of pairwise cooperativity.

    Raw values are read with ``csv_to_kv(coop_file, coop=True)`` and shifted
    by -1 before plotting, so a raw value of 1 maps to 0:
    • negative cooperativity (raw < 1) -> plotted value < 0
    • positive cooperativity (raw > 1) -> plotted value > 0
    When the shifted values straddle 0 the colour scale is centred on 0 with a
    TwoSlopeNorm (NOT forced to be symmetric); otherwise a linear Normalize
    over the data range is used.

    Parameters
    ----------
    N : int
        Size of the N×N matrix; pair indices must lie in ``range(N)``.
    coop_file : str
        Path to the cooperativity CSV (keys like "(i, j)").
    simple : bool, optional
        If True, hide the colour-bar and tick labels and draw a frame for a
        cleaner plot. Default is False.

    Raises
    ------
    ValueError
        If the file yields no numeric values.
    """
    # ------------------------------------------------------------- #
    # 1. Load the data into an N×N matrix.  Pair (i, j) goes to row j,
    #    column i; the upper triangle (incl. diagonal) is masked below,
    #    so only pairs with i < j are visible.
    # ------------------------------------------------------------- #
    heatmap = np.full((N, N), np.nan)               # unset cells stay NaN
    for coop_dict in csv_to_kv(coop_file, coop=True):   # (>= 1 dict, < 1 dict)
        for (i, j), val in coop_dict.items():
            heatmap[j, i] = val - 1                 # raw 1 (neutral) -> 0

    # ------------------------------------------------------------- #
    # 2. Colour map and normalisation
    # ------------------------------------------------------------- #
    def truncate_cmap(cmap='PuOr', lo=0.05, hi=0.95, n=256):
        """Return *cmap* restricted to the [lo, hi] part of its range."""
        base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
        return mpl.colors.LinearSegmentedColormap.from_list(
            f'trunc({base.name},{lo:.2f},{hi:.2f})',
            base(np.linspace(lo, hi, n))
        )

    cmap = truncate_cmap('PuOr')
    finite_vals = heatmap[~np.isnan(heatmap)]

    if finite_vals.size == 0:
        raise ValueError('No numeric values found in the data.')

    vmin, vmax = finite_vals.min(), finite_vals.max()

    # The data were shifted so that 0 is neutral.  A diverging norm centred
    # on 0 is only valid (TwoSlopeNorm needs vmin < vcenter < vmax) when the
    # data contain both negative and positive values.
    if vmin < 0 < vmax:
        norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
    else:                      # all >= 0 OR all <= 0
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # ------------------------------------------------------------- #
    # 3. Plot
    # ------------------------------------------------------------- #
    mask = np.triu(np.ones_like(heatmap, dtype=bool))

    plt.figure(figsize=(6, 6), dpi=300)
    ax = sns.heatmap(
        heatmap,
        mask=mask,
        cmap=cmap,
        norm=norm,
        square=True,
        linewidths=0.1,
        linecolor='#d7d7d7',
        annot=False,
        cbar=not simple,  # Hide colorbar if simple=True
        cbar_kws={'shrink': 0.6} if not simple else {}
    )

    if not simple:
        # Number catalysts from 1 instead of seaborn's 0-based tick labels.
        ax.set_xticklabels([str(int(t.get_text()) + 1) for t in ax.get_xticklabels()])
        ax.set_yticklabels([str(int(t.get_text()) + 1) for t in ax.get_yticklabels()])
    else:
        # Blank tick labels and a near-black frame around the matrix.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.8)
            spine.set_edgecolor(_new_black)

    # Remove any automatic title that might be present
    ax.set_title('')

    plt.tight_layout()
    plt.show()

In [None]:
def viz_pair_landscape(N, soln_file, label=None):
    """
    Plot the pairwise ("doubles") landscape from a solution CSV as a heat-map.

    Parameters
    ----------
    N : int
        Number of catalysts; the heat-map is N×N.
    soln_file : str
        CSV with a ``Combi`` column (string repr of a combination) and a
        ``k_eff`` column.  Only length-2 combinations are plotted.
    label : str, optional
        Figure title.

    Notes
    -----
    Cells without a pair entry keep the default value 1.  Pair ``(a, b)`` is
    written to row ``b``, column ``a``, and only the strict lower triangle is
    shown, so pairs are presumably stored with ``a < b`` (TODO confirm
    against the solver output).  Tick labels are shifted by +1 so catalysts
    are numbered from 1.
    """
    df = pd.read_csv(soln_file)
    # Cells with no pair entry keep the default value of 1.
    heatmap_data = np.ones((N, N))

    # Parse the 'Combi' strings into Python objects and KEEP only pairs.
    # NOTE(review): eval() executes arbitrary code from the CSV; only use this
    # on trusted solver output.  It is kept (rather than ast.literal_eval) in
    # case the strings contain non-literal reprs; TODO confirm and switch.
    df['combi'] = df['Combi'].apply(eval)
    df = df[df['combi'].apply(len) == 2]

    for pair, k_eff in zip(df['combi'], df['k_eff']):
        heatmap_data[int(pair[1])][int(pair[0])] = k_eff

    # Hide the upper triangle (including the diagonal).
    mask = np.triu(np.ones_like(heatmap_data, dtype=bool))

    plt.figure(figsize=(6, 6))
    ax = sns.heatmap(heatmap_data, mask=mask, cmap=red_cmap, cbar=False,
                     square=True, annot=False, linewidths=0.1,
                     linecolor='#d7d7d7', fmt="")

    # Number catalysts from 1 instead of seaborn's 0-based tick labels.
    # (Uses `tick`, not `label`, so the title argument is not visually shadowed.)
    ax.set_xticklabels([str(int(tick.get_text()) + 1) for tick in ax.get_xticklabels()])
    ax.set_yticklabels([str(int(tick.get_text()) + 1) for tick in ax.get_yticklabels()])

    plt.title(label)
    plt.show()

## viz simulation results

In [None]:
def plot_tradeoff(filename, num_combi, col1, col2, color1, title=None):
    """
    Scatter ``col2`` against ``col1`` from a CSV, with screening efficiency on
    a twin y-axis (previously named plot_one-Pair_perf_csv).

    The left axis shows raw ``col2`` values and is inverted (larger values
    lower).  The right axis shows ``1 - col2 / num_combi``, labelled
    'screening efficiency'.  Both series are drawn in ``color1``.
    """
    data = pd.read_csv(filename)
    xs = data[col1]
    ys = data[col2]

    _, left = plt.subplots(figsize=(6, 6))
    left.scatter(xs, ys, color=color1)
    left.set_xlabel(col1)
    left.set_ylabel(col2)
    left.set_title(title)

    right = left.twinx()
    right.scatter(xs, 1 - ys / num_combi, color=color1)

    # The LEFT axis is inverted so that larger col2 values (which mean lower
    # efficiency) sit lower, the same direction as the efficiency axis.
    left.invert_yaxis()
    right.set_ylabel('screening efficiency')

    plt.show()


In [7]:
def is_pareto_efficient(costs):
    """
    Return a boolean mask of the Pareto-efficient rows, treating every column
    as an objective to MAXIMISE (despite the historical name ``costs``).

    A row is dropped when a retained row is >= it in every column.  Among
    exact duplicates at most the first occurrence is kept.

    Parameters
    ----------
    costs : array-like, shape (n_points, n_objectives)
        Anything ``np.asarray`` accepts (ndarray, nested list, DataFrame.values).

    Returns
    -------
    np.ndarray of bool, shape (n_points,)
    """
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        if is_efficient[i]:
            # Keep only points that beat c in at least one objective ...
            is_efficient[is_efficient] = np.any(costs[is_efficient] > c, axis=1)
            # ... and c itself, which never beats itself.
            is_efficient[i] = True
    return is_efficient

In [None]:
def plot_tradeoffs(filenames, num_combi, col1, col2, labels, title=None, plot_pareto=False):
    '''
    Overlay several runs on one efficiency-vs-``col1`` scatter, optionally
    highlighting each run's Pareto front.

    Parameters
    ----------
    filenames : sequence of str
        CSV files, one per run.
    num_combi : int or float
        Total number of combinations; efficiency is ``1 - col2 / num_combi``.
    col1 : str
        Column plotted on the x-axis.
    col2 : str
        Column converted to efficiency (y-axis, limited to [0, 1]).
    labels : sequence of str
        Legend label for each file, paired with ``filenames`` by position.
    title : str, optional
        Axes title.
    plot_pareto : bool
        If True, re-draw the points that are Pareto-efficient (maximising both
        ``col1`` and efficiency) more opaquely, with their own legend entry.
    '''
    # Colours are reused cyclically when there are more than three files.
    palette = ['#00356B', '#A51C30', '#a4a4a4']

    _, ax = plt.subplots(figsize=(5, 5), dpi=300)

    for i, (filename, label) in enumerate(zip(filenames, labels)):
        colour = palette[i % len(palette)]
        df = pd.read_csv(filename)
        df['eff'] = 1 - df[col2] / num_combi

        ax.scatter(df[col1], df['eff'], color=colour, alpha=0.2,
                   edgecolor='none', label=label)
        if plot_pareto:
            pareto = is_pareto_efficient(df[[col1, 'eff']].values)
            ax.scatter(df.loc[pareto, col1], df.loc[pareto, 'eff'],
                       color=colour, alpha=0.6, edgecolors=colour,
                       label=f'{label} (Pareto front)')

    ax.set_xlabel(col1)
    ax.set_ylabel('efficiency')
    ax.set_title(title)
    ax.set_ylim(0, 1)
    ax.legend()

    plt.show()

In [None]:
def viz_batch_optim(infiles, N, metric_column='coop_accuracy', beta=1, viz_dist=True, plt_agg=False, report_params=False, coop_only=False):
    '''
    Find the best F-beta-scoring row in each output file and summarise the batch.

    Each CSV matched by the glob pattern ``infiles`` is one run.  For every row,
    efficiency = 1 - num_pools / C(N, 2); rows with efficiency < 0 are ignored.
    The score combines the chosen metric and efficiency:

        F = (1 + beta**2) * metric * eff / (beta**2 * metric + eff)

    so beta > 1 weights efficiency more heavily and beta < 1 weights the metric
    more.  When metric and efficiency are both 0 the score is defined as 0.

    Parameters:
    - infiles: glob pattern of run CSVs (processed in sorted order)
    - N: size of the catalyst set; C(N, 2) pairs is the efficiency baseline
    - metric_column: either 'coop_accuracy' or 'mean_sensitivity' to specify which metric to use in F-score
    - beta: F-beta weighting factor
    - viz_dist: if False, don't plot or show anything at all
    - plt_agg: if True (and viz_dist), also scatter metric vs efficiency coloured by score
    - report_params: path to an output CSV that summarises the parameter sets
      (pool_size, num_meet, num_redun, target_metric) of the highest F-score rows;
      False to skip
    - coop_only: if True, only consider rows where target_metric == 'coop'

    Returns:
    - mean_metric: mean accuracy or sensitivity across all top-scoring rows
    - mean_efficiency: mean efficiency across all top-scoring rows
    (None, None) when no file yields a valid row.
    '''
    b2 = beta ** 2
    n_pairs = math.comb(N, 2)
    param_cols = ['pool_size', 'num_meet', 'num_redun', 'target_metric']

    top_scores, top_metrics, top_effcs, top_params = [], [], [], []
    top_scores_dict, top_metric_dict, top_eff_dict = {}, {}, {}

    # Sorted so results and report tie-ordering do not depend on filesystem order.
    for infile in sorted(glob.glob(infiles)):
        df = pd.read_csv(infile)

        if coop_only:
            df = df[df['target_metric'] == 'coop']
            if df.empty:
                print(f"Warning: No 'coop' target_metric entries found in {infile}")
                continue

        # assign() returns a new frame, avoiding SettingWithCopyWarning on the
        # filtered slices.
        df = df.assign(efficiency=1 - df['num_pools'] / n_pairs)
        df = df[df['efficiency'] >= 0]
        if df.empty:
            print(f"Warning: No valid entries after filtering in {infile}")
            continue

        metric, eff = df[metric_column], df['efficiency']
        denom = b2 * metric + eff
        # metric == eff == 0 gives 0/0 (NaN vectorised, ZeroDivisionError when
        # computed row-wise on Python floats); define F = 0 there.
        score = ((1 + b2) * (metric * eff) / denom).where(denom != 0, 0.0)
        df = df.assign(score=score)
        if df['score'].isna().all():
            print(f"Warning: No computable F-score in {infile}")
            continue

        best_row = df.loc[df['score'].idxmax()]
        top_score = best_row['score']
        top_metric = best_row[metric_column]
        top_eff = best_row['efficiency']

        top_scores.append(top_score)
        top_metrics.append(top_metric)
        top_effcs.append(top_eff)

        # Parameter set of the winning row (without num_top).
        params = tuple(best_row[param_cols].tolist())
        top_params.append(params)
        top_scores_dict.setdefault(params, []).append(top_score)
        top_metric_dict.setdefault(params, []).append(top_metric)
        top_eff_dict.setdefault(params, []).append(top_eff)

    if not top_scores:
        print("Warning: No valid data found after all filtering")
        return None, None

    metric_name = 'Accuracy' if metric_column == 'coop_accuracy' else 'Sensitivity'

    if viz_dist:
        _, (ax2, ax3) = plt.subplots(2, 1, figsize=(6, 8), dpi=300)

        # Distribution of the winning rows' metric values.
        sns.histplot(np.array(top_metrics).flatten(), kde=True, ax=ax2,
                     color='#0F4D92', edgecolor='#ffffff')
        ax2.set_title(f'{metric_name} for highest-F parameter sets')
        ax2.set_xlabel(metric_name.lower())
        ax2.set_ylabel('frequency')

        # Distribution of the winning rows' efficiencies.
        sns.histplot(top_effcs, kde=True, ax=ax3, color='#0F4D92', edgecolor='#ffffff')
        ax3.set_title('Efficiencies for highest-F parameter sets')
        ax3.set_xlabel('efficiency')
        ax3.set_ylabel('frequency')

        plt.tight_layout()
        plt.show()

        if plt_agg:
            plt.figure(figsize=(5, 4))
            scatter = plt.scatter(top_effcs, top_metrics, c=top_scores, cmap='viridis')
            plt.colorbar(scatter, label='F-beta score')
            plt.title(f'{metric_name} vs Efficiency for Highest F-beta scores')
            plt.xlabel('Efficiency')
            plt.ylabel(metric_name)
            plt.show()

    if report_params:
        # One report row per distinct winning parameter set.
        report_data = []
        for params, count in Counter(top_params).items():
            avg_score = round(np.mean(top_scores_dict[params]), 2)
            avg_metric = round(np.mean(top_metric_dict[params]), 2)
            avg_eff = round(np.mean(top_eff_dict[params]), 2)
            report_data.append(list(params) + [count, avg_score, avg_metric, avg_eff])

        metric_col_name = 'avg_accuracy' if metric_column == 'coop_accuracy' else 'avg_sensitivity'
        report_df = pd.DataFrame(report_data,
                                 columns=param_cols + ['count', 'avg_f_score',
                                                       metric_col_name, 'avg_efficiency'])

        # Most frequent first, then by average score.
        report_df = report_df.sort_values(['count', 'avg_f_score'], ascending=[False, False])
        report_df.to_csv(report_params, index=False)
        print(f"Parameter report saved to {report_params}")

    return pd.Series(top_metrics).mean(), pd.Series(top_effcs).mean()

In [None]:
def heatmap_ae_tradeoff(csv_path, mode='accuracy'):
    """
    Draw two annotated heat-maps over the (p_pos, p_neg) grid: a quality
    metric (left, ``blue_cmap``) and efficiency (right, ``red_cmap``).

    Parameters:
    - csv_path: CSV with columns p_pos, p_neg, efficiency and the metric column
    - mode: 'accuracy' (metric column ``acc_disc``) or 'sensitivity'
      (metric column ``sensitivity``)

    Tick labels show p_pos / p_neg as percentages.  Grid combinations missing
    from the CSV are drawn as white "n/a" cells; every other cell is labelled
    with its value (2 decimals) in white on dark cells, near-black on light ones.

    Raises:
    - ValueError: unknown mode, or required columns missing from the CSV
    """
    df = pd.read_csv(csv_path)

    if mode not in ('accuracy', 'sensitivity'):
        raise ValueError("mode must be either 'accuracy' or 'sensitivity'")
    metric_col, title_prefix = (
        ("acc_disc", "Accuracy") if mode == 'accuracy' else ("sensitivity", "Sensitivity")
    )

    missing = {"p_pos", "p_neg", metric_col, "efficiency"}.difference(df.columns)
    if missing:
        raise ValueError(f"CSV is missing columns: {sorted(missing)}")

    x_vals = np.sort(df["p_pos"].unique())
    y_vals = np.sort(df["p_neg"].unique())

    # Left-merge the full Cartesian grid (p_neg outer, p_pos inner) so that
    # combinations absent from the CSV survive as NaN rows.
    full_grid = pd.DataFrame([{'p_pos': x, 'p_neg': y} for y in y_vals for x in x_vals])
    df_complete = full_grid.merge(df, on=['p_pos', 'p_neg'], how='left')

    _, axs = plt.subplots(1, 2, figsize=(10, 4.5), dpi=300)

    def as_percent(v):
        return f"{v*100:.1f}" if v >= 0.01 else f"{v*100:.2f}"

    def brightness(rgba):
        red, green, blue, _alpha = rgba
        return 0.299 * red + 0.587 * green + 0.114 * blue

    panels = [(metric_col, blue_cmap, title_prefix), ("efficiency", red_cmap, "Efficiency")]

    for ax, (metric, cmap, panel_title) in zip(axs, panels):
        grid = (df_complete.pivot(index="p_neg", columns="p_pos", values=metric)
                .reindex(index=y_vals, columns=x_vals))
        z_values = grid.values.copy()
        is_missing = np.isnan(z_values)

        im = ax.imshow(z_values, origin="lower", cmap=cmap, aspect="equal")

        x_labels = [as_percent(x) for x in x_vals]
        y_labels = [as_percent(y) for y in y_vals]
        ax.set_xticks(range(len(x_vals)))
        ax.set_xticklabels(x_labels, rotation=90)
        ax.set_yticks(range(len(y_vals)))
        ax.set_yticklabels(y_labels)

        for row, col in np.ndindex(z_values.shape):
            if is_missing[row, col]:
                ax.add_patch(plt.Rectangle((col-0.5, row-0.5), 1, 1,
                                           facecolor='white', edgecolor='gray'))
                ax.text(col, row, "n/a", ha="center", va="center",
                        color='black', fontsize=10)
            else:
                value = z_values[row, col]
                ink = "white" if brightness(cmap(im.norm(value))) < 0.5 else _new_black
                ax.text(col, row, f"{value:.2f}",
                        ha="center", va="center", color=ink, fontsize=10)

        ax.set_title(panel_title)
        ax.set_xlabel("p_pos (%)")
        ax.set_ylabel("p_neg (%)")

    plt.tight_layout()
    plt.show()

## viz deconv

In [None]:
def viz_deconv_score_pairs(csv_file,
                           label=None,
                           twoslope=False,
                           cell_size=0.45,        # inches per cell
                           dpi=300,
                           inner_lw=5,            # width of white cell spacers
                           spine_lw=0.8,          # width of the frame
                           tickmarks=False):      # draw label-less tick-marks
    """
    Lower-triangular heat-map of pair-wise deconvolution scores.

    Reads ``csv_file`` with columns ``pool`` (a 2-element list/tuple such as
    "[3, 7]", 1-indexed) and ``score``; rows whose pool is not a pair are
    ignored.  Pair ``[a, b]`` is drawn at row ``b-1``, column ``a-1``; the
    upper triangle (incl. diagonal) is hidden.  Cells with no score keep the
    value 1 (these also count toward the automatic colour range).

    • every cell is separated by a white line (``inner_lw``)
    • the whole matrix has a near-black frame (``spine_lw``)
    • with ``tickmarks=True`` tick-marks are drawn at cell centres, labels blank
    • ``twoslope=True`` centres the colour scale on 0; it only takes effect
      when the scores straddle 0, otherwise the default linear scale is used

    Raises
    ------
    ValueError
        If no row of the CSV contains a pair.
    """

    # ───── 1. load & parse CSV
    df = pd.read_csv(csv_file)

    pairs, scores = [], []
    for _, row in df.iterrows():
        pool_txt = str(row['pool']).replace('\\"', '"')   # undo escaped quotes
        pair = None

        try:                                      # safest: ast
            pool = ast.literal_eval(pool_txt)
            if isinstance(pool, (list, tuple)) and len(pool) == 2:
                pair = pool
        except (ValueError, SyntaxError, TypeError):
            pass                                  # fall through to the regex

        if pair is None:                          # fallback: regex "[a, b]"
            m = re.search(r'\[(\d+)\s*,\s*(\d+)\]', pool_txt)
            if m:
                pair = [int(m.group(1)), int(m.group(2))]

        if pair is not None:
            i, j = pair[0] - 1, pair[1] - 1       # 1-indexed ids -> 0-indexed
            pairs.append((i, j))
            scores.append(float(row['score']))

    if not pairs:
        raise ValueError(f"No pair-valued 'pool' entries found in {csv_file}")

    # ───── 2. build matrix
    n = max(max(p) for p in pairs) + 1
    data = np.ones((n, n))
    for (i, j), s in zip(pairs, scores):
        data[j, i] = s                            # fill lower triangle

    mask = np.triu(np.ones_like(data, bool))      # hide upper triangle

    # ───── 3. colour-map & norm
    cmap = mpl.colormaps.get_cmap('Purples')
    norm = None
    if twoslope:
        vmin, vmax = min(scores), max(scores)
        # TwoSlopeNorm requires vmin < vcenter < vmax.
        if vmin < 0 < vmax:
            norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

    # ───── 4. plot
    fig_size = cell_size * n
    plt.figure(figsize=(fig_size, fig_size), dpi=dpi)

    ax = sns.heatmap(data,
                     mask=mask,
                     cmap=cmap,
                     norm=norm,
                     square=True,
                     linewidths=inner_lw,
                     linecolor='#ffffff',          # white spacers
                     cbar=True,
                     annot=False,
                     xticklabels=False,
                     yticklabels=False,
                     cbar_kws={'shrink': .6})

    cbar = ax.collections[0].colorbar   # grab the colour-bar object
    cbar.ax.tick_params(labelsize=14)   # 14-point colour-bar tick labels
    # tick-marks without labels
    centres = np.arange(n) + 0.5
    if tickmarks:
        ax.set_xticks(centres)
        ax.set_yticks(centres)
        ax.set_xticklabels([''] * n)
        ax.set_yticklabels([''] * n)
    ax.tick_params(axis='both', length=4, width=.8, color=_new_black)

    # near-black frame
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_edgecolor(_new_black)
        spine.set_linewidth(spine_lw)

    if label:
        ax.set_title(label, pad=6)

    plt.tight_layout()
    plt.show()


## viz experiments

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import ListedColormap, TwoSlopeNorm, Normalize

# Two-colour map (light grey, dark grey), presumably for 0/1 grids.
# plot_oneshot_expt below builds its own white/dark-grey map instead.
binary_cmap = ListedColormap(['#bfbfbe', '#4d4d4d'])

# plot_oneshot_expt also relies on `blue_cmap`, defined in the colour-map
# cell near the top of this file.


def _draw_cell_separators(ax, n_rows, n_cols):
    """Draw thick white lines on every cell boundary of an n_rows × n_cols image."""
    for r in range(n_rows + 1):
        ax.axhline(r - .5, color='white', lw=5, zorder=5)
    for c in range(n_cols + 1):
        ax.axvline(c - .5, color='white', lw=5, zorder=5)


def _draw_pool_grid(pools, n_rows, n_cats, cmap, cell_size, dpi):
    """Figure marking catalyst membership: one row per pool, one column per catalyst."""
    fig, ax = plt.subplots(figsize=(cell_size * n_cats, cell_size * n_rows), dpi=dpi)
    fig.subplots_adjust(0, 0, 1, 1)           # remove outer padding

    mat = np.zeros((n_rows, n_cats), dtype=int)
    for r, pool in enumerate(pools):
        mat[r, np.asarray(pool) - 1] = 1       # catalyst ids are 1-indexed
    ax.imshow(mat, cmap=cmap, aspect='equal', interpolation='none')

    _draw_cell_separators(ax, n_rows, n_cats)

    # Light-grey inset outline marks the empty cells (value 0).
    for r, c in zip(*np.nonzero(mat == 0)):
        ax.add_patch(plt.Rectangle((c - 0.4, r - 0.4), 0.8, 0.8, fill=False,
                                   edgecolor='#bfbfbe', linewidth=1.5, zorder=10))

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
    return fig


def _draw_value_column(values, cmap, norm, fmt, fontsize, cell_size, dpi):
    """One-column annotated heat-map of *values*; text colour follows cell brightness."""
    n_rows = len(values)
    fig, ax = plt.subplots(figsize=(cell_size, cell_size * n_rows), dpi=dpi)
    fig.subplots_adjust(0, 0, 1, 1)

    column = np.asarray(values).reshape(-1, 1)
    ax.imshow(column, cmap=cmap, norm=norm, aspect='equal', interpolation='none')
    _draw_cell_separators(ax, n_rows, 1)

    for r, v in enumerate(column.flat):
        rgba = cmap(norm(v))
        bright = 0.299 * rgba[0] + 0.587 * rgba[1] + 0.114 * rgba[2]
        ax.text(0, r, fmt.format(v), ha='center', va='center',
                fontsize=fontsize, color=('white' if bright < .5 else 'black'))

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
    return fig


def plot_oneshot_expt(file_name, cell_size=0.45, dpi=300, n_cats=11):
    """
    Draw per-generation panels for a one-shot experiment CSV.

    The CSV needs columns ``pool`` (string like "[1, 4, 7]", 1-indexed
    catalyst ids), ``output`` and ``coop``.  A new generation starts whenever
    the pool size drops relative to the previous row.  For every generation
    three separate figures are drawn and shown:

    A. a grid (pools × ``n_cats``) marking which catalysts are in each pool;
    B. a one-column heat-map of ``output`` (``blue_cmap``, one scale for all
       generations, values printed as integers);
    C. a one-column heat-map of ``coop`` (PuOr, centred on 0 when the values
       straddle 0, otherwise a linear scale; values printed with sign).

    Parameters
    ----------
    file_name : str
        Path to the experiment CSV.
    cell_size : float
        Inches per cell.
    dpi : int
        Figure resolution.
    n_cats : int, default 11
        Number of catalyst columns in the grid; every id must be <= n_cats.

    Returns
    -------
    list of (fig_grid, fig_output, fig_coop) tuples, one per generation
    (empty when the CSV has no rows).
    """
    # ───── 1. load & pre-process
    df = pd.read_csv(file_name)
    if df.empty:
        return []
    df['pool'] = df['pool'].apply(
        lambda s: list(map(int, s.strip('[]').split(',')))
    )
    pool_len = df['pool'].str.len()
    # Generation counter increments each time the pool size shrinks.
    df['generation'] = (pool_len.shift(fill_value=pool_len.iloc[0]) > pool_len).cumsum() + 1

    # ───── 2. global colour scales (shared by all generations)
    norm_out = Normalize(vmin=df['output'].min(), vmax=df['output'].max())
    coop_min, coop_max = df['coop'].min(), df['coop'].max()
    # TwoSlopeNorm requires coop_min < 0 < coop_max.
    if coop_min < 0 < coop_max:
        norm_coop = TwoSlopeNorm(vmin=coop_min, vcenter=0, vmax=coop_max)
    else:
        norm_coop = Normalize(vmin=coop_min, vmax=coop_max)

    cmap_out = blue_cmap
    cmap_coop = mpl.colormaps.get_cmap('PuOr')
    grid_cmap = ListedColormap(['white', '#4d4d4d'])   # empty / filled cells

    figs_per_gen = []
    for _, gen_df in df.groupby('generation', sort=False):
        fig_g = _draw_pool_grid(gen_df['pool'], len(gen_df), n_cats,
                                grid_cmap, cell_size, dpi)
        fig_o = _draw_value_column(gen_df['output'].to_numpy(), cmap_out, norm_out,
                                   '{:.0f}', 10, cell_size, dpi)
        fig_c = _draw_value_column(gen_df['coop'].to_numpy(), cmap_coop, norm_coop,
                                   '{:+.1f}', 9, cell_size, dpi)

        for fig in (fig_g, fig_o, fig_c):
            fig.show()
        figs_per_gen.append((fig_g, fig_o, fig_c))

    return figs_per_gen