In [1]:
def simpleaxis(ax):
    """Reduce a Matplotlib Axes to a minimal left/bottom frame with thin lines.

    Hides the top and right spines, keeps ticks only on the bottom x-axis and
    the left y-axis, and sets every spine and the tick marks to width 0.5.
    """
    thin = 0.5
    for hidden in ('top', 'right'):
        ax.spines[hidden].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    # All four spines get the thin width, including the hidden ones.
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_linewidth(thin)
    ax.tick_params(width=thin)

def find_nearest(array, value):
    """Return the (flattened) index of the element of *array* closest to *value*.

    Ties resolve to the lowest index, following ``ndarray.argmin``.
    """
    distance = np.abs(np.asarray(array) - value)
    return distance.argmin()

def loadRippleFrames(expt):
    """Return the imaging-frame index of each detected ripple in *expt*.

    The result is cached in ``<expt.LFPFilePath()>/ripple_frames.pkl``. If that
    cache cannot be loaded, the frames are rebuilt from the first file
    matching ``*keep_ripples.npy`` under ``expt.LFPFilePath()``, and the cache
    is (re)written.

    Each ripple is assigned to the frame whose timestamp
    (``frame_index * expt.frame_period()``) is nearest to the first column of
    that ripple's row in the ``.npy`` array. Presumably that column is the
    ripple start time (TODO confirm).

    Raises IndexError if the cache is unusable and no ``*keep_ripples.npy``
    file is found.
    """
    cache_path = expt.LFPFilePath() + '/' + 'ripple_frames.pkl'
    try:
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    except Exception:
        # Missing or unreadable cache: fall through and rebuild it. Catch
        # Exception rather than using a bare except, so that
        # KeyboardInterrupt and SystemExit still propagate.
        pass

    r_path = [fn for fn in locate('*keep_ripples.npy', expt.LFPFilePath(), ignore=None)][0]
    rtimes = np.load(r_path)
    ts = np.arange(expt.num_frames()) * expt.frame_period()
    ripple_frames = np.asarray([find_nearest(ts, row[0]) for row in rtimes])
    # NOTE(review): find_nearest always returns an index < len(ts) ==
    # expt.num_frames(), so this filter never drops anything. Ripples outside
    # the imaging window are snapped to the first or last frame instead.
    # Confirm whether out-of-range ripples were meant to be excluded.
    rFrames = ripple_frames[ripple_frames < expt.num_frames()]
    with open(cache_path, 'wb') as f:
        pickle.dump(rFrames, f)
    return rFrames

def ROI_planeID(expt, label):
    """Return an int array with one plane index per ROI in signal set *label*.

    For each ROI, the value is the third coordinate of the first vertex of
    its first polygon.
    """
    rois = expt.imaging_dataset().signals()[label]['rois']
    planes = [rois[k]['polygons'][0][0][2] for k in range(len(rois))]
    return np.asarray(planes).astype('int')

def resample_trace(fluorescence_trace, number_of_data_points):
    """Linearly resample a trace onto *number_of_data_points* even samples.

    The trace is treated as sampled at positions ``0 .. len(trace) - 1``.
    The output spans that same range, endpoints included.
    """
    grid = np.arange(len(fluorescence_trace))
    linear = interpolate.interp1d(grid, fluorescence_trace)
    return linear(np.linspace(grid[0], grid[-1], num=number_of_data_points))

def get_immobility_frames(expt, immobility_cutoff=0.2):
    """Return indices where ``expt.velocity()[0]`` is below *immobility_cutoff*."""
    speed = expt.velocity()[0]
    is_still = speed < immobility_cutoff
    return np.nonzero(is_still)[0]

def get_dur_SWR_fr(mousename, expt_trial_id_number, exptListAll, low_dur_cutoff, high_dur_cutoff):
    """Return within-experiment indices of ripples in a duration percentile band.

    The function proceeds as follows:

    1. Read ``ripple_durations`` from ``Ripple_Properties.pkl`` for every
       experiment in *exptListAll*.
    2. Pool the durations in list order and z-score them.
    3. Select the ripples whose pooled z-scored duration lies strictly
       between the *low_dur_cutoff* and *high_dur_cutoff* percentiles.
    4. From that selection, return those belonging to the experiment whose
       ``trial_id`` equals *expt_trial_id_number*. The indices are relative
       to that experiment's first ripple.

    Returns an empty array when that experiment has no ripples. Raises
    IndexError if no experiment in *exptListAll* has the requested trial id.
    """
    # NOTE(review): ``mouse`` is unused below. The lookup is kept only in
    # case dbMouse validates *mousename* or has side effects; confirm and drop.
    mouse = dbMouse(mousename)
    trial_ids = []
    n_ripples = []
    all_durs = []
    for expt in exptListAll:
        filepath = expt.LFPFilePath() + '/Ripple_Properties.pkl'
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        durs = data['ripple_durations']
        all_durs.extend(durs)  # pool ripples across experiments, in order
        n_ripples.append(len(durs))
        trial_ids.append(expt.trial_id)
    trial_ids = np.asarray(trial_ids)
    # With nonzero spread, z-scoring is monotonic, so this band selects the
    # same ripples it would select on the raw durations.
    all_durs = stats.zscore(all_durs)
    low = np.percentile(all_durs, low_dur_cutoff)
    high = np.percentile(all_durs, high_dur_cutoff)
    dur_inds = np.where((all_durs > low) & (all_durs < high))[0]

    # starts[i] is the pooled index of experiment i's first ripple, and
    # starts[i + 1] is one past its last ripple.
    starts = [0]
    for count in n_ripples:
        starts.append(starts[-1] + count)

    expt_ind = np.where(trial_ids == expt_trial_id_number)[0][0]
    start, stop = starts[expt_ind], starts[expt_ind + 1]
    # Offset by ``start`` (not the first element of the range) so that an
    # experiment with zero ripples yields an empty array instead of IndexError.
    return np.intersect1d(dur_inds, range(start, stop)) - start

