## Response discriminability

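This notebook generates the example plots for Figure 5: the trial-by-trial spike-distance matrix (Figure 5A) and the confusion matrix of a k-nearest-neighbors stimulus classifier, for one example unit. Two example units are defined in separate cells below. Run only the one you want to plot, because the later cell overwrites `unit_name`.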

In [None]:
import sys

# make the helper modules in ../scripts importable; the guard keeps re-runs of
# this cell from stacking duplicate entries onto sys.path
if "../scripts" not in sys.path:
    sys.path.insert(0, "../scripts")

In [None]:
import json
from functools import partial
from pathlib import Path

import graphics_defaults  # noqa: F401
import matplotlib.pyplot as plt
import pandas as pd
import pyspike
from core import (
    MotifSplitter,
    pairwise_spike_comparison,
    split_trials,
    trial_to_spike_train,
)
from numpy.random import default_rng
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import LeaveOneOut, cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

# fixed seed so the shuffled cross-validation folds (and therefore the saved
# figures) are reproducible under Restart & Run All; change it to try other shuffles
RNG_SEED = 1234
rng = default_rng(RNG_SEED)

In [None]:
# root of the dataset; metadata, responses and stimuli are sibling subdirectories
dataset_dir = Path("../datasets/zebf-social-acoustical-ephys")
metadata_dir, response_dir, stim_dir = (
    dataset_dir / subdir for subdir in ("metadata/", "responses/", "stimuli")
)

In [None]:
# example 1: BS neuron
# NOTE: the next cell reassigns unit_name, so under "Run All" example 2 is the
# one that gets plotted; run only the cell for the example you want.
unit_name = "C104_3_1_c201"

In [None]:
# example 2: NS neuron
# NOTE: this overwrites any unit_name set by the example-1 cell; skip this cell
# to plot example 1 instead.
unit_name = "C44_3_1_c74"

In [None]:
# load the response and split by motif
pprox_file = (response_dir / unit_name).with_suffix(".pprox")
# JSON is UTF-8 by spec; don't depend on the platform's default text encoding
unit = json.loads(pprox_file.read_text(encoding="utf-8"))
splitter = MotifSplitter()
# NOTE(review): the label "igmi8fxa" (index level 1, presumably a motif id) is
# excluded from the analysis; the reason is not recorded here — document it
motifs = split_trials(splitter, unit, metadata_dir).drop("igmi8fxa", level=1)

In [None]:
# classifier

def inv_spike_sync_matrix(*args, **kwargs):
    """Dissimilarity form of ``pyspike.spike_sync_matrix``.

    Every positional and keyword argument is forwarded unchanged to
    ``pyspike.spike_sync_matrix``; the result is ``1 - <sync matrix>``, so
    larger synchrony gives a smaller value.
    """
    sync = pyspike.spike_sync_matrix(*args, **kwargs)
    return 1 - sync

# k: number of nearest neighbors consulted by the KNeighborsClassifier
# instances built in this notebook
n_neighbors = 9

class ShuffledLeaveOneOut(LeaveOneOut):
    """Leave-one-out cross-validator whose training indices are randomly permuted.

    Each fold holds out a single sample exactly as ``LeaveOneOut`` does, but the
    training indices are passed through ``rng.permutation`` before being handed
    to the estimator. Presumably this keeps k-NN tie-breaking from depending on
    the original trial order — confirm.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of the per-fold permutations.
    """

    def __init__(self, rng):
        super().__init__()
        self.rng = rng

    def split(self, *args, **kwargs):
        folds = super().split(*args, **kwargs)
        for train_idx, test_idx in folds:
            shuffled_train = self.rng.permutation(train_idx)
            yield shuffled_train, test_idx
            
def kneighbors_classifier(distance_matrix, rng, normalize="true"):
    """Cross-validated confusion matrix of a k-nearest-neighbors classifier.

    Each row of ``distance_matrix`` is classified from its precomputed
    distances to the other rows. The classifier uses ``n_neighbors`` neighbors
    and leave-one-out cross-validation, with each training fold shuffled
    by ``rng``.

    Parameters
    ----------
    distance_matrix : pandas.DataFrame
        Square matrix of pairwise dissimilarities; its index holds the class
        (stimulus) label of each row.
    rng : numpy.random.Generator
        Used to permute the training indices of every fold.
    normalize : {"true", "pred", "all"} or None
        Passed through to ``sklearn.metrics.confusion_matrix``.

    Returns
    -------
    pandas.DataFrame
        Confusion matrix with true labels as rows and predicted labels as
        columns, both ordered by first appearance in the index.
    """
    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric="precomputed")
    loo = ShuffledLeaveOneOut(rng)
    # factorize returns integer codes plus the labels in code order, so the
    # frame's axis labels are guaranteed to match the matrix rows/columns
    group_idx, names = pd.factorize(distance_matrix.index)
    pred = cross_val_predict(neigh, distance_matrix.to_numpy(), group_idx, cv=loo)
    cm = confusion_matrix(
        group_idx, pred, labels=list(range(len(names))), normalize=normalize
    )
    return pd.DataFrame(cm, index=names, columns=names)


In [None]:
# convert every trial to pyspike's spike-train format, passing the shortest
# interval_end found across all motifs as the common end of the interval
shortest_end = motifs.interval_end.min()
to_spike_train = partial(trial_to_spike_train, interval_end=shortest_end)
spike_trains = motifs.apply(to_spike_train, axis=1)

## Spike distance matrix

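The first step is to calculate the pairwise spike-train distances between all trials at a single background level. The matrix is plotted as SPIKE-synchronization similarity (1 − distance), with white grid lines drawn every 10 trials. This is Figure 5A.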

In [None]:
bkgnd_level = -100
# NOTE(review): assumes each stimulus contributes this many consecutive trials — confirm
trials_per_stim = 10
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(0.9, 1.5), dpi=400)
st = spike_trains.loc[bkgnd_level]
dist = inv_spike_sync_matrix(st)
# plot similarity (1 - dissimilarity). interpolation=None would fall back to
# rcParams["image.interpolation"]; "nearest" keeps each trial pair a sharp cell
img = ax.imshow(
    1 - dist, vmin=0, vmax=1, aspect="equal", origin="upper", interpolation="nearest"
)
# imshow centers cell i on coordinate i, so the edge in front of trial i lies
# at i - 0.5; draw the block separators there rather than through a cell
for x in range(trials_per_stim, dist.shape[0], trials_per_stim):
    ax.axvline(x - 0.5, color="w", linewidth=0.5)
    ax.axhline(x - 0.5, color="w", linewidth=0.5)
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
fig.colorbar(img, ax=ax, location="bottom", shrink=0.3, aspect=10);

In [None]:
# savefig does not create missing directories; make sure the output folder exists
Path("../figures").mkdir(parents=True, exist_ok=True)
fig.savefig(f"../figures/{unit_name}_motif_distances_training.pdf")

## Confusion matrix

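Generated by using a k-nearest-neighbors classifier on the spike-distance matrix, with leave-one-out cross-validation, to assign each trial to the most likely stimulus. Each row (true stimulus) is normalized to sum to 1.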

In [None]:
# confusion matrix at the same background level as the distance matrix
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(0.8, 1), dpi=300)
spike_dists = pairwise_spike_comparison(
    spike_trains.loc[bkgnd_level], comparison_fun=inv_spike_sync_matrix, stack=False
)
# reuse the classifier helper rather than duplicating its k-NN / shuffled-LOO setup
conf_mtx = kneighbors_classifier(spike_dists, rng, normalize="true")
img = ax.imshow(conf_mtx.to_numpy(), origin="upper", aspect="equal", vmin=0, vmax=1.0)
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
ax.set_title(f"{bkgnd_level} dB", fontdict={"fontsize": 6})
fig.colorbar(img, ax=ax, location="bottom", shrink=0.4, aspect=10);

In [None]:
# savefig does not create missing directories; make sure the output folder exists
Path("../figures").mkdir(parents=True, exist_ok=True)
fig.savefig(f"../figures/{unit_name}_motif_discrim_training.pdf")