This notebook generates example plots of neural responses: spike rasters, firing rates and selectivity, response sparseness, and motif discriminability.

In [None]:
import sys

# Put ../scripts at the front of the module search path so the project-local
# modules imported later in this notebook (presumably `graphics_defaults`,
# `core`, `motif_discrim`) resolve from there.
sys.path[:0] = ["../scripts"]

In [None]:
# Neurobank registry URL, presumably read by the `nbank` client used below to
# locate unit (pprox) and stimulus (wav) resources.
%env NBANK_REGISTRY https://gracula.psyc.virginia.edu/neurobank

In [None]:
import json
import logging
from pathlib import Path
from functools import partial

import ewave
import numpy as np
from numpy.random import default_rng
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from dlab import pprox, nbank, spikes, plotting, signal
import statsmodels.api as sm
import statsmodels.formula.api as smf
import statsmodels.stats.multitest as smt

import graphics_defaults
from core import MotifSplitter, MotifBackgroundSplitter, split_trials

# Seed the generator so any random draws made with it are reproducible under
# Restart & Run All.
rng = default_rng(1024)

In [None]:
# Gaussian smoothing kernel for firing-rate estimates. The values look like
# seconds (20 ms bandwidth, 5 ms bins) -- TODO confirm against dlab.signal.kernel.
# NOTE(review): `kernel` is not referenced in the other cells of this notebook.
rate_bandwidth = 0.02
rate_binwidth = 0.005
kernel, _ = signal.kernel("gaussian", rate_bandwidth, rate_binwidth)

In [None]:
# Display labels for the `area` categories used in recording_metadata.csv.
area_names = dict(deep="L3/NCM", intermediate="L2a/L2b", superficial="L1/CM")

# Per-unit spike-waveform class. The site name is the unit name with its last
# "_"-separated field removed.
feature_file = Path("..", "build", "mean_spike_features.csv")
features = pd.read_csv(feature_file)[["unit", "spike"]]
features["site"] = features["unit"].map(lambda name: name.rpartition("_")[0])

# Recording-site metadata; `area` becomes an ordered categorical of display labels
# (an unknown area code raises KeyError).
site_file = Path("..", "inputs", "recording_metadata.csv")
sites = pd.read_csv(site_file, index_col="site")
sites["area"] = pd.Categorical(
    sites["area"].map(area_names.__getitem__),
    categories=["L2a/L2b", "L1/CM", "L3/NCM"],
    ordered=True,
)

bird_file = Path("..", "inputs", "bird_metadata.csv")
birds = pd.read_csv(bird_file, index_col="bird")

# Attach site metadata (inner join: units without a known site are dropped) and
# index by unit name.
features = features.join(sites, on="site", how="inner").set_index("unit")

In [None]:
# Motifs shown as columns of the raster figure.
selected_motifs = "0oq8ifcb g29wxi4q vekibwgj ztqee46x".split()
# Example units, one raster row each, top to bottom.
unit_names = """
    C194_3_1_c126
    C104_4_1_c120
    C42_4_1_c131
    C42_4_1_c14
    C104_3_1_c201
    C44_3_1_c74
""".split()
features.loc[unit_names]

In [None]:
# Per-trial responses of each example unit. Keep trials whose background level is
# -100 dBFS (presumably "no background") plus the silent trials, and drop
# background-only trials.
unit_responses = []
for unit_name, pprox_file in nbank.find_resources(*unit_names):
    unit = json.loads(pprox_file.read_text())
    splitter = MotifBackgroundSplitter()
    responses = (
        split_trials(splitter, unit)
        .reset_index()
        .rename(columns=lambda s: s.replace("-", "_"))
        .query("background_dBFS == -100 | foreground == 'silence'")
        .query("foreground != 'background'")
        .drop(["background", "foreground_dBFS", "offset"], axis=1)
        .assign(unit=unit_name)
    )
    unit_responses.append(responses)
