# In [1]:
def simpleaxis(ax):
    """Give a matplotlib axes a minimal frame.

    Hides the top and right spines, keeps ticks only on the bottom and left
    axes, and draws every spine and tick with width 0.5.
    """
    for hidden_side in ('top', 'right'):
        ax.spines[hidden_side].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(0.5)
    ax.tick_params(width=0.5)

def ROI_planeID(expt, label):
    """Return one integer plane ID per ROI in the signal set ``label``.

    For every entry of ``expt.imaging_dataset().signals()[label]['rois']`` the
    third coordinate of the first vertex of its first polygon is taken and the
    results are cast to ``int``.

    Parameters
    ----------
    expt : experiment object
        Must provide ``imaging_dataset().signals()``.
    label : hashable
        Key of the signal set whose ROIs are inspected.

    Returns
    -------
    numpy.ndarray of int
        One value per ROI, in ROI order (empty when there are no ROIs).
    """
    rois = expt.imaging_dataset().signals()[label]['rois']
    # [0][0][2]: first polygon, first vertex, third coordinate -- presumably the
    # z/plane index of the ROI; confirm against the ROI storage format.
    # range() instead of xrange() so this runs on both Python 2 and Python 3.
    plane_ids = [rois[i]['polygons'][0][0][2] for i in range(len(rois))]
    return np.asarray(plane_ids).astype('int')

def resample_trace(fluorescence_trace, number_of_data_points):
    """Linearly resample a 1-D trace to ``number_of_data_points`` samples.

    The input is treated as sampled at integer positions ``0 .. len-1``; the
    output covers that same first-to-last span at evenly spaced positions, so
    the first and last input samples are preserved. The trace needs at least
    two samples (scipy's ``interp1d`` rejects shorter inputs).
    """
    sample_positions = np.arange(len(fluorescence_trace))
    linear_interpolant = interpolate.interp1d(sample_positions, fluorescence_trace)
    query_positions = np.linspace(sample_positions.min(), sample_positions.max(),
                                  num=number_of_data_points)
    return linear_interpolant(query_positions)

def calculate_Cross_Correlation(expt, cell, velocity, number_of_lag_seconds = 5):
    """Find the lag at which a cell trace correlates most strongly with velocity.

    Both traces are mean-centred (NaN-aware). The centred cell trace is then
    shifted by every whole-frame lag in ``[-L, L]``, with
    ``L = rint(number_of_lag_seconds / frame_period)``. At each lag a Pearson
    coefficient is computed with ``numpy.ma.corrcoef``, with NaNs masked out of
    both traces.

    Assuming ``shift`` is ``scipy.ndimage.shift``, ``shifted[t] = cell[t - lag]``.
    A negative lag therefore means the cell trace follows the velocity.

    NOTE(review): frames shifted in from outside the trace are filled with 0.0
    (the trace mean after centring) and count as observations. This pulls the
    coefficients at large lags towards zero; ``cval=np.nan`` would mask them
    instead. Kept as-is to preserve existing results.

    Parameters
    ----------
    expt : experiment object
        Only ``expt.frame_period()`` is used, to convert between frames and seconds.
    cell, velocity : 1-D array-like
        Equal-length, frame-aligned traces.
    number_of_lag_seconds : float
        Half-width of the lag search window, in seconds.

    Returns
    -------
    (corr_coef, lag) : tuple of float
        The coefficient with the largest absolute value and its lag in seconds.
        ``(nan, nan)`` when no lag yields a defined coefficient (e.g. a constant
        or all-NaN trace).
    """
    frames_per_second = 1/expt.frame_period()
    velocity = np.asarray(velocity, dtype=float)
    cell = np.asarray(cell, dtype=float)
    mean_centered_velocity = velocity - np.nanmean(velocity)
    mean_centered_cell_trace = cell - np.nanmean(cell)

    max_lag_frames = int(np.rint(frames_per_second * number_of_lag_seconds))
    lag_frames = np.arange(-max_lag_frames, max_lag_frames + 1)
    correlation_coefficients = np.empty(len(lag_frames))
    for k, lag in enumerate(lag_frames):
        shifted_trace = shift(mean_centered_cell_trace, lag, mode = 'constant', order = 0, cval = 0.0)
        coefficient = ma.corrcoef(ma.masked_invalid(shifted_trace),
                                  ma.masked_invalid(mean_centered_velocity))[0, 1]
        # A fully masked / zero-variance result comes back as ma.masked; store NaN.
        correlation_coefficients[k] = np.nan if coefficient is ma.masked else float(coefficient)

    # Explicit guard: nanargmax raises on all-NaN input, and pandas' idxmax
    # no longer reliably returns NaN in that case.
    if np.all(np.isnan(correlation_coefficients)):
        return np.nan, np.nan
    best = int(np.nanargmax(np.abs(correlation_coefficients)))
    lag_seconds = lag_frames / float(frames_per_second)
    return correlation_coefficients[best], lag_seconds[best]
    
