# Supplementary figure for OSN adaptation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import os, json
pj = os.path.join

In [None]:
do_save_plots = True

# Input/output folders. Everything is resolved relative to the repository
# root, three directory levels above this notebook.
root_dir = pj("..", "..", "..")
params_folder = pj(root_dir, "results", "common_params")
data_folder = pj(root_dir, "results", "for_plots")
data_folder_nl = pj(root_dir, "results", "for_plots", "nonlin_adapt")
# Individual panel PDFs are written here
panels_folder = "panels/"

In [None]:
# Shared plotting resources: rcParams, color maps and model display names.
# JSON is UTF-8 by specification, so the encoding is passed explicitly
# instead of relying on the platform's default locale encoding.
def _load_params_json(fname):
    """Return the parsed JSON content of the file fname in params_folder."""
    with open(pj(params_folder, fname), "r", encoding="utf-8") as fjson:
        return json.load(fjson)

# rcParams
new_rcParams = _load_params_json("olfaction_rcparams.json")
plt.rcParams.update(new_rcParams)

# color maps
all_back_colors = _load_params_json("back_colors.json")
back_color = all_back_colors["back_color"]
back_color_samples = all_back_colors["back_color_samples"]
back_palette = all_back_colors["back_palette"]

orn_colors = _load_params_json("orn_colors.json")

neuron_colors = np.asarray(_load_params_json("inhibitory_neuron_two_colors.json"))
neuron_colors_full = np.asarray(_load_params_json("inhibitory_neuron_full_colors.json"))

model_colors = _load_params_json("model_colors.json")
model_nice_names = _load_params_json("model_nice_names.json")

models = list(model_colors.keys())
print(models)

In [None]:
# Display names for the model variants specific to this figure
model_nice_names["ibcm_adapt"] = "IBCM + adapt."
model_nice_names["biopca_adapt"] = "BioPCA + adapt."
model_nice_names["adapt"] = "OSN adaptation"
model_nice_names["optimal_adapt"] = "Optim. $P$ + adapt."

# Each "<model>_adapt" variant reuses the color of its base model;
# plain "adapt" gets its own color.
model_colors["ibcm_adapt"] = model_colors.get("ibcm")
model_colors["biopca_adapt"] = model_colors.get("biopca")
model_colors["adapt"] = "xkcd:purple"
model_colors["optimal_adapt"] = model_colors.get("optimal")

In [None]:
# Figure-specific aesthetics.
# Keep the default figure width but lower the height to 1.6 so that four
# rows of panels fit on the page.
plt.rcParams["figure.figsize"] = (plt.rcParams["figure.figsize"][0], 1.6)

# Thinner patch edges, and legend spacings scaled to 75 % of the reference
# values listed in legend_rc (i.e. 25 % smaller).
plt.rcParams["patch.linewidth"] = 0.75
legend_rc = {
    "labelspacing": 0.5,
    "handlelength": 2.0,
    "handleheight": 0.7,
    "handletextpad": 0.8,
    "borderaxespad": 0.5,
    "columnspacing": 2.0,
}
plt.rcParams.update({"legend." + key: 0.75 * val for key, val in legend_rc.items()})

# Styles reused by later panels
new_color = "r"
linestyles = ["-", "--", ":", (0, (5, 1, 2, 1)), "-."]
neuron_styles = [*linestyles, (0, (1, 2, 1, 2))]

In [None]:
def l2_norm(vecs, axis=-1):
    r"""Return the Euclidean (l2) norm of the vector(s) stored in vecs.

    Args:
        vecs (ndarray): a single 1d vector, or an array of vectors whose
            components run along `axis`.
        axis (int): axis holding the vector components, summed over.
            Defaults to the last axis.

    Returns:
        ndarray or scalar: the norms, shaped like vecs with `axis` removed.
    """
    squares = vecs**2
    return np.sqrt(np.sum(squares, axis=axis))

# Panel A: OSN adaptation cartoon
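Here, we need to plot a cartoon $\epsilon_i(t)$ trace at high resolution. We do this by building a whiff–blank OSN input as a function of time (a weighted mixture of three odors taken from a sample turbulent background) and integrating the adaptation variable $\epsilon_i$, which accumulates the deviation of the OSN activity from a target level, with time constant $\tau_a = 250$ ms.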

In [None]:
# OSN input for the cartoon: a weighted mixture (0.3, 0.7, 0.5) of odors 0-2,
# index 1 on the last axis of the sample turbulent background, over samples
# 200 to 10199. The .npz archive is closed as soon as the array is read.
with np.load(pj(data_folder, "sample_turbulent_background.npz")) as f:
    tc_ser = f["nuser"]
c_seq = (tc_ser[200:10200, 0, 1]*0.3 + tc_ser[200:10200, 1, 1]*0.7
         + tc_ser[200:10200, 2, 1]*0.5)
# 10 ms per step in the loaded example: t_seq holds sample times in ms
t_seq = np.arange(0.0, c_seq.shape[0]*10.0, 10.0)
def whiffseq(t):
    """Return the cartoon OSN input at time t, in ms.

    c_seq holds one sample every 10 ms, so t is mapped to a sample index by
    truncating t / 10.
    """
    return c_seq[int(t / 10.0)]

def osn_toy(s, eps):
    r"""Toy saturating OSN response $s / (s + e^{\epsilon})$ to input s.

    exp(eps) is the input level that gives half-maximal activity.
    """
    threshold = np.exp(eps)
    return s / (s + threshold)

In [None]:
def integrate_epsil_trace(whiff, duration=9000.0, dt=1.0, t_a=25.0,
                          target_resp=0.25, epsil_range=(-5.0, 5.0),
                          activation=None):
    r"""Integrate the toy OSN adaptation variable epsilon(t) with forward Euler.

    Dynamics: d(epsilon)/dt = (a(t) - target_resp) / t_a, with epsilon clipped
    to epsil_range after every step and OSN activity
    a(t) = activation(whiff(t), epsilon(t)). Epsilon starts at 0.

    Args:
        whiff (callable): OSN input as a function of time t.
        duration (float): total integration time, in the same units as t.
        dt (float): Euler time step.
        t_a (float): adaptation time constant, in the units of t.
            NOTE(review): with whiffseq, t is in ms, so the default is 25 ms,
            while the notebook text quotes 250 ms -- confirm.
        target_resp (float): activity level at which epsilon stops changing.
        epsil_range (2-sequence of float): lower and upper bounds on epsilon.
        activation (callable or None): function (s, eps) -> activity;
            osn_toy when None.

    Returns:
        tser, epsil_ser, conc_ser, activ_ser: 1d arrays of equal length with
            the time points and, at each of them, epsilon, the input and the
            OSN activity.
    """
    if activation is None:
        activation = osn_toy

    # Size every container from the time axis so all series share its length,
    # even when duration / dt is not an integer.
    tser = np.arange(0.0, duration + dt/2.0, dt)
    n_pts = tser.size
    epsil_ser = np.zeros(n_pts)
    conc_ser = np.zeros(n_pts)
    activ_ser = np.zeros(n_pts)

    # Initial conditions (epsil_ser[0] = 0)
    t = 0.0
    conc = whiff(t)
    conc_ser[0] = conc
    activ = activation(conc, epsil_ser[0])
    activ_ser[0] = activ
    for i in range(n_pts - 1):
        epsil_new = epsil_ser[i] + dt * (activ - target_resp) / t_a
        epsil_ser[i+1] = np.clip(epsil_new, epsil_range[0], epsil_range[1])
        t += dt
        conc = whiff(t)
        conc_ser[i+1] = conc
        activ = activation(conc, epsil_ser[i+1])
        activ_ser[i+1] = activ

    return tser, epsil_ser, conc_ser, activ_ser

In [None]:
# Cartoon traces over 6000 time units (ms, given whiffseq's 10 ms samples)
ttrace, epsil_trace, conc_trace, activ_trace = integrate_epsil_trace(
    whiffseq,
    duration=6000.0,
)

In [None]:
# Three tiny cartoon panels, one figure each: OSN input, OSN activity and
# adaptation variable epsilon_i(t), without ticks.
fig0, ax0 = plt.subplots()
fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()
figures = [fig0, fig1, fig2]
axes = [ax0, ax1, ax2]
skp_plot = 50  # plot one integration point in 50
lw = 0.75
clr_activ = sns.color_palette("cubehelix")[3]

# (trace, color, title) for each panel, in the order of figures
panel_specs = [
    (conc_trace, back_palette[1], "OSN input"),
    (activ_trace, clr_activ, "OSN activity"),
    (epsil_trace, model_colors["adapt"], r"$\epsilon_i(t)$"),
]
for fig, ax, (trace, clr, title) in zip(figures, axes, panel_specs):
    fig.set_size_inches(0.6, 0.35)
    ax.plot(ttrace[::skp_plot], trace[::skp_plot], color=clr, lw=lw)
    ax.set_title(title, fontsize=5, y=0.55, x=0.04, ha="left", color=clr)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.tight_layout(h_pad=0.0, w_pad=0.0)

if do_save_plots:
    resources_folder = pj(panels_folder, "resources")
    # savefig does not create missing folders
    os.makedirs(resources_folder, exist_ok=True)
    for i, fig in enumerate(figures):
        fig.savefig(pj(resources_folder, "adaptation_cartoon_{}.pdf".format(i)),
                    bbox_inches="tight", transparent=True)
plt.show()
# plt.close() alone only closes the current figure; close all three
for fig in figures:
    plt.close(fig)

# Panel B: dynamics of epsilon

In [None]:
# Functions to load simulations results for a given epsilon from within a larger npz archive
def load_ibcm_adapt_simul(fp):
    """Collect the IBCM + OSN adaptation simulation arrays from archive fp.

    Args:
        fp: mapping-like archive supporting .get (e.g. an opened .npz file).

    Returns:
        dict: result name -> value (None for keys missing from fp). The
            archive key "eps_ser" is exposed as "epsilon_ser"; all other
            names are the archive keys.
    """
    # output name -> archive key
    key_map = {
        "cbars_gamma_ser": "cbars_gamma_ser",
        "bkvec_ser": "bkvec_ser",
        "back_components": "back_components",
        "y_norm_ser": "y_norm_ser",
        "conc_ser": "conc_ser",
        "mixed_new_odors": "mixed_new_odors",
        "new_odors": "new_odors",
        "epsilon_ser": "eps_ser",
    }
    return {out_name: fp.get(arch_key) for out_name, arch_key in key_map.items()}

def load_biopca_adapt_simul(fp):
    """Collect the BioPCA + OSN adaptation simulation arrays from archive fp.

    Args:
        fp: mapping-like archive supporting .get (e.g. an opened .npz file).

    Returns:
        dict: archive key -> value (None for keys missing from fp).
    """
    wanted = ("true_pca_vals", "learnt_pca_vals", "pca_align_error",
              "bkvec_norm_ser", "y_norm_ser")
    return {key: fp.get(key) for key in wanted}

In [None]:
# Axes zoom effect from Matplotlib documentation
from matplotlib.transforms import (Bbox, TransformedBbox,
                                   blended_transform_factory)
from mpl_toolkits.axes_grid1.inset_locator import (BboxConnector,
                                                   BboxConnectorPatch,
                                                   BboxPatch)
def connect_bbox(bbox1, bbox2,
                 loc1a, loc2a, loc1b, loc2b,
                 prop_lines, prop_patches=None):
    """Create the artists linking two bounding boxes (Matplotlib zoom-effect recipe).

    Args:
        bbox1, bbox2: the two (transformed) bounding boxes to connect.
        loc1a, loc2a: corner codes on bbox1 / bbox2 joined by the first line.
        loc1b, loc2b: corner codes joined by the second line.
        prop_lines (dict): artist properties of the two connector lines.
        prop_patches (dict or None): artist properties of the two box patches
            and of the filled connector. Defaults to prop_lines with alpha
            multiplied by 0.2 and clip_on=False.

    Returns:
        tuple: (line a, line b, patch on bbox1, patch on bbox2, filled
            connector). Nothing is added to any Axes here.
    """
    if prop_patches is None:
        prop_patches = {
            **prop_lines,
            "alpha": prop_lines.get("alpha", 1) * 0.2,
            "clip_on": False,
        }

    # Merge fixed properties into the dicts rather than passing them as extra
    # keywords: the default prop_patches already contains clip_on, and passing
    # it twice raises TypeError (multiple values for a keyword argument).
    line_props = {**prop_lines, "clip_on": False}
    c1 = BboxConnector(bbox1, bbox2, loc1=loc1a, loc2=loc2a, **line_props)
    c2 = BboxConnector(bbox1, bbox2, loc1=loc1b, loc2=loc2b, **line_props)

    box_props = {**prop_patches, "color": "grey"}
    bbox_patch1 = BboxPatch(bbox1, **box_props)
    bbox_patch2 = BboxPatch(bbox2, **box_props)

    p = BboxConnectorPatch(bbox1, bbox2,
                           loc1a=loc1a, loc2a=loc2a, loc1b=loc1b, loc2b=loc2b,
                           **{**prop_patches, "clip_on": False})

    return c1, c2, bbox_patch1, bbox_patch2, p

def zoom_effect01(ax1, ax2, xmin, xmax, **kwargs):
    """
    Visually connect the x range [xmin, xmax] of *ax1* to *ax2*.

    The range is shaded on *ax2* only (the shading patch built for *ax1* is
    returned but not added), and two lines plus a filled band join the
    corresponding corners of the range in both Axes.

    Parameters
    ----------
    ax1
        The main Axes (the one showing the range as a zoom window).
    ax2
        The Axes on which the range is marked.
    xmin, xmax
        Limits of the range, in data coordinates along x.
    **kwargs
        Properties of the connector lines; also the base properties of the
        patches, which get ec="none" and alpha=0.2.

    Returns
    -------
    tuple
        (line a, line b, patch on ax1, patch on ax2, filled connector).
    """
    # x in data coordinates, y spanning the full Axes height
    span = Bbox.from_extents(xmin, 0, xmax, 1)
    box_main = TransformedBbox(span, ax1.get_xaxis_transform())
    box_zoom = TransformedBbox(span, ax2.get_xaxis_transform())

    patch_props = {**kwargs, "ec": "none", "alpha": 0.2}
    artists = connect_bbox(
        box_main, box_zoom,
        loc1a=3, loc2a=2, loc1b=4, loc2b=1,
        prop_lines=kwargs, prop_patches=patch_props)
    line_a, line_b, _patch_main, patch_zoom, band = artists

    # Everything is attached to ax2; the main-Axes patch is deliberately left out
    for artist in (patch_zoom, line_a, line_b, band):
        ax2.add_patch(artist)

    return artists

In [None]:
def moving_average(points, kernelsize, boundary="free"):
    r""" Moving average filtering on the array of experimental points,
    averaging along the first axis over a block of size kernelsize.
    kernelsize should be an odd number; otherwise, the odd number just
    lower is used.
    The ith smoothed value, S_i, is:
        $$ S_i = \frac{1}{kernelsize} \sum_{j = i-kernelsize//2}^{i + kernelsize//2} x_j $$
    Points within w = kernelsize//2 of the edges are treated according to
    `boundary`.

    Args:
        points (ndarray): the experimental data points, smoothed along axis 0.
        kernelsize (int): positive integer giving the total number of points
            summed (made odd by subtracting 1 if even).
        boundary (str): how to deal with points within w of the edges
            "shrink": the window of a point at distance d < w from the edge
                is shrunk symmetrically to size 2d + 1 (size 1 at the edge).
            "free": the window is asymmetric, full (w points) on the inside
                and clipped on the side near the edge.
            "noflux": these points are set to the value of the closest point
                with a full window (i.e. at distance w from the edge).

    Returns:
        smoothed (ndarray): the smoothed data points, same shape as points.

    Raises:
        ValueError: if boundary is unknown or kernelsize < 1.
    """
    if boundary not in ("shrink", "free", "noflux"):
        raise ValueError("Unknown boundary {}".format(boundary))
    if kernelsize < 1:
        raise ValueError("kernelsize must be at least 1, got {}".format(kernelsize))

    smoothed = np.zeros(points.shape)
    if kernelsize % 2 == 0:  # if an even number was given
        kernelsize -= 1
    w = kernelsize // 2  # half-width
    end = smoothed.shape[0]  # number of points along axis 0

    # Smooth the middle points using slicing: sum the 2w neighbours.
    smoothed[w:end - w] = points[w:end - w]
    for j in range(w):
        smoothed[w:end - w] += (points[w - j - 1:end - w - j - 1]
                                + points[w + j + 1:end - w + j + 1])

        # Use the loop to treat the two points at a distance j from boundaries
        if boundary == "shrink":
            smoothed[j] = np.sum(points[0:2*j + 1], axis=0) / (2*j + 1)
            smoothed[-j - 1] = np.sum(points[-2*j - 1:], axis=0) / (2*j + 1)
        elif boundary == "free":
            smoothed[j] = np.sum(points[0:j + w + 1], axis=0) / (j + w + 1)
            smoothed[-j - 1] = np.sum(points[-j - w - 1:], axis=0) / (j + w + 1)

    # Normalize the middle points
    smoothed[w:end - w] = smoothed[w:end - w] / kernelsize

    # noflux: copy the nearest full-window value to the edge points. Skipped
    # when w == 0, where there are no edge points and smoothed[-0:] would
    # select (and overwrite) the whole array.
    if boundary == "noflux" and w > 0:
        smoothed[:w] = smoothed[w]
        smoothed[-w:] = smoothed[-w - 1]

    return smoothed

In [None]:
# Check what individual OSN epsilons are doing on fast and long time scales
def plot_eps_series(tser, eps_ser, n_shown, tzoom_interv, skp_n=1, smoothsize=201):
    """Plot OSN adaptation variables epsilon_i(t) on a short and a long time scale.

    Top panel: raw epsilon_i(t) over the index window tzoom_interv. Bottom
    panel: epsilon_i(t) smoothed by a moving average of smoothsize points,
    over the whole series (every 20th point). The zoom window of the top
    panel is marked on the bottom panel.

    Args:
        tser (1darray): time points, in 10 ms simulation steps.
        eps_ser (2darray): epsilon series, time along axis 0, OSNs along axis 1.
        n_shown (int): number of OSNs (first columns of eps_ser) considered.
        tzoom_interv (2-tuple of int): start and stop array indices of the
            zoom window shown in the top panel.
        skp_n (int): plot one OSN out of every skp_n.
        smoothsize (int): moving average window, in array elements.

    Returns:
        fig, axes (list of the two data Axes), axleg (Axes holding the legend).
    """
    fig = plt.figure()
    gs = fig.add_gridspec(2, 4)
    axes = [fig.add_subplot(gs[0, :3])]
    axes.append(fig.add_subplot(gs[1, :3], sharey=axes[0]))
    axleg = fig.add_subplot(gs[:, 3])

    fig.set_size_inches(plt.rcParams["figure.figsize"][0], plt.rcParams["figure.figsize"][1])
    eps_ser_smooth = moving_average(eps_ser, kernelsize=smoothsize, boundary="free")
    tscale = 10.0 / 1000.0 / 60.0  # 10 ms steps -> minutes

    osn_colors = sns.cubehelix_palette(n_colors=n_shown,
            start=0.0, rot=1.0, gamma=1.0, hue=0.8, light=0.85, dark=0.15, reverse=True)

    tsl_local = slice(*tzoom_interv, 1)  # zoom window, every point
    tsl_global = slice(0, None, 20)      # whole series, every 20th point
    n_labels = 6
    # Label about n_labels of the plotted OSNs; with fewer plotted OSNs than
    # n_labels, label every one (avoids a modulo by zero below).
    skp_lbl = max(1, (n_shown // skp_n) // n_labels)
    for i in range(0, n_shown, skp_n):
        lbl = "{}".format(i) if (i // skp_n) % skp_lbl == 0 else ""
        axes[0].plot(tser[tsl_local]*tscale, eps_ser[tsl_local, i],
                     alpha=0.7, lw=0.75, color=osn_colors[i], label=lbl)
        axes[1].plot(tser[tsl_global]*tscale, eps_ser_smooth[tsl_global, i],
                     alpha=0.7, lw=0.75, color=osn_colors[i])

    axes[0].set(ylabel=r"$\epsilon_i(t)$")
    axes[1].set(xlabel="Time (min)", ylabel=r"Smoothed $\epsilon_i(t)$")
    # NOTE(review): ticks chosen for the zoom window used in this notebook
    axes[0].set_xticks([5, 6, 7])
    axleg.legend(*axes[0].get_legend_handles_labels(), frameon=False, title="OSN")
    axleg.set_axis_off()
    t1, t2 = tser[tzoom_interv[0]]*tscale, tser[tzoom_interv[1]]*tscale
    zoom_effect01(axes[0], axes[1], t1, t2, lw=0.8)
    fig.tight_layout(h_pad=-0.1)
    return fig, axes, axleg

In [None]:
# Example simulations (IBCM and BioPCA networks with OSN adaptation); the
# loaders pull the needed arrays out before each archive is closed.
with np.load(pj(data_folder_nl, "saved_ibcm_simulations_adapt_osn.npz")) as ibcm_archive:
    ibcm_simul_example = load_ibcm_adapt_simul(ibcm_archive)
with np.load(pj(data_folder_nl, "saved_biopca_simulations_adapt_osn.npz")) as biopca_archive:
    biopca_simul_example = load_biopca_adapt_simul(biopca_archive)

In [None]:
eps_ser_ibcm = ibcm_simul_example["epsilon_ser"]
chosen_smoothsize = 101
# Time axis in 10 ms simulation steps: the saved series spans 360000 steps,
# one saved sample every 360000 // n_saved steps. Build exactly one time
# point per saved sample so tser always matches eps_ser_ibcm in length
# (np.arange(0, 360000, step) returns extra points when n_saved does not
# divide 360000).
n_saved = eps_ser_ibcm.shape[0]
tser = np.arange(n_saved) * (1.0 * (360000 // n_saved))
fig, axes, axleg = plot_eps_series(tser, eps_ser_ibcm, 50, (600, 840),
                                   skp_n=4, smoothsize=chosen_smoothsize)
if do_save_plots:
    fig.savefig(pj(panels_folder, "supfig_adapt_osn_epsilon_dynamics.pdf"),
               transparent=True, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Duration, in s, of the moving-average window used for the smoothed panel:
# chosen_smoothsize array elements, each spanning (tser[1] - tser[0])
# simulation steps of 10 ms.
avg_window_s = chosen_smoothsize * (tser[1] - tser[0]) * 10.0 / 1000.0
print("Averaging time window in s:", avg_window_s)

# Panel C: weakly curved manifold as a result

In [None]:
# Since we will make similar plots, define functions
def plot_manifold(bkser, bkvecs, conc_ser, view_params,
                  mixed_new_odors=None, new_odor_vec=None, dims=(0, 1, 2)):
    """3D scatter of background OSN inputs in three chosen OSN dimensions.

    Time points are colored by which odors are present (concentration > 0):
    one color per odor present alone, grey (faded, thinned 1 in 5) for
    mixtures of 2+ odors. Time points with no odor are not plotted.
    Normalized background odor vectors are drawn as black arrows from the
    origin. Optionally, mixtures with new odors and the new odor vectors are
    overlaid in shades of red.

    Args:
        bkser (2darray): background series, time along axis 0, OSNs along axis 1.
        bkvecs (2darray): background odor vectors, one per row.
        conc_ser (2darray): odor concentrations, time along axis 0, odors
            along axis 1.
        view_params (dict): keyword arguments for ax.view_init; azim defaults
            to 240 and elev to 3. The dict itself is not modified.
        mixed_new_odors (3darray or None): indexed [new odor, sample, OSN].
        new_odor_vec (2darray or None): new odor vectors, one per row.
        dims (3-tuple of int): OSN indices shown on the x, y and z axes.

    Returns:
        fig, ax, tightbox: tightbox is a manually trimmed bounding box meant
            to be passed as bbox_inches when saving.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    # Too many combinations for 6 odors: highlight single-odor points only
    where_each = (conc_ser > 0).astype(bool)
    n_odors = where_each.shape[1]
    locations = {}
    # Track places with 0 or 1 odor; start with places with no odor, which are
    # excluded from "2+ odors" and match no single odor, so are not plotted.
    any_single_odor = np.all(where_each == False, axis=1)
    for i in range(n_odors):
        mask = np.zeros((1, n_odors), dtype=bool)
        mask[0, i] = True
        locations["Odor {}".format(i)] = np.all(where_each == mask, axis=1)
        any_single_odor += locations["Odor {}".format(i)]  # boolean OR: add odor i alone
    locations["2+ odors"] = np.logical_not(any_single_odor)  # 2+ odors anywhere else
    single_odor_colors = sns.color_palette("colorblind", n_colors=n_odors)
    all_colors = {"Odor {}".format(i): single_odor_colors[i] for i in range(n_odors)}
    all_colors["2+ odors"] = "grey"

    locations_order = ["2+ odors"] + ["Odor {}".format(i) for i in range(n_odors)]
    for lbl in locations_order:
        alpha = 0.3 if lbl.startswith("2+") else 1.0
        slc = locations[lbl]
        tskp = 5 if lbl.startswith("2+") else 1
        zshift = 0.03 if lbl.startswith("2+") else 0.0
        shift = 0.03 if lbl.startswith("2+") else 0.0
        lbl_append = ""
        bk_subset = [bkser[slc, d].copy() + shift for d in dims]
        bk_subset[2] -= zshift
        ax.scatter(bk_subset[0][::tskp], bk_subset[1][::tskp], bk_subset[2][::tskp],
                   s=4, lw=0.3, label=lbl+lbl_append, color=all_colors[lbl], alpha=alpha)

    # Unit background odor vectors as arrows from the origin (one per odor)
    vecs = bkvecs / l2_norm(bkvecs, axis=1)[:, None]
    orig = np.zeros([3, vecs.shape[0]])
    ax.quiver(*orig, *(vecs[:, list(dims)].T), color="k", lw=1.5, arrow_length_ratio=0.2)
    ax.scatter(0, 0, 0, color="k", s=25)

    # Also show what adding a new odor can do -- out of the manifold
    new_odor_lbl = "+ new odor"
    if mixed_new_odors is not None:
        n_new_odors = mixed_new_odors.shape[0]
        new_odors_palette = sns.dark_palette("r", n_colors=n_new_odors+1)[1:]
        for i in range(n_new_odors):
            lbl = new_odor_lbl + " {}".format("abcdefghijklmnop".upper()[i])
            all_colors[lbl] = new_odors_palette[i]
            ax.scatter(mixed_new_odors[i, :, dims[0]], mixed_new_odors[i, :, dims[1]],
                        mixed_new_odors[i, :, dims[2]], s=6, lw=0.3,
                        label=lbl+lbl_append, color=all_colors[lbl], alpha=1.0)
            if new_odor_vec is not None:
                vec = new_odor_vec[i] / l2_norm(new_odor_vec[i])
                ax.quiver(*np.zeros(3), *(vec[list(dims)]), color=all_colors[lbl],
                          lw=1.5, arrow_length_ratio=0.2)
    else:  # No new odors shown
        n_new_odors = 0

    # Labeling
    for lbl, f in enumerate([ax.set_xlabel, ax.set_ylabel, ax.set_zlabel]):
        f("OSN {} (of {})".format(lbl+1, bkser.shape[1]), labelpad=-17.5)
    for f in [ax.set_xticks, ax.set_yticks, ax.set_zticks]:
        f([])
    for f in [ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels]:
        f([], pad=0.1)
    # Work on a copy so the caller's dict is not modified
    view = {"azim": 240, "elev": 3, **view_params}
    ax.view_init(**view)

    # handles[0] is "2+ odors" (plotted first): move it after the single-odor
    # entries and before the new-odor entries (appended when there are none).
    handles, labels = ax.get_legend_handles_labels()
    insert_at = len(handles) - n_new_odors
    handles.insert(insert_at, handles[0])
    labels.insert(insert_at, labels[0])
    handles.pop(0)
    labels.pop(0)
    ax.legend(handles=handles, labels=labels,
        frameon=True, ncol=1, loc="upper left", bbox_to_anchor=(0.85, 1.0),
        title="Odor presence", title_fontsize=6)
    fig.tight_layout()

    # Adjust the tight bbox manually to remove whitespace above, below and on
    # the left. NOTE(review): relies on the private Bbox attribute _bbox.
    fig.tight_layout()
    tightbox = fig.get_tightbbox()
    tightbox._bbox.y0 = tightbox._bbox.y0*1.1   # bottom
    tightbox._bbox.y1 = tightbox._bbox.y1 + 0.7*tightbox._bbox.y0  # top
    tightbox._bbox.x0 = tightbox._bbox.x0 * 0.6  # left side

    return fig, ax, tightbox

In [None]:
# Inspect the background process
def pairplots_background(bkser, bkvecs, epsser=None, mixed_new_odors=None, new_odor_vec=None):
    """Grid of 2D views of the background in successive pairs of OSNs.

    Panel i shows OSN index 2i+1 (x) against OSN index 2i (y), labeled
    1-based, with every 30th background sample in black and the normalized
    background odor vectors as colored segments from the origin. When
    epsser is given, dashed segments show the effective odor vectors (see
    below). New odor mixtures and vectors are overlaid in shades of red.

    Args:
        bkser (2darray): background series, time along axis 0, OSNs along axis 1.
        bkvecs (2darray): background odor vectors, one per row.
        epsser (2darray or None): epsilon series, time along axis 0, OSNs
            along axis 1.
        mixed_new_odors (3darray or None): indexed [new odor, sample, OSN].
        new_odor_vec (2darray or None): new odor vectors, one per row.

    Returns:
        fig, axes (array of Axes, as returned by plt.subplots).
    """
    # Background vectors time series with mixed concentrations
    tslice = slice(0, None, 30)

    # Scale odor affinities K_{i gamma} by the average saturation threshold
    # of the OSN, exp(epsilon_i), to get the effective affinity scale,
    # before normalizing each odor vector K_gamma
    vecs = bkvecs / l2_norm(bkvecs, axis=1)[:, None]
    if epsser is not None:
        mean_eps = np.mean(epsser, axis=0)
        vecs_eff = bkvecs / np.exp(mean_eps[None, :])
        vecs_eff = vecs_eff / l2_norm(vecs_eff, axis=1)[:, None]
    else:
        vecs_eff = None

    # Number of new odors, from mixed_new_odors if given, else new_odor_vec
    if mixed_new_odors is not None:
        n_new_odors = mixed_new_odors.shape[0]
    elif new_odor_vec is not None:
        n_new_odors = len(new_odor_vec)
    else:
        n_new_odors = 0

    n_comp = bkvecs.shape[0]
    n_cols = 6
    # One panel per pair of consecutive OSNs; an odd last OSN is not shown
    n_plots = bkser.shape[1] // 2
    n_rows = n_plots // n_cols + min(1, n_plots % n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True)
    fig.set_size_inches(n_cols*1.0, n_rows*1.0)
    single_odor_colors = sns.color_palette("colorblind", n_colors=n_comp)
    if n_new_odors > 0:
        new_odors_palette = sns.dark_palette("r", n_colors=n_new_odors+1)[1:]
    else:
        new_odors_palette = []
    for i in range(n_plots):
        ax = axes.flat[i]
        ax.scatter(bkser[tslice, 2*i+1], bkser[tslice, 2*i],
                   s=9, alpha=0.5, color="k")
        for j in range(n_comp):
            # Segment from the origin to (vec[2i+1], vec[2i])
            ax.plot(*zip([0.0, 0.0], vecs[j, 2*i:2*i+2][::-1]), lw=1.5,
                    color=single_odor_colors[j])
            if vecs_eff is not None:
                ax.plot(*zip([0.0, 0.0], vecs_eff[j, 2*i:2*i+2][::-1]), lw=1.5, ls="--",
                    color=single_odor_colors[j])
        if mixed_new_odors is not None:
            for j in range(n_new_odors):
                ax.scatter(mixed_new_odors[j, :, 2*i+1], mixed_new_odors[j, :, 2*i],
                       s=6, alpha=1.0, color=new_odors_palette[j])
        if new_odor_vec is not None:
            for j in range(n_new_odors):
                vec = new_odor_vec[j] / l2_norm(new_odor_vec[j])
                ax.plot(*zip([0.0, 0.0], vec[2*i:2*i+2][::-1]), lw=2.0,
                    color=new_odors_palette[j])
        ax.set(xlabel="OSN {}".format(2*i+2), ylabel="OSN {}".format(2*i+1))
    # Hide the unused panels of the last row
    for i in range(n_plots, n_rows*n_cols):
        axes.flat[i].set_axis_off()
    fig.tight_layout()

    return fig, axes

In [None]:
# Background manifold of the IBCM example in the space of the first three
# OSNs. New odors can be overlaid by also passing mixed_new_odors and
# new_odor_vec (from ibcm_simul_example["mixed_new_odors"] / ["new_odors"]).
fig, ax, box = plot_manifold(
    ibcm_simul_example["bkvec_ser"],
    ibcm_simul_example["back_components"],
    ibcm_simul_example["conc_ser"],
    {"elev": 30, "azim": 210},
    dims=(0, 1, 2),
)
if do_save_plots:
    fig.savefig(
        pj(panels_folder, "supfig_adapt_osn_curved_manifold.pdf"),
        transparent=True,
        bbox_inches=box,
    )
plt.show()
plt.close()

In [None]:
# Pairwise OSN views of the background (every other saved sample), with the
# effective odor vectors computed from the epsilon series
fig, axes = pairplots_background(ibcm_simul_example["bkvec_ser"][::2],
                    ibcm_simul_example["back_components"], ibcm_simul_example["epsilon_ser"][::2])

if do_save_plots:
    # Use a tight bbox of this figure: `box` is the trimmed bbox of the 3D
    # manifold figure and does not describe this one
    fig.savefig(pj(panels_folder, "reviewer_figure_nonlinear_manifold_pairplot_adapt_osn.pdf"),
                transparent=True, bbox_inches="tight")
plt.show()
plt.close()

# Panel D: Odor recognition performance with OSN adaptation

In [None]:
def hist_outline(ax, bins, height, **plot_kwargs):
    """Draw a histogram as a step outline over a lightly shaded area.

    Args:
        ax: matplotlib Axes to draw on.
        bins (1darray): bin edges, one more than the number of bars.
        height (1darray): bar heights.
        **plot_kwargs: passed to ax.plot for the outline; the shaded area
            reuses its "color" (if any) with alpha 0.3.

    Returns:
        ax
    """
    # Every bar contributes its left and right edge at the same height
    xs = np.column_stack([bins[:-1], bins[1:]]).ravel()
    ys = np.repeat(height, 2)
    ax.plot(xs, ys, **plot_kwargs)
    baseline = min(0.0, height.min())
    ax.fill_between(xs, baseline, ys, color=plot_kwargs.get("color"), alpha=0.3)
    return ax

# Plot jaccard similarities
def plot_jaccards(jac_histograms, jac_stats):
    """Overlay per-model histograms of Jaccard similarities for one new odor concentration.

    Args:
        jac_histograms (dict): model -> (heights, bin edges), as returned by
            np.histogram.
        jac_stats (dict): model -> [mean, median, variance]; the median is
            marked by a dashed vertical line.

    Returns:
        fig, ax, leg
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.95,
                        plt.rcParams["figure.figsize"][1])
    # Fixed drawing order, IBCM last (on top); models without results are skipped
    draw_order = ("none", "adapt", "optimal_adapt", "biopca_adapt", "ibcm_adapt")
    for model in (name for name in draw_order if name in jac_histograms):
        clr = model_colors.get(model)
        heights, edges = jac_histograms[model]
        hist_outline(ax, edges, heights,
                     label=model_nice_names.get(model, model), color=clr, alpha=1.0)
        ax.axvline(jac_stats[model][1], ls="--", color=clr)

    ax.set_xlabel("Jaccard similarity (higher is better)")
    ax.set_ylabel("Probability density")
    leg = ax.legend(loc="upper left", bbox_to_anchor=(0.6, 1.025), frameon=False,
                    borderaxespad=0.0)
    return fig, ax, leg

In [None]:
# Odor recognition at new odor concentration 0.5 <c>: 96 background seeds,
# 100 new odors against each, tested at 100 test times per simulation.
conc_choice = 0.5
conc_str = str(conc_choice).replace(".", "-")
with np.load(pj(data_folder_nl, f"osn_adaptation_odor_recognition_results_{conc_str}.npz")) as res:
    all_jacs = {key: res[key] for key in res.keys()}

# Per-model histogram, empirical CDF of Jaccard distances, and summary stats
show_models = ["none", "adapt", "biopca_adapt", "ibcm_adapt", "optimal_adapt"]
jac_histograms, jac_cdfs, jac_stats = {}, {}, {}
for m in show_models:
    jacs_sim = all_jacs[m].flatten()
    jac_histograms[m] = np.histogram(jacs_sim, bins="doane", density=True)
    # Jaccard distances only take a discrete set of values, so the CDF is
    # built from the frequency of each distinct value; np.unique already
    # returns these values sorted.
    jacs_dists = 1.0 - jacs_sim
    dists_axis, dists_counts = np.unique(jacs_dists, return_counts=True)
    dists_counts = dists_counts / jacs_dists.size
    dists_cdf = np.cumsum(dists_counts)
    jac_cdfs[m] = dists_cdf, dists_axis
    jac_stats[m] = [np.mean(jacs_sim), np.median(jacs_sim), np.var(jacs_sim)]

In [None]:
fig, ax, leg = plot_jaccards(jac_histograms, jac_stats)
ax.set_title(
    f"New odor concentration = {conc_choice:.1f}" + r"$\langle c \rangle$",
    y=0.98,
)
fig.tight_layout()
if do_save_plots:
    # Include the legend, placed partly outside the Axes, in the saved bbox
    fig.savefig(
        pj(panels_folder, f"supfig_adapt_osn_odor_recognition_jaccards_{conc_str}.pdf"),
        transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,),
    )
plt.show()
plt.close()

In [None]:
# Additional panel: new odor concentration c = 1.0 <c>
conc_choice2 = 1.0
conc_str2 = str(conc_choice2).replace(".", "-")
with np.load(pj(data_folder_nl, f"osn_adaptation_odor_recognition_results_{conc_str2}.npz")) as res:
    all_jacs2 = {key: res[key] for key in res.keys()}

# Per-model histogram, empirical CDF of Jaccard distances, and summary stats
show_models = ["none", "adapt", "biopca_adapt", "ibcm_adapt", "optimal_adapt"]
jac_histograms2, jac_cdfs2, jac_stats2 = {}, {}, {}
for m in show_models:
    jacs_sim = all_jacs2[m].flatten()
    jac_histograms2[m] = np.histogram(jacs_sim, bins="doane", density=True)
    # Jaccard distances only take a discrete set of values, so the CDF is
    # built from the frequency of each distinct value; np.unique already
    # returns these values sorted.
    jacs_dists = 1.0 - jacs_sim
    dists_axis, dists_counts = np.unique(jacs_dists, return_counts=True)
    dists_counts = dists_counts / jacs_dists.size
    dists_cdf = np.cumsum(dists_counts)
    jac_cdfs2[m] = dists_cdf, dists_axis
    jac_stats2[m] = [np.mean(jacs_sim), np.median(jacs_sim), np.var(jacs_sim)]

In [None]:
fig, ax, leg = plot_jaccards(jac_histograms2, jac_stats2)
# The title must report this panel's concentration (conc_choice2)
ax.set_title("New odor concentration = {:.1f}".format(conc_choice2) + r"$\langle c \rangle$",
            y=0.98)
fig.tight_layout()
if do_save_plots:
    fig.savefig(pj(panels_folder, f"reviewer_figure_adapt_osn_odor_recognition_jaccards_{conc_str2}.pdf"),
               transparent=True, bbox_inches="tight", bbox_extra_artists=(leg,))
plt.show()
plt.close()

# Panel E: IBCM convergence for moderate nonlinearity

## IBCM plotting functions

In [None]:
def plot_ibcm_cgammas_series(t_axis, i_highlights, cgammaser):
    r"""Plot the alignments of IBCM neurons with each background component over time.

    Non-highlighted neurons with odd indices are drawn in the background
    color; highlighted neurons are drawn on top in their own colors. For each
    neuron, successive components gamma get lower alpha and line width.

    Args:
        t_axis (1darray): time points, one per row of cgammaser (x label in min).
        i_highlights (sequence of int): neurons to highlight; at most three,
            since three highlight colors are defined.
        cgammaser (3darray): alignments indexed [time, neuron, component].

    Returns:
        fig, ax, axi: axi is always None (no inset Axes).
    """
    fig = plt.figure()

    gs = fig.add_gridspec(3, 3)
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.65,
                        plt.rcParams["figure.figsize"][1])
    ax = fig.add_subplot(gs[:, :3])
    axi = None

    neuron_colors3 = neuron_colors_full[[8, 17, 23]]
    clr_back = back_palette[-1]
    plot_skp = 20  # plot one time point in 20
    n_b = cgammaser.shape[2]
    n_i_ibcm_loc = cgammaser.shape[1]
    base_lw = plt.rcParams["lines.linewidth"]

    # Background neurons first: skip highlighted ones, thin out even indices
    for i in range(n_i_ibcm_loc):
        if i in i_highlights or i % 2 == 0:
            continue
        for j in range(n_b):
            ax.plot(t_axis[::plot_skp], cgammaser[::plot_skp, i, j], color=clr_back,
                ls="-", alpha=1.0-0.1*j, lw=base_lw-j*0.1)

    # Highlighted neurons on top
    for j in range(n_b):
        for k, i_high in enumerate(i_highlights):
            ax.plot(t_axis[::plot_skp], cgammaser[::plot_skp, i_high, j],
                    color=neuron_colors3[k], ls="-", alpha=1.0-0.1*j,
                    lw=base_lw-j*0.1)

    # Annotations
    ax.set(xlabel="Time (min)",
           ylabel=r"Alignments $\bar{h}_{i\gamma} = \mathbf{\bar{m}}_i \cdot \mathbf{s}_{\gamma}$")

    gs.tight_layout(fig, w_pad=-0.2)
    return fig, ax, axi

In [None]:
def plot_ibcm_cgammas_matrix(cgammas_mat, i_high):
    r"""Heat map of time-averaged IBCM neuron alignments with background components.

    Each cell is drawn individually so each row can use its own colormap:
    highlighted neurons use a light palette of their color, the others grey.
    The tick label and tick mark of highlighted neurons get their color.

    Args:
        cgammas_mat (2darray): alignments indexed [neuron, component].
        i_high (list of int): highlighted neurons (at most three).

    Returns:
        fig, ax, cbar
    """
    n_i, n_comp = cgammas_mat.shape
    neuron_colors3 = neuron_colors_full[[8, 17, 23]]

    fig, ax = plt.subplots()
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.4, plt.rcParams["figure.figsize"][1])
    # Rescale values to [0, 1] for the colormaps; a constant matrix maps to 0
    # instead of dividing 0 by 0.
    cmin, cmax = cgammas_mat.min(), cgammas_mat.max()
    if cmax > cmin:
        normed_matrix = (cgammas_mat - cmin) / (cmax - cmin)
    else:
        normed_matrix = np.zeros(cgammas_mat.shape)
    for i in range(n_i):
        if i in i_high:
            cmap = sns.light_palette(neuron_colors3[i_high.index(i)], as_cmap=True)
        else:
            cmap = sns.color_palette("Greys", as_cmap=True)
        for j in range(n_comp):
            ax.fill_between([-0.5+j, 0.5+j], -0.5+i, 0.5+i, color=cmap(normed_matrix[i, j]))

    ax.set_xlim([-0.6, -0.6+n_comp])
    ax.set_ylim([-0.6, -0.6+n_i])
    ax.set_xticks([0, 3, 5])
    ytick_locs = list(range(0, n_i, 2))
    ax.set_yticks(ytick_locs)

    # Color the label and tick mark of highlighted neurons. Walk the Tick
    # objects: Axis.get_ticklines() interleaves the two tick lines of every
    # tick, so indexing it by tick position would color the wrong marks.
    for loc, tick in zip(ytick_locs, ax.yaxis.get_major_ticks()):
        if loc in i_high:
            clr = neuron_colors3[i_high.index(loc)]
            tick.label1.set_color(clr)
            tick.tick1line.set_color(clr)

    ax.set_xlabel(r"Component $\gamma$", size=6)
    ax.set_ylabel(r"IBCM neuron index $i$", size=6)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(cgammas_mat.min(), cgammas_mat.max()),
        cmap="Greys"), ax=ax, aspect=30, pad=0.1)
    cbar.set_ticks([])
    cbar.set_label(label=r"Alignments ${\bar{h}}_{i\gamma}$, 45 min", fontsize=6)
    fig.tight_layout()
    return fig, ax, cbar


## IBCM plot

In [None]:
tu_scale = 1.0 / 100.0 / 60.0  # 10 ms steps to min
cgammaser = ibcm_simul_example["cbars_gamma_ser"]
n_i_ibcm = cgammaser.shape[1]
nsteps = cgammaser.shape[0]
n_components = cgammaser.shape[2]
# The nsteps saved samples evenly cover 360000 simulation steps. Build exactly
# nsteps time points: np.arange with the float step 360000 / nsteps can return
# nsteps + 1 points through rounding, which would no longer match cgammaser
# (this axis is also used for the BioPCA plot).
tser_example = np.arange(nsteps) * (360000 / nsteps) * tu_scale  # in min

# Average alignments over the last quarter of the simulation
transient = int(3*tser_example.size/4)
i_highlights = [2, 8, 18]  # Neurons to highlight

cgammas_matrix = np.mean(cgammaser[transient:], axis=0)

In [None]:
# Alignment time series of the IBCM neurons, highlighted neurons in color
fig, ax, axi = plot_ibcm_cgammas_series(tser_example, i_highlights, cgammaser)
if do_save_plots:
    fig.savefig(
        pj(panels_folder, "supfig_adapt_osn_ibcm_cgamma_series.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()

In [None]:
# Time-averaged alignment matrix; the third returned object is the colorbar
fig, ax, cbar = plot_ibcm_cgammas_matrix(cgammas_matrix, i_highlights)
if do_save_plots:
    fig.savefig(
        pj(panels_folder, "supfig_adapt_osn_ibcm_cgamma_matrix.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()

# Panel F: BioPCA convergence for moderate nonlinearity

In [None]:
def plot_biopca_convergence(t_axis, true_pvs, learnt_pvs):
    """Plot BioPCA's learnt principal values over time against the true ones.

    Args:
        t_axis (1darray): time points (x label in min), one per row of learnt_pvs.
        true_pvs (1darray): true principal values, one per component.
        learnt_pvs (2darray): learnt values, time along axis 0, components
            along axis 1.

    Returns:
        fig, ax, handles: handles are the two proxy artists of the legend
            (solid: BioPCA, dashed: true PCA).
    """
    n_comp = learnt_pvs.shape[1]
    pca_palette = sns.color_palette("colorblind", n_colors=n_comp)
    fig, ax = plt.subplots()
    fig.set_size_inches(plt.rcParams["figure.figsize"][0]*0.95,
                        plt.rcParams["figure.figsize"][1])

    base_lw = plt.rcParams["lines.linewidth"]
    true_max = true_pvs.max()
    for i in range(n_comp):
        clr = pca_palette[i]
        lw_i = base_lw - 0.5*i/n_comp
        ax.plot(t_axis, learnt_pvs[:, i], label="Value {}".format(i),
                lw=lw_i, zorder=10-i, color=clr)
        # Only draw true values that are not negligible relative to the largest
        if true_pvs[i] / true_max > 1e-12:
            ax.axhline(true_pvs[i], ls="--", color=clr, lw=lw_i, zorder=n_comp-i)
    ax.set(ylabel="Principal values, diag$(L^{-1})$", yscale="log", xlabel="Time (min)")

    # Legend with line-style proxies instead of one entry per component
    handles = [
        mpl.lines.Line2D([0], [0], color="grey", ls="-", label=r"BioPCA ($L^{-1}$ diagonal)",
                         lw=base_lw),
        mpl.lines.Line2D([0], [0], color="grey", ls="--", label="True PCA",
                         lw=base_lw),
    ]
    leg = ax.legend(handles=handles, frameon=False, fontsize=5.0)
    leg.set_zorder(30)
    ymin, ymax = ax.get_ylim()
    ax.set_ylim([ymin*0.8, ymax])

    fig.tight_layout()
    return fig, ax, handles

In [None]:
# True vs learnt principal values, and alignment error, of the BioPCA example
true_pca_vals, learnt_pca_vals, pca_align_error = (
    biopca_simul_example[key]
    for key in ("true_pca_vals", "learnt_pca_vals", "pca_align_error")
)

In [None]:
# BioPCA principal values over time
fig, ax, handles = plot_biopca_convergence(tser_example, true_pca_vals, learnt_pca_vals)
# Redraw the legend at an explicit position; the automatic placement chosen
# inside plot_biopca_convergence does not suit this panel.
ax.legend(handles=handles, frameon=False, fontsize=5.0,
          loc="upper left", bbox_to_anchor=(0.35, 0.68))
if do_save_plots:
    fig.savefig(
        pj(panels_folder, "supfig_adapt_osn_biopca_eigenvalues.pdf"),
        transparent=True,
        bbox_inches="tight",
    )
plt.show()
plt.close()