motifs = pd.concat(unit_responses)
# A missing `events` entry (no spikes) counts as zero spikes.
motifs["n_events"] = motifs.events.fillna("").apply(len)
motifs["rate"] = motifs.n_events / motifs.interval_end
# "silence" comes first so it is the reference level of the rate models. The
# remaining motifs are sorted: iterating a set of strings gives a different order
# in each interpreter process (str hashing is randomized), which would make the
# category and coefficient order non-reproducible across runs.
motif_names = ["silence"] + sorted(set(motifs.foreground.unique()) - {"silence"})
motifs["foreground"] = pd.Categorical(motifs.foreground, categories=motif_names, ordered=True)

In [None]:
# Waveform for every name in `motif_names`, stored as (samples, sampling_rate).
# Resources that don't resolve to a local Path (presumably remote-only) map to None.
wav_signals = {}
for name, location in nbank.find_resources(*motif_names):
    if isinstance(location, Path):
        with ewave.wavfile(location, "r") as fp:
            wav_signals[name] = (fp.read(), fp.sampling_rate)
    else:
        wav_signals[name] = None

In [None]:
# Raster figure: one column per motif in `selected_motifs`. The top row shows the
# motif's spectrogram; each following row shows the spike rasters of one unit in
# `unit_names` (one row of tick marks per trial).
n_motifs = len(selected_motifs)
n_units = len(unit_names)
# index trials by (motif, unit) so each panel's trials can be pulled with .loc
df = motifs.set_index(["foreground", "unit"])
fig, axes = plt.subplots(nrows=1 + n_units, ncols=n_motifs, sharex=True, sharey="row", figsize=(3, 2.9), dpi=300)
for col, motif in zip(axes.T, selected_motifs):
    # wav_signals[motif] is (samples, sampling_rate); it is None when the wav file
    # did not resolve to a local Path, in which case this unpacking fails
    plotting.spectrogram(col[0], frequency_range=(0, 8000), *wav_signals[motif])
    # frequency ticks labelled in kHz
    col[0].set_yticks([0, 4000, 8000], ["0", "4", "8"])
    col[0].get_xaxis().set_visible(False)
    # col[0].set_title(motif)
    motif_trials = df.loc[motif]
    for row, unit in zip(col[1:], unit_names):
        trials = motif_trials.loc[unit]
        # trial i is drawn at height i, one tick mark per spike time
        for i, trial in enumerate(trials.itertuples()):
            # trials without spikes hold a float (NaN) instead of spike times
            if isinstance(trial.events, float):
                continue
            row.plot(
                trial.events,
                [i] * trial.events.size,
                color="k",
                marker="|",
                markeredgewidth=0.2,
                markersize=1.2,
                linestyle="",
            )
        row.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
# Axis cleanup. Later calls override earlier ones on the same axes, so order matters.
# Spectrogram row: hide the x axis; only the first column keeps its y axis (frequency).
for i in range(n_motifs):
    for boundary in ("right", "bottom", "top"):
        axes[0, i].spines[boundary].set_visible(False)
    axes[0, i].get_xaxis().set_visible(False)
    if i > 0:
        axes[0, i].get_yaxis().set_visible(False)
        axes[0, i].spines["left"].set_visible(False)
# Raster rows: label each row with the unit's area and spike type, strip all spines,
# and keep only the bottom spine / x axis (time) of the last row.
for i, unit_name in enumerate(unit_names, start=1):
    info = features.loc[unit_name]
    axes[i, 0].set_ylabel(f"{info.area}\n{info.spike}")
    if i < n_units:
        for ax in axes[i]:
            for boundary in ("left", "right", "bottom", "top"):
                ax.spines[boundary].set_visible(False)
            ax.get_xaxis().set_visible(False)
            #ax.set_frame_on(False)
    else:
        for ax in axes[i]:
            for boundary in ("left", "right","top"):
                ax.spines[boundary].set_visible(False)
fig.subplots_adjust(hspace=0.1, wspace=0.1)

