### Plots PSTHs and calculates basic neural measures for Chronic data. 

#### Run this notebook after creating the playback pickle files (stored in each site's `PlaybackPkl/` folder) with PlotSpikeSortedCategoriesGUI.ipynb.

#### You should only need to modify `rootPath` in cell 2 (Set Paths) for this notebook to work.


In [1]:
# Dependencies 
import os
import glob
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle as pk
from scipy.stats import t

# Display and GUI
import ipywidgets as widgets

# Data Processing Stuff
from sklearn.neighbors import KernelDensity

from IPython.display import display


### Set Paths
Note: the data directory (`data/birds/`) is assumed to be accessible from `rootPath`.

In [2]:
rootPath = '/Users/frederictheunissen/Code/songephys/'
dataPath = 'data/birds/'

# Names of the bird folders under rootPath + dataPath, newest (by mtime) first.
# Plain files in that folder are skipped.
birds = list(map(
    os.path.basename,
    filter(os.path.isdir,
           sorted(glob.glob(os.path.join(rootPath + dataPath, "*")),
                  key=os.path.getmtime, reverse=True)),
))

### Input Widgets

In [3]:
def _placeholder_dropdown(description):
    """Return a Dropdown that only offers the 'None' placeholder.

    Its options are filled in later by the load_* callbacks once the parent
    selection is made.
    """
    return widgets.Dropdown(
        options=['None'],
        value='None',
        description=description,
        disabled=False,
    )


# Bird selector, pre-set to the first (most recently modified) bird.
bird_picker = widgets.Dropdown(options=birds, value=birds[0], description="Bird", disabled=False)

# Cascading selectors below the bird.
site_picker = _placeholder_dropdown("Site")
playbackPkl_picker = _placeholder_dropdown("ElectCluster")
stimulus_picker = _placeholder_dropdown("Stimulus")

# Button that saves the current PSTH figure (its callback is attached in the plotting cell).
save_button = widgets.Button(
    description='Save PSTH',
    disabled=False,
    button_style='success',
    tooltip='Click me to save figure',
)


### Loading Functions

In [6]:
def load_bird(change):
    """Populate the site list for a newly chosen bird.

    Observer for bird_picker (also called once directly during
    initialisation). When the bird differs from the cached one, the entries
    of <rootPath><dataPath><bird>/sites/ are listed (sorted) into
    site_picker. The site, playback and stimulus dropdowns are all reset to
    'None'.

    Parameters
    ----------
    change : dict
        Traitlets change notification; only change['new'] (the bird name) is used.

    Returns
    -------
    tuple
        The cached selection load_bird._loaded = (bird, site, playbackPkl, stimulus).
    """
    bird = change['new']
    # Use value equality: `is not` on strings tests object identity, not content.
    if bird != load_bird._loaded[0]:
        sites = sorted(
            os.path.basename(site)
            for site in glob.glob(os.path.join(rootPath + dataPath + bird + '/sites/', "*"))
        )

        # Update the cache before touching widgets, whose observers may fire immediately.
        load_bird._loaded = (bird, 'None', 'None', 'None')
        site_picker.options = ['None'] + sites
        site_picker.value = 'None'
        playbackPkl_picker.options = ['None']
        playbackPkl_picker.value = 'None'
        stimulus_picker.options = ['None']
        stimulus_picker.value = 'None'

    return load_bird._loaded


def load_site(change):
    """List the playback pickles of a newly chosen recording site.

    Observer for site_picker. For a real site, the global playPklPath is set
    to <rootPath><dataPath><bird>/sites/<site>/PlaybackPkl/. playbackPkl_picker
    is then filled with the sorted basenames of the *.pkl files found there.
    For 'None', playPklPath is cleared and playbackPkl_picker is emptied.
    In both cases the playback and stimulus selections are reset to 'None'.

    Parameters
    ----------
    change : dict
        Traitlets change notification; change['new'] is the site name.

    Returns
    -------
    tuple
        The cached selection load_site._loaded = (bird, site, 'None', 'None').
    """
    global playPklPath

    bird = bird_picker.value
    site = change['new']
    # Compare by value. The original `tuple is not tuple` was always True,
    # because both tuples are freshly built objects.
    if (bird, site) != tuple(load_site._loaded[:2]):
        if site != 'None':
            playPklPath = rootPath + dataPath + bird + '/sites/' + site + '/PlaybackPkl/'
            playPkls = sorted(
                os.path.basename(efile)
                for efile in glob.glob(playPklPath + '*.pkl')
            )
            playbackPkl_picker.options = ['None'] + playPkls
        else:
            playPklPath = None
            # Was `electrode_picker`, which is not defined anywhere (NameError).
            playbackPkl_picker.options = ['None']

        playbackPkl_picker.value = 'None'
        stimulus_picker.options = ['None']
        stimulus_picker.value = 'None'

        load_site._loaded = (bird, site, 'None', 'None')

    return load_site._loaded

