In [None]:
from pylab import *
from scipy import *
from scipy import stats, io
import numpy as np
import struct
from phy.io import KwikModel
from attrdict import AttrDict
import matplotlib.pyplot as plt
import os as os

In [None]:
#----------------------------------------------------------------------------------------
# READ STIMULUS
#----------------------------------------------------------------------------------------
# Read the binary stimulus file (float32): one block of 25 piezos x 10240 samples per sweep.
# NOTE(review): the original comment said "902 of 25 piezos x 1024 samples", but the reshape
# below assumes 10240 samples per piezo per sweep — confirm which is correct.
# The text file holds the stimulus type of each sweep: F sparse, C correlated, U uncorrelated.
def read_stimulus(binname='Stimulus_UCC.bin', textname='Stimulus_UCC.txt'):
    """Read the stimulus waveforms and the stimulus-type code of each sweep.

    binname  : path of the float32 binary stimulus file
    textname : path of the text file with the stimulus-type codes
               (presumably one per sweep — verify against the recording protocol)

    Returns
    -------
    read_data : float32 array of shape (n_sweeps, 25, 10240)
    txt_data  : 1-D chararray of decoded str codes (e.g. 'F', 'C', 'U')
    """
    # Passing the path (not an open handle) lets np.fromfile open and close the file itself,
    # so no file descriptor is leaked.
    read_data = np.fromfile(binname, dtype=np.float32)
    read_data = read_data.reshape((-1, 25, 10240))  # assumes 25 whiskers, 10240 time bins
    # ndmin=1 keeps a 1-D result even for a one-line file, so callers can slice and len() it.
    txt_data = np.loadtxt(textname, dtype='S8', ndmin=1)
    txt_data = txt_data.view(np.chararray).decode('utf-8')  # bytes -> str codes

    return read_data, txt_data

#----------------------------------------------------------------------------------------
# READKWIKINFO
#----------------------------------------------------------------------------------------
# Read the output of KlustaKwik: the spike times and the cluster number of each spike.
# Cluster numbers follow the KlustaViewa numbering (they can be as high as 130, e.g.).
def readkwikinfo(kwik, grupete=3):
    """Return the spike times of every cluster whose group equals `grupete`, plus the sample rate.

    kwik    : path of the .kwik file, loaded with phy's KwikModel
    grupete : cluster group to keep, compared with model.cluster_groups[cluster]
              (per the original notes: 0 = noise, 1 = MUA, 2 = GOOD, 3 = unsorted)

    Returns
    -------
    spikedata   : dict cluster -> AttrDict with the single key 'spike_times', so the
                  times are reachable as spikedata[cluster].spike_times
    sample_rate : model.sample_rate
    """
    model = KwikModel(kwik)  # load kwik model from file
    spiketimes = model.spike_times  # absolute spike times
    spike_clusters = model.spike_clusters  # cluster number of each spike
    clusters = model.cluster_groups  # cluster -> group
    sample_rate = model.sample_rate  # sampling frequency

    spikedata = {}
    for cluster, clustergroup in clusters.items():
        if clustergroup == grupete:
            # Select this cluster's spikes once; the result is stored directly.
            in_cluster = np.where(spike_clusters == cluster)
            spikedata[cluster] = AttrDict({'spike_times': spiketimes[in_cluster]})

    return spikedata, sample_rate

