# Convergence of IBCM and BioPCA habituation to turbulent backgrounds

# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns
import os, json
pj = os.path.join

# Initialization

### Aesthetic parameters

In [None]:
# Set to False to display the figures without writing the PDF panels
do_save_plots = True

# Input and output locations
root_dir = pj(*[".."] * 3)
data_folder = pj(root_dir, "results", "for_plots")
data_folder_conv = pj(data_folder, "convergence")
params_folder = pj(root_dir, "results", "common_params")
panels_folder = "panels/"

In [None]:
def _read_params_json(filename):
    """Parse one JSON file located in ``params_folder``."""
    with open(pj(params_folder, filename), "r") as handle:
        return json.load(handle)


# Matplotlib rcParams stored with the other common parameters
new_rcParams = _read_params_json("olfaction_rcparams.json")
plt.rcParams.update(new_rcParams)

# Colors read from back_colors.json
all_back_colors = _read_params_json("back_colors.json")
back_color = all_back_colors["back_color"]
back_color_samples = all_back_colors["back_color_samples"]
back_palette = all_back_colors["back_palette"]

# Colors read from orn_colors.json
orn_colors = _read_params_json("orn_colors.json")

# Neuron colors, converted to arrays so they can be fancy-indexed
neuron_colors = np.asarray(_read_params_json("inhibitory_neuron_two_colors.json"))
neuron_colors_full = np.asarray(_read_params_json("inhibitory_neuron_full_colors.json"))

# Per-model colors and display names; the model list is the color keys
model_colors = _read_params_json("model_colors.json")
model_nice_names = _read_params_json("model_nice_names.json")

models = list(model_colors)
print(models)

In [None]:
# Extra aesthetic parameters for this figure
# Keep the default width but reduce the height, to squeeze four rows of plots
plt.rcParams["figure.figsize"] = (plt.rcParams["figure.figsize"][0], 1.6)

# More legend rcParams: each spacing below is scaled by 0.75 before being set
plt.rcParams["patch.linewidth"] = 0.75
legend_rc = {
    "labelspacing": 0.5,
    "handlelength": 2.0,
    "handleheight": 0.7,
    "handletextpad": 0.8,
    "borderaxespad": 0.5,
    "columnspacing": 2.0,
}
plt.rcParams.update({"legend." + name: 0.75 * value
                     for name, value in legend_rc.items()})

# Colors, line styles and markers available to the cells below
new_color = "r"
linestyles = ["-", "--", ":", (0, (5, 1, 2, 1)), "-."]
neuron_styles = [*linestyles, (0, (1, 2, 1, 2))]

markerstyles = ["o", "s", "^", "v", "X", "*", "d", "h", "<", "p", "P"]

# IBCM plotting functions

