In [2]:
import os, sys, pickle
import datetime, time
import pdb

import numpy
import numpy as np
from matplotlib import pyplot as plt
from ipyparallel import Client

import baysian_neural_decoding as bnd

In [2]:
# Local kernel setup: inline figures, and auto-reload edited modules
# (e.g. baysian_neural_decoding) before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2
# Presumably defines ANIMALS (used below but not defined in this notebook) — TODO confirm in animal_info.py
%run animal_info

In [5]:
# Initializing engines
# Connect to the already-running ipyparallel cluster and take a DirectView over all engines.
rc = Client()
dv = rc[:]
# Blocking view: push/gather/execute calls wait for the engines to finish.
dv.block = True
# Make dv the target of the %px / %%px magics used in later cells.
dv.activate()
print("Number of active engines: {0}".format(len(dv)))

Number of active engines: 2


In [6]:
# Loading packages on engines
%%px --block
import sys
# Add path to 
sys.path.append('YOUR_BASE_DIRECTORY/baysian_neural_decoding/')
import numpy
import baysian_neural_decoding as bnd
from random import shuffle
%load_ext autoreload
%autoreload 2
%run animal_info

[stderr:0] 
[stderr:1] 


In [7]:
def current_time(abbreviated = False):
    """Return the current local time formatted as a string.

    abbreviated=False -> 'YYYY-MM-DD HH:MM:SS' (the form append_log uses)
    abbreviated=True  -> 'YYYYMMDD-HHMMSS'
    """
    fmt = '%Y%m%d-%H%M%S' if abbreviated else '%Y-%m-%d %H:%M:%S'
    return datetime.datetime.now().strftime(fmt)

def append_log(file_name, message, echo=False):
    """Append one timestamped line to the log file ``file_name``.

    The written line is ``[<current_time()>] <message>`` followed by a newline.
    When ``echo`` is true, the same text (trailing newline included) is also
    printed to stdout.
    """
    stamped = ''.join(['[', current_time(), '] ', message, '\n'])
    with open(file_name, "a") as log:
        log.write(stamped)
    if echo:
        print(stamped)

def convert_data_format(s, a, np, data):
    """Flatten per-repetition decoder results into one dict of lists.

    Parameters
    ----------
    s, a : numpy arrays of stimulus labels ('T'/'F') and actions ('NP'/'W').
    np : nosepoke values, stored unchanged under 'nosepokes'
        (note: inside this function the name shadows the usual numpy alias).
    data : iterable of results, each indexable as result['stimulus'][field]
        and result['choice'][field].

    Returns
    -------
    dict with one list per decoder field (one element per item of ``data``),
    plus the raw 'stimulus', 'action', 'nosepokes' inputs and a boolean
    'correct' mask.
    """
    # A trial is correct when 'T' is answered with 'NP' or 'F' with 'W'.
    correct = (s == 'T')*(a == 'NP')+(s == 'F')*(a == 'W')

    # (output key, section of each result, field inside that section)
    layout = (
        ('all_stimulus', 'stimulus', 'counts_summary'),
        ('all_action', 'choice', 'counts_summary'),
        ('correct_stimulus', 'stimulus', 'correct_counts_summary'),
        ('correct_action', 'choice', 'correct_counts_summary'),
        ('incorrect_stimulus', 'stimulus', 'incorrect_counts_summary'),
        ('incorrect_action', 'choice', 'incorrect_counts_summary'),
        ('all_stimulus_probs', 'stimulus', 'probs_summary'),
        ('all_action_probs', 'choice', 'probs_summary'),
        ('correct_stimulus_probs', 'stimulus', 'correct_probs_summary'),
        ('correct_action_probs', 'choice', 'correct_probs_summary'),
        ('incorrect_stimulus_probs', 'stimulus', 'incorrect_probs_summary'),
        ('incorrect_action_probs', 'choice', 'incorrect_probs_summary'),
        ('stimulus_choices', 'stimulus', 'counts'),
        ('action_choices', 'choice', 'counts'),
        ('stimulus_probs', 'stimulus', 'probs'),
        ('action_probs', 'choice', 'probs'),
        ('stimulus_times', 'stimulus', 'times'),
        ('action_times', 'choice', 'times'),
    )
    converted = {key: [result[section][field] for result in data]
                 for key, section, field in layout}

    converted['stimulus'] = s
    converted['action'] = a
    converted['nosepokes'] = np
    converted['correct'] = correct
    return converted

