# Dynamical Neuroscience in Ukraine Academy: Day 4, Tutorial 2
# Real data with Python

In [None]:
# @title Helper functions
# Install packages
!pip install -U hvplot bokeh>/dev/null
!pip install matplotlib pandas xarray numpy tqdm spykes >/dev/null
import os

import pandas as pd
import xarray as xr
import requests
from pathlib import Path
import zipfile


def download_data():
    """
    Download the example dataset archive and unpack it into the ``data`` folder

    The archive is fetched only when ``data/dataset.zip`` is not on disk yet;
    it is unpacked next to the zip file on every call.
    :return: None
    :raises requests.HTTPError: when the server answers with an error status
    """
    path = Path('data/dataset.zip')
    path.parent.mkdir(exist_ok=True)
    if not path.exists():
        # Only hit the network when there is no cached archive
        print('Downloading data... Please wait. Should take less than 4 min')
        # Get link:
        r = requests.get('http://data.cortexlab.net/singlePhase3/data/dataset.zip')
        # Fail here rather than saving an error page under the name dataset.zip
        r.raise_for_status()
        with open(path, "wb") as fid:
            # Write out content of link:
            fid.write(r.content)
    # Unzip
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(path.parent)


def load_spikes_from_phy(path_to_data='/Users/myroshnychenkm2/Downloads/dataset/', sampling_frequency=30000):
    """
    Get spikes from a kilosort/phy result folder, keeping only units labelled 'good'

    Reads ``cluster_groups.csv`` (tab-separated), ``spike_clusters.npy`` and
    ``spike_times.npy`` from the folder.
    :param path_to_data: folder holding the three files; a trailing separator is optional
    :param sampling_frequency: divisor applied to the stored spike times
        (presumably the recording's samples per second, giving times in seconds)
    :return: (ids, ts)
    :ids: cluster id of every kept spike, 1xN
    :ts: corresponding spiketime as float, stored value / sampling_frequency, 1xN
    Kept spikes are ordered unit by unit, following the order of the 'good'
    rows in cluster_groups.csv.
    """
    groupfname = os.path.join(path_to_data, 'cluster_groups.csv')
    groups = pd.read_csv(groupfname, delimiter='\t')

    # load spike times and cluster IDs
    # (os.path.join instead of string concatenation, so a folder given
    # without a trailing slash resolves to the right files too)
    ids = np.load(os.path.join(path_to_data, 'spike_clusters.npy')).flatten()
    ts = np.load(os.path.join(path_to_data, 'spike_times.npy')).flatten()

    # Create the list of our "good" labeled units
    ids_to_take = groups[(groups.group == 'good')].cluster_id
    # Find which spikes belong to our "good" groups: one index array per unit
    per_unit = [np.flatnonzero(ids == i)
                for i in tqdm(ids_to_take, desc='Selecting only good spikes')]
    if per_unit:
        spikes_to_take = np.concatenate(per_unit)
    else:
        # No 'good' unit at all: select nothing (np.concatenate rejects an empty list)
        spikes_to_take = np.array([], dtype=int)
    # only take spikes that are in our list
    ids = ids[spikes_to_take]
    ts = ts[spikes_to_take].astype(float) / sampling_frequency

    return ids, ts


def bin_neuron(spike_times, bin_size=.100, window=None):
    """
    Count the spikes of a single neuron in consecutive time bins
    :param spike_times: spike times of the neuron
    :param bin_size: width of one bin, in sec
    :param window: [start, stop] of the period to bin; when None, 0 up to the latest spike
    :return: array holding the number of spikes in each bin
    """
    start = 0 if window is None else window[0]
    stop = spike_times.max() if window is None else window[1]
    # Edges are generated up to one bin_size past `stop`
    edges = np.arange(start, stop + bin_size, bin_size)
    counts, _ = np.histogram(spike_times, edges)
    return counts