def calc_zscore(stimName, window=0.5):
    """Z-score and p-value of the evoked spike-count change for one stimulus.

    For every trial of the row of the global dfRelTime whose 'file' equals
    stimName, the spike count in [0, window) s is taken minus the count in
    [-window, 0) s. The z-score is the mean of these per-trial differences
    divided by their sample standard deviation (ddof=1). The p-value is
    two-sided, from a one-sample t-test of the differences against zero:
    t = z * sqrt(nTrials), with nTrials - 1 degrees of freedom.

    Parameters
    ----------
    stimName : str
        Value of dfRelTime['file'] identifying the stimulus.
    window : float, optional
        Width in seconds of both the pre- and post-onset counting windows
        (default 0.5, the value used throughout this notebook).

    Returns
    -------
    (zscore, pvalue) : tuple of float
        (0.0, 1.0) if the stimulus is missing, duplicated, or has no trials.
        (spike difference, 1.0) if it has a single trial.
    """
    matches = dfRelTime[dfRelTime['file'] == stimName]
    if len(matches) != 1:
        print('Stimulus not found or too many')
        return 0.0, 1.0

    row = matches.iloc[0]
    nTrials = int(row['nTrials'])

    # Per-trial (post - pre) spike-count difference.
    spikeDiff = np.zeros(nTrials)
    for it in range(nTrials):
        spikes = np.asarray(row['spikeTimes'][it])
        nPost = np.sum((spikes >= 0) & (spikes < window))
        nPre = np.sum((spikes >= -window) & (spikes < 0))
        spikeDiff[it] = nPost - nPre

    if nTrials == 0:
        return 0.0, 1.0
    if nTrials == 1:
        # No variance estimate is possible. Return a scalar, not the 1-element array.
        return float(spikeDiff[0]), 1.0

    sdiffSD = np.std(spikeDiff, ddof=1)
    if sdiffSD == 0:
        # All trials have identical differences. Nudge the first trial by one
        # spike so that a finite z-score can still be reported (kept from the
        # original implementation).
        spikeDiff[0] += 1
        sdiffSD = np.std(spikeDiff, ddof=1)

    zscore = np.mean(spikeDiff) / sdiffSD
    # One-sample t statistic is mean / (SD / sqrt(n)) with n - 1 dof; two-sided p.
    pvalue = 2.0 * t.sf(abs(zscore) * np.sqrt(nTrials), nTrials - 1)

    return zscore, pvalue

    
def load_playbackPkl(change):
    """Load the selected playback pickle and list its stimuli.

    Observer for playbackPkl_picker. The file playPklPath + <selection> is
    read as three consecutive pickled objects: unitInfo, dfAbsTime and
    dfRelTime. unitInfo and dfRelTime are stored as globals for the plotting
    functions. stimulus_picker is repopulated from dfRelTime['file'] and
    reset to 'None'.

    Parameters
    ----------
    change : dict
        Traitlets change notification; change['new'] is the pickle file name.

    Returns
    -------
    tuple
        The cached selection load_playbackPkl._loaded = (bird, site, playbackPkl, 'None').
    """
    global unitInfo, dfRelTime

    bird = bird_picker.value
    site = site_picker.value
    playbackPkl = change['new']

    # Value comparison. The original `tuple is not tuple` was always True.
    if (bird, site, playbackPkl) != tuple(load_playbackPkl._loaded[:3]):
        if playbackPkl != 'None':
            # NOTE: pickle.load can execute arbitrary code; only open trusted files.
            # `with` closes the file even if a load fails.
            with open(playPklPath + playbackPkl, 'rb') as fileIn:
                unitInfo = pk.load(fileIn)
                pk.load(fileIn)  # dfAbsTime: not used here, but must be read to reach dfRelTime
                dfRelTime = pk.load(fileIn)

            stimulus_picker.options = ['None'] + list(dfRelTime['file'])
        else:
            stimulus_picker.options = ['None']

        stimulus_picker.value = 'None'
        load_playbackPkl._loaded = (bird, site, playbackPkl, 'None')

    return load_playbackPkl._loaded