In [None]:
def plot_ibcm_hgammas_series(t_axis, i_highlights, hgammaser, squeeze=0.65, figax=None):
    """Plot the alignment time series of IBCM neurons, highlighting a few.

    Odd-indexed, non-highlighted neurons are drawn first in the background
    color; the highlighted neurons are then drawn on top, one color per
    neuron. Every series is thinned to one point in 20.

    Args:
        t_axis: time axis, indexed like the first axis of ``hgammaser``.
        i_highlights: indices of the neurons to highlight; at most three,
            since only three highlight colors are selected.
        hgammaser: array indexed as [time, neuron, component].
        squeeze: factor applied to the default figure width.
        figax: optional ``(fig, ax)`` pair to draw on. A new figure is
            created when it is None or empty.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    # The previous default was an empty list, which failed on unpacking
    if figax is None or len(figax) == 0:
        fig, ax = plt.subplots()
    else:
        fig, ax = figax
    # By default, we squeeze to be able to put matrix of hgammas series besides
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*squeeze,
                        plt.rcParams["figure.figsize"][1])

    neuron_colors3 = neuron_colors_full[[8, 17, 23]]
    clr_back = back_palette[-1]
    plot_skp = 20
    # Take the sizes from the array itself rather than from a notebook global
    n_neurons, n_b = hgammaser.shape[1], hgammaser.shape[2]
    base_lw = plt.rcParams["lines.linewidth"]
    t_thin = t_axis[::plot_skp]

    # Plot all other neurons first, keeping only odd indices (thinning)
    for i in range(n_neurons):
        if i in i_highlights or i % 2 == 0:
            continue
        for j in range(n_b):
            ax.plot(t_thin, hgammaser[::plot_skp, i, j], color=clr_back,
                    ls="-", alpha=1.0-0.1*j, lw=base_lw-j*0.1)

    # Now plot the highlighted neurons; later components are fainter and thinner
    for j in range(n_b):
        for k, i_neuron in enumerate(i_highlights):
            ax.plot(t_thin, hgammaser[::plot_skp, i_neuron, j],
                    color=neuron_colors3[k], ls="-", alpha=1.0-0.2*j,
                    lw=base_lw-j*0.2)

    # Annotations
    ax.set(xlabel="Time (min)",
           ylabel=r"Alignments $\bar{h}_{i\gamma} = \mathbf{\bar{m}}_i \cdot \mathbf{s}_{\gamma}$")

    return fig, ax

def plot_ibcm_hgammas_matrix(hgammas_mat, i_high, squeeze=0.4):
    """Draw a matrix of alignments as colored cells, one row per neuron.

    Rows listed in ``i_high`` use a light palette built from that neuron's
    highlight color; the other rows use a grey scale. Cell shades come from
    the matrix rescaled between its minimum and maximum.

    Args:
        hgammas_mat: 2D array indexed as [neuron, component].
        i_high: indices of the neurons to highlight; at most three, since
            only three highlight colors are selected.
        squeeze: factor applied to the default figure width.

    Returns:
        ``(fig, ax, cbar)``, where ``cbar`` is a grey-scale colorbar spanning
        the matrix range, with its ticks removed.
    """
    n_i, n_comp = hgammas_mat.shape
    i_high = list(i_high)  # .index() is needed below
    neuron_colors3 = neuron_colors_full[[8, 17, 23]]
    vmin, vmax = hgammas_mat.min(), hgammas_mat.max()

    fig, ax = plt.subplots()
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*squeeze,
                        plt.rcParams["figure.figsize"][1])

    # Add patches manually with fill_between.
    # Color highlighted neurons, leave others grayscale.
    normed_matrix = (hgammas_mat - vmin) / (vmax - vmin)
    for i in range(n_i):
        if i in i_high:
            cmap = sns.light_palette(neuron_colors3[i_high.index(i)], as_cmap=True)
        else:
            cmap = sns.color_palette("Greys", as_cmap=True)
        for j in range(n_comp):
            ax.fill_between([-0.5+j, 0.5+j], -0.5+i, 0.5+i, color=cmap(normed_matrix[i, j]))

    ax.set_xlim([-0.6, -0.6+n_comp])
    ax.set_ylim([-0.6, -0.6+n_i])
    ax.set_yticks(list(range(0, n_i, 2)))

    # Color the tick label and tick mark of each highlighted neuron.
    # The tick is taken from the tick object itself: get_ticklines() returns
    # two lines per tick, so indexing it by tick number hits the wrong tick.
    tick_labels = ax.get_yticklabels()
    major_ticks = ax.yaxis.get_major_ticks()
    for tick, lbl in zip(major_ticks, tick_labels):
        neuron_idx = int(lbl.get_text())
        if neuron_idx in i_high:
            clr = neuron_colors3[i_high.index(neuron_idx)]
            lbl.set_color(clr)
            tick.tick1line.set_color(clr)
            # Tick marks are drawn as markers, so set their edge color as well
            tick.tick1line.set_markeredgecolor(clr)

    ax.set_xlabel(r"Component $\gamma$", size=6)
    ax.set_ylabel(r"IBCM neuron index $i$", size=6)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin, vmax),
        cmap="Greys"), ax=ax, aspect=30, pad=0.1)
    cbar.set_ticks([])
    cbar.set_label(label=r"Alignments ${\bar{h}}_{i\gamma}$, 45 min", fontsize=6)
    fig.tight_layout()
    return fig, ax, cbar


In [None]:
def plot_ibcm_hgammas_series_percomp(t_axis, comp_high, hgammaser, n_b, figax=None):
    """Plot alignment time series colored by component.

    Assumes a separate hgammas matrix will be plotted as a legend.
    Components not in ``comp_high`` are drawn first in the background color,
    for every other neuron; highlighted components are drawn on top for all
    neurons, one color per component. Every series is thinned to one point
    in 20.

    Args:
        t_axis: time axis, indexed like the first axis of ``hgammaser``.
        comp_high: indices of the components to highlight.
        hgammaser: array indexed as [time, neuron, component].
        n_b: number of components to loop over and to size the palette.
        figax: optional ``(fig, ax)`` pair to draw on. When None, a new
            figure is created and resized.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    if figax is None:
        fig, ax = plt.subplots()
        fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.75,
                            plt.rcParams["figure.figsize"][1])
    else:
        fig, ax = figax

    odor_colors = sns.color_palette("colorblind", n_colors=n_b)
    clr_back = back_palette[-1]
    plot_skp = 20
    # Take the neuron count from the array itself rather than from a notebook global
    n_neurons = hgammaser.shape[1]
    base_lw = plt.rcParams["lines.linewidth"]
    t_thin = t_axis[::plot_skp]

    # Plot all other components first, for every other neuron
    for j in range(n_b):
        if j in comp_high:
            continue
        for i in range(0, n_neurons, 2):
            ax.plot(t_thin, hgammaser[::plot_skp, i, j], color=clr_back,
                    ls="-", alpha=0.8, lw=base_lw)

    # Now plot the highlighted components. The color follows the position in
    # comp_high; the legend label shows the actual component index.
    for k, comp in enumerate(comp_high):
        for i in range(n_neurons):
            # Only the first neuron carries a label, so each component
            # contributes a single legend entry
            lbl = r"$\gamma = {}$".format(comp) if i == 0 else ""
            ax.plot(t_thin, hgammaser[::plot_skp, i, comp],
                    color=odor_colors[k], ls="-", alpha=1.0-0.05*i,
                    label=lbl, lw=base_lw-0.05*i)

    ax.set(xlabel="Time (min)",
           ylabel=r"Alignments $\bar{h}_{i\gamma} = \mathbf{\bar{m}}_i \cdot \mathbf{s}_{\gamma}$")
    ax.legend(frameon=False, title="Component", title_fontsize=6)

    fig.tight_layout()
    return fig, ax

