### Imports 

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import scipy.signal as sg
import matplotlib.pyplot as plt
from neuropy import plotting
import subjects

### Waveform amplitude stability

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

import subjects

# Session list analysed in this section, taken from the project `subjects` registry.
sessions = subjects.sd.ratSday3

In [None]:
# Mean waveform amplitude of every neuron in consecutive 2-hour bins, for each
# session in `sessions`.
#   mean_wav_amp_by_session : list with one (n_neurons, n_bins) array per session
#   mean_wav_amp            : the last session's array (used by the plot cell)
bin_size_sec = 2 * 3600
mean_wav_amp_by_session = []
for sess in sessions:
    rec_duration = sess.eegfile.duration
    # Append the recording end as the final edge: binned_statistic ignores
    # values outside the edges, so without it spikes after the last full 2-h
    # boundary would be dropped. The final bin may be shorter than 2 h.
    t_bin = np.append(np.arange(0, rec_duration, bin_size_sec), rec_duration)
    spktrns = sess.neurons.spiketrains
    # NOTE(review): assumes one amplitude per spike (paired with spiketrains
    # below) - confirm against neuropy's Neurons.waveforms_amplitude.
    wav_amps = sess.neurons.waveforms_amplitude

    mean_wav_amp = np.asarray(
        [
            stats.binned_statistic(spk, amp, bins=t_bin, statistic="mean")[0]
            for spk, amp in zip(spktrns, wav_amps)
        ]
    )
    # Keep every session's result instead of overwriting it on the next pass.
    mean_wav_amp_by_session.append(mean_wav_amp)

In [None]:
%matplotlib widget

# One line per neuron: rows of mean_wav_amp are neurons, and plt.plot draws
# each column of a 2-D array as a separate line.
plt.plot(np.transpose(mean_wav_amp))

### Stability viewer from phyio

In [None]:
import numpy as np
from scipy import stats

import subjects

# First session of the ratSday3 entry in the project's `subjects` registry.
sess = subjects.sd.ratSday3[0]

In [None]:
# Units from `sess.neurons_iso` and the table stored in their metadata under
# the "stability_metrics" key.
neurons = sess.neurons_iso
metrics = neurons.metadata["stability_metrics"]

In [None]:
import seaborn as sns

# Fresh 0..n-1 index, discarding the previous one.
metrics = metrics.reset_index(drop=True)
# One line panel per cluster (isolation distance vs. epoch), 5 panels per row.
sns.relplot(
    kind="line",
    data=metrics,
    x="epoch",
    y="isolation_distances",
    col="cluster_id",
    col_wrap=5,
)

### Waveform stability using mean waveforms

In [None]:
from neuropy.core import Signal
from neuropy.utils.signal_process import filter_sig
from palettable.cartocolors.qualitative import Pastel_6
from spikemetrics import calculate_metrics, Epoch, calculate_pc_metrics
from pathlib import Path

rng = np.random.default_rng()
sess = subjects.nsd.ratSday2[0]
# Alternative unit selections (toggle as needed):
# neurons = sess.neurons_stable.get_neuron_type("pyr")
# neurons = sess.neurons.get_neuron_type("pyr")
neurons = sess.neurons

# Spike-sorting output folder for this session (phy-format files read below).
clupath = Path(
    "/data/Clustering/sessions/RatS/Day2NSD/spykcirc/RatS-Day2NSD-2020-11-27_10-22-29-1.GUI/"
)

# Memory-map the per-spike arrays rather than reading them fully into RAM.
(
    pc_feat,
    pc_feat_ind,
    spike_clusters,
    spike_times,
    spike_templates,
    spike_amp,
) = (
    np.load(clupath / fname, mmap_mode="r")
    for fname in (
        "pc_features.npy",
        "pc_feature_ind.npy",
        "spike_clusters.npy",
        "spike_times.npy",
        "spike_templates.npy",
        "amplitudes.npy",
    )
)
cluster_info = pd.read_csv(clupath / "cluster_info.tsv", sep="\t")


