## Noise invariance example

This notebook generates the example plots for Figure 8.

In [None]:
import sys

# Prepend the project's script directory (presumably home of the local helper
# modules imported below, e.g. core, filters, graphics_defaults) so it takes
# precedence over anything else on the import path.
sys.path[:0] = ["../scripts"]

In [None]:
import json
from pathlib import Path

import graphics_defaults  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
from core import MotifSplitter, load_wave, rescale, split_trials
from dlab import plotting, signal, spikes
from filters import SpectrogramTransform


In [None]:
# Location of the dataset and of its metadata, response, and stimulus subdirectories.
dataset_dir = Path("../datasets/zebf-social-acoustical-ephys")
metadata_dir, response_dir, stim_dir = (
    dataset_dir / subdir for subdir in ("metadata/", "responses/", "stimuli")
)

# Spectrogram parameters (units not stated here; presumably seconds and Hz).
window_size = 0.020
max_frequency = 8000

In [None]:
# Firing-rate estimation settings: bin width and Gaussian smoothing bandwidth
# (presumably in seconds), plus the smoothing kernel built from them.
rate_binwidth, rate_bandwidth = 0.005, 0.02
kernel, _ = signal.kernel("gaussian", rate_bandwidth, rate_binwidth)

In [None]:
# Example unit and the motifs shown in the figure (one column per motif).
unit_name = "C104_3_1_c67"
selected_motifs = "g29wxi4q vekibwgj 9ex2k0dy".split()

In [None]:
# Load the unit's response record (a JSON document with a .pprox suffix).
# JSON text is UTF-8, so decode explicitly rather than with the platform's
# locale-dependent default encoding.
pprox_file = (response_dir / unit_name).with_suffix(".pprox")
unit = json.loads(pprox_file.read_text(encoding="utf-8"))

In [None]:
splitter = MotifSplitter()
# Split the trials into per-motif segments, dropping motif "igmi8fxa" from index level 1.
motifs = split_trials(splitter, unit, metadata_dir).drop("igmi8fxa", level=1)
motif_names = motifs.index.unique(level="foreground")

# Load the stand-alone recording of every motif, keyed by motif name.
wav_signals = {}
for name in motif_names:
    wav_signals[name] = recording = load_wave((stim_dir / name).with_suffix(".wav"))
    # The stand-alone motif files are normalized to -20 dBFS; rescale them (presumably
    # to a -30 dBFS target) to match their amplitude in the scene stimuli.
    # NOTE(review): rescale()'s return value is discarded, so this relies on it
    # modifying `recording` in place -- confirm.
    rescale(recording, -30)

In [None]:
# Set up the spectrogram transform shared by all motifs. Derive the sampling rate
# from every loaded stimulus (rather than whichever one happened to load last) and
# require that they agree, since one transform is applied to all of them.
sampling_rates = {wav["sampling_rate"] for wav in wav_signals.values()}
if len(sampling_rates) != 1:
    raise ValueError(
        f"expected a single sampling rate across stimuli, got {sorted(sampling_rates)}"
    )
(sampling_rate,) = sampling_rates
stfter = SpectrogramTransform(window_size, sampling_rate, max_frequency)
fgrid = stfter.freq

In [None]:
n_motifs = len(selected_motifs)
# Plot each background (noise) level in a different color. Colors are keyed by the
# values of the first index level of `motifs`; the raster and rate panels below look
# them up by `trial.Index` and by the "background-dBFS" group key, so that level is
# presumably "background-dBFS" -- confirm against split_trials.
background_levels = motifs.index.unique(level=0)
palette = plt.color_sequences["tab20"]
if len(background_levels) > len(palette):
    # zip() would silently drop the extra levels and the color lookups below would KeyError
    raise ValueError(
        f"{len(background_levels)} background levels but only {len(palette)} colors in 'tab20'"
    )
colors = dict(zip(background_levels, palette))

fig = plt.figure(figsize=(2.4, 2.9), dpi=300)
# One subfigure (column) per motif. squeeze=False + ravel() keeps `subfigs` a 1-D
# array even when only one motif is selected; the default squeeze would return a
# bare SubFigure that cannot be zipped, sliced, or iterated below.
subfigs = fig.subfigures(
    1, n_motifs, hspace=0.001, wspace=0.0001, squeeze=False
).ravel()
for motif, subfig in zip(selected_motifs, subfigs):
    trials = motifs.xs(motif, level="foreground")
    # rows: spectrogram, spike raster, smoothed firing rate
    axes = subfig.subplots(3, sharex=True, height_ratios=[1, 5, 1])

    # Spectrogram in dB relative to stfter.scale1; the small offset avoids log10(0).
    spec = stfter.transform(wav_signals[motif]["signal"], scaling=None) + 1e-6
    log_spec = 10 * np.log10(spec / stfter.scale1)
    tgrid = stfter.tgrid(spec)
    axes[0].imshow(
        log_spec,
        aspect="auto",
        origin="lower",
        vmin=-90,
        vmax=-20,
        # fgrid / 1000: presumably Hz -> kHz, so the ticks below read 1 and 8
        extent=(tgrid[0], tgrid[-1], fgrid[0] / 1000, fgrid[-1] / 1000),
    )
    axes[0].set_yticks([1, 8])

    # Spike raster: one row per trial, colored by background level. Trials whose
    # events entry is a float (presumably NaN, i.e. no spikes) are skipped but
    # still occupy their row.
    for i, trial in enumerate(trials.sort_index(ascending=False).itertuples()):
        if isinstance(trial.events, float):
            continue
        background_level = trial.Index
        axes[1].plot(
            trial.events,
            [i] * trial.events.size,
            color=colors[background_level],
            marker="|",
            markeredgewidth=0.5,
            linestyle="",
        )
    axes[1].set_ylim(0, trials.shape[0])
    axes[1].get_yaxis().set_visible(False)
    plotting.adjust_raster_ticks(axes[1], gap=3.2)

    # Smoothed firing rate per background level, computed from the concatenated
    # events of all trials at that level.
    for lvl, trls in trials.sort_index(ascending=False).groupby("background-dBFS"):
        rate, bins = spikes.rate(
            trls.events.dropna().explode(),
            rate_binwidth,
            kernel,
            start=0,
            stop=trials.interval_end.max(),
        )
        axes[2].plot(bins, rate, color=colors[lvl])
    plotting.simple_axes(*axes)

# Put all rate panels on a common y scale.
max_rate = max(subfig.axes[2].get_ylim()[1] for subfig in subfigs)
for subfig in subfigs:
    subfig.axes[2].set_ylim((0, max_rate))
    subfig.subplots_adjust(left=0.05, right=0.95, hspace=0.08)
# Only the leftmost column keeps its y axes.
for subfig in subfigs[1:]:
    for ax in subfig.axes:
        ax.get_yaxis().set_visible(False)
        ax.spines["left"].set_visible(False)

In [None]:
# Write the figure as a PDF, creating the output directory if it does not exist yet
# (savefig does not create missing directories).
figure_dir = Path("../figures")
figure_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_dir / f"{unit_name}_motif_noise_rasters.pdf")