# Example responses

This notebook generates the panels for Figure 3: spike-waveform features of narrow- vs. wide-spiking neurons, and example responses from units in different areas.

In [None]:
import sys

# Make the project's helper modules in ../scripts (presumably core, filters,
# graphics_defaults) importable. The membership check keeps this cell idempotent:
# re-running it does not stack duplicate entries on sys.path.
SCRIPTS_DIR = "../scripts"
if SCRIPTS_DIR not in sys.path:
    sys.path.insert(0, SCRIPTS_DIR)

In [None]:
import json
from pathlib import Path

import graphics_defaults  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from core import (
    MotifSplitter,
    find_resources,
    load_wave,
    rescale,
    split_trials,
)
from dlab import plotting
from filters import SpectrogramTransform

In [None]:
# plot colors indexed in (wide, narrow) order, matching the order of the waveform-feature plot loop
unit_type_colors = ["#70549B", "#FF7F0E"]

# display labels for the area codes stored in the `area` column of recordings.csv
area_names = dict(
    deep="L3/NCM",
    intermediate="L2a/L2b",
    superficial="L1/CM",
)

In [None]:
# root of the deposited dataset
dataset_dir = Path("../datasets/zebf-social-acoustical-ephys")
# set this to `./inputs` if analyzing new data
metadata_dir, response_dir, stim_dir = (
    dataset_dir / subdir for subdir in ("metadata/", "responses/", "stimuli")
)
# spectrogram parameters: analysis window (presumably seconds, i.e. 20 ms -- confirm
# against SpectrogramTransform) and maximum frequency (presumably Hz)
window_size = 0.020
max_frequency = 8000

## Spike features

Spike waveforms are extracted during spike sorting. The script `scripts/unit_waveforms.py` was used to upsample and align the waveforms, make various measurements (peak heights, trough depth, trough-to-peak time), and classify the units by spike waveform using a Gaussian mixture model. Because the files containing the individual spike waveforms are too big to deposit in a public data repository, only the results of this preprocessing were saved in the dataset.

In [None]:
# load these from build if you ran scripts/unit_waveforms.py on new data
feature_file = metadata_dir / "mean_spike_features.csv"
waveform_file = metadata_dir / "mean_spike_waveforms.csv"
features = pd.read_csv(feature_file, index_col="unit")

# exclude positive spikes and others that can't be classified (no `spike` label)
unit_features = features.dropna(subset=["spike"])

# unit names belonging to each spike-waveform class
narrow_units, wide_units = (
    unit_features[unit_features.spike == label].index for label in ("narrow", "wide")
)

In [None]:
# upsampling rate of the saved waveforms; this is taken from `scripts/unit_waveforms.py`
upsampled_rate_khz = 150
# rows are time samples, columns are units
mean_waveforms = pd.read_csv(waveform_file, index_col="time_samples")
# convert the time axis from samples to milliseconds (samples / samples-per-ms)
mean_waveforms.index = mean_waveforms.index / upsampled_rate_khz
# shape is (time points, units); the original unpacked these in the wrong order
npoints, ncells = mean_waveforms.shape

In [None]:
fig, ax = plt.subplots(nrows=1, figsize=(1.7, 1.7), dpi=300)
# inset showing the mean waveform of each spike class against time
wave_ax = ax.inset_axes([0.55, 0.7, 0.3, 0.2])
wave_ax.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
# groups are plotted in (wide, narrow) order, which is the order of `unit_type_colors`
for unit_group, color in zip((wide_units, narrow_units), unit_type_colors):
    group_features = unit_features.loc[unit_group]
    ax.plot(
        group_features.peak2_t,
        group_features.ptratio,
        ".",
        color=color,
        markersize=3.5,
        markeredgewidth=0.0,
        alpha=0.3,
    )
    # average across the group's units (columns) at each time point
    wave_ax.plot(mean_waveforms[unit_group].mean(axis=1), color=color)
plotting.simple_axes(ax)
ax.set_xlabel("Spike width (ms)")
ax.set_ylabel("Peak/trough ratio")
# make sure the output directory exists on a fresh checkout
Path("../figures").mkdir(parents=True, exist_ok=True)
fig.savefig("../figures/unit_waveform_features.pdf")

## Example responses

In [None]:
# load some metadata so we know what areas the example units are from
site_file = metadata_dir / "recordings.csv"
sites = pd.read_csv(site_file, index_col="site")
# relabel area codes for display and fix their plotting order
area_order = ["L2a/L2b", "L1/CM", "L3/NCM"]
sites["area"] = pd.Categorical(
    [area_names[code] for code in sites["area"]],
    categories=area_order,
    ordered=True,
)

# a unit's site is its name without the final "_"-separated component
# (e.g. "C194_3_1_c126" -> "C194_3_1"); keep only units with a known site
units = (
    features.reset_index()[["unit", "spike"]]
    .assign(site=lambda frame: frame.unit.map(lambda name: name.rpartition("_")[0]))
    .join(sites, on="site", how="inner")
    .set_index("unit")
)


In [None]:
# selected example motifs and units, one of each type from each area;
# the order of `unit_names` sets the order of the raster rows
selected_motifs = "0oq8ifcb g29wxi4q vekibwgj ztqee46x".split()
unit_names = """
    C194_3_1_c126
    C104_4_1_c120
    C42_4_1_c131
    C42_4_1_c14
    C104_3_1_c201
    C44_3_1_c74
""".split()
units.loc[unit_names]