In [9]:
def animal_script(animal_info, neuron, log_file, case='whole_trial', spike_cutoff=3, **kwargs):
        # Collecting the basic information
        multiple = FLAGS['multiple']
        append_log(log_file, "  Loading data...")
        trial_duration = animal_info['trial_duration']
        last_trial = animal_info['last_trial']
        if not multiple:
            neuron = [neuron]
            
        event_set, spike_set = bnd.load_events_spikes_script(neuron_num=neuron, **animal_info)
        st0, s0, a0, np0, r0 = bnd.create_complete_table(event_set, spike_set, animal_info['variables'], 
                                                         trial_duration = trial_duration, 
                                                         pre_trial_duration = trial_duration)
        
        if not multiple:
            r0 = r0[0]

        if last_trial != None:
            st0 = st0[:last_trial]
            s0 = s0[:last_trial]
            a0 = a0[:last_trial]
            np0 = np0[:last_trial]
            if multiple:
                r0 = [r[:last_trial] for r in r0]
            else:
                r0 = r0[:last_trial]
        else:
            last_trial = len(s0)

        # Filtering out low spike counts
        cutoff_filter = bnd.spike_cutoff_script(r0, multiple=multiple, spike_cutoff=spike_cutoff)
        st0 = numpy.array(st0)[cutoff_filter]
        s0 = numpy.array(s0)[cutoff_filter]
        a0 = numpy.array(a0)[cutoff_filter]
        np0 = numpy.array(np0)[cutoff_filter]
        if multiple:
            r0 = [numpy.array(r)[cutoff_filter] for r in r0]
        else:
            r0 = numpy.array(r0)[cutoff_filter]

        # Replacing withhold trials with average response time
        avg_np = numpy.nanmean(np0)
        trial_times = np0[:]
        trial_times[numpy.isnan(trial_times)] = avg_np
        num_trials = len(s0) 
        
        # Creating spikes, inference times, and offsets on a case by case basis
        if case is 'whole_trial':
            # for standard analysis            
            s_offset = numpy.matrix(num_trials*[0]).T
            s_inf_times = numpy.matrix(zip(num_trials*[0], num_trials*[numpy.max(trial_times)]))
            s_trial_times = numpy.matrix(zip(num_trials*[0], trial_times[:]))
            c_offset = numpy.matrix(trial_times[:]).T
            c_inf_times = numpy.matrix(zip(num_trials*[-numpy.max(trial_times)], num_trials*[0]))
            c_trial_times = numpy.matrix(zip(-trial_times[:], num_trials*[0]))
        elif case is 'whole_trial_pre':
            # for standard analysis    
            window = FLAGS['window']
            s_offset = numpy.matrix(num_trials*[0]).T
            s_inf_times = numpy.matrix(zip(num_trials*[- window], num_trials*[numpy.max(trial_times) + window]))
            s_trial_times = numpy.matrix(zip(num_trials*[0], trial_times[:]))
            c_offset = numpy.matrix(trial_times[:]).T
            c_inf_times = numpy.matrix(zip(num_trials*[-numpy.max(trial_times) - window], num_trials*[window]))
            c_trial_times = numpy.matrix(zip(-trial_times[:], num_trials*[0]))     
        elif case is 'first_second':
            # for standard analysis            
            trial_length = FLAGS['trial_length']
            s_offset = numpy.matrix(num_trials*[0]).T
            s_inf_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
            s_trial_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
            c_offset = numpy.matrix(num_trials*[0]).T
            c_inf_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
            c_trial_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
        elif case is 'last_second':
            # for standard analysis            
            trial_length = FLAGS['trial_length']
            s_offset = numpy.matrix(trial_times[:]).T
            s_inf_times = numpy.matrix(zip(num_trials*[- trial_length], num_trials*[0]))
            s_trial_times = numpy.matrix(zip(num_trials*[- trial_length], num_trials*[0]))  
            c_offset = numpy.matrix(trial_times[:]).T
            c_inf_times = numpy.matrix(zip(num_trials*[- trial_length], num_trials*[0]))
            c_trial_times = numpy.matrix(zip(num_trials*[- trial_length], num_trials*[0])) 
        elif case is 'first_second_pre':
            # for standard analysis            
            window = FLAGS['window']
            trial_length = FLAGS['trial_length']
            s_offset = numpy.matrix(num_trials*[0]).T
            s_inf_times = numpy.matrix(zip(num_trials*[- window], num_trials*[trial_length + window]))
            s_trial_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
            c_offset = numpy.matrix(num_trials*[0]).T
            c_inf_times = numpy.matrix(zip(num_trials*[- window], num_trials*[trial_length + window]))
            c_trial_times = numpy.matrix(zip(num_trials*[0], num_trials*[trial_length]))
        elif case is 'last_second_pre':
            # for standard analysis            
            window = FLAGS['window']
            trial_length = FLAGS['trial_length']
            s_offset = numpy.matrix(trial_times[:]).T
            s_inf_times = numpy.matrix(zip(num_trials*[- trial_length - window], num_trials*[window]))
            s_trial_times = numpy.matrix(zip(num_trials*[-trial_length], num_trials*[0]))  
            c_offset = numpy.matrix(trial_times[:]).T
            c_inf_times = numpy.matrix(zip(num_trials*[-trial_length - window], num_trials*[window]))
            c_trial_times = numpy.matrix(zip(num_trials*[-trial_length], num_trials*[0]))              

        else:
            raise ValueError()
        
        # Converting the responses
        append_log(log_file, "  Converting responses...")

        max_time = numpy.max([s_inf_times + s_offset, s_trial_times + s_offset,  c_inf_times + c_offset,  c_trial_times + c_offset])
        min_time = numpy.min([s_inf_times + s_offset, s_trial_times + s_offset,  c_inf_times + c_offset,  c_trial_times + c_offset])
        if multiple:
            spikes = [[ numpy.array(resp[(resp > min_time)*(resp < max_time)]) for resp in r ] for r in r0]
        else:
            spikes = [ numpy.array(resp[(resp > min_time)*(resp < max_time)]) for resp in r0 ]
        
        times = {
            'total': (min_time, max_time),
            'stimulus': [s_offset, s_inf_times, s_trial_times],
            'choice': [c_offset, c_inf_times, c_trial_times]
                }

        variable_dict = {
            'stimulus': [{'T':0, 'F':1}, s0], 
            'choice': [{'NP':0, 'W':1}, a0]
                }

        condition_variables = {
            'correct': (s0 == 'T')*(a0 == 'NP')+(s0 == 'F')*(a0 == 'W'),
            'incorrect': (s0 == 'T')*(a0 == 'W')+(s0 == 'F')*(a0 == 'NP')
                }
        
        append_log(log_file, "  Calculating parameters...")
        PARAMS = bnd.pre_script(variable_dict, spikes, times, RESP_FUNCTION, PRE_FUNCTION, condition_variables=condition_variables, num_folds=NUM_FOLDS, **FLAGS)
        dv.push({'PARAMS': PARAMS})
            
        dv.push({'variable_dict' : variable_dict, 
                 'condition_variables' : condition_variables, 
                 'spikes' : spikes, 
                 'times' : times, 
                 'trial_duration' : trial_duration})
                
        # Running loop on engines
        append_log(log_file, "  Calculating predictions...")
        %px collection = [ bnd.main_script(variable_dict, spikes, times, RESP_FUNCTION, PROB_FUNCTION, condition_variables=condition_variables, num_folds=NUM_FOLDS, params=PARAMS, **FLAGS) for i in xrange(NUM_REPETITIONS) ]
        
        # Gathering data from engines
        append_log(log_file, "  Gathering predictions...")
        collection = dv.gather('collection', block = True)
        return(s0, a0, np0, collection)

