# Outline of Part 2, Analysis of sorted spikes 

## Set up, paths
## Functions - easier than you think!
## Load and inspect spikes
### Spike time raster
### Binned raster
## Zero in on a feature of data
## Create a perievent time histogram

In [None]:
# @title Helper functions
!pip install -U matplotlib pandas xarray numpy tqdm spykes hvplot bokeh

import os

import pandas as pd
import xarray as xr
import requests
from pathlib import Path
import zipfile


def download_data():
    """
    Download the example dataset archive and unpack it next to the archive

    The archive is stored as ``data/dataset.zip`` and extracted into ``data/``.
    The download is skipped when the archive is already on disk; the
    extraction is always (re)done.
    :return: None
    :raises requests.HTTPError: when the server answers with an error status
    """
    url = 'http://data.cortexlab.net/singlePhase3/data/dataset.zip'
    path = Path('data/dataset.zip')
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        # Only hit the network when the archive is not cached yet
        r = requests.get(url)
        # Fail loudly instead of saving an error page as a broken "zip"
        r.raise_for_status()
        with open(path, "wb") as fid:
            # Write out content of link:
            fid.write(r.content)
    # Unzip
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(path.parent)


def load_spikes_from_phy(path_to_data='/Users/myroshnychenkm2/Downloads/dataset/', sampling_frequency=30000):
    """
    Get spikes from a kilosort/phy result folder

    Reads ``cluster_groups.csv``, ``spike_clusters.npy`` and ``spike_times.npy``
    from the folder and keeps only the spikes of clusters labeled 'good'.
    :param path_to_data: folder holding the three files (trailing separator optional)
    :param sampling_frequency: the raw spike times are divided by this value
        (presumably samples per second, which yields seconds -- confirm for your recording)
    :return:
    :id: neuron id, 1xN
    :ts: corresponding spiketime, 1xN
    Spikes are returned grouped by cluster, in the order the 'good' clusters
    appear in ``cluster_groups.csv``.
    """
    groupfname = os.path.join(path_to_data, 'cluster_groups.csv')
    groups = pd.read_csv(groupfname, delimiter='\t')

    # load spike times and cluster IDs
    # os.path.join (instead of string "+") so a folder given without a
    # trailing slash works, consistent with the csv path above
    with open(os.path.join(path_to_data, 'spike_clusters.npy'), 'rb') as f:
        ids = np.load(f).flatten()
    with open(os.path.join(path_to_data, 'spike_times.npy'), 'rb') as f:
        ts = np.load(f).flatten()

    # Create the list of our "good" labeled units
    ids_to_take = groups[(groups.group == 'good')].cluster_id
    # Find which spikes belong to our "good" groups
    spikes_to_take = []
    for i in tqdm(ids_to_take, desc='Selecting only good spikes'):
        spikes_to_take.extend((ids == i).nonzero()[0])
    # only take spikes that are in our list
    ids = np.array(ids[spikes_to_take])
    ts = np.array(ts[spikes_to_take]).astype(float) / sampling_frequency

    return ids, ts


def bin_neuron(spike_times, bin_size=.100, window=None):
    """
    Count the spikes of a single neuron in consecutive time bins
    :param spike_times: array of spike times of one neuron
    :param bin_size: in sec
    :param window: [start, stop] of the binned period; defaults to [0, last spike time]
    :return: array with the number of spikes in each bin
    """
    # Without an explicit window, cover everything from 0 up to the last spike
    start, stop = (0, spike_times.max()) if window is None else (window[0], window[1])
    # Extend by one bin so the stop time itself still falls inside a bin
    edges = np.arange(start, stop + bin_size, bin_size)
    counts, _ = np.histogram(spike_times, edges)
    return counts


def bin_neurons(spike_times, neuron_ids, bin_size=None, window=None, plotose=False):
    """
    Make binned raster for many neurons
    :param spike_times: spike times of all neurons, one entry per spike
    :param neuron_ids: neuron id of every spike (same length as spike_times)
    :param bin_size: in sec; None falls back to 0.1, the default of bin_neuron
    :param window: [start, stop] of the binned period; defaults to [0, last spike time]
    :param plotose: if True, plot the raster before returning it
    :return: xarray DataArray of spike counts with dims (Single_unit_id, Time).
        Single_unit_id is the position of the neuron among the sorted unique
        ids (0..N-1), not the original id; Time holds the left edge of each bin.
    """
    if bin_size is None:
        # None used to be forwarded to np.arange and raise a TypeError
        bin_size = .100
    if window is None:
        window = [0, spike_times.max()]
    unit_ids = np.unique(neuron_ids)
    # the following uses a list comprehension, i.e. an inline for loop (look it up):
    spike_counts = np.vstack([bin_neuron(spike_times[neuron_ids == unit_id], bin_size, window)
                              for unit_id in tqdm(unit_ids)])
    # Same bin edges as bin_neuron; drop the last edge to keep one label per bin
    bin_starts = np.arange(window[0], window[1] + bin_size, bin_size)[:-1]
    raster = xr.DataArray(spike_counts,
                          coords=dict(Time=bin_starts,
                                      Single_unit_id=range(len(unit_ids))),
                          dims=['Single_unit_id', 'Time'])
    if plotose:
        raster.plot(robust=True)
    return raster


def identify_down_states(ts, bin_size=.05, number_of_neurons_treshold=20, minimum_time_between_states=0.15):
    """
    Find spontaneous periods of quiescence in spiketimes

    All spikes are pooled and binned. Bins whose pooled spike count is below
    the threshold are candidates; a candidate is kept only when it lies more
    than ``minimum_time_between_states`` after the previous candidate bin.
    :param ts: spike times of all neurons pooled together, in sec
    :param bin_size: in sec
    :param number_of_neurons_treshold: a bin is a candidate when its pooled spike count is below this value
    :param minimum_time_between_states: in sec
    :return: (lfp, down_states) -- lfp is the pooled spike count per bin,
        down_states the kept times in sec (bin index * bin_size, shifted by -.03)
    """
    lfp = bin_neuron(np.sort(ts), bin_size=bin_size)
    candidate_bins = np.where(lfp < number_of_neurons_treshold)[0]
    # Distance (in bins) of every candidate to the candidate before it
    gaps = np.diff(candidate_bins)
    # NOTE: the very first candidate has no predecessor and is never kept
    keep = gaps > minimum_time_between_states / bin_size
    # Report exactly what the filter below removes (uses the parameter, not a hard-coded .15)
    print(f'Eliminating {candidate_bins[1:][~keep].shape} that are too short')
    down_states = candidate_bins[1:][keep]
    print(f'Ended up with {down_states.shape} down states')
    # convert into seconds:
    down_states = down_states * bin_size
    # NOTE(review): fixed -.03 s shift kept from the original; its rationale is not documented
    down_states -= .03
    print(down_states[(down_states > 63) & (down_states < 68)])  # compare with raster
    return lfp, down_states


class PSTH:
    """
    A collection of functions dealing with peristimulus time histogram
    """

    @staticmethod
    def make_psth(trial_starts):
        """
        Simple wrapper creating a dataframe with times we want to lock onto
        :param trial_starts: List of times of interest (trials)
        :return: pandas DataFrame with a single column 'trialStart'
        """
        trials = pd.DataFrame()
        trials['trialStart'] = trial_starts
        return trials

    @staticmethod
    def spykes_get_times(s_ts, s_id, debug=False):
        """
        Wrap the spike times of every unit in a spykes NeuroVis object
        :param s_ts: spike times, one entry per spike
        :param s_id: unit id of every spike (same length as s_ts)
        :param debug: if True, print the units that are skipped and the
            number of spikes of every unit that is kept
        :return: list of NeuroVis objects named 'ram<id>'; units with fewer
            than 2 spikes are left out
        """
        from spykes.plot import neurovis
        s_id = s_id.astype('int')
        neuron_list = []
        for iu in np.unique(s_id):
            spike_times = s_ts[s_id == iu]
            if len(spike_times) < 2:
                if debug:
                    print('Too few spiketimes in this unit: ' + str(spike_times))
                continue
            neuron_list.append(neurovis.NeuroVis(spike_times, name='ram' + str(iu)))

        if debug:
            for neuron in neuron_list:
                print(len(neuron.spiketimes))
        return neuron_list

    @staticmethod
    def spykes_summary(spikes, spykes_df, event, window=None, bin_size=10, fr_thr=.1, plotose=True):
        """
        Compute (and optionally plot) the PSTH of a population of units
        :param spikes: list of NeuroVis objects, e.g. from spykes_get_times
        :param spykes_df: DataFrame holding the event times, e.g. from make_psth
        :param event: name of the column of spykes_df to align to
        :param window: [start, stop] around the event; None means [-100, 100]
            (presumably in ms, following spykes -- confirm)
        :param bin_size: bin size passed to spykes (presumably in ms -- confirm)
        :param fr_thr: units whose firingrate is not above this value are dropped
        :param plotose: if True, plot a heat map and the population PSTH
        :return: (pop, mean_psth) -- the PopVis object and the result of get_all_psth
        """
        import spykes
        if window is None:
            # Built per call: a list literal as default would be shared between calls
            window = [-100, 100]
        assert window[1] - window[0] > 0, 'Window size must be greater than zero!'
        # filter firing rate
        spikes = [i for i in spikes if i.firingrate > fr_thr]
        pop = spykes.plot.popvis.PopVis(spikes)
        # calculate psth
        mean_psth = pop.get_all_psth(event=event, df=spykes_df, window=window, binsize=bin_size, plot=False)
        assert mean_psth['data'][0].size > 0, 'Empty group PSTH!'
        if plotose:
            # % plot heatmap of average psth
            plt.figure(figsize=(10, 10))
            pop.plot_heat_map(mean_psth, sortby=None, sortorder='ascend', normalize=None,
                              colors=['viridis'])  # or latency

            # %% Population PSTH
            plt.figure()
            pop.plot_population_psth(all_psth=mean_psth, event_name='Event',
                                     colors=([.5, .5, .5], [0, .6, 0]))

        return pop, mean_psth


# Fetch and unpack the dataset right away so the files are in data/ before a later cell loads them
download_data()

In [None]:
# Import basic libraries
import matplotlib.pylab as plt
import numpy as np
from tqdm import tqdm

In [None]:
# @title Figure settings

%config InlineBackend.figure_format = 'retina'

plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

In [None]:
# Load the unit id and spike time of every spike that belongs to a 'good' unit in the unpacked dataset
neuron_ids_huge, spike_times_all_neurons_huge = load_spikes_from_phy('data/')

# Inspect variables

# Visualizing spiketimes without binning 

In [None]:
# Let's inspect the spikes
print(neuron_ids_huge)
print(spike_times_all_neurons_huge)

### Exercise 1
What are the dimensions of each variable?

In [None]:
# Work on a smaller chunk of the recording; the *_huge variables stay untouched
# Build the selection mask once and apply it to both arrays so they stay aligned
before_cutoff = spike_times_all_neurons_huge < 70
neuron_ids = neuron_ids_huge[before_cutoff]
spike_times_all_neurons = spike_times_all_neurons_huge[before_cutoff]

In [None]:
# Let's inspect the timestamps from one neuron
# Take the unit that fired the spike stored at position 300 (300 looks like an arbitrary example index)
id_of_interest = neuron_ids[300]
# Boolean mask: keep only the spike times whose unit id matches
spike_times_of_interest = spike_times_all_neurons[neuron_ids == id_of_interest]
print(spike_times_of_interest)

In [None]:
# Plot spiketimes
plt.plot(spike_times_of_interest);

## What are neurons' identities in this plot? 
## Why is the curve progressively rising?
## What's on the x- and y- axes?


### Exercise 1
Let's inspect the timestamps from one neuron - version 2

In [None]:
plt.plot(spike_times_of_interest)
plt.title('Timestamps of a neuron')
plt.ylabel('Time')
plt.xlabel('Number of spikes');

 What is a plot like this good for? (Hint: what would the change in firing rate look like?)

### Exercise 2
# Visualizing raster of spiketimes 

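**Suggestions**
* Loop over the spike times in `spike_times_of_interest` (loop variable `timestamp`)
* At each iteration
    * Use `timestamp` as the x position of the tick
    * Use a constant (e.g. `1`) as the y position, since all spikes come from the same neuron
    * Draw the tick with `plt.scatter(..., marker='|')`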

In [None]:
## TODO for students: draw the raster of one neuron, one spike at a time
# Uncomment below to get started


# Raster of one neuron, by hand
# for timestamp in spike_times_of_interest:
#     plt.scatter(x=..., y=..., marker='|', color='black')

In [None]:
# to_remove solution

# Raster of one neuron, by hand
for timestamp in spike_times_of_interest:
    plt.scatter(x=timestamp, y=1, marker='|', color='black')

This is our first For loop! How does it work?


In [None]:
# to_remove solution
# Whole raster for one neuron, no loop
plt.scatter(x=spike_times_of_interest,
            y=np.ones(spike_times_of_interest.shape[0]),
            marker='|');

In [None]:
## TODO for students: draw the whole raster of one neuron without a loop
# Uncomment below to get started
# Whole raster for one neuron, no loop
# plt.scatter(x=...,
#             y=np.ones(spike_times_of_interest.shape...),
#             marker='|');

In [None]:
# to_remove solution
# For all neurons now
plt.scatter(x=spike_times_all_neurons,
            y=neuron_ids,
            marker='|', color='black', alpha=.7);

In [None]:
## TODO for students: draw the raster of all neurons at once
# Uncomment below to get started

# For all neurons now
# plt.scatter(x=...,
#             y=...,
#             marker='|', color='black', alpha=.7);

# Visualizing binned raster


In [None]:
# Bin neurons
raster = bin_neurons(spike_times_all_neurons, neuron_ids, bin_size=.05)
raster

In [None]:
raster.plot()

# For an interactive version including the whole dataset, uncomment the following code

In [None]:
# import holoviews as hv
# import hvplot.xarray # noqa
# hv.extension('bokeh')
# raster_huge = bin_neurons(spike_times_all_neurons_huge, neuron_ids_huge, bin_size=.2)

# raster_huge.hvplot(x='Time', y='Single_unit_id',clim=(0,22)).opts(cmap='viridis')

In [None]:
# Convenient one-liners to "slice" the raster by time or neurons
raster.sel(Single_unit_id=slice(20, 1200), Time=slice(60, 65)).plot(robust=True);
## Do you see the difference? We are missing-ish the up states

In [None]:
# PSTH
lfp, down_states = identify_down_states(spike_times_all_neurons_huge)

In [None]:
# PSTH based on homemade code
raster = bin_neurons(spike_times_all_neurons_huge, neuron_ids_huge, bin_size=.005)
unit_of_interest = 241  # position along Single_unit_id of the unit we look at
n_bins = 80  # (.100 + .300) sec window / .005 sec bins
# The unit does not change inside the loop, so select it once up front
unit_raster = raster.sel(Single_unit_id=unit_of_interest)
psth = []
for down_state in tqdm(down_states):
    # get raster: from 100 ms before to 300 ms after the down state
    raster_now = unit_raster.sel(Time=slice(down_state - .100,
                                            down_state + .300))
    trial = raster_now.values[:n_bins]
    # A window cut short by the start/end of the recording has fewer bins
    # and would make np.vstack fail, so only complete windows are kept
    if trial.shape[0] == n_bins:
        # add raster to psth (append changes the list in place: no `psth = ...` needed)
        psth.append(trial)
psth_all = np.vstack(psth)
# With for loop as the basic building block, you can do anything!
## plot psth
plt.pcolormesh(psth_all);
## Q: What are the axes?

In [None]:
# ============= Kording psth toolbox is a lot easier ===============
# make a psth object using the toolbox:
# (a DataFrame whose 'trialStart' column holds the down-state times we align to)
psth_object = PSTH.make_psth(down_states)
# make raster using toolbox:
# (one spykes NeuroVis object per unit that has at least 2 spikes)
spykes_times = PSTH.spykes_get_times(spike_times_all_neurons_huge, neuron_ids_huge)
# Get mean PSTH for all neurons
# (only units whose firingrate is above the fr_thr default are included)
# NOTE(review): window and bin_size are presumably in ms, following spykes -- confirm
# With plotose=True this call already draws an unsorted heat map and the population PSTH
pop, all_psth = PSTH.spykes_summary(spikes=spykes_times, spykes_df=psth_object,
                                    event='trialStart', window=[-300, 400], bin_size=5, plotose=True)
# Plot mean PSTH for all neurons
# (the same PSTH once more, this time with sortby='rate')
pop.plot_heat_map(all_psth, sortby='rate', sortorder='ascend', normalize=None, colors=['viridis']);  # or latency
## Note the periodicity
## Q: Do you think if you flip the detection, the result will be similar and opposite? Hint: Inspect the raster.
# Test your hypothesis by setting down_states = np.where(lfp > 90)[0] * 0.05  (bin index times bin size, so the times are in seconds)