def load_stimulus(change):
    """Remember the chosen stimulus together with its parent selections.

    Observer for stimulus_picker. The cached tuple
    (bird, site, playbackPkl, stimulus) is what print_plot_PSTH reads to
    build the name of the saved figure.

    Parameters
    ----------
    change : dict
        Traitlets change notification; change['new'] is the stimulus name.

    Returns
    -------
    tuple
        The newly cached selection.
    """
    selection = (
        bird_picker.value,
        site_picker.value,
        playbackPkl_picker.value,
        change['new'],
    )
    load_stimulus._loaded = selection
    return selection
       




  if site is not 'None':
  if playbackPkl is not 'None':


### Initialize Input Widgets

In [7]:
# Every loader caches the (bird, site, playbackPkl, stimulus) selection it last handled.
for loader in (load_bird, load_site, load_playbackPkl, load_stimulus):
    loader._loaded = ('None',) * 4

# Fill the site list for the default (most recently modified) bird before wiring observers.
load_bird({'type': 'change', 'new': birds[0], 'old': 'None'})

# Connect each dropdown to its loader so selections cascade down the hierarchy.
for picker, loader in ((bird_picker, load_bird),
                       (site_picker, load_site),
                       (playbackPkl_picker, load_playbackPkl),
                       (stimulus_picker, load_stimulus)):
    picker.observe(loader, 'value')

### Output Widget for Plot Results

In [8]:
# Text areas receiving the z-score summary lines printed by plot_PSTH and plot_meanRate.
plotPSTH_output, plotMean_output = widgets.Output(), widgets.Output()

### Plotting Functions

In [9]:
figPSTH = None  # Most recent PSTH figure; print_plot_PSTH saves this object.