def get_phy_id(indx, spiketrains=None, cluster_table=None):
    """Return the phy cluster id of neuron ``indx`` by matching spike counts.

    Neurons are linked to phy clusters only through their number of spikes,
    so the lookup fails when no cluster, or more than one cluster, has the
    same spike count.

    Parameters
    ----------
    indx : int
        Index into ``spiketrains``.
    spiketrains : sequence of array-like, optional
        Spike trains to look up; defaults to the notebook-level
        ``neurons.spiketrains``.
    cluster_table : pandas.DataFrame, optional
        Table with ``n_spikes`` and ``id`` columns; defaults to the
        notebook-level ``cluster_info``.

    Returns
    -------
    The single matching value of the ``id`` column.

    Raises
    ------
    ValueError
        If zero or several clusters have the neuron's spike count.
    """
    if spiketrains is None:
        spiketrains = neurons.spiketrains
    if cluster_table is None:
        cluster_table = cluster_info

    n_spikes = len(spiketrains[indx])
    phy_id = cluster_table.loc[cluster_table["n_spikes"] == n_spikes, "id"].to_numpy()
    # Explicit check instead of `assert`, so it still fires under `python -O`
    # and says what went wrong.
    if len(phy_id) != 1:
        raise ValueError(
            f"expected exactly one phy cluster with {n_spikes} spikes for "
            f"neuron {indx}, found {len(phy_id)}"
        )
    return phy_id[0]

In [None]:
# ===== Which neurons to choose =========
# Stable pyramidal units are located among all units by their spike counts.
# NOTE(review): count matching assumes no two units share a spike count.
nspikes_stable_pyr = sess.neurons_stable.get_neuron_type("pyr").n_spikes
nspikes_all = sess.neurons.n_spikes
# flatnonzero gives a 1-D index array. np.where would return a 1-tuple,
# which made stable_pyr_indx inconsistent with unstable_pyr_indx.
stable_pyr_indx = np.flatnonzero(np.isin(nspikes_all, nspikes_stable_pyr))
unstable_pyr_indx = np.setdiff1d(
    np.flatnonzero(sess.neurons.neuron_type == "pyr"), stable_pyr_indx
)
stable_pyr_indx, unstable_pyr_indx

In [None]:
fig = plotting.Fig(grid=(15, 8), fontsize=8)
cmap = Pastel_6.mpl_colors
color_thresh = ""
nframes = 60  # samples per waveform snippet, centred on the spike
n_sample_spikes = 300  # spikes averaged per window for the mean waveform

# neuron_indxs = [5, 10, 51] #index for stable pyr
# neuron_indxs = [5, 13, 63, 7, 65] # index for all pyr
# neuron_indxs = [8, 22, 106, 99, 112]  # index for all units
neuron_indxs = [115, 22, 82, 16, 99, 112]  # index for all units
# neuron_indxs = [57, 75, 83, 88]  # index for all units
# neuron_indxs = [115]

mean_frate = neurons.firing_rate[neuron_indxs]

# Session-wide quantities are the same for every neuron, so compute them once.
# (They are also used by the isolation-distance loop further down.)
srate = sess.datfile.sampling_rate
rec_duration = sess.eegfile.duration
windows = np.linspace(0, rec_duration, 6)  # 6 edges -> 5 equal windows
n_windows = len(windows) - 1
window_numbers = np.arange(1, n_windows + 1)
raw_traces = sess.datfile._raw_traces
# Last valid frame bound for snippet extraction (floor keeps it conservative).
n_raw_frames = int(sess.datfile.duration * srate)

# ---- Rows 0-4: mean waveforms (first vs. last window) and firing rate ----
for i, neuron_indx in enumerate(neuron_indxs):
    spktime = neurons.spiketrains[neuron_indx]
    shank_id = neurons.shank_ids[neuron_indx]
    channels = sess.recinfo.channel_groups[shank_id]
    spk_window_loc = np.digitize(spktime, windows)

    subfig = fig.add_subfigure(fig.gs[0:4, i])
    axs = subfig.subplots(1, 2, sharey=True, sharex=True)

    for w, window_loc in enumerate([1, n_windows]):
        spk_window = spktime[spk_window_loc == window_loc]
        start_frames = (spk_window * srate).astype("int") - nframes // 2
        # Keep only snippets fully inside the recording. A negative start
        # frame would otherwise wrap around to the end of the file.
        inside = (start_frames >= 0) & (start_frames + nframes <= n_raw_frames)
        start_frames = start_frames[inside]
        if start_frames.size == 0:
            continue  # no usable spikes in this window; leave the panel empty
        # Sample without replacement so no spike enters the average twice.
        n_pick = min(n_sample_spikes, start_frames.size)
        start_frames = np.sort(rng.choice(start_frames, size=n_pick, replace=False))
        frames_in_windows = np.concatenate(
            [np.arange(f, f + nframes) for f in start_frames]
        )

        lfp = raw_traces[np.ix_(list(channels), frames_in_windows)]
        # NOTE(review): the high-pass filter runs over concatenated short
        # snippets, so discontinuities between snippets can leak into the
        # filtered waveform; consider filtering longer segments before cutting.
        lfp = filter_sig.highpass(lfp, cutoff=500, fs=srate)
        # (n_channels, nframes * n_spikes) -> (n_channels, nframes, n_spikes),
        # then average over spikes.
        lfp = lfp.reshape(len(channels), nframes, -1, order="F").mean(axis=-1)

        lfp = Signal(lfp, channel_id=channels, sampling_rate=srate)
        plotting.plot_signal_traces(lfp, ax=axs[w], pad=0.7, color=cmap[i], lw=2)

    # Firing rate per window as a percentage of the neuron's mean firing rate.
    window_frate = np.histogram(spktime, windows)[0] / np.diff(windows)

    ax = fig.subplot(fig.gs[4, i])
    ax.plot(
        window_numbers,
        (window_frate / mean_frate[i]) * 100,
        marker="o",
        color=cmap[i],
        markersize=5,
    )
    ax.set_xticks(window_numbers)
    ax.axhline(25, color="gray", lw=2)
    ax.set_ylim(0, 300)

