## Import Packages

In [1]:
import numpy as np
import nexfile # a .py file
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
import matplotlib.pyplot as plt
import copy
import scipy
import scipy.io
import scipy.linalg
import scipy.stats

## Data Loading Functions

In [2]:
def loadnpz(name, allow_pickle=False):
    """
    Load the single unnamed array stored in a compressed .npz archive.

    The archives read here are written with ``np.savez_compressed(path, array)``,
    which stores a positional array under the key ``arr_0``.

    Args:
       name (str): path of the .npz file
       allow_pickle (bool): forwarded to ``np.load``; needed for object arrays
    Returns:
       data (np array): the array stored under ``arr_0``
    """
    # NOTE: allow_pickle=True can execute arbitrary code while un-pickling;
    # only enable it for archives that come from a trusted source.
    # The context manager closes the underlying zip file once the array has been
    # read (the bare np.load result kept the file handle open).
    with np.load(name, allow_pickle=bool(allow_pickle)) as archive:
        return archive['arr_0']

def loadFileNames(data_dir):
    """
    List every session file found in the per-rat folders of ``data_dir + '/driveNeuron'``.

    Only sub-folders whose name parses as an integer (the rat id) are visited;
    every entry inside such a folder is returned, no extension filter is applied.

    Args:
        data_dir (str): local directory that stores the folder for neuron data
    Returns:
        fileNames (list): full path of each entry as str, built as
            data_dir + '/driveNeuron/' + rat id + '/' + entry name
    """
    dir_name = data_dir + '/driveNeuron'  # folder holding one sub-folder per rat

    # Keep the entries whose name is an integer rat id.
    rat_nums = []
    for entry in os.listdir(dir_name):
        try:
            int(entry)
        except ValueError:
            # Not a rat folder (e.g. a hidden or helper file): skip it.
            continue
        rat_nums.append(entry)

    # NOTE(review): os.listdir returns entries in arbitrary order and the position
    # in the returned list is what callers use as the session number. The order is
    # deliberately left untouched so existing session numbers stay valid.
    fileNames = []
    for num in rat_nums:
        folder = dir_name + "/" + num + "/"  # folder of one rat
        fileNames.extend(folder + event for event in os.listdir(folder))

    return fileNames


In [3]:
# ***Change this according to local directory***
# Root data folder: loadFileNames reads '<data_dir>/driveNeuron' and the save
# functions write their .npz files below '<data_dir>/eventData/'.
data_dir = '/Volumes/TOSHIBA/data'

## Event Data Preprocessing (save data to .npz files) 

In [130]:
def add_events(outtimeData, outputNames, new_events, included_events):
    '''
    Keep only correctly sequenced trials of one position (A or B) and add the
    conditional Sample events.

    The timestamps of the included Sample / Match / Nonmatch events are merged and
    sorted. A trial is valid when a Sample event is immediately followed by a Match
    or a Nonmatch event. The entries of the three included events in outtimeData
    are replaced by the timestamps belonging to valid trials, and two new entries
    are appended: Sample timestamps followed by Match, then Sample timestamps
    followed by Nonmatch.

    Both outtimeData and outputNames are modified in place and also returned.

    Args:
        outtimeData: the Timestamps of output events variables (one array per name)
        outputNames (list of str): the name of output events
        new_events (list of str): ['A_S|M', 'A_S|NM'] or ['B_S|M', 'B_S|NM'] (Match must be before NM events)
        included_events: (list of str) events needed to find the specific event given the other event happened (All A or B events)
    Returns:
        updated outtimeData, outputNames
    '''
    # Substring that identifies each kind of event name.
    markers = (("_MATCH", "M"), ("_NON", "NM"), ("_SAMPLE", "S"))
    slot = {"M": None, "NM": None, "S": None}  # index in outputNames, None if absent
    tagged = {"M": [], "NM": [], "S": []}      # (timestamp, event name) pairs

    for name in included_events:
        for marker, kind in markers:
            if marker not in name:
                continue
            if name in outputNames:
                slot[kind] = outputNames.index(name)
                tagged[kind] = [(t, name) for t in outtimeData[slot[kind]]]
            else:
                # Event not recorded in this session: it contributes no timestamps.
                tagged[kind] = []

    # All events of this position in chronological order.
    timeline = sorted(tagged["S"] + tagged["NM"] + tagged["M"])

    kept = {"S|M": [], "S|NM": [], "M": [], "NM": [], "S": []}
    # Look at each event together with the one that directly follows it.
    for (t_now, name_now), (t_next, name_next) in zip(timeline, timeline[1:]):
        if '_SAMPLES' not in name_now:
            continue
        if '_MATCH' in name_next:
            kept["S|M"].append(t_now)
            kept["M"].append(t_next)
            kept["S"].append(t_now)
        elif '_NONMATCH' in name_next:
            kept["S|NM"].append(t_now)
            kept["NM"].append(t_next)
            kept["S"].append(t_now)

    # Replace the original timestamps by the filtered ones. An event that is absent
    # from outputNames has no entry to replace (previously this raised
    # UnboundLocalError because its index was never assigned).
    for kind in ("M", "NM", "S"):
        if slot[kind] is not None:
            outtimeData[slot[kind]] = np.array(kept[kind])

    # append the filtered sample timestamps at the end of the timeData
    outtimeData.append(np.array(kept["S|M"]))
    outtimeData.append(np.array(kept["S|NM"]))
    outputNames.extend(new_events)

    return outtimeData, outputNames

In [128]:
def saveData(data_dir, eventNames, rat_idx=None, session_num=None, time_fr=10, fileLabel='', rm_groups=None):
    """
    Read the selected sessions, cut a spike raster around every event and save
    the results as .npz files below data_dir + '/eventData/seperate/'.

    Saved per readable session (file name suffix: fileLabel + '_' + session number):
        data_*:   spikeTimer, binary array (events, input channels, 500*time_fr bins)
        input_*:  inputNames, the wire/cell names of each input neuron spike channel
        output_*: outputType, for each event the index of its name in outputNames
    Saved once per call:
        validArgs_*:   session numbers that could be read
        ratLabels_*:   rat id of each of those sessions
        outputNames_*: the output event names of the last processed session
    Per-session and overall event counts are printed as a consistency check.

    Args:
       data_dir(str): local directory that stores the folder for neuron data
       eventNames(list of str): names of events
       rat_idx(list of int): rat id as labeled on file names (None or empty: no rat filter)
       session_num(list of int): index of sessions in sequence of imported files
           (None or empty: all sessions); must be empty when rat_idx is given
       time_fr(int): even number time frame centered around the event
       fileLabel(str): label for saved files
       rm_groups(list of string): event names that want to be removed from the output events
    Returns:
       None
    """
    # None stands for "empty"; this avoids mutable default arguments.
    rat_idx = [] if rat_idx is None else rat_idx
    session_num = [] if session_num is None else session_num
    rm_groups = [] if rm_groups is None else rm_groups

    bins_per_second = 500      # time resolution of the saved spike raster
    half_window = time_fr / 2  # the window reaches this far on each side of the event
    out_prefix = data_dir + '/eventData/seperate/'

    # This gives the file names with directory information of each session.
    fileNames = loadFileNames(data_dir)

    def rat_of(path):
        # The rat id is the folder that directly contains the session file. Taking
        # it relative to the end of the path keeps it correct for any data_dir
        # (a fixed position from the start only fits one directory depth).
        return path.split('/')[-2]

    # Return invalid input message
    if len(rat_idx) != 0 and len(session_num) != 0:
        print('Invalid input, session_num need to be empty to filter rat_id!')
        return

    # Get the range of session numbers and validate input session_num
    if len(session_num) == 0:
        s = range(len(fileNames))
        print(f'Save all sessions data: {s} (or select rat data)')
        if len(rat_idx) != 0:
            s = [a for a in s if int(rat_of(fileNames[a])) in rat_idx]
            print(f'Save select rat data: {len(s)}')
    elif not all(x in range(len(fileNames)) for x in session_num):
        print(f'Input session number(s) is invalid. \nMust be between [0, {len(fileNames)-1}]')
        return
    else:
        s = session_num
        print(f'Save selected sessions data: {s}')

    # Make sure the target folder exists before any (slow) processing starts.
    os.makedirs(out_prefix, exist_ok=True)

    validArgs = []          # session numbers that nexfile could read
    ratLabels = []          # rat id of each valid session
    output = np.array([])   # event indices of all sessions, for the final count
    outputNames = []        # stays empty if no session is readable

    for a in s:
        fileName = fileNames[a]
        rat_id = rat_of(fileName)

        # Each file is read once; unreadable sessions are reported and skipped.
        try:
            fileData = nexfile.Reader().ReadNexFile(fileName)
        except Exception:
            print("Except: ", a, rat_id)
            continue
        validArgs.append(a)
        ratLabels.append(rat_id)

        # Collect names and timestamps of the neuron channels (inputs) and of the
        # requested events (outputs).
        inputNames = []
        intimeData = []
        outputNames = []
        outtimeData = []
        for var in fileData['Variables']:
            if 'Timestamps' not in var.keys():
                continue
            varName = var['Header']["Name"]
            times = np.array(var['Timestamps'])
            if ('wire' in varName) and ('cell' in varName):
                inputNames.append(varName)
                intimeData.append(times)
            elif varName in eventNames:
                outputNames.append(varName)
                outtimeData.append(times)

        # deal with exception of no A_Match events in variables
        if 'A_MATCH' not in outputNames:
            outputNames.insert(0, 'A_MATCH')
            outtimeData.insert(0, np.array([]))

        # adding new groups of Sample events given conditions
        # & filtering only the events with the right sequence of action
        outtimeData, outputNames = add_events(outtimeData, outputNames, new_events = ['A_S|M', 'A_S|NM'], included_events = ['A_MATCH', 'A_NONMATCH', 'A_SAMPLES'])
        outtimeData, outputNames = add_events(outtimeData, outputNames, new_events = ['B_S|M', 'B_S|NM'], included_events = ['B_MATCH', 'B_NONMATCH', 'B_SAMPLES'])

        # remove the events that are not needed
        for e in rm_groups:
            idx = outputNames.index(e)
            outputNames.pop(idx)
            outtimeData.pop(idx)

        # outputType holds, for every event, the index of its name in outputNames
        # (events grouped by name, in the order of outputNames).
        event_times = outtimeData[:len(outputNames)]
        counts = np.array([times.shape[0] for times in event_times], dtype=int)
        numOutput = int(counts.sum())
        outputType = np.repeat(np.arange(len(event_times), dtype=float), counts)

        print(np.unique(outputType, return_counts=True))

        # spikeTimer holds one binary spike train per event and channel, covering
        # time_fr seconds centred on the event.
        spikeTimer = np.zeros((numOutput, len(inputNames), bins_per_second * time_fr))
        row = 0
        for times in event_times:
            for timeNow in times:
                for i, spike_times in enumerate(intimeData):
                    spikes = spike_times - timeNow                    # spike times relative to the event
                    spikes = spikes[np.abs(spikes) < half_window]     # keep spikes inside the window
                    spikes = spikes + half_window                     # shift so the window starts at 0
                    spikes = np.floor(spikes * bins_per_second).astype(int)  # bin index (rounded down)
                    spikeTimer[row, i, spikes] = 1
                row += 1

        tag = fileLabel + '_' + str(a)
        np.savez_compressed(out_prefix + 'data_' + tag + '.npz', spikeTimer) #This saves the neuron spike data
        inputNames = np.array(inputNames)
        np.savez_compressed(out_prefix + 'input_' + tag + '.npz', inputNames) #This saves the wire/cell names of each input neuron spike channel.
        np.savez_compressed(out_prefix + 'output_' + tag + '.npz', outputType) #This saves the index of the event being predicted according to outputNames
        print("Saved " + tag)

        output = np.concatenate((output, outputType))

    # Total number of saved events per event index. The label is looked up by the
    # index value itself, so a missing index cannot shift the names.
    # NOTE(review): the names come from the last processed session; this assumes
    # every session yields the same outputNames order - confirm.
    codes, totals = np.unique(output, return_counts=True)
    for code, total in zip(codes, totals):
        code = int(code)
        label = outputNames[code] if code < len(outputNames) else code
        print(label, total)

    validArgs = np.array(validArgs)
    np.savez_compressed(out_prefix + 'validArgs_' + fileLabel + '.npz', validArgs) #This saves which sessions have valid data.
    ratLabels = np.array(ratLabels)
    np.savez_compressed(out_prefix + 'ratLabels_' + fileLabel + '.npz', ratLabels) #This saves each sessions' corresponding rat label.
    np.savez_compressed(out_prefix + 'outputNames_' + fileLabel + '.npz', outputNames) # This saves the outputNames for reference of the index in output


In [125]:
# Event variable names that saveData keeps from each session file. The last four
# ('..._S|M', '..._S|NM') are the derived groups that saveData adds through add_events.
eventNames = ['A_MATCH', 'A_NONMATCH', 'A_SAMPLES', 'B_MATCH', 'B_NONMATCH', 'B_SAMPLES', 'A_S|M', 'A_S|NM', 'B_S|M', 'B_S|NM']


In [121]:
# Save only the listed session numbers, using the file label 'exp_p'.
saveData(data_dir, eventNames, rat_idx=[], session_num=[10,15,16,17,18,19,22,24,64], fileLabel='exp_p')

Save selected sessions data: [10, 15, 16, 17, 18, 19, 22, 24, 64]
(array([1., 2., 7.]), array([80, 80, 80]))
Saved exp_p_10
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 20, 21, 18, 41, 59,  1, 20, 18, 41]))
Saved exp_p_15
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([22, 22, 20, 38, 58, 22, 20, 38]))
Saved exp_p_16
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([ 6,  6, 16, 58, 74,  6, 16, 58]))
Saved exp_p_17
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([ 3,  3, 12, 65, 77,  3, 12, 65]))
Saved exp_p_18
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([21, 21, 20,  7, 27, 21, 20,  7]))
Saved exp_p_19
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([14, 14,  4,  9, 13, 14,  4,  9]))
Saved exp_p_22
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([46, 47,  7, 37, 44, 46,  7, 37]))
Saved exp_p_24
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([29, 29, 24, 27, 51, 29, 24, 27]))
Saved exp_p_64
A_MATCH 1
A_NONMATCH 241
A_SAMPLES 243
B_MATCH 121
B_NONMATCH 282
B_SAMPLES 

In [104]:
# Save every session of rat 1036, using the file label 'exp_p'.
saveData(data_dir, eventNames, rat_idx=[1036], session_num=[], fileLabel='exp_p')

Save all sessions data: range(0, 520) (or select rat data)
Save select rat data: 22
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([42, 42, 16, 22, 38, 42, 16, 22]))
Saved exp_p_6
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 8, 35, 43, 14, 23, 37,  8, 35, 14, 23]))
Saved exp_p_7
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 4, 28, 32, 16, 32, 48,  4, 28, 16, 32]))
Saved exp_p_8
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([46, 46,  6, 28, 34, 46,  6, 28]))
Saved exp_p_9
(array([1., 2., 7.]), array([80, 80, 80]))
Saved exp_p_10
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 59, 60,  7, 13, 20,  1, 59,  7, 13]))
Saved exp_p_11
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 2, 39, 41, 14, 25, 39,  2, 39, 14, 25]))
Saved exp_p_12
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 23, 24,  9, 15, 24,  1, 23,  9, 15]))
Saved exp_p_13
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 27, 28,  9, 43, 52,  1, 27,  9, 43]

In [105]:
# Save all sessions of all rats, using the file label 'exp'.
saveData(data_dir, eventNames, fileLabel='exp')

Save all sessions data: range(0, 520) (or select rat data)
Except:  317 1138
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 5, 27, 32,  6, 42, 48,  5, 27,  6, 42]))
Saved exp_0
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 5, 20, 25,  5, 45, 50,  5, 20,  5, 45]))
Saved exp_1
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 7, 14, 21,  1, 35, 37,  7, 14,  1, 35]))
Saved exp_2
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 4, 22, 26,  4, 50, 54,  4, 22,  4, 50]))
Saved exp_3
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 9, 25, 34,  1, 45, 46,  9, 25,  1, 45]))
Saved exp_4
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 22, 28,  3, 49, 52,  6, 22,  3, 49]))
Saved exp_5
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([42, 42, 16, 22, 38, 42, 16, 22]))
Saved exp_6
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 8, 35, 43, 14, 23, 37,  8, 35, 14, 23]))
Saved exp_7
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), a

Saved exp_74
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 8, 48, 56,  8, 16, 24,  8, 48,  8, 16]))
Saved exp_75
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([10, 46, 56,  6, 18, 24, 10, 46,  6, 18]))
Saved exp_76
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([13, 29, 42, 13, 25, 38, 13, 29, 13, 25]))
Saved exp_77
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 9, 36, 45,  5, 30, 35,  9, 36,  5, 30]))
Saved exp_78
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([12, 18, 30,  2, 32, 34, 12, 18,  2, 32]))
Saved exp_79
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([17, 37, 54, 18,  8, 26, 17, 37, 18,  8]))
Saved exp_80
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([27, 28, 55,  4, 34, 38, 27, 28,  4, 34]))
Saved exp_81
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([25, 17, 42,  7, 31, 38, 25, 17,  7, 31]))
Saved exp_82
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([26, 21, 47, 14, 39, 53, 26, 21, 14

Saved exp_148
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([21, 29, 50,  1, 29, 30, 21, 29,  1, 29]))
Saved exp_149
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([14, 37, 51,  4, 25, 29, 14, 37,  4, 25]))
Saved exp_150
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([11, 37, 48,  6, 26, 32, 11, 37,  6, 26]))
Saved exp_151
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([16, 32, 48,  6, 26, 32, 16, 32,  6, 26]))
Saved exp_152
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([10, 38, 48,  5, 27, 32, 10, 38,  5, 27]))
Saved exp_153
(array([0., 1., 2., 4., 5., 6., 7., 9.]), array([13, 65, 78,  2,  2, 13, 65,  2]))
Saved exp_154
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 33, 39, 28, 13, 41,  6, 33, 28, 13]))
Saved exp_155
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 32, 33, 19, 28, 47,  1, 32, 19, 28]))
Saved exp_156
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 34, 40,  8, 32, 40,  6, 34,  8, 32]))

Saved exp_221
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([10, 45, 55, 17,  8, 25, 10, 45, 17,  8]))
Saved exp_222
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([13, 22, 41,  6, 27, 39, 13, 22,  6, 27]))
Saved exp_223
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([18, 27, 45,  5, 30, 35, 18, 27,  5, 30]))
Saved exp_224
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 5, 36, 41, 12, 27, 39,  5, 36, 12, 27]))
Saved exp_225
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 26, 39,  6, 23, 41,  6, 26,  6, 23]))
Saved exp_226
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 1, 31, 32, 14, 34, 48,  1, 31, 14, 34]))
Saved exp_227
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 7, 29, 36, 12, 32, 44,  7, 29, 12, 32]))
Saved exp_228
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 4, 29, 33,  9, 38, 47,  4, 29,  9, 38]))
Saved exp_229
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 5, 35, 40,  9, 31, 40,  

Saved exp_294
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([13, 39, 52, 15, 33, 48, 13, 39, 15, 33]))
Saved exp_295
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([13, 22, 35, 16, 49, 65, 13, 22, 16, 49]))
Saved exp_296
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([12, 39, 51, 16, 33, 49, 12, 39, 16, 33]))
Saved exp_297
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 7, 46, 53, 14, 33, 47,  7, 46, 14, 33]))
Saved exp_298
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 49, 55, 12, 33, 45,  6, 49, 12, 33]))
Saved exp_299
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([13, 47, 60, 33, 57, 90, 13, 47, 33, 57]))
Saved exp_300
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([22, 49, 71, 14, 65, 79, 22, 49, 14, 65]))
Saved exp_301
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([22, 54, 76, 27, 47, 74, 22, 54, 27, 47]))
Saved exp_302
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([17, 58, 75, 21, 54, 75, 1

Saved exp_368
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 5, 31, 36, 22, 22, 44,  5, 31, 22, 22]))
Saved exp_369
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([15, 24, 39, 25, 16, 41, 15, 24, 25, 16]))
Saved exp_370
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 2, 20, 22, 55,  3, 58,  2, 20, 55,  3]))
Saved exp_371
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([11, 27, 38, 11, 51, 62, 11, 27, 11, 51]))
Saved exp_372
(array([1., 2., 3., 4., 5., 7., 8., 9.]), array([27, 27, 33, 40, 73, 27, 33, 40]))
Saved exp_373
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([10, 42, 52, 35, 13, 48, 10, 42, 35, 13]))
Saved exp_374
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 4, 48, 52, 34, 14, 48,  4, 48, 34, 14]))
Saved exp_375
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([25, 15, 40,  6, 34, 40, 25, 15,  6, 34]))
Saved exp_376
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([10, 32, 42,  9, 29, 38, 10, 32,  9, 29]))

Saved exp_441
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([11, 39, 50, 13, 37, 50, 11, 39, 13, 37]))
Saved exp_442
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([18, 32, 50, 13, 37, 50, 18, 32, 13, 37]))
Saved exp_443
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([12, 34, 46, 18, 36, 54, 12, 34, 18, 36]))
Saved exp_444
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([22, 27, 49, 19, 32, 51, 22, 27, 19, 32]))
Saved exp_445
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([27, 23, 50,  4, 46, 50, 27, 23,  4, 46]))
Saved exp_446
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([18, 37, 55,  7, 38, 45, 18, 37,  7, 38]))
Saved exp_447
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([15, 35, 50, 11, 39, 50, 15, 35, 11, 39]))
Saved exp_448
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([15, 32, 47, 13, 40, 53, 15, 32, 13, 40]))
Saved exp_449
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([21, 32, 53,  9, 38, 47, 2

Saved exp_514
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([14, 38, 52, 21, 27, 48, 14, 38, 21, 27]))
Saved exp_515
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 7, 39, 46, 16, 38, 54,  7, 39, 16, 38]))
Saved exp_516
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 2, 44, 46, 15, 39, 54,  2, 44, 15, 39]))
Saved exp_517
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([11, 39, 50, 15, 35, 50, 11, 39, 15, 35]))
Saved exp_518
(array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), array([ 6, 41, 47, 17, 36, 53,  6, 41, 17, 36]))
Saved exp_519
A_MATCH 5240
A_NONMATCH 18571
A_SAMPLES 23885
B_MATCH 5747
B_NONMATCH 18263
B_SAMPLES 24080
A_S|M 5240
A_S|NM 18571
B_S|M 5747
B_S|NM 18263


In [None]:
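# Check of the totals printed above:
# A_MATCH + A_NONMATCH = 5240 + 18571 = 23811; B_MATCH + B_NONMATCH = 5747 + 18263 = 24010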

In [131]:
# Save all sessions with the four derived 'S|M' / 'S|NM' groups removed, using the file label 'exp_org'.
saveData(data_dir, eventNames, fileLabel='exp_org', rm_groups=['A_S|M', 'A_S|NM', 'B_S|M', 'B_S|NM'])

Save all sessions data: range(0, 520) (or select rat data)
Except:  317 1138
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 27, 32,  6, 42, 48]))
Saved exp_org_0
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 20, 25,  5, 45, 50]))
Saved exp_org_1
(array([0., 1., 2., 3., 4., 5.]), array([ 7, 14, 21,  1, 35, 36]))
Saved exp_org_2
(array([0., 1., 2., 3., 4., 5.]), array([ 4, 22, 26,  4, 50, 54]))
Saved exp_org_3
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 25, 34,  1, 45, 46]))
Saved exp_org_4
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 22, 28,  3, 49, 52]))
Saved exp_org_5
(array([1., 2., 3., 4., 5.]), array([42, 42, 16, 22, 38]))
Saved exp_org_6
(array([0., 1., 2., 3., 4., 5.]), array([ 8, 35, 43, 14, 23, 37]))
Saved exp_org_7
(array([0., 1., 2., 3., 4., 5.]), array([ 4, 28, 32, 16, 32, 48]))
Saved exp_org_8
(array([1., 2., 3., 4., 5.]), array([46, 46,  6, 28, 34]))
Saved exp_org_9
(array([1., 2.]), array([80, 80]))
Saved exp_org_10
(array([0., 1., 2., 3., 4., 5.]), array([ 1, 59, 60,  7, 

Saved exp_org_98
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 53, 58,  8, 34, 42]))
Saved exp_org_99
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 41, 47, 12, 21, 33]))
Saved exp_org_100
(array([0., 1., 2., 3., 4., 5.]), array([16, 32, 48,  8, 24, 32]))
Saved exp_org_101
(array([0., 1., 2., 3., 4., 5.]), array([10, 39, 49, 11, 20, 31]))
Saved exp_org_102
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 47, 52,  3, 25, 28]))
Saved exp_org_103
(array([0., 1., 2., 3., 4., 5.]), array([ 3, 36, 39,  8, 33, 41]))
Saved exp_org_104
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 21, 26,  7, 19, 26]))
Saved exp_org_105
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 49, 55, 17, 43, 60]))
Saved exp_org_106
(array([0., 1., 2., 3., 4., 5.]), array([21, 55, 76, 11, 63, 74]))
Saved exp_org_107
(array([0., 1., 2., 3., 4., 5.]), array([11, 58, 69, 25, 45, 70]))
Saved exp_org_108
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 71, 77, 35, 38, 73]))
Saved exp_org_109
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 59, 

Saved exp_org_195
(array([0., 1., 2., 3., 4., 5.]), array([ 2, 40, 42,  7, 31, 38]))
Saved exp_org_196
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 15, 24, 18, 38, 56]))
Saved exp_org_197
(array([0., 1., 2., 3., 4., 5.]), array([ 5, 14, 19, 21, 40, 61]))
Saved exp_org_198
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 17, 23, 23, 34, 57]))
Saved exp_org_199
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 17, 23, 12, 44, 56]))
Saved exp_org_200
(array([0., 1., 2., 3., 4., 5.]), array([11, 10, 21, 13, 46, 59]))
Saved exp_org_201
(array([0., 1., 2., 3., 4., 5.]), array([ 3, 34, 37,  6, 37, 43]))
Saved exp_org_202
(array([0., 1., 2., 3., 4., 5.]), array([ 3, 35, 38,  8, 34, 42]))
Saved exp_org_203
(array([0., 1., 2., 3., 4., 5.]), array([ 4, 34, 38, 13, 29, 42]))
Saved exp_org_204
(array([0., 1., 2., 3., 4., 5.]), array([ 2, 38, 40,  9, 31, 40]))
Saved exp_org_205
(array([0., 1., 2., 3., 4., 5.]), array([ 8, 31, 39, 15, 26, 41]))
Saved exp_org_206
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 36

Saved exp_org_292
(array([0., 1., 2., 3., 4., 5.]), array([19, 31, 50, 15, 35, 50]))
Saved exp_org_293
(array([0., 1., 2., 3., 4., 5.]), array([34, 15, 49,  4, 47, 51]))
Saved exp_org_294
(array([0., 1., 2., 3., 4., 5.]), array([13, 39, 52, 15, 33, 48]))
Saved exp_org_295
(array([0., 1., 2., 3., 4., 5.]), array([13, 22, 35, 16, 49, 65]))
Saved exp_org_296
(array([0., 1., 2., 3., 4., 5.]), array([12, 39, 51, 16, 33, 49]))
Saved exp_org_297
(array([0., 1., 2., 3., 4., 5.]), array([ 7, 46, 53, 14, 33, 47]))
Saved exp_org_298
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 49, 55, 12, 33, 45]))
Saved exp_org_299
(array([0., 1., 2., 3., 4., 5.]), array([13, 47, 60, 33, 57, 90]))
Saved exp_org_300
(array([0., 1., 2., 3., 4., 5.]), array([22, 49, 71, 14, 65, 79]))
Saved exp_org_301
(array([0., 1., 2., 3., 4., 5.]), array([22, 54, 76, 27, 47, 74]))
Saved exp_org_302
(array([0., 1., 2., 3., 4., 5.]), array([17, 58, 75, 21, 54, 75]))
Saved exp_org_303
(array([0., 1., 2., 3., 4., 5.]), array([19, 45

Saved exp_org_390
(array([0., 1., 2., 3., 4., 5.]), array([ 8, 37, 45,  8, 47, 55]))
Saved exp_org_391
(array([0., 1., 2., 3., 4., 5.]), array([ 3, 33, 36, 38,  6, 44]))
Saved exp_org_392
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 33, 42,  1, 37, 38]))
Saved exp_org_393
(array([0., 1., 2., 3., 4., 5.]), array([ 6, 34, 40,  5, 35, 40]))
Saved exp_org_394
(array([0., 1., 2., 3., 4., 5.]), array([ 8, 36, 44,  3, 53, 56]))
Saved exp_org_395
(array([0., 1., 2., 3., 4., 5.]), array([10, 36, 46,  2, 52, 54]))
Saved exp_org_396
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 39, 48,  2, 30, 32]))
Saved exp_org_397
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 33, 42,  2, 36, 38]))
Saved exp_org_398
(array([0., 1., 2., 3., 4., 5.]), array([14, 34, 48, 11, 41, 52]))
Saved exp_org_399
(array([0., 1., 2., 3., 4., 5.]), array([14, 37, 51, 10, 39, 49]))
Saved exp_org_400
(array([0., 1., 2., 3., 4., 5.]), array([25, 24, 49, 12, 39, 51]))
Saved exp_org_401
(array([0., 1., 2., 3., 4., 5.]), array([11, 39

Saved exp_org_487
(array([0., 1., 2., 3., 4., 5.]), array([12, 36, 48,  8, 44, 52]))
Saved exp_org_488
(array([0., 1., 2., 3., 4., 5.]), array([11, 35, 46,  2, 52, 54]))
Saved exp_org_489
(array([0., 1., 2., 3., 4., 5.]), array([16, 37, 53,  6, 41, 47]))
Saved exp_org_490
(array([0., 1., 2., 3., 4., 5.]), array([12, 53, 65, 11, 24, 35]))
Saved exp_org_491
(array([0., 1., 2., 3., 4., 5.]), array([ 9, 38, 47, 12, 41, 53]))
Saved exp_org_492
(array([0., 1., 2., 3., 4., 5.]), array([14, 39, 53,  8, 39, 47]))
Saved exp_org_493
(array([0., 1., 2., 3., 4., 5.]), array([16, 41, 57, 11, 57, 68]))
Saved exp_org_494
(array([0., 1., 2., 3., 4., 5.]), array([14, 48, 62, 15, 48, 63]))
Saved exp_org_495
(array([0., 1., 2., 3., 4., 5.]), array([18, 39, 57,  1, 42, 43]))
Saved exp_org_496
(array([0., 1., 2., 3., 4., 5.]), array([19, 37, 56,  1, 43, 44]))
Saved exp_org_497
(array([0., 1., 2., 3., 4., 5.]), array([14, 44, 58,  9, 33, 42]))
Saved exp_org_498
(array([0., 1., 2., 3., 4., 5.]), array([28, 27

In [107]:
def _strip_version_suffix(names, suffix='_ver_0'):
    """
    Removes a trailing suffix from every channel name, in place
    Args:
       names (np array): 1-D array of channel-name strings (modified in place)
       suffix (str): suffix to drop from the names that end with it
    Returns:
       names (np array): the same array object that was passed in
    """
    for idx, name in enumerate(names):
        if name.endswith(suffix):
            names[idx] = name[:-len(suffix)]
    return names


def combineEventData(data_dir, fileLabel, M=50):
    """
    Merges the per-session event files into session-independent arrays.
    This function saves:
        the neuron spike data,
        the index of the events,
        the valid session index,
        the binary of which neurons exist in each session
    Args:
       data_dir (str): local directory that stores the folder for neuron data
       fileLabel (str): label for saved files
       M (int): number of consecutive raw time samples summed into one time bin.
                The last axis of every per-session data array must be divisible by M.
                (Per the original author's note, M = 50 corresponds to a tenth of a second.)
    Returns:
       None
    """
    validArgs = loadnpz(data_dir + '/eventData/seperate/validArgs_' + fileLabel + '.npz')

    # Every list is seeded with an empty float array so that the dtype promotion performed
    # by np.concatenate matches what growing an initially-empty np.array([]) would produce.
    namesParts = [np.array([])]
    outputParts = [np.array([])]
    keyParts = [np.array([])]
    sessionNames = []  # cleaned channel names of each session, reused in the second pass
    for v in validArgs:
        inputNames = loadnpz(data_dir + '/eventData/seperate/input_' + fileLabel + '_' + str(v) + '.npz')
        outputType = loadnpz(data_dir + '/eventData/seperate/output_' + fileLabel + '_' + str(v) + '.npz')

        inputNames = _strip_version_suffix(inputNames)  # This removes '_ver_0' from names
        sessionNames.append(inputNames)

        namesParts.append(inputNames)
        outputParts.append(outputType)  # output event index for each session
        keyParts.append(np.zeros(outputType.shape[0]) + v)  # session number of each event

    # This combines the data from sessions
    inputNamesAll = np.concatenate(namesParts)
    outputTypeAll = np.concatenate(outputParts)
    keyAll = np.concatenate(keyParts).astype(int)

    inputNamesUnique = np.unique(inputNamesAll)  # This is a list of unique neuron channel names

    np.savez_compressed(data_dir + '/eventData/general/uniqueNames.npz', inputNamesUnique)

    # Column of every channel name inside inputNamesUnique, built once instead of
    # scanning inputNamesUnique for each name of each session.
    channelIndex = {name: idx for idx, name in enumerate(inputNamesUnique)}

    numEvents = outputTypeAll.shape[0]
    numChannels = inputNamesUnique.shape[0]

    # sensorLocation is a binary array showing which neuron channels exist in each session
    sensorLocation = np.zeros((validArgs.shape[0], numChannels))

    # dataAll will contain all the combined neuron spike data. It is allocated once the first
    # session is binned, so that the number of time bins always follows from the data and M.
    dataAll = None
    count1 = 0  # row of dataAll at which the current session's events start

    for v0, (v, inputNames) in enumerate(zip(validArgs, sessionNames)):
        print(v , '/' , len(validArgs))

        # inputArgs are the columns of inputNamesUnique that correspond to this session's channels
        inputArgs = np.array([channelIndex[name] for name in inputNames], dtype=int)
        sensorLocation[v0, inputArgs] = 1

        data = loadnpz(data_dir + '/eventData/seperate/data_' + fileLabel + '_' + str(v) + '.npz')

        # Group the time axis into blocks of M samples and count the spikes in each block.
        data = data.reshape((data.shape[0], data.shape[1], data.shape[2] // M, M))
        data = np.sum(data, axis=3)

        # log(count + 1) keeps the values from becoming overly large when many spikes occur rapidly.
        data = np.log(data + 1)

        if dataAll is None:
            dataAll = np.zeros((numEvents, numChannels, data.shape[2]))

        size1 = data.shape[0]
        dataAll[count1:count1 + size1, inputArgs] = data  # This adds the spike data to the array of all spike data.
        count1 += size1

    if dataAll is None:
        # No valid session at all: keep the historical 100-bin layout for the (empty) output.
        dataAll = np.zeros((numEvents, numChannels, 100))

    np.savez_compressed(data_dir + '/eventData/combined/data_' + fileLabel + '.npz', dataAll ) #This saves the neuron spike data
    np.savez_compressed(data_dir + '/eventData/combined/outputType_' + fileLabel + '.npz', outputTypeAll ) #This saves the event labels
    np.savez_compressed(data_dir + '/eventData/combined/keys_' + fileLabel + '.npz', keyAll ) #This saves the session numbers
    np.savez_compressed(data_dir + '/eventData/combined/inputLocation_' + fileLabel + '.npz', sensorLocation ) #This saves the binary of which neurons exist in each session.


In [132]:
# Build the combined arrays for the 'exp_org' label; the function prints one progress line per session.
combineEventData(data_dir, fileLabel='exp_org')

0 / 519
1 / 519
2 / 519
3 / 519
4 / 519
5 / 519
6 / 519
7 / 519
8 / 519
9 / 519
10 / 519
11 / 519
12 / 519
13 / 519
14 / 519
15 / 519
16 / 519
17 / 519
18 / 519
19 / 519
20 / 519
21 / 519
22 / 519
23 / 519
24 / 519
25 / 519
26 / 519
27 / 519
28 / 519
29 / 519
30 / 519
31 / 519
32 / 519
33 / 519
34 / 519
35 / 519
36 / 519
37 / 519
38 / 519
39 / 519
40 / 519
41 / 519
42 / 519
43 / 519
44 / 519
45 / 519
46 / 519
47 / 519
48 / 519
49 / 519
50 / 519
51 / 519
52 / 519
53 / 519
54 / 519
55 / 519
56 / 519
57 / 519
58 / 519
59 / 519
60 / 519
61 / 519
62 / 519
63 / 519
64 / 519
65 / 519
66 / 519
67 / 519
68 / 519
69 / 519
70 / 519
71 / 519
72 / 519
73 / 519
74 / 519
75 / 519
76 / 519
77 / 519
78 / 519
79 / 519
80 / 519
81 / 519
82 / 519
83 / 519
84 / 519
85 / 519
86 / 519
87 / 519
88 / 519
89 / 519
90 / 519
91 / 519
92 / 519
93 / 519
94 / 519
95 / 519
96 / 519
97 / 519
98 / 519
99 / 519
100 / 519
101 / 519
102 / 519
103 / 519
104 / 519
105 / 519
106 / 519
107 / 519
108 / 519
109 / 519
110 / 519


## Data Checking

In [134]:
# Load the combined spike array written by combineEventData(); its axes are (event, neuron channel, time bin).
data_all = loadnpz(data_dir+'/eventData/combined/data_exp_org.npz') # S|M & S|NM(NS).npz')
data_all.shape

(95642, 152, 100)

In [135]:
# Load the event label of every row of the combined spike array (one entry per event).
output_all = loadnpz(data_dir+'/eventData/combined/outputType_exp_org.npz') # _S|M & S|NM(NS).npz')
output_all.shape

(95642,)

In [138]:
a, c = np.unique(output_all, return_counts=True)
eventNames = loadnpz(data_dir+'/eventData/seperate/outputNames_exp_org.npz') #_S|M & S|NM(NS).npz')

# The labels in output_all are presumably indices into eventNames (verify against the function
# that writes outputNames_*.npz). Looking names up by label value, rather than by position in
# the np.unique result, keeps names and counts aligned even when a label never occurs.
countsByName = {eventNames[int(label)]: count for label, count in zip(a, c)}

for name, count in countsByName.items():
    print(name, count)

print("A_MATCH + A_NONMATCH: ", countsByName.get('A_MATCH', 0) + countsByName.get('A_NONMATCH', 0))
print("B_MATCH + B_NONMATCH: ", countsByName.get('B_MATCH', 0) + countsByName.get('B_NONMATCH', 0))
# print("A diff: ", (c[0] + c[1]) - (c[4] + c[5])) # 23719
# print("B_diff: ", (c[2] + c[3]) - (c[6] + c[7])) # 23914

# some Sample events are followed by another Sample events (Not always in order of Sample & Match/ Nonmatch events)
# thus the difference in extracted Sample events given different conditions
# also for some session there is no A_Match event as variable name, causing mislabeling of output index

A_MATCH 5240
A_NONMATCH 18571
A_SAMPLES 23811
B_MATCH 5747
B_NONMATCH 18263
B_SAMPLES 24010
A_MATCH + A_NONMATCH:  23811
B_MATCH + B_NONMATCH:  24010


In [119]:
# Load the session number of every event.
# NOTE(review): this reads the 'exp' label while the data/outputType cells read 'exp_org' -
# confirm that keys_exp.npz is the intended file and not a leftover from an earlier run.
key_all = loadnpz(data_dir+'/eventData/combined/keys_exp.npz')
key_all

array([  0,   0,   0, ..., 519, 519, 519])

In [120]:
# Load the binary (session, neuron channel) matrix marking which channels exist in each session.
# NOTE(review): this reads the 'exp' label while the data/outputType cells read 'exp_org' -
# confirm that inputLocation_exp.npz is the intended file.
inputLocation = loadnpz(data_dir+'/eventData/combined/inputLocation_exp.npz')
inputLocation.shape

(519, 152)

## Unused Functions

In [16]:
def investigateEventData(data_dir, eventNames, fileLabel):
    """
    For every valid session, builds a binary spike raster around each event and saves it
    together with the neuron channel names and the event labels.
    This function saves, under data_dir + '/eventData/seperate/investigate/':
        the spike raster of each session (event x neuron channel x 2000 time bins),
        the neuron channel names of each session,
        the event label of each raster row (index into the event-name list),
        the valid session index
    Args:
       data_dir (str): local directory that stores the folder for neuron data
       eventNames (list): kept for interface compatibility only; as in the original
                          implementation the value is replaced by a fixed event list
       fileLabel (str): label for loaded and saved files
    Returns:
       None
    """
    # NOTE(review): the argument is deliberately overridden, exactly as the original code did.
    eventNames = ['A_MATCH', 'A_NONMATCH', 'B_MATCH', 'B_NONMATCH', 'A_SAMPLES', 'B_SAMPLES', 'NOSEPOKE']

    halfWindow = 2    # spikes closer than this to the event time are kept (same unit as the timestamps)
    binsPerUnit = 500 # raster bins per unit of time
    numBins = 2000    # 2 * halfWindow * binsPerUnit

    #"validArgs" will be the set of sessions index with valid data that are saved from the function saveEventData().
    validArgs = loadnpz(data_dir+'/eventData/seperate/validArgs_' + fileLabel + '.npz')

    #full directory of each session
    fileNames = loadFileNames(data_dir)

    # loop over all the sessions
    for a in validArgs:
        reader = nexfile.Reader()
        fileData = reader.ReadNexFile(fileNames[a])

        # Collect the timestamps of every variable that has them. The "Nums" lists hold
        # positions inside spikeData (not inside fileData['Variables']), so variables
        # without timestamps never shift or stall the bookkeeping.
        spikeData = []
        inputNums = []
        inputNames = []
        outputNums = []
        outputNames = []
        for variable in fileData['Variables']:
            if 'Timestamps' not in variable.keys():
                continue

            name = variable['Header']["Name"]
            position = len(spikeData)
            spikeData.append(np.array(variable['Timestamps']))

            if ('wire' in name) and ('cell' in name):  # neuron channel
                inputNums.append(position)
                inputNames.append(name)
            if name in eventNames:  # event of interest
                outputNums.append(position)
                outputNames.append(name)

        # Process the events in alphabetical order of their names, keeping names and
        # positions aligned with each other.
        order = np.argsort(np.array(outputNames))
        outputNums = [outputNums[i] for i in order]
        outputNames = [outputNames[i] for i in order]

        # One label per event occurrence: the index of the event's name in eventNames.
        numOutput = 0
        outputType = np.array([])
        for b, name in zip(outputNums, outputNames):
            eventTimes = spikeData[b]
            numOutput += eventTimes.shape[0]
            outputType = np.concatenate(( outputType, np.zeros(eventTimes.shape) + eventNames.index(name) ))

        spikeTimer = np.zeros((numOutput, len(inputNums), numBins ))

        row = 0
        for b in outputNums:
            for timeNow in spikeData[b]:
                for column, d in enumerate(inputNums):
                    # Spike times relative to the event, restricted to the window around it
                    spikes = spikeData[d] - timeNow
                    spikes = spikes[np.abs(spikes) < halfWindow]
                    spikes = np.floor((spikes + halfWindow) * binsPerUnit).astype(int)
                    # Floating-point rounding could land exactly on numBins; keep it in range.
                    spikes = np.minimum(spikes, numBins - 1)

                    spikeTimer[row, column, spikes] = 1
                row += 1

        print (np.unique(outputType))

        np.savez_compressed(data_dir + '/eventData/seperate/investigate/data_' + fileLabel + '_' + str(a) + '.npz', spikeTimer)
        inputNames = np.array(inputNames)
        np.savez_compressed(data_dir + '/eventData/seperate/investigate/input_' + fileLabel + '_' + str(a) + '.npz', inputNames)
        np.savez_compressed(data_dir + '/eventData/seperate/investigate/output_' + fileLabel + '_' + str(a) + '.npz', outputType)

    validArgs = np.array(validArgs)
    np.savez_compressed(data_dir + '/eventData/seperate/investigate/validArgs_' + fileLabel + '.npz', validArgs)


In [72]:
def saveLocationNum(data_dir):
    """
    Converts the brain location and the hemisphere of every neuron channel into numbers
    and saves them in data_dir + '/eventData/general/'.
    Args:
       data_dir (str): local directory that stores the folder for neuron data
    Returns:
       None
    """
    # List of all neuron channel names
    channelNames = loadnpz(data_dir + '/eventData/general/uniqueNames.npz')

    # Channel names are '_'-separated: field 1 is the hemisphere, field 2 the brain location
    nameFields = [channel.split('_') for channel in channelNames]
    hemispheres = np.array([fields[1] for fields in nameFields])
    locations = np.array([fields[2] for fields in nameFields])

    # np.unique(..., return_inverse=True) gives the sorted unique names and, for every
    # channel, the position of its name in that sorted list (the numeric code).
    locationLabels, locationCodes = np.unique(locations, return_inverse=True)
    _, hemisphereCodes = np.unique(hemispheres, return_inverse=True)  # unique hemisphere names are not saved

    # Saves information
    np.savez_compressed(data_dir + '/eventData/general/brainLocationNamesUnique.npz', locationLabels)
    np.savez_compressed(data_dir + '/eventData/general/brainLocationNames.npz', locationCodes)
    np.savez_compressed(data_dir + '/eventData/general/brainHemisphereNames.npz', hemisphereCodes)


In [73]:
# Write the numeric brain-location and hemisphere codes of every neuron channel to disk.
saveLocationNum(data_dir)

In [74]:
def saveWireNum(data_dir):
    """
    Converts the wire name of every neuron channel into a number and saves it
    in data_dir + '/eventData/general/wireNames.npz'.
    Args:
       data_dir (str): local directory that stores the folder for neuron data
    Returns:
       None
    """
    # Names of the neuron channels
    channelNames = loadnpz(data_dir + '/eventData/general/uniqueNames.npz')

    # The wire name is the channel name without its first and its last two '_'-separated fields
    wireLabels = np.array(['_'.join(channel.split('_')[1:-2]) for channel in channelNames])

    # Numeric code of each wire: position of its name among the sorted unique wire names
    _, wireCodes = np.unique(wireLabels, return_inverse=True)

    # This saves the wire number for each neuron input channel
    np.savez_compressed(data_dir + '/eventData/general/wireNames.npz', wireCodes)


In [75]:
# Write the numeric wire code of every neuron channel to disk.
saveWireNum(data_dir)