### Generates reconstructions of the smoothed PSTH from the PC responses on each trial for the single units.

#### Run this notebook after GenerateDataBase2

#### You should only have to modify rootPath in cell 2 for this notebook to work (figures are saved to a hard-coded Desktop folder; change those paths too if needed).


In [None]:
# Dependencies 
import os
import glob
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle as pk
from scipy.stats import t
from sklearn.decomposition import PCA


### Set Paths
Note that it is assumed that the data directory is accessible from rootPath.

In [None]:
# Root of the data directory -- the only setting that should need editing.
rootPath = '/Users/frederictheunissen/Google Drive/My Drive/julie/'
# Sub-directory of rootPath holding one folder of unit pkl files per bird.
pklPath = 'pkl'

# The 6 birds from Julie data set
birds = [
    'BlaBro09xxF',
    'GreBlu9508M',
    'WhiBlu5396M',
    'LblBlu2028M',
    'WhiWhi4522M',
    'YelBlu6903F',
]

In [None]:
# Read the data base (presumably written by GenerateDataBase2). The pickle holds
# two consecutive objects: the unit DataFrame, then the PCA fitted to the spike
# KDEs (used below via .components_, .mean_ and .explained_variance_ratio_).
# NOTE: pickle can execute arbitrary code -- only load files you produced.
inPath = rootPath + 'JulieDataBase.pkl'
with open(inPath, "rb") as fileIn:   # closes the file even if unpickling fails
    dfDataBase = pk.load(fileIn)
    pcKDE = pk.load(fileIn)

In [None]:
# Display the database loaded from JulieDataBase.pkl
dfDataBase

### Verify PCs

In [None]:
# PCs of KDE: print the fraction of variance explained by each PC and plot
# its cumulative sum against the number of PCs kept.
print(pcKDE.explained_variance_ratio_)

varRatio = pcKDE.explained_variance_ratio_
nComp = len(varRatio)   # number of fitted PCs (was hard-coded to 10)

fig, ax = plt.subplots(figsize=(8,4), dpi=300)
ax.plot(np.arange(1, nComp + 1), np.cumsum(varRatio))
ax.set_xlabel('Number of PCs')
ax.set_ylabel('Variance Explained')

fig.savefig('/Users/frederictheunissen/Desktop/PCAVarianceExplainedJulie.eps')

In [None]:

# Mean KDE and the first 5 PC waveforms on one axis. The mean is scaled by
# 0.005, presumably so it is comparable in size to the PCs -- TODO confirm.
fig, ax = plt.subplots(figsize=(8,4), dpi=300)
ax.plot(pcKDE.mean_*.005, label='Mean')
for iPC in range(5):
    ax.plot(pcKDE.components_[iPC, :], label='PC' + str(iPC))
ax.set_xlabel('Time (ms)')

ax.legend()

fig.savefig('/Users/frederictheunissen/Desktop/PCA5PCsJulie.eps')


### Functions
These are the same functions that are in GenerateDataBase2 and are needed to select the auditory units.

In [None]:
# These 3 z_score_stim functions could be combined into 1.

