Notebook to plot spike classification figure panels.


In [None]:
import sys
sys.path.insert(0, "../scripts")

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from dlab import plotting

from sklearn.mixture import GaussianMixture

import graphics_defaults

In [None]:
# Plot colors for the two spike classes, indexed in the order the classes are
# drawn in the feature scatter: 0 for wide-spiking units, 1 for narrow-spiking
# units.
unit_colors = ["#70549B", "#FF7F0E"]

## Spike features

In [None]:
# Read the table of mean spike features from the build directory; rows are
# keyed by the "unit" column.
feature_file = Path("..", "build", "mean_spike_features.csv")
features = pd.read_csv(feature_file, index_col="unit")

In [None]:
# Sampling rate of the stored waveforms in kHz (per the variable name).
upsampled_rate_khz = 150
waveform_file = Path("..") / "build" / "mean_spike_waveforms.csv"
# The table is indexed by "time_samples"; every other column is selected by
# unit label elsewhere in this notebook, i.e. one column per unit.
mean_waveforms = pd.read_csv(waveform_file, index_col="time_samples")
# Sample number divided by a rate in kHz gives time in ms.
mean_waveforms.index /= upsampled_rate_khz
# Rows are time points and columns are units, so the shape unpacks as
# (npoints, ncells) -- not the other way round.
npoints, ncells = mean_waveforms.shape

In [None]:
# Keep only units that carry a spike label, then collect the unit identifiers
# belonging to each of the two classes.
unit_features = features.loc[features["spike"].notna()]
narrow_units = unit_features.index[unit_features["spike"].eq("narrow").to_numpy()]
wide_units = unit_features.index[unit_features["spike"].eq("wide").to_numpy()]

In [None]:
# Scatter of spike width against peak/trough ratio for every labelled unit,
# with the average waveform of each class drawn in a small inset.
fig, ax = plt.subplots(nrows=1, figsize=(1.7, 1.7), dpi=300)
axin1 = ax.inset_axes([0.55, 0.7, 0.3, 0.2])
axin1.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
# Wide units come first so that they are drawn with unit_colors[0].
for idx, group in enumerate([wide_units, narrow_units]):
    ax.plot(
        unit_features.loc[group, "peak2_t"],
        unit_features.loc[group, "ptratio"],
        ".",
        color=unit_colors[idx],
        markersize=3.5,
        markeredgewidth=0.0,
        alpha=0.3,
    )
    # Average across the columns (units) of this class at each time point.
    axin1.plot(
        mean_waveforms.loc[:, group].mean(axis="columns"),
        color=unit_colors[idx],
    )
plotting.simple_axes(ax)
ax.set_xlabel("Spike width (ms)")
ax.set_ylabel("Peak/trough ratio")
fig.savefig("../figures/unit_waveform_features.pdf")

## Classify separately in each area

A quick check of whether the spike classification differs between areas: the two-component mixture is fit to all units first, then separately within each area.

In [None]:
def classify_spikes(df, *, random_state=None):
    """Label units as narrow or wide with a two-component Gaussian mixture.

    The mixture is fit to the ``peak2_t`` and ``ptratio`` columns of ``df``.
    The component with the smaller mean ``peak2_t`` is taken as the narrow
    class.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per unit; must contain the columns ``peak2_t`` and
        ``ptratio``.
    random_state : optional
        Forwarded to ``GaussianMixture`` so the fit can be made reproducible.
        The default (``None``) leaves the mixture's initialisation unseeded,
        as before.

    Returns
    -------
    pandas.Series
        Float series named ``is_narrow`` on the index of ``df``: 1.0 for units
        assigned to the narrow component and 0.0 for the others. Empty when
        ``df`` has no rows.
    """
    if len(df) == 0:
        # Nothing to fit; return an empty result instead of letting the
        # mixture model raise on zero samples.
        return pd.Series([], index=df.index, dtype=float, name="is_narrow")
    X = df.loc[:, ["peak2_t", "ptratio"]]
    gmix = GaussianMixture(n_components=2, random_state=random_state).fit(X)
    # Column 0 of the component means corresponds to peak2_t.
    narrow = gmix.means_[:, 0].argmin()
    labels = (gmix.predict(X) == narrow).astype(float)
    return pd.Series(labels, index=df.index, name="is_narrow")

In [None]:
# Classify all labelled units with a single mixture model and color the
# feature scatter by the resulting 0/1 label.
is_narrow = classify_spikes(unit_features)
fig, ax = plt.subplots(nrows=1, figsize=(2, 2), dpi=300)
scatter = ax.scatter(
    unit_features["peak2_t"],
    unit_features["ptratio"],
    c=is_narrow,
    s=0.7,
    cmap="tab10",
    alpha=0.5,
)
ax.set_xlabel("Spike width (ms)")
ax.set_ylabel("Peak/trough ratio")
plotting.simple_axes(ax)

In [None]:
# Maps the area labels found in the recording metadata ("area" column) to the
# names used as panel titles in the per-area figure.
area_names = {
    "deep": "L3/NCM",
    "intermediate": "L2a/L2b",
    "superficial": "L1/CM"
}


In [None]:
# Recording-site metadata, keyed by site name. The raw area labels are
# translated through area_names and stored as an ordered categorical so that
# grouping by area follows the category order given here.
site_file = Path("..", "inputs", "recording_metadata.csv")
sites = pd.read_csv(site_file, index_col="site")
sites["area"] = pd.Categorical(
    [area_names[label] for label in sites["area"]],
    categories=["L2a/L2b", "L1/CM", "L3/NCM"],
    ordered=True,
)

In [None]:
# Per-unit table of the two classification features, joined with the metadata
# of the site each unit was recorded at, then classified within each area.
feats = unit_features[["peak2_t", "ptratio"]].reset_index()
# The site name is the unit name with its last "_"-separated field removed.
feats["site"] = feats["unit"].apply(lambda name: name.rpartition("_")[0])
feats = feats.join(sites, on="site", how="inner")
# Each area gets its own mixture fit; dropping the area level of the result
# leaves the original row index so the labels align with feats.
feats["is_narrow"] = (
    feats.groupby("area", observed=False)
    .apply(classify_spikes, include_groups=False)
    .droplevel(0)
)

In [None]:
# One panel per area: feature scatter split by the within-area classification,
# with the class-average waveforms in an inset.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(4.5, 1.5), dpi=300, sharey=True)

for idx, (area, area_df) in enumerate(feats.groupby("area", observed=False)):
    axin1 = ax[idx].inset_axes([0.55, 0.7, 0.3, 0.2])
    axin1.tick_params(axis="y", which="both", left=False, right=False, labelleft=False)
    # is_narrow is 0 for the wide class and 1 for the narrow class.
    for group in range(2):
        group_df = area_df[area_df.is_narrow == group]
        ax[idx].plot(
            group_df.peak2_t,
            group_df.ptratio,
            ".",
            markersize=4,
            markeredgewidth=0.0,
            alpha=0.3,
        )
        # Average across this class's unit columns at each time point.
        axin1.plot(mean_waveforms[group_df["unit"]].mean(axis=1))
    # Panel-level decoration does not depend on the class, so it is applied
    # once per area rather than once per class.
    ax[idx].set_title(area)
    plotting.simple_axes(ax[idx])
ax[1].set_xlabel("Spike width (ms)")
ax[0].set_ylabel("Peak/trough ratio")
fig.savefig("../figures/unit_waveform_features_by_area.pdf")