def find_putative_run_intervals(velocity_array, velocity_cut_off = 0.2):
    """Find contiguous bouts where velocity is at or above a cut-off.

    A frame is "running" when ``velocity >= velocity_cut_off``; NaN frames
    count as not running.

    - A start is the first running frame after a non-running frame.
    - A stop is the first non-running frame after a running frame.
    - A trace that begins running gets a start at frame 0.
    - A trace that ends running gets a stop at its last frame (``len - 1``).

    Because one running test is used for both starts and stops, they always
    alternate and pair up one-to-one.

    Parameters
    ----------
    velocity_array : 1-D array-like
        Velocity per frame.
    velocity_cut_off : float
        Running threshold (inclusive).

    Returns
    -------
    numpy.ndarray of int, shape (n_bouts, 2)
        ``[start, stop]`` frame pairs. Shape is ``(0, 2)`` when there is no bout.
    """
    velocity_array = np.asarray(velocity_array, dtype=float)
    n_frames = len(velocity_array)
    if n_frames == 0:
        return np.empty((0, 2), dtype=int)

    running = velocity_array >= velocity_cut_off
    # A transition between frames i and i+1 is reported at frame i+1.
    run_start_frames = np.flatnonzero(~running[:-1] & running[1:]) + 1
    run_stop_frames = np.flatnonzero(running[:-1] & ~running[1:]) + 1
    if running[0]:
        run_start_frames = np.insert(run_start_frames, 0, 0)
    if running[-1]:
        run_stop_frames = np.append(run_stop_frames, n_frames - 1)
    return np.column_stack((run_start_frames, run_stop_frames)).astype(int)

def merge_nearby_intervals(run_intervals, minimum_interval_separation = 20):
    """Fuse consecutive [start, stop] intervals separated by a short gap.

    The gap between interval k and interval k+1 is ``start_{k+1} - stop_k``.
    Whenever it is below ``minimum_interval_separation``, that stop and that
    start are both removed, so the two intervals become one spanning
    ``start_k .. stop_{k+1}``. Runs of short gaps collapse into a single
    interval.

    Returns a new array reshaped to ``(-1, 2)``; the input is not modified.
    """
    boundaries = run_intervals.flatten()
    # Odd positions (1, 3, ...) hold stops; the following even positions
    # (2, 4, ...) hold the next interval's start.
    gaps = boundaries[2::2] - boundaries[1:-1:2]
    closing_gaps = np.flatnonzero(gaps < minimum_interval_separation)
    positions_to_drop = np.concatenate((2 * closing_gaps + 1, 2 * closing_gaps + 2))
    merged = np.delete(run_intervals, positions_to_drop, axis = None)
    return merged.reshape(-1, 2)

def remove_short_intervals(run_intervals, minimum_interval_length = 5):
    """Drop rows whose ``stop - start`` is below ``minimum_interval_length``.

    ``run_intervals`` is a 2-D array of ``[start, stop]`` rows. The surviving
    rows keep their original order.
    """
    too_short = [row_index
                 for row_index, row in enumerate(run_intervals)
                 if row[1] - row[0] < minimum_interval_length]
    return np.delete(run_intervals, too_short, axis = 0)

def remove_small_amplitude_intervals(velocity_array, run_intervals, velocity_threshold = 0.5):
    """Drop intervals whose peak velocity never reaches ``velocity_threshold``.

    The peak of the interval ``[start, stop]`` is taken over
    ``velocity_array[start:stop]`` (stop frame excluded), ignoring NaNs.
    Intervals with no non-NaN sample in that range (an empty slice or all-NaN)
    have no measurable peak and are dropped as well.

    Parameters
    ----------
    velocity_array : 1-D array-like
        Velocity per frame.
    run_intervals : 2-D numpy array
        ``[start, stop]`` frame rows.
    velocity_threshold : float
        Minimum peak velocity an interval must reach to be kept.

    Returns
    -------
    numpy.ndarray
        The kept rows of ``run_intervals``, in their original order.
    """
    velocity_array = np.asarray(velocity_array, dtype=float)
    small_amplitude_intervals = []
    for i in range(len(run_intervals)):
        segment = velocity_array[run_intervals[i, 0]:run_intervals[i, 1]]
        # The built-in max() is order-dependent with NaNs (max([nan, x]) is nan),
        # so NaNs are removed explicitly; this also avoids max() of an empty slice.
        valid = segment[~np.isnan(segment)]
        if valid.size == 0 or valid.max() < velocity_threshold:
            small_amplitude_intervals.append(i)
    return np.delete(run_intervals, small_amplitude_intervals, axis = 0)