def plot_PSTH(stimName):
    """Draw the raster/PSTH figure for one stimulus of the loaded playback file.

    Looks up the single row of the global dfRelTime whose 'file' equals
    stimName, closes the previous PSTH figure, and builds a new one with:
      * the stimulus waveform (row['tStim'], row['stimWav']) on top,
      * one panel per trial with the microphone trace (row['tMic'],
        row['micWav'][it]) and a vertical tick at every spike time,
      * the KDE-smoothed rate (row['tKDE'], row['spikeKDE']) at the bottom.

    A summary of the per-trial spike-count difference (500 ms after onset
    minus 500 ms before) is printed into plotPSTH_output. It is given as a
    z-score and a two-sided one-sample t-test p-value with nTrials - 1
    degrees of freedom. As a sanity check, the KDE rate integrated over
    [0, 1) s is then printed next to the mean observed spike count in that
    window.

    Parameters
    ----------
    stimName : str
        Stimulus file name, or 'None' when nothing is selected.
    """
    global figPSTH

    if stimName == 'None':
        print('No stimulus chosen to plot')
        return

    # There should be exactly one row per stimulus file.
    matches = dfRelTime[dfRelTime['file'] == stimName]
    if len(matches) != 1:
        print('Stimulus not found or too many')
        return
    row = matches.iloc[0]
    nTrials = int(row['nTrials'])

    if figPSTH is not None:
        plt.close(figPSTH)
    figPSTH = plt.figure()

    # Stimulus waveform.
    aSound = plt.axes([0.1, 0.85, 0.87, 0.15])
    aSound.plot(row['tStim'], row['stimWav'])

    # One panel per trial: microphone trace plus spike ticks.
    spikeDiff = np.zeros(nTrials)
    spikeTot = np.zeros(nTrials)
    for it in range(nTrials):
        spikes = np.asarray(row['spikeTimes'][it])
        aTrial = plt.axes([0.1, 0.85 - 0.6*(it + 1)/nTrials, 0.87, 0.6/nTrials])
        aTrial.plot(row['tMic'], row['micWav'][it], 'silver', lw=0.5)
        aTrial.get_yaxis().set_ticks([])

        # Spikes in the 500 ms after onset minus spikes in the 500 ms before.
        spikeDiff[it] = (np.sum((spikes >= 0) & (spikes < 0.5))
                         - np.sum((spikes >= -0.5) & (spikes < 0)))

        # Total number of spikes in the first 1 s after onset.
        spikeTot[it] = np.sum((spikes >= 0) & (spikes < 1.0))

        # One vertical tick per spike, drawn in a single call.
        micMax = row['micWav'][it].max()
        aTrial.vlines(spikes, 0, micMax*0.8, colors='k')

    # z-score and two-sided p-value of the post-minus-pre difference.
    if nTrials > 1:
        sdiffSD = np.std(spikeDiff, ddof=1)
        if sdiffSD == 0:
            titleStr = 'Spike Diff is %f in all trials' % np.mean(spikeDiff)
        else:
            zscore = np.mean(spikeDiff)/sdiffSD
            # t = mean / (SD / sqrt(n)) with n - 1 degrees of freedom.
            pvalue = 2.0*t.sf(abs(zscore)*np.sqrt(nTrials), nTrials - 1)
            titleStr = 'z = %.3f p = %.4f' % (zscore, pvalue)
    elif nTrials == 1:
        titleStr = 'Spike Diff %f' % spikeDiff[0]
    else:
        titleStr = 'No data'

    plotPSTH_output.clear_output()
    with plotPSTH_output:
        print(titleStr)

    # Average rate from the KDE (smoothed PSTH).
    aRate = plt.axes([0.1, 0.1, 0.87, 0.15])
    aRate.plot(row['tKDE'], row['spikeKDE'])
    plt.xlabel('Time (s)')
    plt.ylabel('Spikes/s')
    plt.show()

    # Sanity check: the KDE rate integrated over [0, 1) s vs the observed mean count.
    spikeIntegral = np.sum(row['spikeKDE'][(row['tKDE'] >= 0) & (row['tKDE'] < 1.0)])
    spikeIntegral *= row['tKDE'][1] - row['tKDE'][0]

    print('KDE rate estimate', spikeIntegral)
    print('Actual rate', np.mean(spikeTot) if nTrials > 0 else 0.0)
    return

def print_plot_PSTH(b):
    """Save the current PSTH figure as a PDF next to the playback pickle.

    Callback for save_button; `b` is the clicked button and is unused. The
    figure figPSTH is written to
    playPklPath + '<playback pickle stem>_<stimulus stem>.pdf', using the
    selection cached in load_stimulus._loaded. Nothing is saved when no
    stimulus is selected or no PSTH figure has been drawn yet.
    """
    _, _, playbackPkl, stimulus = load_stimulus._loaded

    # figPSTH stays None until plot_PSTH has drawn at least one figure.
    if stimulus == 'None' or figPSTH is None:
        return

    clusterName = os.path.splitext(playbackPkl)[0]
    stimName = os.path.splitext(stimulus)[0]

    figName = playPklPath + clusterName + '_' + stimName + '.pdf'
    figPSTH.savefig(figName)
    