def calc_zscore_stim(stimName):
    """Z-score and two-sided p-value of the response to a single stimulus.

    Looks up the row of the global ``dfRelTime`` whose 'file' equals
    ``stimName``. For each of its trials, the number of spikes in [0, 0.5)
    minus the number in [-0.5, 0) is computed. The z-score is the mean of
    these differences divided by their sample SD (ddof=1).

    Returns (zscore, pvalue, nTrials). If zero or several rows match, prints a
    message and returns (0.0, 1.0, 0). With a single trial the result is
    (0, 1.0, nTrials).

    NOTE(review): the t distribution uses nTrials degrees of freedom; a
    one-sample t-test would normally use nTrials - 1. Kept as-is to match
    GenerateDataBase2.
    """
    global dfRelTime

    # Exactly one row is expected to match.
    matches = [r for _, r in dfRelTime.iterrows() if r['file'] == stimName]
    if len(matches) != 1:
        print('Stimulus not found or too many')
        return 0.0, 1.0, 0

    stimRow = matches[0]
    nTrials = stimRow['nTrials']

    def evoked_minus_baseline(spikes):
        # Spikes in the 500 ms after onset minus spikes in the 500 ms before.
        return np.sum((spikes >= 0) & (spikes < 0.5)) - np.sum((spikes >= -0.5) & (spikes < 0))

    spikeDiff = np.array(
        [evoked_minus_baseline(stimRow['spikeTimes'][k]) for k in range(nTrials)],
        dtype=float,
    )

    if nTrials <= 1:
        return 0, 1.0, nTrials

    diffSD = np.std(spikeDiff, ddof=1)
    if diffSD == 0:
        # Identical differences on every trial: perturb one to obtain a finite SD.
        spikeDiff[0] += 1
        diffSD = np.std(spikeDiff, ddof=1)

    zscore = np.mean(spikeDiff)/diffSD
    tStat = zscore*np.sqrt(nTrials)
    # Two-sided p-value from the tail on the side of the observed z-score.
    tail = t.cdf(tStat, nTrials) if zscore < 0.0 else 1.0 - t.cdf(tStat, nTrials)
    return zscore, tail*2.0, nTrials

def calc_zscore_category(catName):
    """Z-score and two-sided p-value of the response to one call type.

    Pools every trial of every row of the global ``dfRelTime`` whose
    'call_type' equals ``catName``. For each trial, the number of spikes in
    [0, 0.5) minus the number in [-0.5, 0) is computed. The z-score is the
    mean of these differences divided by their sample SD (ddof=1).

    Returns (zscore, pvalue, nTotal), where nTotal is the pooled trial count.
    Returns (0.0, 1.0, 0) if no trials match, and (0, 1.0, 1) for one trial.

    The t statistic and degrees of freedom use the pooled nTotal, as
    calc_zscore_all does. Previously they used the nTrials of the last
    matching row only.

    NOTE(review): nTotal degrees of freedom kept for consistency with the
    other z-score functions; a one-sample t-test would normally use nTotal - 1.
    """
    global dfRelTime

    # Collect all rows of this call type and count their trials.
    rows = []
    nTotal = 0
    for index, row in dfRelTime.iterrows():
        if row['call_type'] == catName:
            rows.append(row)
            nTotal += row['nTrials']
    if nTotal == 0:
        return 0.0, 1.0, nTotal

    # Evoked-minus-baseline spike count for every pooled trial.
    spikeDiff = np.zeros(nTotal)
    itot = 0
    for row in rows:
        for it in range(row['nTrials']):
            spikes = row['spikeTimes'][it]
            spikeDiff[itot] = np.sum((spikes >= 0) & (spikes < 0.5)) - np.sum((spikes >= -0.5) & (spikes < 0))
            itot += 1

    # Calculate z-score and pvalue
    if nTotal > 1:
        sdiffSD = np.std(spikeDiff, ddof=1)
        if sdiffSD == 0:
            spikeDiff[0] += 1   # Add a spike to generate SD
            sdiffSD = np.std(spikeDiff, ddof=1)

        zscore = np.mean(spikeDiff)/sdiffSD
        tStat = zscore*np.sqrt(nTotal)   # pooled count, not the last row's nTrials
        if zscore < 0.0:
            pvalue = t.cdf(tStat, nTotal)*2.0
        else:
            pvalue = (1.0 - t.cdf(tStat, nTotal))*2.0
    else:
        zscore = 0
        pvalue = 1.0

    return zscore, pvalue, nTotal