def bin_neurons(spike_times, Neuron_IDs, bin_size=None, window=None, plotose=False):
    """
    Make binned raster for many neurons
    :param spike_times: spike times of all neurons, one entry per spike
    :param Neuron_IDs: neuron id of every spike, aligned with spike_times
    :param bin_size: in sec; None falls back to 0.1, the default of bin_neuron
    :param window: [start, stop] of the period to bin; when None, 0 up to the latest spike
    :param plotose: if True, also plot the raster
    :return: xarray.DataArray of spike counts with dims ('Neuron_ID', 'Time').
        The Neuron_ID coordinate is the position (0..N-1) of each neuron in the
        sorted unique ids, not the original id; Time holds the left edge of each bin.
    """
    if bin_size is None:
        # The declared default used to crash in `window[1] + bin_size`
        bin_size = .100
    if window is None:
        window = [0, spike_times.max()]
    unique_ids = np.unique(Neuron_IDs)
    # One row of counts per neuron, all binned over the same window
    spike_counts = [bin_neuron(spike_times[Neuron_IDs == Neuron_ID], bin_size, window)
                    for Neuron_ID in tqdm(unique_ids)]
    spike_counts = np.vstack(spike_counts)
    # Same edges as bin_neuron builds; dropping the last edge leaves one value per bin
    bin_starts = np.arange(window[0], window[1] + bin_size, bin_size)[:-1]
    raster = xr.DataArray(spike_counts, coords=dict(Time=bin_starts,
                                                    Neuron_ID=range(len(unique_ids))),
                          dims=['Neuron_ID', 'Time'])
    if plotose:
        raster.plot(robust=True)
    return raster


def identify_down_states(ts, bin_size=.05, number_of_neurons_treshold=20, minimum_time_between_states=0.15):
    """
    Find spontaneous periods of quiescence in spiketimes

    All spikes are pooled and binned; bins whose count is below the threshold
    are candidates. A candidate is kept only when the previous candidate bin
    lies more than minimum_time_between_states before it.
    :param ts: spike times of all neurons pooled together (sorted internally)
    :param bin_size: width of the bins used for the pooled count, in sec
    :param number_of_neurons_treshold: a bin is a candidate when its pooled spike count is below this value
    :param minimum_time_between_states: required gap to the previous candidate bin, in the units of bin_size
    :return: (lfp, down_states)
    :lfp: pooled spike count per bin
    :down_states: kept bin indices multiplied by bin_size, then shifted 0.03 earlier
    """
    lfp = bin_neuron(np.sort(ts), bin_size=bin_size)
    down_states = np.where(lfp < number_of_neurons_treshold)[0]
    # Distance, in bins, from each candidate bin to the previous candidate bin
    down_states_lengths = np.diff(down_states)
    # Use the parameter in the report as well, instead of a hard-coded .15
    min_gap_in_bins = minimum_time_between_states / bin_size
    print(f'Eliminating {down_states[1:][down_states_lengths < min_gap_in_bins].shape} that are too short')
    down_states = down_states[1:][down_states_lengths > min_gap_in_bins]
    print(f'Ended up with {down_states.shape} down states')
    # convert into seconds:
    down_states = down_states * bin_size
    # NOTE(review): fixed 0.03 offset; its rationale is not visible in this file
    down_states -= .03
    print(down_states[(down_states > 63) & (down_states < 68)])  # compare with raster
    return lfp, down_states