In [None]:
# savefig does not create missing directories; make sure ../figures exists.
(Path("..") / "figures").mkdir(exist_ok=True)
fig.savefig(Path("..") / "figures" / "motif_rasters.pdf")

## Firing rates and selectivity

In [None]:
def rate_model(df):
    """Fit a Poisson GLM of spike counts on stimulus identity for one unit.

    The model is ``n_events ~ foreground`` with ``log(interval_end)`` as the
    offset, so coefficients are on the log-rate scale. The intercept is the log
    firing rate for the reference (first) level of ``foreground``. Every other
    coefficient is the log rate ratio of that stimulus relative to the reference.

    Parameters
    ----------
    df : pandas.DataFrame
        Trials from a single unit, with columns ``n_events`` (spike count),
        ``foreground`` (stimulus label) and ``interval_end`` (used as the
        exposure in the offset).

    Returns
    -------
    pandas.DataFrame
        Indexed by ``stimulus`` (reference level first), with columns ``coef``,
        ``std err``, ``pvalue`` (Sidak-adjusted), ``coef_lcl`` / ``coef_ucl``
        (confidence limits from ``conf_int``) and ``responsive``
        (``coef > 0`` and adjusted ``pvalue < 0.05``).

    Raises
    ------
    ValueError
        If the stimulus labels cannot be matched one-to-one with the fitted
        coefficients.
    """
    lm = smf.glm(
        "n_events ~ foreground",
        data=df,
        family=sm.families.Poisson(),
        offset=np.log(df["interval_end"]),
    ).fit()
    # Stimulus labels in the order patsy uses for the design matrix: the
    # categories of a categorical column, otherwise the sorted unique values.
    # Deriving them from `df` (not a global list) keeps labels and coefficients
    # aligned for whatever frame is passed in.
    foreground = df["foreground"]
    if isinstance(foreground.dtype, pd.CategoricalDtype):
        stimuli = list(foreground.cat.categories)
    else:
        stimuli = sorted(foreground.unique())
    if len(stimuli) != len(lm.params):
        raise ValueError(
            f"{len(stimuli)} stimulus labels for {len(lm.params)} model coefficients"
        )
    conf_int = lm.conf_int()
    # NOTE(review): the Sidak family includes the intercept, whose test
    # (log rate == 0) is not a stimulus comparison; the `responsive` flag on the
    # reference row is likewise not meaningful.
    coefs = (
        pd.DataFrame({
            "stimulus": stimuli,
            "coef": lm.params,
            "std err": lm.bse,
            "pvalue": smt.multipletests(lm.pvalues, method="sidak")[1],
            "coef_lcl": conf_int[0],
            "coef_ucl": conf_int[1],
        })
        .reset_index(drop=True)
        .set_index("stimulus")
    )
    coefs["responsive"] = (coefs.coef > 0) & (coefs.pvalue < 0.05)
    return coefs

In [None]:
# Colors used for the two example units; index 0 is also used for the wide-spiking
# site and index 1 for the narrow-spiking site in the sparseness figure.
spike_type_colors = ["#70549B", "#FF7F0E"]
# The L3/NCM units among the raster examples, in raster order.
example_units = features.loc[unit_names].loc[lambda d: d["area"] == "L3/NCM"].index
# Per-unit Poisson rate model (one coefficient table per unit).
rates = (
    motifs.set_index("unit")
    .loc[example_units]
    .groupby("unit")
    .apply(rate_model)
)