def BuildPSTH(Spikes, Vtag1, sampling_freq, t_before, t_after,
              binname='Stimulus_UCC.bin', textname='Stimulus_UCC.txt') :
    """Build per-whisker PSTH data for every neuron, using only the sparse ('F') sweeps.

    Spikes        : dict neuron -> object with a `spike_times` array (s), as returned by readkwikinfo
    Vtag1         : sync-channel samples; its +1 / -1 steps mark sweep starts / stops
    sampling_freq : rate (Hz) used to convert Vtag1 sample indices to seconds
    t_before      : window before each stimulus onset (positive, s)
    t_after       : window after each stimulus onset (positive, s)
    binname, textname : stimulus files passed to read_stimulus (defaults unchanged)

    Returns
    -------
    (PSTH_spike_counts, hist_inds) : dicts neuron -> the per-whisker output of PSTH_spikes
    """
    stim, stimtype = read_stimulus(binname, textname)

    # +1 / -1 steps of the sync channel mark sweep onsets / offsets.
    edges = np.diff(Vtag1)
    # Fixed sample offsets taken from the original analysis (their origin is not documented here):
    # a sweep is taken to begin 2999 samples before the rising step and end 4110 after the falling one.
    starts = (np.where(edges == 1)[0] - 2999)/float(sampling_freq)  # time in seconds
    stops = (np.where(edges == -1)[0] + 4110)/float(sampling_freq)  # time in seconds

    # Keep only as many stimulus sweeps as were recorded...
    n_sweeps = len(stops)
    stim = stim[:n_sweeps]
    stimtype = stimtype[:n_sweeps]

    # ...and of those, only the sparse ('F') ones.
    sparse = np.where(stimtype == 'F')[0]
    stim = stim[sparse]
    starts = starts[sparse]
    stops = stops[sparse]

    # Onset sample indices of the +/-1108.8889 deflections, per whisker and per sweep.
    # NOTE(review): exact equality relies on the comparison being done in float32
    # (read_stimulus loads the stimulus as float32); [::2] keeps every second matching
    # sample — presumably each pulse spans two samples, TODO confirm.
    stimtimes = {}
    for w in range(25):
        timesUP = [(np.where(sweep[w] == 1108.8889)[0] - 1)[::2] for sweep in stim]  # -1 corrects for the 0 at the start of stim
        timesDOWN = [(np.where(sweep[w] == -1108.8889)[0] - 1)[::2] for sweep in stim]
        stimtimes[w] = timesUP, timesDOWN  # stimtimes[whisker][0][sweep] = UP, [1][sweep] = DOWN

    # The PSTH of each neuron is built on [-t_before, t_after] around every onset.
    PSTH_spike_counts = {}
    hist_inds = {}
    for neuron, unit in Spikes.items():
        PSTH_spike_counts[neuron], hist_inds[neuron] = PSTH_spikes(
            stim, stimtype, stimtimes, unit.spike_times, sampling_freq,
            t_before, t_after, starts, stops)

    return PSTH_spike_counts, hist_inds

def PSTH_spikes(stimulation, stimtype, stimtimes, spikes, samp, t_before, t_after, starts, stops):
    """
    For every whisker and deflection direction, collect the spikes falling in
    [onset - t_before, onset + t_after] around each stimulus onset, expressed as offsets
    from that onset in stimulus samples.

    stimulation   : array (n_sweeps, 25, n_samples) of the F-sweep stimuli; only its length
                    and len(stimulation[i, 0]) are used
    stimtype      : stimulus type codes (unused here; kept for call compatibility)
    stimtimes     : dict whisker -> (up_onsets, down_onsets), each a list with one array of
                    stimulus-sample onset indices per sweep (as built by BuildPSTH)
    spikes        : array of spike times (s)
    samp          : unused here (kept for call compatibility); the stimulus rate is hard-coded below
    t_before      : duration before the stim (positive, s)
    t_after       : duration after the stim (positive, s)
    starts        : start times of the F sweeps (s)
    stops         : stop times of the F sweeps (s)

    Returns
    -------
    PSTH_spike_counts : dict whisker -> (n_up, n_down) numbers of spikes collected
    hist_inds         : dict whisker -> (up_offsets, down_offsets), float arrays of spike
                        positions relative to onset, in stimulus samples
    """
    stim_samp = 1/.0009997575757  # stimulus samples per second (hard-coded)

    PSTH_spike_counts = {}
    hist_inds = {}
    for w in range(25):
        offsets_by_direction = []
        for direction in (0, 1):  # 0 = UP, 1 = DOWN
            chunks = []
            for i in range(len(stimulation)):
                sweep_start = starts[i]
                sweep_duration = float(stops[i] - starts[i])
                samples_per_sweep = len(stimulation[i, 0])
                for onset in stimtimes[w][direction][i]:
                    onset_time = sweep_start + onset/stim_samp
                    window = (onset_time - t_before < spikes) & (spikes < onset_time + t_after)
                    in_window = spikes[window]
                    # Map spike times back to stimulus-sample indices by scaling over the sweep.
                    spike_idx = np.around((in_window - sweep_start)/sweep_duration*samples_per_sweep)
                    chunks.append(spike_idx - onset)
            # One pass: concatenating in sweep/onset order gives the same layout the
            # count-then-fill approach produced, without recomputing every window mask.
            offsets = np.concatenate(chunks).astype(float) if chunks else np.zeros(0)
            offsets_by_direction.append(offsets)
        up, down = offsets_by_direction
        PSTH_spike_counts[w] = len(up), len(down)
        hist_inds[w] = up, down

    return PSTH_spike_counts, hist_inds