def plot_meanRate(playbackPklName):
    """Plot the firing rate pooled over every stimulus and trial of a playback file.

    Spike times from all trials of all rows of the global dfRelTime are
    pooled and smoothed with a Gaussian KernelDensity (30 ms bandwidth). The
    density is evaluated on the tKDE time axis of the last row and rescaled
    so that its sum over that grid, divided by the sampling rate, equals the
    mean number of spikes per trial (units: spikes/s). The post-minus-pre
    500 ms spike-count difference of every trial gives a z-score and a
    two-sided one-sample t-test p-value (nTotal - 1 dof), printed into
    plotMean_output.

    Parameters
    ----------
    playbackPklName : str
        Selected playback pickle name, or 'None' when nothing is selected.
    """
    kWidth = 0.03    # KDE bandwidth in seconds (30 ms)

    if playbackPklName == 'None':
        print('No playback file chosen to plot')
        return

    # Single pass: collect spikes and per-trial differences for every trial.
    allSpikes = []
    spikeDiff = []
    tRate = None
    for _, row in dfRelTime.iterrows():
        tRate = row['tKDE']    # time axis is shared by all rows; keep the last one
        for it in range(row['nTrials']):
            spikes = np.asarray(row['spikeTimes'][it], dtype=float)
            allSpikes.append(spikes)
            # Spike difference: 500 ms after onset minus 500 ms before.
            spikeDiff.append(np.sum((spikes >= 0) & (spikes < 0.5))
                             - np.sum((spikes >= -0.5) & (spikes < 0)))
    spikeDiff = np.asarray(spikeDiff, dtype=float)
    nTotal = len(spikeDiff)
    allSpikes = np.concatenate(allSpikes) if allSpikes else np.zeros(0)

    # Check the number of spikes, not trials: fitting a KDE to zero spikes raises.
    if allSpikes.size == 0 or tRate is None:
        tRate = np.linspace(0, 1.0, num=1001)
        meanSpikeDens = np.zeros(tRate.shape)
    else:
        tRate = np.asarray(tRate).reshape(-1, 1)
        sampRate = 1.0/float(tRate[1, 0] - tRate[0, 0])
        kdeMean = KernelDensity(kernel='gaussian', bandwidth=kWidth).fit(allSpikes.reshape(-1, 1))

        # Scale the density to spikes/s per trial.
        meanSpikeDens = np.exp(kdeMean.score_samples(tRate))
        meanSpikeDens = meanSpikeDens*allSpikes.size*sampRate/(np.sum(meanSpikeDens)*nTotal)

    # z-score and two-sided p-value of the pooled post-minus-pre difference.
    if nTotal > 1:
        sdiffSD = np.std(spikeDiff, ddof=1)
        if sdiffSD == 0:
            titleStr = 'Spike Diff is %f in all trials' % np.mean(spikeDiff)
        else:
            zscore = np.mean(spikeDiff)/sdiffSD
            # t = mean / (SD / sqrt(n)) with n - 1 degrees of freedom.
            pvalue = 2.0*t.sf(abs(zscore)*np.sqrt(nTotal), nTotal - 1)
            titleStr = 'z = %.3f p = %.4f' % (zscore, pvalue)
    elif nTotal == 1:
        titleStr = 'Spike Diff %f' % spikeDiff[0]
    else:
        titleStr = 'No data'

    plotMean_output.clear_output()
    with plotMean_output:
        print(titleStr)

    # Draw on a fresh figure rather than whatever pyplot figure is current.
    _, ax = plt.subplots()
    ax.plot(tRate, meanSpikeDens)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Spike Rate (spikes/s)')
    ax.set_title('Playback Evoked Firing Rate')
    plt.show()
    return

# Clicking "Save PSTH" calls print_plot_PSTH, which writes figPSTH to a PDF.
save_button.on_click(print_plot_PSTH)


### Output Widgets

In [10]:
# Widget to plot rate with a given threhold
# Redraw the single-stimulus raster/PSTH whenever a stimulus is selected.
PSTH_plot = widgets.interactive_output(plot_PSTH, {"stimName": stimulus_picker})

# Redraw the firing rate pooled over all stimuli whenever a playback file is selected.
mean_plot = widgets.interactive_output(plot_meanRate, {"playbackPklName": playbackPkl_picker})

### Run Widget

In [11]:
# Left column: cascading selectors, each followed by its summary text area.
controls = widgets.VBox([
    bird_picker,
    site_picker,
    playbackPkl_picker,
    plotMean_output,
    stimulus_picker,
    plotPSTH_output,
])

# Selectors beside the pooled mean-rate plot, then the PSTH and the save button.
widgets.VBox([
    widgets.HBox([controls, mean_plot]),
    PSTH_plot,
    save_button,
])

VBox(children=(HBox(children=(VBox(children=(Dropdown(description='Bird', options=('ZF17M_ZF18F', 'ZF4F', 'ZF1…