# Testing on one Animal

In [10]:
ANIMAL = 'PFC_12302014'
NEURON = 1

NUM_FOLDS = 10
# Repetitions run on EACH engine: this cell pushes the value unchanged, so the
# gathered collection holds NUM_REPETITIONS * len(dv) results.
NUM_REPETITIONS = 2
SPIKE_CUTOFF = 3
CASE = 'first_second'

# Response function chooses how to characterize the spike trains (ISIs, first spike latency, etc.)
# See the modules included in the baysian_neural_decoding package for functions
RESP_FUNCTION = bnd.calc_ISIs

# Probability function determines how to model the responses
PROB_FUNCTION = bnd.timed_prob

# Pre-function determines any hyperparameters needed for the algorithm
PRE_FUNCTION = bnd.set_bw

# Additional flags, forwarded as keyword arguments to bnd.pre_script / bnd.main_script.
# animal_script itself reads 'multiple', 'window' (the *_pre cases) and
# 'trial_length' (the first_second / last_second cases).
FLAGS = {'multiple': False,
         'use_false': False,
         'use_PSTH': False,
         'within_class': False,
         'shuffle': False,
         'prob_from_spikes': False,
         'at_best': False,
         'log': True,
         'window':  1.0,
         'step': .1,
         'bin_size': .020,
         'trial_length': .750,
         'model': 'kde'}