def display_PSTH(histdata, counts, t_before, t_after) :
    """Plot the 25 per-whisker PSTHs of one neuron on a 5x5 grid in the current figure.

    histdata : dict whisker -> (up_offsets, down_offsets), spike offsets from stimulus
               onset in stimulus samples (as returned by PSTH_spikes)
    counts   : unused (kept for call compatibility)
    t_before : window before onset (s)
    t_after  : window after onset (s)

    Bars are weighted by 1 / (total binned spikes of this neuron over all whiskers and both
    directions), and every panel shares one y limit. Index 0 is drawn in blue ('Pos'),
    index 1 in red ('Neg'). The first panel's title reports the common peak height.
    """
    stim_samp = 1/.0009997575757  # stimulus samples per second (hard-coded)
    before_index = int(np.around(t_before*stim_samp))  # window edges in stimulus samples
    after_index = int(np.around(t_after*stim_samp))
    histlength = before_index + after_index + 1
    bins = np.linspace(-before_index, after_index, histlength)  # one-sample-wide bins

    # Raw per-whisker counts, needed only for the normalisation and the common y limit.
    # np.histogram gives the same counts as hist() without drawing into (and closing) a figure.
    nup = np.zeros((25, histlength-1))
    ndown = np.zeros((25, histlength-1))
    for i in range(25):
        if histdata[i][0].size:
            nup[i, :] = np.histogram(histdata[i][0], bins=bins)[0]
        if histdata[i][1].size:
            ndown[i, :] = np.histogram(histdata[i][1], bins=bins)[0]

    total = np.sum(nup + ndown)
    if total > 0:
        normnum = 1/total
        height = max(np.max(nup), np.max(ndown))/total
        ymax = 1.02 * height
    else:
        # No spike fell in any window: avoid 1/0 and a NaN y limit (which ylim rejects).
        normnum = 0.0
        height = 0.0
        ymax = 1.0

    clf()
    for i in range(25):
        if i == 0:
            ax1 = subplot(5, 5, 1, frame_on=False)
        else:
            subplot(5, 5, i+1, sharex=ax1, sharey=ax1, frame_on=False)
            xticks([], [])  # gets rid of the x ticks and numbers
            yticks([], [])  # gets rid of the y ticks and numbers
        for direction, color, label in ((0, 'b', 'Pos'), (1, 'r', 'Neg')):
            data = histdata[i][direction]
            if data.size:
                hist(data, bins=bins, color=color, alpha=0.5, edgecolor='none', histtype='stepfilled',
                     label=label, weights=np.repeat(normnum, len(data)))
        xlim(-before_index, after_index)
        axvline(0, color='r', linewidth=1)
        axhline(0, color='r', linewidth=2)
        ylim(0, ymax)
        # Outline over samples 0-30 after onset (presumably marks the stimulus pulse).
        plot(np.array([0, 10, 20, 30]), np.array([0, ymax*0.9, ymax*0.9, 0]), 'r-', linewidth=0.5)
    ax1.set_title('ymax =' + str(np.around(height, decimals=3)), fontsize=8)
        