In [None]:
# One panel per example unit: model-predicted firing rate for each stimulus, in
# descending order. Dashed line: spontaneous rate (exp of the "silence" intercept).
# Open markers: all stimuli. Filled markers: stimuli flagged `responsive`.
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(1, 2.2), dpi=300)
for i, unit_name in enumerate(example_units):
    coefs = rates.loc[unit_name]
    spont = coefs.loc["silence"]
    evoked = (
        coefs.iloc[1:]
        .sort_values(by="coef", ascending=False)
        .assign(
            rank=lambda d: np.arange(d.shape[0]) + 1,
            pred=lambda d: np.exp(d.coef + spont.coef),
        )
    )
    responsive = evoked.loc[evoked.responsive]
    axes[i].axhline(np.exp(spont["coef"]), color="black", linestyle="--")
    axes[i].plot(
        evoked["rank"], evoked.pred, "o-",
        markersize=2.5, markerfacecolor="white", markeredgewidth=0.2,
        color=spike_type_colors[i],
    )
    axes[i].plot(responsive["rank"], responsive.pred, "o", markersize=1.5, color=spike_type_colors[i])
    axes[i].set_title(unit_name, fontsize=3, pad=0, loc="right")
    # y range padded by 5% of the top-ranked (highest) predicted rate
    axes[i].set_ylim(-0.05 * evoked["pred"].iloc[0], 1.05 * evoked["pred"].iloc[0])
axes[1].set_ylabel("Firing rate (Hz)")
axes[1].set_xlabel("Stimulus rank")

In [None]:
# savefig does not create missing directories; make sure ../figures exists.
(Path("..") / "figures").mkdir(exist_ok=True)
fig.savefig(Path("..") / "figures" / "motif_rates.pdf")

In [None]:
# Both example units on one axis. Rates are divided by each unit's peak predicted
# rate so the tuning curves are comparable; dashed lines mark each unit's
# spontaneous rate on the same scale.
fig, axes = plt.subplots(nrows=1, ncols=1, sharex=True, figsize=(1.0, 1.0), dpi=300)
for i, unit_name in enumerate(example_units):
    coefs = rates.loc[unit_name]
    spont = coefs.loc["silence"]
    evoked = (
        coefs.iloc[1:]
        .sort_values(by="coef", ascending=False)
        .assign(
            rank=lambda d: np.arange(d.shape[0]) + 1,
            pred=lambda d: np.exp(d.coef + spont.coef),
            norm=lambda d: d.pred / d.pred.max(),
        )
    )
    responsive = evoked.loc[evoked.responsive]
    axes.axhline(np.exp(spont["coef"]) / evoked.pred.max(), linestyle="--", color=spike_type_colors[i])
    axes.plot(
        evoked["rank"], evoked.norm, "o-",
        markersize=1.5, markerfacecolor="white", markeredgewidth=0.2,
        color=spike_type_colors[i],
    )
    axes.plot(responsive["rank"], responsive.norm, "o", markersize=1.0, color=spike_type_colors[i])
axes.set_ylabel("Normalized Firing Rate")
axes.set_xlabel("Stimulus rank")

## Sparseness

In [None]:
# Precomputed rate-model coefficients for all units. The bird ID is the first
# "_"-separated field of the unit name. Inner joins keep only units and birds
# with metadata.
rate_file = Path("..", "build", "motif_rate_coefs.csv")
rates = pd.read_csv(rate_file)
rates["bird"] = rates["unit"].map(lambda name: name.partition("_")[0])
rates = (
    rates
    .join(features, on="unit", how="inner")
    .join(birds, on="bird", how="inner")
)
rates

In [None]:
def _site_responsiveness(rates_df, site, spike):
    """Return the unit x stimulus table of `is_responsive` flags for one site.

    Rows are units (``unit``) and columns are stimuli (``foreground``). Only rows
    of ``rates_df`` at recording site ``site`` with spike type ``spike`` are used.
    A boolean mask replaces ``.set_index("site").loc[site]`` because ``.loc``
    returns a Series, not a DataFrame, when the label matches a single row,
    which would break the ``set_index``/``unstack`` that follows.

    Raises
    ------
    KeyError
        If no rows match ``site`` and ``spike``.
    """
    selected = rates_df.loc[(rates_df["site"] == site) & (rates_df["spike"] == spike)]
    if selected.empty:
        raise KeyError(f"no {spike!r}-spiking units at site {site!r}")
    return selected.set_index(["unit", "foreground"])["is_responsive"].unstack()