def calc_zscore_all():
    """Z-score and two-sided p-value of the response pooled over all stimuli.

    Uses every trial of every row of the global ``dfRelTime``. For each trial,
    the number of spikes in [0, 0.5) minus the number in [-0.5, 0) is
    computed. The z-score is the mean of these differences divided by their
    sample SD (ddof=1).

    Returns (zscore, pvalue, nTotal). With at most one trial in total the
    result is (0, 1.0, nTotal).
    """
    global dfRelTime

    def count_diff(spikes):
        # Spike difference: first 500 ms after onset minus the 500 ms before.
        return np.sum((spikes >= 0) & (spikes < 0.5)) - np.sum((spikes >= -0.5) & (spikes < 0))

    nTotal = sum(r['nTrials'] for _, r in dfRelTime.iterrows())
    spikeDiff = np.array(
        [count_diff(r['spikeTimes'][k])
         for _, r in dfRelTime.iterrows()
         for k in range(r['nTrials'])],
        dtype=float,
    )

    if nTotal <= 1:
        return 0, 1.0, nTotal

    diffSD = np.std(spikeDiff, ddof=1)
    if diffSD == 0:
        spikeDiff[0] += 1   # Add a spike to generate SD
        diffSD = np.std(spikeDiff, ddof=1)

    zscore = np.mean(spikeDiff)/diffSD
    tStat = zscore*np.sqrt(nTotal)
    # Two-sided p-value from the tail on the side of the observed z-score.
    tail = t.cdf(tStat, nTotal) if zscore < 0.0 else 1.0 - t.cdf(tStat, nTotal)
    return zscore, tail*2.0, nTotal

    
def load_playbackPkl(playbackPkl):
    """Load one unit's playback pickle into the globals unitInfo and dfRelTime.

    The file must contain three consecutive pickles: unitInfo, the
    absolute-time DataFrame, and the relative-time DataFrame. The
    absolute-time DataFrame is read only to advance the stream.

    The globals are replaced only after all three objects load. A truncated or
    corrupt file therefore cannot leave unitInfo from this file paired with
    dfRelTime from a previous one. On failure a message is printed and the
    globals keep their previous values.

    Returns True on success, False otherwise. Callers that ignore the return
    value are unaffected.
    """
    global unitInfo, dfRelTime

    # NOTE: pickle can execute arbitrary code -- only load files you produced.
    try:
        with open(playbackPkl, 'rb') as fileIn:
            try:
                newUnitInfo = pk.load(fileIn)
                pk.load(fileIn)              # dfAbsTime: not used here
                newRelTime = pk.load(fileIn)
            except Exception:
                # Best-effort: report and skip, but no longer swallow KeyboardInterrupt.
                print('Empty file: ', playbackPkl)
                return False
    except OSError as err:
        print("OS error: {0}".format(err))
        return False

    unitInfo = newUnitInfo
    dfRelTime = newRelTime
    return True
               


### Loop through the data and, for each stimulus, plot the spike raster and the PC reconstructions of the smoothed PSTH.

In [None]:
# Loop again through the data: for each auditory single unit and each
# vocalization stimulus, project every trial's spikes onto the first nPC PCs,
# then plot the raster, the per-trial PC reconstructions, the smoothed PSTH and
# the reconstruction from the trial-averaged PC scores.
itKDE = 1000    # Number of points in the KDE for PCs
tInt = 5.0      # This is for a correction for KDEs... (see Unit_readh5_files)
good_Calls = ['Ag', 'Be', 'DC', 'Di', 'LT', 'Ne', 'So', 'Te', 'Th', 'Wh']
nPC = 5
figDir = '/Users/frederictheunissen/Desktop/'   # where 'p' saves figures