def plot_ibcm_hgammas_matrix_percomp(hgammas_mat, comp_high, n_i, n_comp):
    """Draw a matrix of alignments as colored cells, one column per component.

    Columns listed in ``comp_high`` use a light palette built from that
    component's color (indexed by position in ``comp_high``); the other
    columns use a grey scale. Cell shades come from the matrix rescaled
    between its minimum and maximum.

    Args:
        hgammas_mat: 2D array indexed as [neuron, component].
        comp_high: indices of the components to highlight.
        n_i: number of neuron rows to draw.
        n_comp: number of component columns to draw.

    Returns:
        ``(fig, ax, cbar)``, where ``cbar`` is a grey-scale colorbar spanning
        the matrix range.
    """
    comp_high = list(comp_high)  # .index() is needed below
    vmin, vmax = hgammas_mat.min(), hgammas_mat.max()

    fig, ax = plt.subplots()
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.6,
                        plt.rcParams["figure.figsize"][1])

    normed_matrix = (hgammas_mat - vmin) / (vmax - vmin)
    odor_colors = sns.color_palette("colorblind", n_colors=n_comp)
    for j in range(n_comp):
        if j in comp_high:
            cmap = sns.light_palette(odor_colors[comp_high.index(j)], as_cmap=True)
        else:
            cmap = sns.color_palette("Greys", as_cmap=True)
        for i in range(n_i):
            ax.fill_between([-0.5+j, 0.5+j], -0.5+i, 0.5+i, color=cmap(normed_matrix[i, j]))

    ax.set_xlim([-0.6, -0.6+n_comp])
    ax.set_ylim([-0.6, -0.6+n_i])
    ax.set_xticks(list(range(0, n_comp)))
    ax.set_yticks(list(range(0, n_i, 2)))

    # Color the tick label and tick mark of each highlighted component.
    # The tick is taken from the tick object itself: get_ticklines() returns
    # two lines per tick, so indexing it by tick number hits the wrong tick.
    tick_labels = ax.get_xticklabels()
    major_ticks = ax.xaxis.get_major_ticks()
    for tick, lbl in zip(major_ticks, tick_labels):
        comp_idx = int(lbl.get_text())
        if comp_idx in comp_high:
            clr = odor_colors[comp_high.index(comp_idx)]
            lbl.set_color(clr)
            tick.tick1line.set_color(clr)
            # Tick marks are drawn as markers, so set their edge color as well
            tick.tick1line.set_markeredgecolor(clr)
    ax.set(xlabel=r"Component $\gamma$", ylabel="IBCM neuron index $i$")
    cbar = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin, vmax),
        cmap="Greys"), ax=ax, label=r"Alignments ${\bar{h}}_{i\gamma}$, 45 min", aspect=30, pad=0.1)
    fig.tight_layout()
    return fig, ax, cbar


# Load IBCM simulations