log_file = '../test.log'
open(log_file, "w").close()  # start each test run with an empty log

append_log(log_file, "Sending settings to engines.")
dv.push({'RESP_FUNCTION': RESP_FUNCTION,
         'PROB_FUNCTION': PROB_FUNCTION,
         'PRE_FUNCTION': PRE_FUNCTION,
         'NUM_FOLDS': NUM_FOLDS,
         'NUM_REPETITIONS': NUM_REPETITIONS,
         'FLAGS': FLAGS})

# Unpack the response times as `nosepokes`, not `np`: rebinding `np` here would
# shadow the numpy alias imported at the top of the notebook.
s, a, nosepokes, data = animal_script(ANIMALS[ANIMAL], NEURON, log_file, spike_cutoff=SPIKE_CUTOFF, case=CASE)
new_data = convert_data_format(s, a, nosepokes, data)

In [11]:
# Stimulus-decoding count summaries, one array per gathered repetition (engines x NUM_REPETITIONS);
# presumably rows = true stimulus, columns = decoded stimulus — TODO confirm in bnd.main_script.
new_data['all_stimulus']

[array([[ 76.,  17.],
        [ 67.,  38.]]), array([[ 76.,  17.],
        [ 69.,  36.]]), array([[ 74.,  19.],
        [ 68.,  37.]]), array([[ 75.,  18.],
        [ 69.,  36.]])]

In [12]:
# Probability-based counterpart of all_stimulus; in the output below each row sums to the matching
# row of all_stimulus, so these look like summed class probabilities — TODO confirm in bnd.main_script.
new_data['all_stimulus_probs']

[array([[ 48.07419247,  44.92580753],
        [ 51.45156376,  53.54843624]]), array([[ 47.43892919,  45.56107081],
        [ 51.39214266,  53.60785734]]), array([[ 46.90030863,  46.09969137],
        [ 51.49863182,  53.50136818]]), array([[ 47.64025896,  45.35974104],
        [ 51.60939895,  53.39060105]])]

## Running on All Animals

In [None]:
from itertools import chain, combinations