for bird in birds:
    # One pkl file per unit for this bird
    pklfiles = glob.glob(os.path.join(rootPath, pklPath, bird, "*.pkl"))
    kp = 'n'
    print('----------------------'+bird+'-------------------------------')

    for playPkl in pklfiles:
        # Reset the globals filled by load_playbackPkl. A file that fails to
        # load is then skipped instead of re-plotting the previous unit's data
        # under this file's name.
        unitInfo = None
        dfRelTime = None
        load_playbackPkl(playPkl)
        if unitInfo is None or dfRelTime is None:
            continue

        # File name (with extension) labels the titles and saved figures.
        site_unit = os.path.basename(playPkl)

        # Select only single units.
        if unitInfo['SpikeSNR'] < 5.0:
            continue

        # Select only units with a significant, positive auditory response.
        zTot, pTot, nTot = calc_zscore_all()
        if pTot > 0.01 or zTot < 1:
            continue

        # Call types presented to this unit, restricted to the vocalizations of interest.
        calls = dfRelTime['call_type'].unique()
        good_calls = [call for call in calls if call in good_Calls]
        # Renditions seen so far of each call type (suffix of saved file names).
        callRendition = np.zeros(len(calls), dtype=int)
        if len(good_calls) == 0:   # only non-vocalization stims for this unit
            continue

        for index, row in dfRelTime.iterrows():
            # Skip non-vocalizations, stims with fewer than 5 trials, and stims
            # with no spikes at all (previously tested via the *sum* of spike times).
            if row['call_type'] not in good_calls:
                continue
            if row['nTrials'] < 5:
                continue
            if np.concatenate(row['spikeTimes']).size == 0:
                continue

            callID = np.argwhere(row['call_type'] == calls)[0][0]
            callRendition[callID] += 1

            nTrials = row['nTrials']
            tKDE = row['tKDE']
            fsamp = 1/(tKDE[1]-tKDE[0])     # KDE sampling rate
            components = pcKDE.components_[0:nPC, :]
            # This is how we remove the mean response before projecting into PC.
            meanScores = np.dot(components, -pcKDE.mean_)

            # PC scores of each trial: each spike adds fsamp times the PC values
            # at the nearest KDE sample. A single spike needs to be worth 1000.0
            # because our sampling rate is 1 ms.
            PCtrial = np.zeros((nTrials, nPC))
            for it in range(nTrials):
                PCtrial[it, :] = meanScores
                for spikeTime in row['spikeTimes'][it]:
                    if spikeTime < -0.5:
                        continue
                    iMin = np.argmin(np.abs(tKDE - spikeTime))
                    if iMin < itKDE:
                        PCtrial[it, :] += fsamp*components[:, iMin]
                    else:
                        break   # past the PC window (presumably spike times are sorted)
            kdeAVG = PCtrial.sum(axis=0)/nTrials   # trial-averaged PC scores

            fig, ax = plt.subplots(2, 1, figsize=(2, 3), dpi=100, sharex=True,
                                   gridspec_kw={'height_ratios': [3, 1]})

            for it in range(nTrials):
                # Single-trial reconstruction from its PC scores.
                KDETrial = np.dot(PCtrial[it, :], components) + pcKDE.mean_
                ax[1].plot(tKDE[0:itKDE], KDETrial, color='0.5', linewidth=0.5)
                # Raster: one row of tick marks per trial.
                for spikeTime in row['spikeTimes'][it]:
                    if spikeTime < -0.5:
                        continue
                    ax[0].plot([spikeTime, spikeTime], [10*it, 10*it + 5], 'k-', linewidth=0.5)

            ax[0].set_title(site_unit + '_' + row['call_type'])
            ax[0].set_xlim(-0.5, 1.5)
            ax[0].set_axis_off()

            # Smoothed PSTH, normalized by the number of trials.
            ax[1].plot(tKDE, row['spikeKDE']*tInt/nTrials)

            # Reconstruction from the trial-averaged PC scores.
            KDERecon2 = np.dot(kdeAVG, components) + pcKDE.mean_
            ax[1].plot(tKDE[0:itKDE], KDERecon2, 'b')
            ax[1].spines['top'].set_visible(False)
            ax[1].spines['right'].set_visible(False)
            ax[1].set_xlim(-0.5, 1.5)
            ax[1].set_xlabel('Time (s)')
            plt.show()

            kp = input('Type n (next), p(print and next), q(next bird), x(exit):')

            if kp == 'p':
                fig.savefig(figDir + bird + '_' + site_unit + '_' + row['call_type'] + str(callRendition[callID]) + '.eps')
            plt.close(fig)   # otherwise one figure per stimulus stays open
            if kp in ('q', 'x'):
                break

        if kp in ('q', 'x'):
            break

    if kp == 'x':
        break
        

    
            

In [None]:
# NOTE(review): leftover inspection cell. Shows the rendition suffix used in the
# saved file name for the last plotted stimulus; depends on state left by the loop above.
str(callRendition[callID])

In [None]:
# NOTE(review): leftover inspection cell. Index into `calls` of the last plotted call type.
callID

In [None]:
# NOTE(review): leftover inspection cell. Renditions counted per call type for the
# last unit that reached the call-type loop.
callRendition