def butter_bandpass_filter(signal, lowcut, highcut, Fs=20000, order=4):
    """Apply a causal Butterworth band-pass filter to *signal*.

    *lowcut* and *highcut* are in the same units as the sampling rate *Fs*.
    They are normalised by the Nyquist frequency before the filter is
    designed. Filtering is single-pass (``lfilter``), so the output has phase
    delay.
    """
    nyquist = Fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return lfilter(numerator, denominator, signal)

def _amp_percentile_frames(amps, ripple_frames, low_pct=None, high_pct=None):
    """Return ripple frames whose amplitude lies in a percentile band of *amps*.

    The band is half-open: ``[P(low_pct), P(high_pct))``. A ``None`` bound
    leaves that side unbounded. An inclusive lower edge plus an exclusive
    upper edge means the five quintile bins below partition every ripple.
    This includes ripples whose amplitude equals a bin edge; before, those
    were dropped from both neighbouring bins.

    Amplitudes and frames are paired positionally (like ``zip``), so trailing
    entries of the longer sequence are ignored. Percentiles are computed once
    over all of *amps*.
    """
    pairs = list(zip(amps, ripple_frames))
    if not pairs:
        # Skip np.percentile, which raises on empty input.
        return np.asarray([])
    lower = None if low_pct is None else np.percentile(amps, low_pct)
    upper = None if high_pct is None else np.percentile(amps, high_pct)
    return np.asarray([frame for amp, frame in pairs
                       if (lower is None or amp >= lower)
                       and (upper is None or amp < upper)])


def get_0_20_perc_amp_ripple_frames(expt, amps, ripple_frames):
    """Frames of ripples with amplitude below the 20th percentile.

    *expt* is unused; it is kept for call-site compatibility.
    """
    return _amp_percentile_frames(amps, ripple_frames, None, 20)


def get_20_40_perc_amp_ripple_frames(expt, amps, ripple_frames):
    """Frames of ripples with amplitude in [P20, P40). *expt* is unused."""
    return _amp_percentile_frames(amps, ripple_frames, 20, 40)


def get_40_60_perc_amp_ripple_frames(expt, amps, ripple_frames):
    """Frames of ripples with amplitude in [P40, P60). *expt* is unused."""
    return _amp_percentile_frames(amps, ripple_frames, 40, 60)


def get_60_80_perc_amp_ripple_frames(expt, amps, ripple_frames):
    """Frames of ripples with amplitude in [P60, P80). *expt* is unused."""
    return _amp_percentile_frames(amps, ripple_frames, 60, 80)


def get_80_100_perc_amp_ripple_frames(expt, amps, ripple_frames):
    """Frames of ripples with amplitude at or above P80. *expt* is unused."""
    return _amp_percentile_frames(amps, ripple_frames, 80, None)

def get_Vgat_mouse_normalized_ripple_amplitudes(mousename, expt_trial_id, exptListAll):
    """Return z-scored ripple amplitudes for one experiment of a mouse.

    ``max_amplitudes`` is read from each experiment's ``Ripple_Properties.pkl``.
    The values are pooled across all of *exptListAll* and z-scored together,
    and the slice belonging to the experiment whose ``trial_id`` equals
    *expt_trial_id* is returned. If several experiments share that trial id,
    the last one wins. Raises KeyError if none matches.
    """
    # NOTE(review): the returned object is unused; the lookup may only serve
    # to validate *mousename* — confirm before removing.
    mouse = dbMouse(mousename)
    ids = []
    counts = []
    pooled = []
    for expt in exptListAll:
        with open(expt.LFPFilePath() + '/Ripple_Properties.pkl', 'rb') as handle:
            props = pickle.load(handle)
        amplitudes = props['max_amplitudes']
        pooled.extend(amplitudes)
        counts.append(len(amplitudes))
        ids.append(expt.trial_id)
    pooled = stats.zscore(pooled)

    # bounds[i]:bounds[i + 1] is experiment i's span in the pooled array.
    bounds = [0]
    for count in counts:
        bounds.append(bounds[-1] + count)

    by_trial = {tid: pooled[lo:hi]
                for tid, lo, hi in zip(ids, bounds[:-1], bounds[1:])}
    return by_trial[expt_trial_id]