In [None]:
def load_motif_responses(unit_name, pprox_file):
    """Load one unit's pprox file and return its trials split by motif.

    Column names have "-" replaced by "_". Only trials with
    `background_dBFS == -100` are kept (presumably the motif-alone condition --
    confirm). The background, foreground level and offset columns are dropped,
    and a `unit` column holding `unit_name` is added.
    """
    unit_data = json.loads(pprox_file.read_text())
    trials = split_trials(MotifSplitter(), unit_data, metadata_dir).reset_index()
    trials = trials.rename(columns=lambda column: column.replace("-", "_"))
    trials = trials.query("background_dBFS == -100")
    trials = trials.drop(["background", "foreground_dBFS", "offset"], axis=1)
    return trials.assign(unit=unit_name)


motifs = pd.concat(
    [
        load_motif_responses(unit_name, pprox_file)
        for unit_name, pprox_file in find_resources(*unit_names, alt_base=response_dir)
    ]
)

In [None]:
# stimulus waveforms for the selected motifs, keyed by motif name
wav_signals = {}
for name, wav_path in find_resources(*selected_motifs, alt_base=stim_dir):
    signal = load_wave(wav_path)
    # these stimuli are scaled to -20 dB FS so need to be corrected to match their amplitude in the scene stimuli
    rescale(signal, -30)
    wav_signals[name] = signal

In [None]:
# set up spectrogram transform. A single transform is shared by all motifs, so their
# sampling rates must agree; deriving the rate here (instead of from the `name` variable
# left over from the previous cell's loop) keeps the cell independent of hidden state.
sampling_rates = {signal["sampling_rate"] for signal in wav_signals.values()}
if len(sampling_rates) != 1:
    raise ValueError(f"expected exactly one sampling rate across stimuli, got {sampling_rates}")
(sampling_rate,) = sampling_rates
stfter = SpectrogramTransform(window_size, sampling_rate, max_frequency)
fgrid = stfter.freq

In [None]:
n_motifs = len(selected_motifs)
n_units = len(unit_names)
# index trials by (motif, unit) so each raster panel can look up its own trials
trials_by_motif = motifs.set_index(["foreground", "unit"])
fig, axes = plt.subplots(
    nrows=1 + n_units,
    ncols=n_motifs,
    sharex=True,
    sharey="row",
    figsize=(3, 2.9),
    dpi=300,
)
# one column per motif: spectrogram in the top row, one spike raster per unit below
for col, motif in zip(axes.T, selected_motifs):
    # small offset keeps log10 finite where the spectrogram has zero power
    spec = stfter.transform(wav_signals[motif]["signal"], scaling=None) + 1e-6
    log_spec = 10 * np.log10(spec / stfter.scale1)
    tgrid = stfter.tgrid(spec)
    spec_image = col[0].imshow(
        log_spec,
        aspect="auto",
        origin="lower",
        vmin=-90,
        vmax=-20,
        # frequency axis shown in kHz
        extent=(tgrid[0], tgrid[-1], fgrid[0] / 1000, fgrid[-1] / 1000),
    )
    col[0].set_yticks([0, 4, 8])
    motif_trials = trials_by_motif.loc[motif]
    for row, unit in zip(col[1:], unit_names):
        # a list of labels always yields a DataFrame; `.loc[unit]` would return a
        # Series (which has no itertuples) when the unit has a single trial
        trials = motif_trials.loc[[unit]]
        for i, trial in enumerate(trials.itertuples()):
            # NOTE(review): a float here is presumably NaN, i.e. a trial with no spikes
            if isinstance(trial.events, float):
                continue
            row.plot(
                trial.events,
                [i] * trial.events.size,
                color="k",
                marker="|",
                markeredgewidth=0.2,
                markersize=1.2,
                linestyle="",
            )
        row.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)

# spectrogram row: hide the x axis everywhere and keep the y axis only on the first column
for i in range(n_motifs):
    for boundary in ("right", "bottom", "top"):
        axes[0, i].spines[boundary].set_visible(False)
    axes[0, i].get_xaxis().set_visible(False)
    if i > 0:
        axes[0, i].get_yaxis().set_visible(False)
        axes[0, i].spines["left"].set_visible(False)
# raster rows: label each row with the unit's area and spike type; only the bottom
# row keeps its bottom spine and x axis
for i, unit_name in enumerate(unit_names, start=1):
    info = units.loc[unit_name]
    axes[i, 0].set_ylabel(f"{info.area}\n{info.spike}")
    if i < n_units:
        for ax in axes[i]:
            for boundary in ("left", "right", "bottom", "top"):
                ax.spines[boundary].set_visible(False)
            ax.get_xaxis().set_visible(False)
    else:
        for ax in axes[i]:
            for boundary in ("left", "right", "top"):
                ax.spines[boundary].set_visible(False)
fig.subplots_adjust(hspace=0.1, wspace=0.1)

# colorbar for the spectrograms, placed against the right edge of the last spectrogram panel
spax_pos = spec_image.axes.get_position()
cax = fig.add_axes((spax_pos.xmax, spax_pos.ymin, 0.01, spax_pos.height))
cax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
cax.yaxis.tick_right()
cbar = fig.colorbar(spec_image, cax=cax)

In [None]:
# make sure the output directory exists on a fresh checkout
Path("../figures").mkdir(parents=True, exist_ok=True)
fig.savefig("../figures/motif_rasters.pdf")