class PETH:
    """
    A collection of functions dealing with peristimulus time histogram
    """

    @staticmethod
    def make_psth(trial_starts):
        """
        Simple wrapper creating a dataframe with times we want to lock onto
        :param trial_starts: List of times of interest (trials)
        :return: pandas DataFrame with a single column 'trialStart'
        """
        trials = pd.DataFrame()
        trials['trialStart'] = trial_starts
        return trials

    @staticmethod
    def spykes_get_times(s_ts, s_id, debug=False):
        """
        Wrap every unit's spike times in a spykes NeuroVis object
        :param s_ts: array of spike times, one entry per spike
        :param s_id: array of unit ids aligned with s_ts (cast to int)
        :param debug: if True, report skipped units and print the spike count of every kept unit
        :return: list of NeuroVis objects named 'ram<id>'; units with fewer than two spikes are left out
        """
        from spykes.plot import neurovis

        s_id = s_id.astype('int')
        neuron_list = list()
        for iu in np.unique(s_id):
            spike_times = s_ts[s_id == iu]
            if len(spike_times) < 2:
                if debug:
                    print('Too few spiketimes in this unit: ' + str(spike_times))
                continue
            neuron_list.append(neurovis.NeuroVis(spike_times, name='ram' + str(iu)))

        if debug:
            for neuron in neuron_list:
                print(len(neuron.spiketimes))
        return neuron_list

    @staticmethod
    def spykes_summary(spikes, spykes_df, event, window=None, bin_size=10, fr_thr=.1, plotose=True):
        """
        Build a spykes PopVis from the units, compute their PSTHs and optionally plot them
        :param spikes: list of NeuroVis objects, e.g. from spykes_get_times
        :param spykes_df: DataFrame holding the event times, e.g. from make_psth
        :param event: name of the column of spykes_df to lock onto
        :param window: [start, stop] around the event, passed to spykes (presumably ms); defaults to [-100, 100]
        :param bin_size: passed to spykes as binsize (presumably ms)
        :param fr_thr: units whose firingrate is not above this value are dropped
        :param plotose: if True, draw the heat map and the population PSTH
        :return: (pop, mean_psth): the PopVis object and the result of get_all_psth
        """
        # Import the submodule explicitly rather than relying on `import spykes` exposing it
        from spykes.plot import popvis

        if window is None:
            # Fresh list on every call instead of a shared mutable default
            window = [-100, 100]
        assert window[1] - window[0] > 0, 'Window size must be greater than zero!'
        # filter firing rate
        spikes = [i for i in spikes if i.firingrate > fr_thr]
        pop = popvis.PopVis(spikes)
        # calculate psth
        mean_psth = pop.get_all_psth(event=event, df=spykes_df, window=window, binsize=bin_size, plot=False)
        assert mean_psth['data'][0].size > 0, 'Empty group PSTH!'
        if plotose:
            # plot heatmap of average psth
            plt.figure(figsize=(10, 10))
            pop.plot_heat_map(mean_psth, sortby=None, sortorder='ascend', normalize=None,
                              colors=['viridis'])  # or latency

            # Population PSTH
            plt.figure()
            pop.plot_population_psth(all_psth=mean_psth, event_name='Event',
                                     colors=([.5, .5, .5], [0, .6, 0]))

        return pop, mean_psth


# Fetch and unpack the dataset into data/ as part of running the helper cell
download_data()

In [None]:
# Import basic libraries
import matplotlib.pylab as plt
import numpy as np
from tqdm import tqdm

In [None]:
# @title Figure settings

# IPython magic: render inline figures in the 'retina' (high-resolution) format
%config InlineBackend.figure_format = 'retina'

# Apply the Neuromatch Academy matplotlib style sheet, loaded from its GitHub URL
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

In [None]:
# Read the spike-sorting output that download_data() unpacked into data/
Neuron_IDs_huge, spike_times_all_neurons_huge = load_spikes_from_phy('data/')

# Inspect variables

In [None]:
# Let's inspect the spikes
Neuron_IDs_huge

In [None]:
spike_times_all_neurons_huge

**Questions**
* What are the numbers in the two blocks above?

### Exercise 1
What are the dimensions of each variable?

In [None]:
# TODO for students:
# Use the `X.shape` formalism to examine the shapes of each.

**Questions**
Is it possible to combine these two sets of numbers? How?