In [None]:
# Load every array saved by the IBCM convergence analysis
with np.load(
    pj(data_folder_conv, "non-gaussian_ibcm_convergence_analysis.npz")
) as fp:
    # Time axis and time series
    tser_scaled = fp["tser_scaled"]
    hgammas_ser = fp["hgammas_ser"]
    backnorm_ser = fp["backnorm_ser"]
    ynorm_ser = fp["ynorm_ser"]
    # Per-neuron specificities and predicted times
    specifs = fp["hgammas_specifs"]
    th_predictions = fp["th_predictions"]
    tu_predictions = fp["tu_predictions"]
    # Reference levels, unpacked into the three values drawn as horizontal lines
    horiz_lines = fp["horiz_lines"]
    saddle_hd_divmean, saddle_u2, target_u2 = horiz_lines
    # Remaining saved quantities
    moments_conc = fp["moments_conc"]
    fixed_point_preds = fp["fixed_point_preds"]
    learnrate_mscale = fp["learnrate_mscale"]

# Panel A: good IBCM example, annotation of convergence metrics

In [None]:
# Duration of one 10 ms step, expressed in minutes
tu_scale = 1.0 / 100.0 / 60.0

# Sizes along the [time, neuron, component] axes of the saved series
nsteps, n_i_ibcm, n_components = hgammas_ser.shape[:3]

# Time axis of this example, in min
tser_example = tser_scaled

# Neurons to highlight
i_highlights = [0, 1, 4]

# Average the alignments over the last quarter of the time series
transient = int(3 * tser_example.size / 4)
hgammas_matrix = hgammas_ser[transient:].mean(axis=0)

In [None]:
# Two stacked panels sharing the time axis: alignments on top, norms below
fig, axes = plt.subplots(2, sharex=True)
ax, ax2 = axes
fig, ax = plot_ibcm_hgammas_series_percomp(tser_example, [0, 1, 2], hgammas_ser,
                                           n_components, figax=[fig, ax])
# Remove x label, the shared axis is labelled on the bottom panel
ax.set_xlabel("")

# Annotate with learning rate and initial conditions
# NOTE(review): tu_scale is documented as a step duration in minutes while the
# title unit reads s^-1 -- confirm the intended unit conversion.
ax.set_title((r"Learning rate: $\mu =" +" {:.1f}".format(learnrate_mscale[0]/tu_scale)
              + r" \,\mathrm{s^{-1}}$"), y=0.98)
arrowprops = {"arrowstyle": "->", "lw": 0.5, "color": "k", "shrinkB": 0.01}
xannot = 5.0
hgampad = 0.1
arrowlen = 0.6
# Two arrows pointing towards zero from above and below
ax.annotate("", xy=(xannot, hgampad), xytext=(xannot, hgampad+arrowlen), arrowprops=arrowprops)
ax.annotate("", xy=(xannot, -hgampad), xytext=(xannot, -hgampad-arrowlen), arrowprops=arrowprops)
ax.annotate((r"$m_{\gamma, \mathrm{init}} \sim " + "{:d}".format(int(learnrate_mscale[1]*1e4))
             + r" \times 10^{-4}$"), xy=(xannot-5.0, -hgampad-arrowlen),
            ha="left", va="top", fontsize=5)

# Adjust size for second subplot
fig.set_size_inches(plt.rcParams["figure.figsize"][0], plt.rcParams["figure.figsize"][1]*1.75)

# Combine plot with the y norm series, to see when it habituates
# NOTE(review): n_odors is not used in this cell, and axis 1 of hgammas_ser is
# the one the plotting helpers loop over as neurons -- kept only in case
# another cell reads it.
n_odors = hgammas_ser.shape[1]
tsl = slice(None, None, 20)
ax2.plot(tser_example[tsl], backnorm_ser[tsl], color="grey", alpha=0.8)
ax2.plot(tser_example[tsl], ynorm_ser[tsl], color="k", alpha=0.8)
ax2.set(xlabel="Time (min)", ylabel="PN activity norm")

# Mark, on both panels, the predicted time of the first neuron whose
# specificity equals last_specif, using that component's palette color
last_specif = 1
th_pred_lastcomp = th_predictions[np.argmax(specifs == last_specif)]
lastcomp_color = sns.color_palette("colorblind")[last_specif]
for axis in axes:
    axis.axvline(th_pred_lastcomp, ls="--", color=lastcomp_color)

# Place the text near the top of the bottom panel, whose limits are read explicitly
txt = ax2.annotate("Predicted $t_h$ for\nlast covered odor", xy=(th_pred_lastcomp+1, ax2.get_ylim()[1]*0.98),
            ha="left", va="top", fontsize=6, color=lastcomp_color)
txt.set_bbox(dict(facecolor='w', alpha=0.8, edgecolor='none', pad=0))

fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_convergence_time_hgamma_example.pdf"),
                transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Matrix of time-averaged alignments for the first example; the third
