In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, List, Dict, Any, Optional

In [None]:
def _select_components(which_marg, components_signif, time_marginalization,
                       show_nonsignificant, num_comp_to_show):
    """Return the indices of the dPCA components to plot (never more than 12).

    - ``which_marg is None``: the first ``min(num_comp_to_show, 12)`` components.
    - 2-4 distinct marginalizations: up to 3 components per marginalization,
      one marginalization after another, with ``time_marginalization`` (if
      given) first. When ``components_signif`` is provided, only components
      with at least one non-zero entry are taken, optionally topped up with
      non-significant ones when ``show_nonsignificant`` is set; when it is
      None, the first 3 components of each marginalization are taken.
    - otherwise: the first 2 components of every marginalization, sorted and
      capped at 12.
    """
    if which_marg is None:
        return np.arange(min(num_comp_to_show, 12))

    unique_margs = np.unique(which_marg)

    if not 1 < len(unique_margs) <= 4:
        per_marg = [np.flatnonzero(which_marg == marg)[:2] for marg in unique_margs]
        if not per_marg:
            # Empty whichMarg: nothing to plot (np.concatenate rejects an empty list).
            return np.array([], dtype=int)
        return np.sort(np.concatenate(per_marg))[:12]

    if time_marginalization is not None:
        # The time marginalization, when named, is laid out first.
        marg_row_seq = [time_marginalization] + [
            m for m in unique_margs if m != time_marginalization
        ]
    else:
        marg_row_seq = unique_margs

    # A component counts as significant if any entry of its row is non-zero.
    if components_signif is None:
        is_signif = None
    else:
        is_signif = np.asarray(components_signif).sum(axis=1) != 0

    chosen = []
    for marg in marg_row_seq:
        in_marg = which_marg == marg
        if is_signif is None:
            picked = np.flatnonzero(in_marg)
        else:
            picked = np.flatnonzero(in_marg & is_signif)
            if show_nonsignificant and len(picked) < 3:
                extra = np.setdiff1d(np.flatnonzero(in_marg), picked)[:3]
                picked = np.concatenate((picked, extra))
        chosen.extend(picked[:3])
    return chosen


def dpca_custom_stimONLY(Xfull, W, V, plot_function: Callable, **kwargs):
    """Project data onto dPCA decoder axes and plot a 4x4 grid of component traces.

    Parameters
    ----------
    Xfull : array-like
        Data array. Axis 0 is the dimension projected by ``W`` (presumably
        neurons -- confirm with caller); all remaining axes (e.g. stimulus,
        time) are flattened into samples and restored in ``Zfull``.
    W : ndarray, shape (Xfull.shape[0], n_components)
        Decoder matrix.
    V : ndarray
        Unused; kept so the signature stays compatible with callers.
    plot_function : callable
        Called once per plotted component as
        ``plot_function(ax, data, time, ylims=[-y, y], component_variance=...,
        component_idx=..., time_events=..., significance=..., marginalization=...)``
        where ``data`` has shape ``Xfull.shape[1:]``.
    **kwargs
        Options overriding the defaults:

        - ``time``, ``timeEvents``: forwarded to ``plot_function``.
        - ``whichMarg``: per-component marginalization index; groups the
          components, selects the y-limit and the text label.
        - ``componentsSignif``: 2-D, one row per component; a row with any
          non-zero entry marks that component as significant.
        - ``timeMarginalization``: marginalization placed first.
        - ``ylims``: y half-range per marginalization (indexed by the
          ``whichMarg`` value) or a single value for all; defaults to
          1.1 * max |projection| of the plotted components.
        - ``marginalizationNames``: labels drawn on each axis.
        - ``explainedVar``: mapping with a ``'componentVar'`` sequence.
        - ``numCompToShow``: cap used when ``whichMarg`` is None (default 15;
          at most 12 components are ever drawn).
        - ``showNonsignificantComponents``: top up each row with
          non-significant components when fewer than 3 are significant.
        - ``legendSubplot``, ``marginalizationColours``, ``X_extra``: accepted
          but not used by this function.

    Returns
    -------
    components_to_plot : list or ndarray
        Indices of the plotted components.
    Zfull : ndarray, shape (len(components_to_plot),) + Xfull.shape[1:]
        Projections of the plotted components.
    options : dict
        Resolved options (``ylims`` is filled in when it was defaulted).
    Z : ndarray, shape (prod(Xfull.shape[1:]), n_components)
        Projections of every sample onto every component.
    """
    # Default options
    options = {
        'time': None,
        'whichMarg': None,
        'timeEvents': None,
        'ylims': None,
        'componentsSignif': None,
        'timeMarginalization': None,
        'legendSubplot': None,
        'marginalizationNames': None,
        'marginalizationColours': None,
        'explainedVar': None,
        'numCompToShow': 15,
        'X_extra': None,
        'showNonsignificantComponents': False
    }
    options.update(kwargs)
    numCompToShow = min(options['numCompToShow'], W.shape[1])

    Xfull = np.asarray(Xfull)
    # Axis 0 is what W projects, so it becomes the columns and every other axis
    # is flattened into rows; this is what lets Zfull below be reshaped back to
    # Xfull.shape[1:] for inputs with more than two dimensions.
    X = Xfull.reshape((Xfull.shape[0], -1)).T
    Xcen = X - X.mean(axis=0)  # center each column across samples
    Z = Xcen @ W

    # Normalise to an array so `== marg` is elementwise even for list input.
    which_marg = None if options['whichMarg'] is None else np.asarray(options['whichMarg'])

    components_to_plot = _select_components(
        which_marg,
        options['componentsSignif'],
        options['timeMarginalization'],
        options['showNonsignificantComponents'],
        numCompToShow,
    )

    Zfull = Z[:, components_to_plot].T.reshape((len(components_to_plot),) + Xfull.shape[1:])

    # y-axis spans: one entry per marginalization index (0 .. max(whichMarg)),
    # so non-contiguous marginalization values still index safely.
    if options['ylims'] is None:
        peak = np.nanmax(np.abs(Zfull)) if Zfull.size else 1.0
        if which_marg is not None and which_marg.size:
            n_ylims = max(int(np.max(which_marg)) + 1, 1)
        else:
            n_ylims = 1
        options['ylims'] = [peak * 1.1] * n_ylims
    ylims = np.atleast_1d(np.asarray(options['ylims'], dtype=float))

    _, axes = plt.subplots(nrows=4, ncols=4, figsize=(18, 10))
    axes = axes.ravel()

    signif = options['componentsSignif']
    explained = options['explainedVar']
    names = options['marginalizationNames']

    for c_idx, component in enumerate(components_to_plot):
        ax = axes[c_idx]
        signif_trace = signif[component] if signif is not None else None

        if which_marg is not None:
            marg_idx = which_marg[component]
            # A single ylim (scalar or length-1) applies to every marginalization.
            this_ylim = ylims[marg_idx] if ylims.size > 1 else ylims[0]
        else:
            marg_idx = None
            this_ylim = ylims[0]

        plot_function(ax, Zfull[c_idx], options['time'], ylims=[-this_ylim, this_ylim],
                      component_variance=explained['componentVar'][component] if explained else None,
                      component_idx=component, time_events=options['timeEvents'],
                      significance=signif_trace, marginalization=marg_idx)

        if marg_idx is not None and names is not None and len(names) > 0:
            ax.text(0.1, 0.9, names[marg_idx], transform=ax.transAxes)

    plt.tight_layout()
    plt.show()

    return components_to_plot, Zfull, options, Z