# ---- Row 5: isolation distance per window ----
# probe_data = sess.probegroup.to_dataframe()
# chanmap = probe_data[probe_data.shank_id==0]
# NOTE(review): hard-coded coordinates for 16 channels (and 16 PC channels in
# pc_feature_ind below); confirm they match the shank layout.
x_coord = np.concatenate((np.arange(8), np.arange(8) + 15))
y_coord = np.concatenate((np.arange(8) * 15, np.arange(8) * 15 + 15))
window_epoch = [
    Epoch(f"window{i+1}", e1, e2)
    for i, (e1, e2) in enumerate(zip(windows[:-1], windows[1:]))
]

for i, neuron_indx in enumerate(neuron_indxs):
    phy_id = get_phy_id(neuron_indx)
    spk_pos = np.where(np.isin(spike_clusters, phy_id))[0]
    spk_time = spike_times[spk_pos] / srate
    shank_id = neurons.shank_ids[neuron_indx]

    # Spikes of every other unit on the same shank form the comparison cluster.
    other_indxs = np.setdiff1d(np.where(neurons.shank_ids == shank_id)[0], neuron_indx)
    other_phy_id = np.array([get_phy_id(_) for _ in other_indxs])
    other_spk_pos = np.where(np.isin(spike_clusters, other_phy_id))[0]
    other_spk_time = spike_times[other_spk_pos] / srate

    # Inputs are identical for every window, so build them once per neuron.
    pos_combined = np.concatenate((spk_pos, other_spk_pos))
    pc_feat_combined = pc_feat[pos_combined, :, :]
    combined_times = np.concatenate([spk_time, other_spk_time])
    combined_labels = np.concatenate(
        [np.ones(len(spk_time)), 2 * np.ones(len(other_spk_time))]
    )
    combined_amps = spike_amp[pos_combined]

    iso_dist = []
    for w in window_epoch:
        try:
            w_metrics = calculate_metrics(
                spike_times=combined_times,
                spike_clusters=combined_labels,
                amplitudes=combined_amps,
                pc_features=pc_feat_combined,
                pc_feature_ind=np.vstack((np.arange(16), np.arange(16))),
                epochs=[w],
                channel_locations=np.vstack((x_coord, y_coord)).T,
                duration=sess.datfile.duration,
                verbose=False,
                params=dict(
                    num_channels_to_compare=7,
                    max_spikes_for_unit=1000,
                    max_spikes_for_nn=1000,
                    n_neighbors=300,
                    n_silhouette=3,
                    isi_threshold=0.002,
                    min_isi=2,
                ),
            )
            w_iso_dist = w_metrics["isolation_distance"].values[0]
        except Exception:
            # Best effort: a window where calculate_metrics fails is plotted
            # as 0. Catching Exception (not bare except) lets Ctrl-C through.
            w_iso_dist = 0

        iso_dist.append(w_iso_dist)

    ax = fig.subplot(fig.gs[5, i])
    ax.plot(window_numbers, iso_dist, marker="o", color=cmap[i], markersize=5)
    ax.axhline(15, color="gray", lw=2)
    ax.set_ylim(0, 75)
    ax.set_xticks(window_numbers)

fig.savefig(subjects.figpath_sd / "neurons_stability")