# returned object is the colorbar
fig, ax, axi = plot_ibcm_hgammas_matrix_percomp(
    hgammas_matrix, [0, 1, 2], n_i_ibcm, n_components)

# Lay out first, then shrink the figure relative to the default size
fig.tight_layout()
base_w, base_h = plt.rcParams["figure.figsize"]
fig.set_size_inches(base_w*0.4, base_h*0.9)

if do_save_plots:
    out_path = pj(panels_folder, "supfig_convergence_time_hgamma_matrix_example.pdf")
    fig.savefig(out_path, transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Panel B: other IBCM example, larger $\mu$, analytics still work.


In [None]:
# Load every array saved by the high-mu IBCM convergence analysis
with np.load(
    pj(data_folder_conv, "non-gaussian_ibcm_convergence_analysis_highmu.npz")
) as fp:
    # Time axis (two names for the same array) and time series
    tser_scaled_himu = fp["tser_scaled"]
    tser_example_himu = tser_scaled_himu
    hgammas_ser_himu = fp["hgammas_ser"]
    backnorm_ser_himu = fp["backnorm_ser"]
    ynorm_ser_himu = fp["ynorm_ser"]
    # Per-neuron specificities and predicted times
    specifs_himu = fp["hgammas_specifs"]
    th_predictions_himu = fp["th_predictions"]
    tu_predictions_himu = fp["tu_predictions"]
    # Remaining saved quantities
    horiz_lines_himu = fp["horiz_lines"]
    moments_conc_himu = fp["moments_conc"]
    fixed_point_preds_himu = fp["fixed_point_preds"]
    # NOTE: this name has no _himu suffix, so it replaces the value loaded
    # from the first simulation file
    learnrate_mscale = fp["learnrate_mscale"]

In [None]:
# Two stacked panels sharing the time axis: alignments on top, norms below
fig, axes = plt.subplots(2, sharex=True)
ax, ax2 = axes
fig, ax = plot_ibcm_hgammas_series_percomp(tser_example_himu, [0, 1, 2], hgammas_ser_himu,
                                           n_components, figax=[fig, ax])
# Remove x label, the shared axis is labelled on the bottom panel
ax.set_xlabel("")

# Annotate with learning rate and initial conditions.
# Learning rate here is 0.002, m initial scale is 2e-3 (1 + fluctuating part of 1e-3)
# NOTE(review): tu_scale is documented as a step duration in minutes while the
# title unit reads s^-1 -- confirm the intended unit conversion.
ax.set_title((r"Learning rate: $\mu =" +" {:.1f}".format(learnrate_mscale[0]/tu_scale)
                + r" \,\mathrm{s^{-1}}$"), y=0.98)
arrowprops = {"arrowstyle": "->", "lw": 0.5, "color": "k", "shrinkB": 0.01}
xannot = 1.0
hgampad = 0.1
arrowlen = 0.6
# Two arrows pointing towards zero from above and below
ax.annotate("", xy=(xannot, hgampad), xytext=(xannot, hgampad+arrowlen), arrowprops=arrowprops)
ax.annotate("", xy=(xannot, -hgampad), xytext=(xannot, -hgampad-arrowlen), arrowprops=arrowprops)
txt = ax.annotate((r"$m_{\gamma, \mathrm{init}} \sim " + "{:d}".format(int(learnrate_mscale[1]*1e3))
             + r" \times 10^{-3}$"), xy=(xannot-1.0, -hgampad-arrowlen),
            ha="left", va="top", fontsize=5, annotation_clip=False)
txt.set_bbox(dict(facecolor='w', alpha=0.8, edgecolor='none', pad=0))
# Re-do the legend without a title, and lower, to save space
ax.legend(frameon=False, loc="upper right", bbox_to_anchor=(1.0, 0.75))

# Adjust size for second subplot
fig.set_size_inches(plt.rcParams["figure.figsize"][0], plt.rcParams["figure.figsize"][1]*1.75)

# Combine plot with the y norm series, to see when it habituates
# NOTE(review): n_odors is not used in this cell, and axis 1 of the series is
# the one the plotting helpers loop over as neurons -- kept only in case
# another cell reads it.
n_odors = hgammas_ser_himu.shape[1]
tsl = slice(None, None, 20)
ax2.plot(tser_example_himu[tsl], backnorm_ser_himu[tsl], color="grey", alpha=0.8)
ax2.plot(tser_example_himu[tsl], ynorm_ser_himu[tsl], color="k", alpha=0.8)
ax2.set(xlabel="Time (min)", ylabel="PN activity norm")

last_specif = 0  # manually looked at which component is selected last
# Mark, on both panels, the predicted time of the first neuron whose
# specificity equals last_specif, using that component's palette color
th_pred_lastcomp = th_predictions_himu[np.argmax(specifs_himu == last_specif)]
lastcomp_color = sns.color_palette("colorblind")[last_specif]
for axis in axes:
    axis.axvline(th_pred_lastcomp, ls="--", color=lastcomp_color)

# Place the text near the top of the bottom panel, whose limits are read explicitly
txt = ax2.annotate("Predicted $t_h$ for\nlast covered odor",
                   xy=(th_pred_lastcomp+1, ax2.get_ylim()[1]*0.98),
            ha="left", va="top", fontsize=6, color=lastcomp_color)
txt.set_bbox(dict(facecolor='w', alpha=0.8, edgecolor='none', pad=0))

fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_convergence_time_hgamma_example_himu.pdf"),
                transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Average the high-mu alignments over the last quarter of their own time
# axis; the cut-off and matrix sizes come from the high-mu arrays rather
# than from the first example's variables
transient_himu = int(3*tser_example_himu.size/4)
hgammas_matrix_himu = np.mean(hgammas_ser_himu[transient_himu:], axis=0)
n_i_himu, n_comp_himu = hgammas_matrix_himu.shape
# The third returned object is the colorbar
fig, ax, axi = plot_ibcm_hgammas_matrix_percomp(hgammas_matrix_himu, [0, 1, 2],
                                                n_i_himu, n_comp_himu)

# Lay out first, then shrink the figure relative to the default size
fig.tight_layout()
fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.4, plt.rcParams["figure.figsize"][1]*0.9)
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_convergence_time_hgamma_matrix_example_himu.pdf"),
                transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Panel C: individual neurons $h_\mathrm{d}$, $u^2$ series