def powerset(iterable):
    """Yield every subset of ``iterable`` as a tuple, smallest subsets first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)  # materialized at call time, so one-shot iterators are consumed here
    return (subset
            for size in range(len(pool) + 1)
            for subset in combinations(pool, size))

def _load_existing_results(path):
    """Return the result dict pickled at ``path``, or None if it can't be read.

    Missing files and unreadable pickles both return None so the caller starts
    fresh. NOTE: pickle.load can execute arbitrary code; only point this at
    result files this notebook wrote itself.
    """
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except Exception:
        return None


DIRECTORY = "../results/ISI_in_time/"
# Total repetitions per neuron, split evenly across the engines
NUM_REPETITIONS = 124

num_engines = len(dv)
# Floor division: a remainder that does not divide evenly is dropped (logged below)
num_engine_reps = NUM_REPETITIONS // num_engines
multiple = FLAGS['multiple']

log_file = os.path.join(DIRECTORY, 'run.log')
if not os.path.exists(DIRECTORY):
    # makedirs also creates missing parent directories (e.g. ../results)
    os.makedirs(DIRECTORY)
if not os.path.exists(log_file):
    open(log_file, "w").close()

append_log(log_file, '*** Starting new run ***', echo=True)
if NUM_REPETITIONS % num_engines:
    append_log(log_file, "Note: running {0} repetitions ({1} per engine on {2} engines) instead of {3}.".format(
        num_engine_reps * num_engines, num_engine_reps, num_engines, NUM_REPETITIONS))
append_log(log_file, "Sending settings to engines.")
dv.push({'RESP_FUNCTION': RESP_FUNCTION,
         'PROB_FUNCTION': PROB_FUNCTION,
         'PRE_FUNCTION': PRE_FUNCTION,
         'NUM_FOLDS': NUM_FOLDS,
         'NUM_REPETITIONS': num_engine_reps,
         'FLAGS': FLAGS})

for animal in ANIMALS.keys():
    if ANIMALS[animal]['include'] == True:
        append_log(log_file, "Starting animal {0}".format(animal))

        if not multiple:
            # One results file per animal, keyed by neuron; resume from it if present.
            current_file = os.path.join(DIRECTORY, animal + '.pickle')
            result_array = _load_existing_results(current_file)
            if result_array is None:
                result_array = {}
            else:
                append_log(log_file, "  already have {0}...".format(animal))

        # Picking what combinations: single neurons, or every non-empty neuron subset
        if multiple:
            neurons = powerset(ANIMALS[animal]['choice_neurons'])
        else:
            neurons = ANIMALS[animal]['choice_neurons']

        for neuron in neurons:
            if multiple:
                if len(neuron) == 0:
                    continue  # the empty subset has nothing to decode
                # One results file per neuron combination (file name kept as before so old results are found).
                current_file = os.path.join(DIRECTORY, animal + '-' + str(neuron) + '.pickle')
                result_array = _load_existing_results(current_file)
                if result_array is None:
                    result_array = {}
                else:
                    append_log(log_file, "  already have {0}, {1}...".format(animal, neuron))
            if neuron in result_array:
                append_log(log_file, "  ...skipping neuron {0}.".format(neuron))
                continue
            append_log(log_file, "Animal {0}, neuron {1}".format(animal, neuron))
            try:
                result_array[neuron] = convert_data_format(*animal_script(ANIMALS[animal], neuron, log_file, spike_cutoff=SPIKE_CUTOFF, case=CASE))
            except KeyboardInterrupt:
                append_log(log_file, "KEYBOARD INTERRUPT!")
                raise
            except Exception as err:
                # Best-effort: move on to the next neuron, but record why this one failed.
                append_log(log_file, "  Problem with animal {0}, neuron {1}: {2!r}".format(animal, neuron, err))
            finally:
                # Saving the output (also after failures/interrupts, so finished neurons are kept)
                append_log(log_file, "  Saving output...")
                with open(current_file, 'wb') as f:
                    pickle.dump(result_array, f)
open(os.path.join(DIRECTORY, "run complete"), "wb").close()
append_log(log_file, "Run complete.", echo=True)

[2017-06-22 17:31:39] *** Starting new run ***