# Wide-spiking units from one site (bs) and narrow-spiking units from another (ns).
bs_site = _site_responsiveness(rates, "C104_3_1", "wide")
ns_site = _site_responsiveness(rates, "C44_3_1", "narrow")
# Sparseness: 1 minus the mean, over stimuli, of the fraction of units responsive
# to each stimulus.
bs_spar = 1 - bs_site.mean().mean()
ns_spar = 1 - ns_site.mean().mean()

In [None]:
# Row position of the first example unit in the wide-spiking table after sorting
# by all stimulus columns (the same ordering the sparseness figure uses).
bs_site.sort_values(by=bs_site.columns.tolist()).index.get_loc(example_units[0])

In [None]:
# Row position of the second example unit in the narrow-spiking table after
# sorting by all stimulus columns (the same ordering the sparseness figure uses).
ns_site.sort_values(by=ns_site.columns.tolist()).index.get_loc(example_units[1])

In [None]:
# (units, stimuli) of the narrow-spiking table; its row count sets that panel's
# height ratio in the sparseness figure.
ns_site.shape

In [None]:
# Binary colormaps: white = not responsive, unit color = responsive.
cmap_bs = ListedColormap(["white", spike_type_colors[0]])
cmap_ns = ListedColormap(["white", spike_type_colors[1]])

# Two stacked responsiveness maps (units x stimuli). Panel heights are proportional
# to the number of units; rows are sorted by all stimulus columns.
fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    sharex=True,
    figsize=(0.7, 1.8),
    height_ratios=(bs_site.shape[0], ns_site.shape[0]),
    dpi=300,
)
for ax, table, cmap, spar in zip(axes, (bs_site, ns_site), (cmap_bs, cmap_ns), (bs_spar, ns_spar)):
    ax.imshow(table.sort_values(list(table.columns)), cmap=cmap, interpolation="none")
    ax.set_title(f"sparseness = {spar:.2f}", fontsize=3, pad=0, loc="right")
axes[1].set_ylabel("Neuron")
axes[1].set_xlabel("Stimulus")

In [None]:
# savefig does not create missing directories; make sure ../figures exists.
(Path("..") / "figures").mkdir(exist_ok=True)
fig.savefig(Path("..") / "figures" / "motif_sparseness.pdf")

## Motif discriminability

In [None]:
from core import trial_to_spike_train, pairwise_spike_comparison, inv_spike_sync_matrix
from motif_discrim import ShuffledLeaveOneOut
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import cross_validate, LeaveOneOut, cross_val_score, cross_val_predict

In [None]:
# NOTE(review): `bkgnd_levels` and `figsize_distances` are not used in this cell.
bkgnd_levels = (-25, -100)
figsize_distances = (1.5, 2.5)
# One panel per example unit: trial-by-trial similarity (1 - the distance matrix
# from `inv_spike_sync_matrix`, shown on a 0..1 scale) for every motif except
# "silence" and "igmi8fxa".
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, figsize=(2.4, 3.5), dpi=300)
for i, unit_name in enumerate(example_units):
    trials = motifs.set_index(["unit", "foreground"]).loc[unit_name].drop(["silence", "igmi8fxa"])
    # every spike train uses the same window, ending at the shortest trial's interval_end
    spike_trains = trials.apply(partial(trial_to_spike_train, interval_end=trials.interval_end.min()), axis=1)
    dist = inv_spike_sync_matrix(spike_trains.to_list())
    # The string "none" disables resampling, matching the other heatmaps in this
    # notebook; interpolation=None would fall back to rcParams["image.interpolation"].
    img = axes[i].imshow(1 - dist, vmin=0, vmax=1, aspect="equal", origin="upper", interpolation="none")
    # White separators between stimulus blocks. Assumes trials are grouped by
    # stimulus in blocks of 10 (TODO confirm). imshow centres pixel k on
    # coordinate k, so the boundary before trial x lies at x - 0.5.
    for x in range(10, 90, 10):
        axes[i].axvline(x - 0.5, color="w", linewidth=0.5)
        axes[i].axhline(x - 0.5, color="w", linewidth=0.5)
    axes[i].get_yaxis().set_visible(False)
    axes[i].get_xaxis().set_visible(False)
    axes[i].set_title(unit_name, fontsize=3, pad=1, loc="right")