In [None]:
# For convergence, track sums of h_gammas within each neuron
hdsums = np.sum(hgammas_ser, axis=2)
u2sums = np.sum(hgammas_ser**2, axis=2)

# One subplot per neuron on a grid of ncols columns; squeeze=False keeps
# `axes` two-dimensional even when a single row is needed
ncols = 5
nrows = n_i_ibcm // ncols + min(1, n_i_ibcm % ncols)
fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False)
fig.set_size_inches(plt.rcParams["figure.figsize"][0]*3.0,
                    plt.rcParams["figure.figsize"][1]*0.7*nrows)
colors = sns.color_palette("tab20", n_colors=n_i_ibcm*2)
colors2 = colors[0], colors[1]  # Same two colors for each subplot
clr_tu = colors[3*2+1]  # Color of the t_u prediction and of the target u^2 level
# Use the IBCM time axis through its alias tser_example: the name tser_scaled
# is assigned again when the BioPCA results are loaded
tsl = slice(0, len(tser_example), 4)
for i in range(n_i_ibcm):
    ax_i = axes.flat[i]
    ax_i.plot(tser_example[tsl], hdsums[tsl, i],
              alpha=1.0, color=colors2[0], ls="-",
              label=r"$ h_d \,/\langle c \rangle = \sum_{\gamma} h_{\gamma}$")
    ax_i.plot(tser_example[tsl], u2sums[tsl, i],
              alpha=1.0, color=colors2[1], ls="--", label=r"$u^2 = \sum_{\gamma} h_{\gamma}^2$")
    ax_i.set_title("Neuron {}, specif: {}".format(i, specifs[i]), y=0.9)
    ax_i.axhline(saddle_hd_divmean, ls="-", color="k", lw=0.75)
    # saddle and final h_d are pretty much the same
    ax_i.axhline(saddle_u2, ls=":", color="grey", lw=0.75)

    # Prediction of convergence time
    th_pred = th_predictions[i]
    ax_i.axvline(th_pred, ymax=0.9,
                 color="r", lw=0.75, ls="-", label="$t_d$ predict.")

    # Phase 2: predict from actual t_d, using the exponential exit rate
    # obtained from the Jacobian matrix at the saddle point
    # Use for the initial value of u^2 the value at the actual t_h
    tu_pred = tu_predictions[i]
    ax_i.axvline(tu_pred, ymax=0.9,
                 color=clr_tu, lw=0.75, ls="-", label=r"$t_u$ predict.")  # after true t_d
    ax_i.axhline(target_u2, ls="--", color=clr_tu, lw=0.75)

    # Annotations where there is space: on the left when the predicted
    # time is late, on the right otherwise
    if th_pred > 25.0:
        xannot, yshiftd, yshiftu, halign = 0.0, 0.0, 0.15, "left"
    else:
        xannot, yshiftd, yshiftu, halign = 60.0, 0.5, 0.0, "right"
    ax_i.annotate(r"Saddle $h_\mathrm{d}/\langle c \rangle$",
                  xy=(xannot, saddle_hd_divmean + yshiftd), fontsize=5, ha=halign, va="bottom")
    ax_i.annotate(r"Saddle $u^2$", xy=(xannot, saddle_u2 + yshiftu), fontsize=5,
                  ha=halign, va="top", color="grey")
    ax_i.annotate(r"Target $u^2$", xy=(xannot, target_u2), fontsize=5,
                  ha=halign, va="top", color=clr_tu)