def _detected_run_intervals(velocity, velocity_cut_off):
    """Run the bout-detection pipeline shared by the run start/stop averages.

    Threshold crossings from ``find_putative_run_intervals`` are merged across
    short gaps, short bouts are removed, and bouts whose velocity never reaches
    the amplitude threshold are removed. Each step uses that helper's defaults.
    """
    putative_intervals = find_putative_run_intervals(velocity, velocity_cut_off = velocity_cut_off)
    merged_intervals = merge_nearby_intervals(putative_intervals)
    long_intervals = remove_short_intervals(merged_intervals)
    return remove_small_amplitude_intervals(velocity, long_intervals)


def _average_event_aligned_trace(expt, cell, event_frames, length_of_window_in_seconds,
                                 number_of_data_points):
    """Average ``cell`` over windows aligned on ``event_frames``.

    Each window is ``cell[frame - w : frame + w]`` with
    ``w = rint(length_of_window_in_seconds * (1 / frame_period)) + 1`` frames,
    so the event frame is the first sample of the window's second half. Windows
    that would extend past either end of ``cell`` are skipped. Each kept window
    is linearly resampled to ``number_of_data_points`` samples, and the windows
    are averaged with ``np.nanmean``.

    Returns an array of length ``number_of_data_points``, which is all-NaN when
    no window fits inside ``cell``.
    """
    half_window = int(np.rint(1/expt.frame_period() * length_of_window_in_seconds)) + 1
    resampled_windows = []
    for frame in event_frames:
        first = int(frame) - half_window
        last = int(frame) + half_window
        # A window starting at frame 0 or ending exactly at len(cell) is still
        # complete, so both bounds are inclusive.
        if first >= 0 and last <= len(cell):
            resampled_windows.append(resample_trace(cell[first:last], number_of_data_points))
    if not resampled_windows:
        # Keep the output shape consistent instead of nanmean's scalar NaN.
        return np.full(number_of_data_points, np.nan)
    return np.nanmean(np.asarray(resampled_windows), axis = 0)


def get_average_run_start_trace(expt, cell, velocity, velocity_cut_off = 0.2,
    length_of_window_in_seconds = 3, number_of_data_points = 50):
    """Average ``cell`` activity aligned on run onsets.

    Run bouts are detected from ``velocity``: threshold crossings at
    ``velocity_cut_off``, followed by gap merging and length and amplitude
    filtering. The first frame of each bout is the alignment point.

    Parameters
    ----------
    expt : experiment object
        Only ``expt.frame_period()`` is used, to convert seconds to frames.
    cell : 1-D array-like
        Activity trace; assumed frame-aligned with ``velocity`` (not checked).
    velocity : 1-D array-like
        Velocity trace used to detect run bouts.
    velocity_cut_off : float
        Running threshold passed to ``find_putative_run_intervals``.
    length_of_window_in_seconds : float
        Half-width of the alignment window.
    number_of_data_points : int
        Length of each resampled window and of the returned average.

    Returns
    -------
    numpy.ndarray
        Mean resampled trace of length ``number_of_data_points``. All-NaN when
        no run start has a complete window inside ``cell``.
    """
    run_start_frames = [interval[0] for interval in _detected_run_intervals(velocity, velocity_cut_off)]
    return _average_event_aligned_trace(expt, cell, run_start_frames,
                                        length_of_window_in_seconds, number_of_data_points)


def get_average_run_stop_trace(expt, cell, velocity, velocity_cut_off = 0.2,
    length_of_window_in_seconds = 3, number_of_data_points = 50):
    """Average ``cell`` activity aligned on run offsets.

    Identical to ``get_average_run_start_trace`` except that the alignment
    point is each bout's stop frame. The stop frame is the first frame below
    the cut-off, or the last frame of the trace when a bout runs to the end.

    Parameters
    ----------
    expt : experiment object
        Only ``expt.frame_period()`` is used, to convert seconds to frames.
    cell : 1-D array-like
        Activity trace; assumed frame-aligned with ``velocity`` (not checked).
    velocity : 1-D array-like
        Velocity trace used to detect run bouts.
    velocity_cut_off : float
        Running threshold passed to ``find_putative_run_intervals``.
    length_of_window_in_seconds : float
        Half-width of the alignment window.
    number_of_data_points : int
        Length of each resampled window and of the returned average.

    Returns
    -------
    numpy.ndarray
        Mean resampled trace of length ``number_of_data_points``. All-NaN when
        no run stop has a complete window inside ``cell``.
    """
    run_stop_frames = [interval[1] for interval in _detected_run_intervals(velocity, velocity_cut_off)]
    return _average_event_aligned_trace(expt, cell, run_stop_frames,
                                        length_of_window_in_seconds, number_of_data_points)