def display_all_PSTHs_of_recording(histdata, counts, pdf_files_directory, t_before, t_after) :
    """Draw each neuron's PSTH grid and save it as a PDF in `pdf_files_directory`.

    histdata            : dict neuron -> per-whisker offsets (see PSTH_spikes)
    counts              : dict neuron -> dict whisker -> (n_up, n_down)
    pdf_files_directory : output directory; works with or without a trailing separator
    t_before, t_after   : PSTH window (s), forwarded to display_PSTH

    Each file is named 'Nrn<neuron>_Pos<total up>_Neg<total down>_hist.pdf', and the same
    label is used as the figure's suptitle.
    """
    for neuron in histdata.keys():
        clf()
        numspikesP = sum(counts[neuron][i][0] for i in range(25))
        numspikesN = sum(counts[neuron][i][1] for i in range(25))
        label = 'Nrn' + str(neuron) + '_Pos' + str(int(numspikesP)) + '_Neg' + str(int(numspikesN))
        display_PSTH(histdata[neuron], counts[neuron], t_before, t_after)
        suptitle(label, fontsize=10)
        # os.path.join inserts the separator only when the directory does not already end with one.
        savefig(os.path.join(pdf_files_directory, label + '_hist.pdf'), format='pdf')
        clf()

In [None]:
#kwiks = ['./m2s1/MEAS-151027-2_ele01_ele08.kwik', './m2s2/MEAS-151027-2_ele09_ele16.kwik', './m2s3/MEAS-151027-2_ele17_ele24.kwik', './m2s4/MEAS-151027-2_ele25_ele32.kwik', './m2s5/MEAS-151027-2_ele33_ele40.kwik', './m2s6/MEAS-151027-2_ele41_ele48.kwik', './m2s7/MEAS-151027-2_ele49_ele56.kwik', './m2s8/MEAS-151027-2_ele57_ele64.kwik']
# INSERT the list of .kwik files, with full directory paths, as `kwiks = [...]` (see the commented-out example above)

In [None]:
# Run the pipeline for every .kwik file in `kwiks` (defined in the cell above) and save one
# PSTH PDF per neuron under <kwik directory>/PDFS/PDFpsth/.
t_before = 0.005  # s of spikes kept before each stimulus onset
t_after = 0.045  # s of spikes kept after each stimulus onset

for sp_file in kwiks :
    #bin_file = open('MEAS-151116-3_Vtag1.dat','rb')
    #INSERT bin_file like the line above (a plain path string such as 'MEAS-151116-3_Vtag1.dat' also works)

    subdir = os.path.dirname(sp_file)
    # os.path.join keeps the output next to the kwik file even when it has no directory part
    # (subdir == ''); plain concatenation would give '/PDFS' at the filesystem root.
    # The final '' keeps the trailing separator of the original '.../PDFS/PDFpsth/' path.
    dirs = [os.path.join(subdir, 'PDFS'), os.path.join(subdir, 'PDFS', 'PDFpsth', '')]
    for out_dir in dirs:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    Spikes, sampling_freq = readkwikinfo(sp_file, 3)  # group 3 = unsorted, per readkwikinfo's notes

    # If one open file object is reused across iterations, rewind it: otherwise every iteration
    # after the first would read from end-of-file and silently get an empty Vtag1.
    if hasattr(bin_file, 'seek'):
        bin_file.seek(0)
    Vtag1 = np.fromfile(file=bin_file, dtype=np.int16)

    PSTH_spike_counts, hist_output = BuildPSTH(Spikes, Vtag1, sampling_freq, t_before, t_after)

    display_all_PSTHs_of_recording(hist_output, PSTH_spike_counts, dirs[1], t_before, t_after)