# Unused grid cells: the first one hosts the legend, all are switched off
for i in range(n_i_ibcm, axes.size):
    if i == n_i_ibcm:
        handles, labels = axes.flat[0].get_legend_handles_labels()
        # Swap the first two entries so u^2 is listed before h_d
        handles[1], handles[0] = handles[0], handles[1]
        labels[1], labels[0] = labels[0], labels[1]
        axes.flat[i].legend(handles, labels, frameon=False, loc="upper left")
    axes.flat[i].set_axis_off()
# Label the x axis of the last ncols neuron subplots; clamp at 0 so a
# negative index cannot wrap around when there are fewer neurons than columns
for i in range(max(0, n_i_ibcm-ncols), n_i_ibcm):
    axes.flat[i].set_xlabel("Time (min)")
# Label the y axis of the first subplot in every row
for j in range(nrows):
    axes[j, 0].set_ylabel(r"Sums of $h_\gamma$s")
fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_convergence_time_individual_neurons.pdf"),
               transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Panels D, E, F: BioPCA $L$, $M$, and background inhibition

In [None]:
def plot_biopca_convergence(t_axis, true_pvs, learned_pvs, figax=None):
    """Plot learned principal values over time against their true values.

    Each column of ``learned_pvs`` is drawn as a solid line labelled
    "Value i" on a logarithmic y axis, thinned to one point in 4. The
    matching entry of ``true_pvs`` is drawn as a dashed horizontal line of
    the same color, unless it is negligible relative to the largest true
    value or ``true_pvs`` has no entry for that column.

    Args:
        t_axis: time axis, indexed like the first axis of ``learned_pvs``.
        true_pvs: 1D sequence of reference principal values.
        learned_pvs: array indexed as [time, component].
        figax: optional ``(fig, ax)`` pair to draw on; a new figure is
            created when it is None.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    # Get existing figure and axis, if any
    if figax is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figax

    true_pvs = np.asarray(true_pvs)
    n_comp = learned_pvs.shape[1]
    n_true = len(true_pvs)
    pv_max = true_pvs.max() if n_true > 0 else None
    pca_palette = sns.color_palette("colorblind", n_colors=n_comp)
    base_lw = plt.rcParams["lines.linewidth"]

    tsl = slice(None, None, 4)
    for i in range(n_comp):
        # Later components get thinner lines and are drawn underneath
        lw_i = base_lw - 0.5*i/n_comp
        ax.plot(t_axis[tsl], learned_pvs[tsl, i], label="Value {}".format(i),
                lw=lw_i, zorder=10-i, color=pca_palette[i])
        # Skip the reference line when no true value is available for this
        # column, or when it is negligible
        if i < n_true and true_pvs[i] / pv_max > 1e-12:
            ax.axhline(true_pvs[i], ls="--", color=pca_palette[i],
                       lw=lw_i, zorder=n_comp-i)
    ax.set(ylabel="Principal values, diag$(L^{-1})$", yscale="log", xlabel="Time (min)")

    return fig, ax

In [None]:
# Extract saved BioPCA simulation
# NOTE: tser_scaled is assigned again here, so from this cell on it holds the
# time axis stored in the BioPCA file.
# The IBCM high-mu time axis is not re-assigned in this cell: it has nothing
# to do with this file and already holds the same value from its own loading cell.
with np.load(pj(data_folder_conv, "non-gaussian_biopca_convergence_analysis.npz")) as fp:
    tser_scaled = fp["tser_scaled"]
    mnormser = fp["mnormser"]
    lser = fp["lser"]
    true_pvs = fp["true_pvs"]
    align_error = fp["align_error"]
    ynormser_pca = fp["ynormser"]
    bknormser_pca = fp["bknormser"]
    decay_times = fp["decay_times"]
    decay_rates = fp["decay_rates"]
    # One predicted decay time per plotted component
    n_i_pca = len(decay_times)

In [None]:
# Visualize convergence dynamics in terms of L, M norm, and PN activity norm
# L plot first
fig, ax = plot_biopca_convergence(tser_scaled, true_pvs[:3], lser)

# Extend the lower y limit to leave room for the annotations at the bottom
ax.set_ylim(ax.get_ylim()[0]*0.65, ax.get_ylim()[1])

# Stop vertical lines at the level of horizontal ones
axis_to_data = ax.transAxes + ax.transData.inverted()
data_to_axis = axis_to_data.inverted()

# Look up each learned curve by its "Value i" label rather than by its
# position among the axes lines, which depends on how many reference lines
# the helper chose to draw
line_colors = {ln.get_label(): ln.get_color() for ln in ax.get_lines()}

comp_colors = []
for i in range(n_i_pca):
    clr = line_colors["Value {}".format(i)]
    comp_colors.append(clr)
    ax.axvline(decay_times[i], ls=":", ymax=data_to_axis.transform((decay_times[i], true_pvs[i]))[1],
                color=clr)
    ax.annotate("Pred.\n" + r"$t_{}$".format(i+1), xy=(decay_times[i]+1.0, ax.get_ylim()[0]),
               ha="left", va="bottom", fontsize=5, color=clr)

# Custom legend distinguishing learned curves, predicted times and true values
handles = [mpl.lines.Line2D([0], [0], color="grey", ls="-", label=r"BioPCA",
                            lw=plt.rcParams["lines.linewidth"]),
          mpl.lines.Line2D([0], [0], color="grey", ls=":", label="Predicted $t_i$",
                          lw=plt.rcParams["lines.linewidth"]),
          mpl.lines.Line2D([0], [0], color="grey", ls="--", label="True PCA",
                          lw=plt.rcParams["lines.linewidth"])]
leg = ax.legend(handles=handles, frameon=False, fontsize=5.0, ncol=2)
leg.set_zorder(30)  # keep the legend above the curves

fig.tight_layout()
if do_save_plots:
    # NOTE(review): the file name repeats "principal"; left unchanged in case
    # other documents reference it.
    fig.savefig(pj(panels_folder, "supfig_convergence_time_biopca_principal_principal_values.pdf"),
               transparent=True, bbox_inches="tight")

plt.show()
plt.close()

In [None]:
# Second BioPCA panel: one curve per component from mnormser, log scale
fig, ax = plt.subplots()
tsl = slice(None, None, 4)  # keep one point in 4
ypositions = [2.0, 0.5, 0.15]  # hand-picked heights of the text annotations
for i in range(n_i_pca):
    clr = comp_colors[i]
    t_pred = decay_times[i]
    li, = ax.plot(tser_scaled[tsl], mnormser[tsl, i], color=clr)
    # Predicted time for this component, in the same color as its curve
    ax.axvline(t_pred, ls=":", color=clr)
    ax.annotate(
        "Pred.\n" + r"$t_{}$".format(i+1),
        xy=(t_pred+1.0, ypositions[i]),
        ha="left", va="bottom", fontsize=5, color=clr,
    )
ax.set(yscale="log", xlabel="Time (min)", ylabel=r"Norm of $M$ rows, $\|\mathbf{m}\|$")
fig.tight_layout()
if do_save_plots:
    out_path = pj(panels_folder, "supfig_convergence_time_biopca_mnorms.pdf")
    fig.savefig(out_path, transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Then, finally, background inhibition
fig, ax = plt.subplots()
# Define the thinning slice here instead of relying on the previous cell
tsl = slice(None, None, 4)
# Plot against the time axis loaded from the BioPCA file, as for the other
# BioPCA series, not against the time axis of the IBCM high-mu run
ax.plot(tser_scaled[tsl], bknormser_pca[tsl], color="grey", alpha=0.8)
ax.plot(tser_scaled[tsl], ynormser_pca[tsl], color="k", alpha=0.8)
ax.set(xlabel="Time (min)", ylabel="PN activity norm")

# Mark the predicted times; the labels sit slightly above the largest plotted
# value of the grey curve, hence annotation_clip=False
y_annot = bknormser_pca[tsl].max()*1.1
for i in range(n_i_pca):
    ax.axvline(decay_times[i], ls=":", color=comp_colors[i])
    ax.annotate("Pred.\n" + r"$t_{}$".format(i+1), xy=(decay_times[i]+1.0, y_annot),
               ha="left", va="top", fontsize=5, color=comp_colors[i], annotation_clip=False)

fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_convergence_time_biopca_ynorm_inhibition.pdf"),
               transparent=True, bbox_inches="tight")
plt.show()
plt.close()