In [None]:
# Limit the number of spikes we're dealing with but keep full variables:
# build the "earlier than 70" mask once and apply it to both arrays
is_early = spike_times_all_neurons_huge < 70
spike_times_all_neurons = spike_times_all_neurons_huge[is_early]
Neuron_IDs = Neuron_IDs_huge[is_early]

Let's inspect the first ten timestamps from one neuron

In [None]:
# TODO for students:
# Complete the following code to print out the first ten timestamps of one neuron.
# Note: Neuron_IDs[300] is the ID attached to the spike at position 300, i.e. the
# neuron that fired that spike -- it is not "neuron number 300".
id_of_interest = Neuron_IDs[300]
# Boolean mask keeps only the spikes fired by that neuron
spike_times_of_interest = spike_times_all_neurons[Neuron_IDs == id_of_interest]

## Visualize neurons' spike times

### Exercise 2
Inspect the timestamps from one neuron

In [None]:
# Plot spiketimes
plt.plot(spike_times_of_interest);

**Questions**
* What are neurons' identities in this plot?
* Why is the curve progressively rising?
* What's on the x- and y- axes?



In [None]:
# Same curve as in the previous cell, now with labelled axes and a title
plt.plot(spike_times_of_interest)
plt.xlabel('Number of spikes')
plt.ylabel('Time')
plt.title('Timestamps of a neuron');

 What is a plot like this good for? (Hint: what would the change in firing rate look like?)

### Exercise 3
Visualize a neuron's spike times using a For loop

**Suggestions**
* Loop over time stamps of action potentials
* At each time stamp
   * Make a vertical tick at that value of time on the x axis
   * What should the y axis be?

In [None]:
## TODO for students:
# Uncomment below to get started


# Raster of one neuron, by hand
# for timestamp in spike_times_of_interest:
#     plt.scatter(x=..., y=..., marker='|', color='black')

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_5c2ba34c.py)

*Example output:*

<img alt='Solution hint' align='left' width=558 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_5c2ba34c_0.png>



### Exercise 4
Visualize a neuron's spike times without a for loop

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_88518c98.py)

*Example output:*

<img alt='Solution hint' align='left' width=558 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_88518c98_1.png>



In [None]:
## TODO for students:
# Uncomment below to get started

# Whole raster for one neuron, no loop
# plt.scatter(x=...,
#             y=np.ones(spike_times_of_interest.shape...),
#             marker='|', color='black');

### Exercise 4 (continued)
Visualize the spike times of all neurons

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_9ecfc96d.py)

*Example output:*

<img alt='Solution hint' align='left' width=558 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_9ecfc96d_1.png>



In [None]:
## TODO for students:
# Uncomment below to get started

# For all neurons now
# plt.scatter(x=...,
#             y=...,
#             marker='|', color='black', alpha=.7);

## Visualize binned raster of spikes

### Exercise 5


In [None]:
# Bin neurons: count spikes per neuron in bins of 0.05 (bin_size is in sec).
# The result is an xarray.DataArray with dims ('Neuron_ID', 'Time').
raster = bin_neurons(spike_times_all_neurons, Neuron_IDs, bin_size=.05)
# A bare name on the last line makes the notebook display the array
raster

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_ce7ef659.py)

*Example output:*

<img alt='Solution hint' align='left' width=529 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_ce7ef659_1.png>



In [None]:
## TODO for students:
# Uncomment below to get started

# raster.plot(x=...,
#             y=...,
#             robust=True);

### Exercise 6
Interactive version including the whole dataset

In [None]:
## TODO for students:
# Uncomment below to plot interactive full dataset. Zoom in on areas of the raster that catch your eye


# import holoviews as hv
# import hvplot.xarray # noqa
# hv.extension('bokeh')
# raster_huge = bin_neurons(spike_times_all_neurons_huge, Neuron_IDs_huge, bin_size=.12)

# raster_huge.hvplot(x='Time', y='Neuron_ID', clim=(0,12)).opts(cmap='viridis')

