# Create a function for plotting a simulated spike raster from a pandas DataFrame (using `hv.Spikes`)

In [None]:
from simulate_spiketimes import sim_spikes, assign_groups
from plot_spike_raster import plot_spike_raster
import holoviews as hv
hv.extension('bokeh')

In [None]:
# Simulate a spike DataFrame (later cells read its 'time' and 'neuron' columns)
# and add a 'group' label to every row via assign_groups.
# NOTE(review): the meaning of the positional args (50, 1, 30) and of 4 / sigma=2
# is defined in simulate_spiketimes — confirm there.
spikes_df = sim_spikes(50, 1, 30)
spikes_df['group'] = assign_groups(spikes_df.time, 4, sigma=2)

In [8]:
plot_spike_raster(
    spikes_df,
    spiketime_col='time',
    neuron_col='neuron',
    # map each spike's color to its 'group' column, using the category10 palette
    spike_train_opts={'color': 'group', 'cmap': 'category10'},
    # render a compact 300 x 300 figure
    overlay_opts={'height': 300, 'width': 300},
)

## Scratch (ignore)

In [None]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.insert(0, '../..')
from scripts.simdata import sim_spikes, assign_groups

import numpy as np
import holoviews as hv
hv.extension('bokeh')

In [None]:
# Simulate a spike DataFrame with the scripts.simdata helpers and add a
# 'group' label to every row via assign_groups.
# NOTE(review): the meaning of the positional args (50, 1, 30) and of 4 / sigma=2
# is defined in scripts.simdata — confirm there.
spikes_df = sim_spikes(50, 1, 30)
spikes_df['group'] = assign_groups(spikes_df.time, 4, sigma=2)

In [None]:
# Inspect the column dtypes of the simulated spike DataFrame
spikes_df.dtypes

In [None]:
# Peek at the first rows of the simulated spike DataFrame
spikes_df.head()

In [None]:
def spike_raster(df, spiketime_col, neuron_col, spike_train_opts=None, overlay_opts=None):

    """
    Plot a spike raster.

    Each neuron's spikes are drawn as one hv.Spikes element on its own row, and the
    rows are collected into an hv.NdOverlay keyed by neuron ID.

    Args:
    - df (pandas.DataFrame): one row per spike, with columns for spike times and
      neuron IDs. All other columns (e.g. 'group') are kept as value dimensions, so
      they can drive style mappings such as color='group' and show up in the hover tool.
    - spiketime_col (str): column name in `df` holding spike times (plotted on x)
    - neuron_col (str): column name in `df` holding neuron IDs. IDs must be numeric,
      because each row is drawn at position ``neuron_id - 0.5``.
    - spike_train_opts (dict, optional): plotting opts applied to every hv.Spikes
      element; its keys override the defaults below
    - overlay_opts (dict, optional): plotting opts applied to the hv.NdOverlay; its
      keys override the defaults below. If it has no 'yticks' entry, one tick per
      neuron ID is used.

    Returns:
    - hv.NdOverlay of hv.Spikes elements, keyed by neuron ID
    """

    default_spike_train_opts = {'color': 'black', 'cmap': 'glasbey_cool',
                                'spike_length': .95, 'tools': ['hover']}
    default_overlay_opts = {'ylabel': 'Neuron', 'xlabel': 'Time', 'show_grid': True,
                            'padding': 0.01, 'width': 1000, 'height': 500,
                            'show_legend': False}

    # Caller-supplied opts win over the defaults. Build fresh dicts so neither the
    # defaults nor the caller's dicts are mutated.
    train_opts = {**default_spike_train_opts, **(spike_train_opts or {})}
    ndoverlay_opts = {**default_overlay_opts, **(overlay_opts or {})}

    # Use the requested time column as the key dimension explicitly; otherwise
    # HoloViews would take whichever column happens to come first. Every other
    # column becomes a value dimension so color mappings and hover can use it.
    value_cols = [col for col in df.columns if col != spiketime_col]

    spikes_dict = {}
    # groupby sorts by key by default, so neurons are visited in ascending ID order.
    for ineuron, ispikes in df.groupby(neuron_col, sort=True):
        spikes_dict[ineuron] = hv.Spikes(
            ispikes, kdims=[spiketime_col], vdims=value_cols
        ).opts(position=ineuron - .5, **train_opts)

    # Tick positions must be a concrete list of neuron IDs (not the dict.keys
    # method). Callers may still override them through overlay_opts['yticks'].
    ndoverlay_opts.setdefault('yticks', list(spikes_dict))

    overlay = hv.NdOverlay(spikes_dict, kdims=ndoverlay_opts['ylabel']).opts(**ndoverlay_opts)

    return overlay

In [None]:
%%time
spike_raster(spikes_df, spiketime_col='time', neuron_col='neuron', spike_train_opts={'color':'group'})

## Now in a script — try it:

In [None]:
from scripts.spike_raster import spike_raster

In [None]:
# Re-simulate a fresh spike DataFrame and 'group' labels (scripts.simdata helpers)
# to exercise the script version of spike_raster.
# NOTE(review): the meaning of the positional args (50, 1, 30) and of 4 / sigma=2
# is defined in scripts.simdata — confirm there.
spikes_df = sim_spikes(50, 1, 30)
spikes_df['group'] = assign_groups(spikes_df.time, 4, sigma=2)

In [None]:
spike_raster(
    spikes_df,
    spiketime_col='time',
    neuron_col='neuron',
    # map each spike's color to its 'group' column, using the category10 palette
    spike_train_opts={'color': 'group', 'cmap': 'category10'},
    # render a compact 300 x 300 figure instead of the default size
    overlay_opts={'height': 300, 'width': 300},
)

## Scratch (ignore)

In [None]:
# %%time 

# spikes_dict = {}
# for ineuron, ispikes in spikes_df.groupby('neuron'):
#     spikes_dict[ineuron] = hv.Spikes(ispikes).opts(
#                 position=ineuron-.5, spike_length=.95, tools=['hover'], color='group')

# hv.NdOverlay(spikes_dict)

In [None]:
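# # NumPy implementation (superseded by the pandas version above; as written it references `spikes_df` and `neuron_idx`, which are not its parameters)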
# def spike_raster(data, spike_col, neuron_col, ylabel='Neuron', xlabel='Time',
#                  show_grid=True, padding=0.01, width=1000, height=500):

#     """
#     Plot a spike raster

#     Args:
#     - data (numpy.ndarray): NumPy array with columns for spike times and neuron IDs
#     - spike_col (int): Column index into `data` for spike times
#     - neuron_col (int): Column index into `data` for the neuron label
    
#     Returns:
#     - HoloViews NdOverlay of HoloViews Spikes elements
#     """

#     data = spikes_df.values

#     # find the unique neurons in the specified column index
#     unique_neurons = np.unique(data[:, neuron_idx])
#     unique_neurons.sort()

#     spikes_dict = {}
#     for i_neuron in unique_neurons:
#         # extract the spike indices for each neuron
#         spike_indices = np.where(data[:, neuron_idx] == i_neuron)[0]
#         # create a Spikes element for this neuron's spike train
#         spikes_dict[i_neuron] = hv.Spikes(data[spike_indices,:], kdims=xlabel).opts(
#                     position=i_neuron-.5, spike_length=.95, tools=['hover'])

#     overlay = hv.NdOverlay(spikes_dict, kdims=ylabel).opts(
#         yticks=unique_neurons, ylabel=ylabel, show_grid=show_grid, show_legend=False, padding=padding, width=width, height=height)
    
#     return overlay
    