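This notebook generates the example single-unit plots used in Figure 2: spike rasters and firing rates for selected motifs, spike-distance matrices, and motif-discrimination results (confusion matrices and accuracy as a function of SNR).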

In [None]:
import sys

# Put the project's helper scripts first on the import path
# (presumably where `graphics_defaults` and `core`, imported below, live).
scripts_dir = "../scripts"
sys.path.insert(0, scripts_dir)

In [None]:
# Set the neurobank registry URL in the environment; presumably read by dlab.nbank
# when nbank.find_resource / find_resources resolve identifiers below.
%env NBANK_REGISTRY https://gracula.psyc.virginia.edu/neurobank

In [None]:
import json
import logging
from pathlib import Path
from functools import partial

import ewave
import numpy as np
from numpy.random import default_rng
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from dlab import pprox, nbank, spikes, plotting, signal

import graphics_defaults
from core import MotifSplitter, split_trials, trial_to_spike_train, pairwise_spike_comparison

# Fixed seed so the shuffled leave-one-out splits used by the k-NN classifiers
# (and therefore the saved figures) are reproducible across runs.
RNG_SEED = 1024
rng = default_rng(RNG_SEED)

In [None]:
# Firing-rate estimation settings (units presumably seconds).
rate_binwidth = 0.005   # bin width of the rate histogram
rate_bandwidth = 0.02   # bandwidth of the Gaussian smoothing kernel
# signal.kernel returns a pair; only the first element (the kernel) is kept.
kernel = signal.kernel("gaussian", rate_bandwidth, rate_binwidth)[0]

In [None]:
# Candidate example units for Figure 2, each with the three motifs to display.
# The unit is chosen by key (not by which cell happens to run last), so that
# Restart & Run All always produces the intended figure.
example_units = {
    # weakly selective
    "weakly_selective": ("C104_3_1_c67", ["g29wxi4q", "vekibwgj", "jkexyrd5"]),
    # selective
    "selective": ("C42_4_1_c131", ["g29wxi4q", "vekibwgj", "ztqee46x"]),
    # high firing rate, temporal patterning, quite invariant
    "invariant": ("C104_4_1_c120", ["g29wxi4q", "vekibwgj", "ztqee46x"]),
}
example_unit = "invariant"
unit_name, selected_motifs = example_units[example_unit]

In [None]:
# Locate the unit's pprox (JSON) file in the neurobank archive and parse it.
pprox_file = nbank.find_resource(unit_name)
pprox_text = pprox_file.read_text()
unit = json.loads(pprox_text)

In [None]:
# Split each trial into per-motif responses.
# NOTE(review): motif "igmi8fxa" is excluded from every analysis; the reason is
# not stated here — confirm.
splitter = MotifSplitter()
motifs = split_trials(splitter, unit).drop("igmi8fxa", level="foreground")
motif_names = motifs.index.unique(level="foreground")

# Stimulus waveforms keyed by motif name, as (samples, sampling_rate) tuples.
wav_signals = {}
for motif_name, wav_location in nbank.find_resources(*motif_names):
    with ewave.wavfile(wav_location, "r") as wav_fp:
        wav_signals[motif_name] = (wav_fp.read(), wav_fp.sampling_rate)

In [None]:
n_motifs = len(selected_motifs)
# One color per background (noise) level; level 0 of the motif index is the background level.
colors = {
    v: c for v, c in zip(motifs.index.unique(level=0), plt.color_sequences["tab20"])
}
fig = plt.figure(figsize=(2.4, 2.9), dpi=300)
# squeeze=False + ravel() keeps `subfigs` a 1-D array even when a single motif is selected
# (with the default squeeze=True it would be a bare SubFigure and the loops below would fail).
subfigs = fig.subfigures(1, n_motifs, squeeze=False, hspace=0.001, wspace=0.0001).ravel()
for motif, subfig in zip(selected_motifs, subfigs):
    trials = motifs.xs(motif, level="foreground")
    # rows: stimulus spectrogram, spike raster, smoothed firing rate
    axes = subfig.subplots(3, sharex=True, height_ratios=[1, 5, 1])
    plotting.spectrogram(axes[0], *wav_signals[motif], frequency_range=(0, 8000))
    # frequency ticks labelled in kHz (1 kHz and 8 kHz)
    axes[0].set_yticks([1000, 8000], ["1", "8"])
    for i, trial in enumerate(trials.sort_index(ascending=False).itertuples()):
        # `events` is NaN (a float) rather than an array for some trials
        # (presumably trials without spikes); nothing to draw for those.
        if isinstance(trial.events, float):
            continue
        background_level = trial.Index
        axes[1].plot(
            trial.events,
            [i] * trial.events.size,
            color=colors[background_level],
            marker="|",
            markeredgewidth=0.5,
            linestyle="",
        )
    axes[1].set_ylim(0, trials.shape[0])
    axes[1].get_yaxis().set_visible(False)
    plotting.adjust_raster_ticks(axes[1], gap=3.2)
    # one smoothed rate curve per background level, pooling that level's trials
    for lvl, trls in trials.sort_index(ascending=False).groupby("background-dBFS"):
        rate, bins = spikes.rate(
            trls.events.dropna().explode(),
            rate_binwidth,
            kernel,
            start=0,
            stop=trials.interval_end.max(),
        )
        axes[2].plot(bins, rate, color=colors[lvl])
    plotting.simple_axes(*axes)

# Put all firing-rate panels on a common y scale starting at zero.
max_rate = max(subfig.axes[2].get_ylim()[1] for subfig in subfigs)
for subfig in subfigs:
    subfig.axes[2].set_ylim((0, max_rate))
    subfig.subplots_adjust(left=0.05, right=0.95, hspace=0.08)
# Only the leftmost column keeps its y axes.
for subfig in subfigs[1:]:
    for ax in subfig.axes:
        ax.get_yaxis().set_visible(False)
        ax.spines["left"].set_visible(False)

In [None]:
# Write the motif raster figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_rasters.pdf")

## Motif discrimination

In [None]:
import pyspike


def inv_spike_sync_matrix(*args, **kwargs):
    """Return the pairwise spike-distance matrix, defined as 1 - SPIKE-synchronization.

    All positional and keyword arguments are forwarded unchanged to
    ``pyspike.spike_sync_matrix``.
    """
    sync_matrix = pyspike.spike_sync_matrix(*args, **kwargs)
    return 1 - sync_matrix

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import cross_validate, LeaveOneOut, cross_val_score, cross_val_predict

# number of neighbors used by the k-NN classifiers below
n_neighbors = 9


class ShuffledLeaveOneOut(LeaveOneOut):
    """Leave-one-out cross-validator whose training indices are randomly permuted.

    Each (train, test) split from ``LeaveOneOut`` is yielded with the training
    indices shuffled by ``rng.permutation``; the held-out test index is unchanged.
    NOTE(review): presumably this randomizes how k-NN breaks ties between
    equidistant neighbors, which would otherwise depend on training order — confirm.
    """

    def __init__(self, rng):
        super().__init__()
        self.rng = rng

    def split(self, *args, **kwargs):
        folds = super().split(*args, **kwargs)
        return ((self.rng.permutation(train_idx), test_idx) for train_idx, test_idx in folds)
            
def kneighbors_classifier(distance_matrix, rng, normalize="true"):
    """Compute the leave-one-out confusion matrix of a k-NN classifier on spike distances.

    Parameters
    ----------
    distance_matrix : pandas.DataFrame
        Precomputed pairwise distances between trials (square, as required by
        ``metric="precomputed"``); its index holds each trial's class label.
    rng : numpy.random.Generator
        Passed to ShuffledLeaveOneOut to permute the training indices of each fold.
    normalize : {"true", "pred", "all", None}
        Passed to ``sklearn.metrics.confusion_matrix`` ("true" normalizes each row).

    Returns
    -------
    pandas.DataFrame
        Confusion matrix with true labels as rows and predicted labels as columns,
        both in order of first appearance in ``distance_matrix.index``.
    """
    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="precomputed")
    loo = ShuffledLeaveOneOut(rng)
    # integer class codes, plus the labels they stand for (order of first appearance)
    group_idx, names = pd.factorize(distance_matrix.index)
    # One cross-validated pass only: a separate cross_val_score call would repeat
    # the work, be discarded, and consume rng draws that change these predictions.
    pred = cross_val_predict(neigh, distance_matrix.values, group_idx, cv=loo)
    cm = confusion_matrix(group_idx, pred, normalize=normalize)
    return pd.DataFrame(cm, index=names, columns=names)


In [None]:
# Convert every trial to a spike train, passing the earliest interval end across
# all trials as a shared `interval_end` (presumably so all trains span one duration).
common_interval_end = motifs.interval_end.min()
spike_trains = motifs.apply(
    partial(trial_to_spike_train, interval_end=common_interval_end),
    axis=1,
)

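These are the spike-distance matrices for training. Only the bottom panel (responses to clean stimuli, background −100 dBFS) is used in the paper. For the noisy levels, we instead use the testing distance matrices, which compare responses to noisy stimuli against responses to clean stimuli. Those are plotted below for background levels of −60 and −25 dBFS, i.e. 30 dB and −5 dB SNR, using SNR = −30 − background dBFS as in the summary plot.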

In [None]:
bkgnd_levels = (-25, -100)
figsize_distances = (1.5, 2.5)
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=figsize_distances, dpi=400)
for ax, bkgnd_level in zip(axes, bkgnd_levels):
    st = spike_trains.loc[bkgnd_level]
    dist = inv_spike_sync_matrix(st)
    # Show similarity (1 - distance, i.e. the SPIKE-sync matrix). The *string* "none"
    # disables resampling; interpolation=None would fall back to the rcParams default.
    img = ax.imshow(1 - dist, vmin=0, vmax=1, aspect="equal", origin="upper", interpolation="none")
    # White lines on the boundaries between motifs, derived from the trial labels
    # instead of assuming 9 motifs x 10 trials. Trials are assumed to be grouped by
    # motif; pixel i spans [i - 0.5, i + 0.5], so boundaries sit at half-integers.
    motif_labels = np.asarray(st.index)
    for edge in np.flatnonzero(motif_labels[1:] != motif_labels[:-1]) + 0.5:
        ax.axvline(edge, color="w", linewidth=0.5)
        ax.axhline(edge, color="w", linewidth=0.5)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
fig.colorbar(img, ax=axes, location="bottom", shrink=0.3, aspect=10);

In [None]:
# Write the training spike-distance figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_distances_training.pdf")

In [None]:
figsize_discrim = (1, 1)
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=figsize_discrim, dpi=300)
for ax, bkgnd_level in zip(axes, bkgnd_levels):
    # pairwise spike distances among all trials at this background level
    spike_dists = pairwise_spike_comparison(
        spike_trains.loc[bkgnd_level],
        comparison_fun=inv_spike_sync_matrix,
        stack=False,
    )
    # motif codes (order of first appearance) and the motif names they stand for
    group_idx, names = spike_dists.index.factorize()
    # leave-one-out k-NN prediction of each trial's motif
    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="precomputed")
    loo = ShuffledLeaveOneOut(rng)
    pred = cross_val_predict(neigh, spike_dists.values, group_idx, cv=loo)
    conf_mtx = confusion_matrix(group_idx, pred, normalize="true")
    img = ax.imshow(conf_mtx, origin="upper", aspect="equal", vmin=0, vmax=1.0)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_title(f"{bkgnd_level} dB", fontsize=6)
fig.colorbar(img, ax=axes, location="bottom", shrink=0.4, aspect=10)

In [None]:
# Write the training confusion-matrix figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_discrim_training.pdf")

In [None]:
# Save the cross-validated predictions for the clean (-100 dBFS) condition; they stand
# in for the test accuracy at that level. `pred`, `names` and `spike_dists` are left
# over from the last iteration of the training-discrimination loop, so check that the
# iteration really was the clean level before labelling the results -100.
assert bkgnd_level == -100, f"last training level was {bkgnd_level}, expected -100"
pred_training = pd.concat(
    [
        pd.Series(names[pred], index=spike_dists.index)
        .rename_axis("foreground")
        .rename("predicted")
    ],
    keys=[-100],
    names=["background-dBFS"],
)

## Invariance

In this approach, we fit the classifier to the responses to clean stimuli (−100 dBFS background). We then measure how accurately it classifies the responses to the noisy stimuli.

In [None]:
# Fit the k-NN classifier on every response to the clean (-100 dBFS) stimuli.
st_train = spike_trains.loc[-100]
train = pairwise_spike_comparison(st_train, comparison_fun=inv_spike_sync_matrix, stack=False)
group_idx, names = train.index.factorize()
neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="precomputed")
neigh.fit(train.values, group_idx)

For each trial not in the training set (every trial with a noisy background), calculate its spike distance (1 − SPIKE-synchronization) to every trial in the training set.

In [None]:
def compare_to_training(st_test):
    """Return spike distances (1 - SPIKE-sync) from one test trial to every training trial.

    ``st_test`` is a row with a ``spikes`` field. The result is a Series aligned
    with ``st_train`` whose index is renamed to "ref".
    """
    def distance_to(st_ref):
        return 1 - pyspike.spike_sync(st_ref, st_test.spikes)

    return st_train.apply(distance_to).rename_axis("ref")


# one row per non-clean trial, one column per clean training trial
dist_to_clean = spike_trains.drop(-100).to_frame("spikes").apply(compare_to_training, axis=1)
predicted_codes = neigh.predict(dist_to_clean.values)
predicted = pd.Series(names[predicted_codes], index=dist_to_clean.index).rename("predicted")

In [None]:
bkgnd_levels = (-60, -25)
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=figsize_distances, dpi=400)
for ax, bkgnd_level in zip(axes, bkgnd_levels):
    # rows: test trials at this level; columns: clean training trials
    dist = dist_to_clean.loc[bkgnd_level]
    # Show similarity (1 - distance). The *string* "none" disables resampling;
    # interpolation=None would fall back to the rcParams default.
    img = ax.imshow(1 - dist, vmin=0, vmax=1, aspect="equal", origin="upper", interpolation="none")
    # White lines on motif boundaries, computed separately for rows and columns from
    # the trial labels (trials assumed grouped by motif). Pixel i spans
    # [i - 0.5, i + 0.5], so the boundaries sit at half-integers.
    col_labels = np.asarray(dist.columns)
    row_labels = np.asarray(dist.index)
    for edge in np.flatnonzero(col_labels[1:] != col_labels[:-1]) + 0.5:
        ax.axvline(edge, color="w", linewidth=0.3)
    for edge in np.flatnonzero(row_labels[1:] != row_labels[:-1]) + 0.5:
        ax.axhline(edge, color="w", linewidth=0.3)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
fig.colorbar(img, ax=axes, location="bottom", shrink=0.3, aspect=10);

In [None]:
# Write the testing spike-distance figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_distances_testing.pdf")

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=figsize_discrim, dpi=300)
for ax, bkgnd_level in zip(axes, bkgnd_levels):
    pred = predicted.loc[bkgnd_level].reset_index()
    # Order rows and columns by `names` (motif order in the clean training set), as in
    # the training confusion matrices. Without `labels`, confusion_matrix sorts the
    # motif names alphabetically, so these panels may not line up with the others.
    conf_mtx = confusion_matrix(
        pred["foreground"].values,
        pred["predicted"].values,
        labels=names,
        normalize="true",
    )
    img = ax.imshow(conf_mtx, origin="upper", aspect="equal", vmin=0, vmax=1.0)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)
    ax.set_title(f"{bkgnd_level} dB", fontdict={"fontsize": 6})
fig.colorbar(img, ax=axes, location="bottom", shrink=0.4, aspect=10);

In [None]:
# Write the testing confusion-matrix figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_discrim_testing.pdf")

In [None]:
# Per-trial correctness (1.0 or 0.0) at every background level. The clean level uses
# the cross-validated training predictions; the other levels use the predictions of
# the classifier trained on the clean responses.
acc = (
    pd.concat([pred_training, predicted])
    .pipe(lambda preds: (preds == preds.index.get_level_values(-1)).astype(float))
    .rename("correct")
)

After saving, this panel has to be edited by hand to add a gap on the x axis between the 70 dB and 35 dB SNR points.

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(1, 1), dpi=300)
# seaborn places numeric x categories in sorted order; pass that order explicitly so
# the SNR tick labels below are guaranteed to match the plotted points (the index
# order of `acc` is order of appearance, not sorted).
bkgnd_order = np.sort(acc.index.unique(level=0))
sns.pointplot(
    x="background-dBFS",
    y="correct",
    order=bkgnd_order,
    errorbar="se",
    capsize=0.3,
    data=acc.reset_index(),
    ax=ax,
)
# SNR = -30 - background dBFS
# NOTE(review): presumably the foreground is presented at -30 dBFS — confirm.
levels = -30 - bkgnd_order
# label the first level and every other level after it
idx_keep = [0] + list(range(1, levels.size, 2))
# chance performance: one over the number of motifs
n_classes = acc.index.get_level_values(-1).nunique()
ax.axhline(1 / n_classes, color="k", linestyle=":")
ax.set_xticks(idx_keep, levels[idx_keep])
ax.set_xlabel("SNR (dB)")
ax.set_ylabel("p(correct)")
ax.set_ylim(0, 1.0);

In [None]:
# Write the discrimination-summary figure (the most recently created `fig`).
fig.savefig(Path("../figures") / f"{unit_name}_motif_discrim_summary.pdf")