### Exercise 7
Slice dataset to save interesting times/neurons

Note: Your solution will most likely look different from mine. We are just looking for an area of the
raster with striking features.

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_70a7f6ae.py)

*Example output:*

<img alt='Solution hint' align='left' width=529 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_70a7f6ae_1.png>



In [None]:
## TODO for students:
# Uncomment below to get started

# raster.sel(...=slice(..., ...),
#            ...=slice(..., ...)
#           ).plot(robust=True);

## Visualize PETH

### Exercise 8

This dataset includes some behavior. However, let's imagine we don't know what behavior is happening and when.
 Can you think of a way to zero in on some features of spikes you found earlier?


* Suggested steps:
   * Make a function that returns raster 1/4 second before and 1/4 second after a certain event

In [None]:
# Pooled spike count per bin (default bin_size of identify_down_states) and the detected event times
multiunit_activity, event_times = identify_down_states(spike_times_all_neurons_huge)

# PETH based on homemade code
# Re-bin the whole recording with bin_size=.01 (in sec)
raster = bin_neurons(spike_times_all_neurons_huge, Neuron_IDs_huge, bin_size=.01)
# Select one neuron
# Neuron_ID is the positional index assigned by bin_neurons (0..N-1), not the phy cluster id
raster241 = raster.sel(Neuron_ID=241)

**Suggestions**
   1. Complete the function `half_a_second()`
      * It should select a slice of 1/4 second before an event and 1/4 second after
   2. Use `half_a_second()` to create a perievent time histogram (PETH)

**Questions**
   * What do you think the events were? Take a guess

[*Click for solution*](https://github.com/mmyros/dnu_course/tree/master//D2_Spikes/solutions/spikes_tutorial_Solution_c6d03403.py)

*Example output:*

<img alt='Solution hint' align='left' width=558 height=413 src=https://raw.githubusercontent.com/mmyros/dnu_course/master/D2_Spikes/static/spikes_tutorial_Solution_c6d03403_1.png>



In [None]:
## TODO for students:
# Uncomment below to get started
# def half_a_second(raster, timestamp):
#     return raster.sel(Time=slice(timestamp - ...,
#                                  timestamp + ...)).values[:100]
#
#
# psth = [half_a_second(raster241, ...) for ... in event_times]
# psth = np.vstack(psth)
#
# plt.pcolormesh(psth);

### Exercise 8 (continued)
PETH for all neurons

In [None]:
# make a psth object using a toolbox:
# (a DataFrame whose 'trialStart' column holds the event times)
psth_object = PETH.make_psth(event_times)
# make raster using toolbox:
# (one spykes NeuroVis object per unit; units with fewer than two spikes are left out)
spykes_times = PETH.spykes_get_times(spike_times_all_neurons_huge, Neuron_IDs_huge)
# Get mean PSTH for all neurons
# NOTE(review): window and bin_size are handed to spykes unchanged; presumably in ms -- confirm in the spykes docs
pop, all_psth = PETH.spykes_summary(spikes=spykes_times, spykes_df=psth_object,
                                    event='trialStart', window=[-300, 400], bin_size=5, plotose=True)

**Questions**
* In your own words, what does the code above do?
* What are the axes in the resulting plots?
* Why does the curve rise up and down repeatedly?

In [None]:
# Plot mean PSTH for all neurons
# Same heat map as drawn inside spykes_summary, except that units are now sorted (sortby='rate')
pop.plot_heat_map(all_psth, sortby='rate', sortorder='ascend', normalize=None, colors=['viridis'])  # or latency

**Questions**
* How is this plot different from the previous two?
* How is this achieved? Take a guess - what would you have done?

**Advanced questions**
* How do you think the gray curve will look if you invert the detection of events we explained in Exercise 8?
* Test your hypothesis by setting `event_times = np.where(multiunit_activity > 90)[0] * 0.05` (bin index × the 0.05 s bin size gives event times in seconds) and replotting