fig.colorbar(img, ax=axes, location="bottom", shrink=0.3, aspect=10)

In [None]:
# savefig does not create missing directories; make sure ../figures exists.
(Path("..") / "figures").mkdir(exist_ok=True)
fig.savefig(Path("..") / "figures" / "motif_distances.pdf")

In [None]:
# Number of neighbours used by the k-NN stimulus classifier.
n_neighbors = 9

def kneighbors_classifier(distance_matrix, rng, normalize="true"):
    """Cross-validated confusion matrix of a k-nearest-neighbours classifier.

    Each trial is classified from its precomputed spike-train distances to the
    other trials, using ``ShuffledLeaveOneOut(rng)`` splits and the module-level
    ``n_neighbors``.

    Parameters
    ----------
    distance_matrix : pandas.DataFrame
        Square matrix of pairwise spike-train distances (required by
        ``metric="precomputed"``). Its index holds each trial's stimulus label;
        these labels are the classes to predict.
    rng : numpy.random.Generator
        Random generator handed to ``ShuffledLeaveOneOut``.
    normalize : {"true", "pred", "all"} or None
        Passed through to ``sklearn.metrics.confusion_matrix``.

    Returns
    -------
    pandas.DataFrame
        Confusion matrix with rows = true stimulus and columns = predicted stimulus,
        labelled in order of first appearance in ``distance_matrix.index``.
    """
    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="precomputed")
    loo = ShuffledLeaveOneOut(rng)
    # factorize returns codes and their labels together, so the class order of the
    # confusion matrix is guaranteed to match its row/column labels.
    group_idx, names = pd.factorize(distance_matrix.index)
    # Only cross-validated predictions are needed; accuracy is the confusion-matrix
    # diagonal, so no separate cross_val_score pass is run. ShuffledLeaveOneOut
    # presumably draws from `rng` on each pass, so seeded results depend on how
    # many passes are made.
    pred = cross_val_predict(neigh, distance_matrix.values, group_idx, cv=loo)
    cm = confusion_matrix(group_idx, pred, labels=np.arange(len(names)), normalize=normalize)
    return pd.DataFrame(cm, index=names, columns=names)

In [None]:
# Fixed seed so the shuffled leave-one-out splits (and hence the confusion
# matrices) are reproducible. The same generator is shared by both units, so the
# results depend on the loop order.
rng = np.random.default_rng(1024)

# One panel per example unit: cross-validated k-NN confusion matrix
# (normalize="true": each true-stimulus row sums to 1).
fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(1.5, 1.5), dpi=300)
for i, unit_name in enumerate(example_units):
    # same stimulus selection as the distance-matrix figure: drop "silence" and "igmi8fxa"
    trials = motifs.set_index(["unit", "foreground"]).loc[unit_name].drop(["silence", "igmi8fxa"])
    # every spike train uses the same window, ending at the shortest trial's interval_end
    spike_trains = trials.apply(partial(trial_to_spike_train, interval_end=trials.interval_end.min()), axis=1)
    spike_dists = pairwise_spike_comparison(spike_trains, comparison_fun=inv_spike_sync_matrix, stack=False)
    conf_mtx = kneighbors_classifier(spike_dists, rng, normalize="true")
    img = axes[i].imshow(conf_mtx, origin="upper", aspect="equal", vmin=0, vmax=1.0)
    axes[i].get_yaxis().set_visible(False)
    axes[i].get_xaxis().set_visible(False)
    axes[i].set_title(unit_name, fontsize=3, pad=1, loc="right")

In [None]:
# savefig does not create missing directories; make sure ../figures exists.
(Path("..") / "figures").mkdir(exist_ok=True)
fig.savefig(Path("..") / "figures" / "motif_classifier.pdf")