In [1]:
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
%matplotlib inline
from fish.ephys.ephys import load, estimate_onset, chop_trials

In [2]:
def trigger_data(triggers, window, fnames, average=False):
    from numpy import array
    from fish.image.vol import get_stack_dims
    from os.path import split, sep
    dims = get_stack_dims(split(fnames[0])[0] + sep)[::-1]
    ims_trial = []
    
    for ind, trig in enumerate(triggers):
        
        tr = []        
        for t_w in window:
            tr.append(load_single_plane(fnames, trig + t_w, tuple(dims)))
        tr = array(tr)
        
        
        if (len(ims_trial) == 0) or not average:
            ims_trial.append(tr)
        if average:
            if ind == 0:
                ims_trial[0] = ims_trial[0].astype('float32') / len(triggers) 
            
            ims_trial[0] += tr.astype('float32') / len(triggers)
        
    if average:         
        return ims_trial[0]
    else:
        return array(ims_trial)
    
def match_cam_time(events, cam, timing):
    from numpy import array, where
    tmp = []
    output = []
    for a in events:
        lags = array([a-b for b in cam])
        before = len(lags[lags > 0]) - 1
        after = before + 1
        
        if (before >= 0) and (after < len(cam)):
            if timing == 'pre':                
                output.append(before)
            if timing == 'post':
                output.append(after)    
    return array(output)

def load_single_plane(fnames, t, dims):
    from numpy import memmap, array    
    fn = fnames[t // dims[0]]
    mem = memmap(fn, shape=dims, dtype='uint16')
    return array(mem[t % dims[0]])

In [23]:
dirs = {}
dirs['ephys'] = 'W:/YuMu/SPIM/active_datasets/20170727/fish1/fish1/6dpf_gfapiglusnfr_gfap_reachr_opto_4/6dpf_gfapiglusnfr_gfap_reachr_opto_4.10chFlt'
dirs['ims'] = 'F:/Yu/20170727/fish1/20170727_1_4_gfapiglusfnr_gfapreachr_6dpf_single_fourROI_5on_10off_100laser594_20170727_143538//'
from fish.image.vol import get_stack_dims

dims = get_stack_dims(dirs['ims'])[::-1]
fnames = glob(dirs['ims'] + 'TM*')
print(len(fnames))

2862


In [24]:
ep_dat = load(dirs['ephys'])
cam_times = np.unique(estimate_onset(ep_dat[2], threshold=3, duration=50))
laser_on = estimate_onset(ep_dat[3], threshold=3.7, duration=400)
laser_off = len(ep_dat[3]) - estimate_onset(ep_dat[3][::-1], threshold=3.7, duration=400)
print(len(cam_times))

143060


In [25]:
trials = chop_trials(ep_dat[4])
# for each trial onset look for the nearest cam time before the onset.
# for each offset, look for the nearest cam time after offset
trials_cam = {}

for key, val in trials.items():
    pre = match_cam_time(val[0], cam_times, timing='pre')
    post = match_cam_time(val[1], cam_times, timing='post')
    trials_cam[key] = [pre, post]

In [26]:
# Get all events
window=np.arange(-300, 150)
triggered = {}
for key, value in trials_cam.items():
    if key > 0:
        print(key)
        print(value)
        triggered[key] = trigger_data(value[-1], window, fnames, average=True)

1.0
[array([   280,   8691,  11096,  14701,  19509,  26720,  29123,  33930,
        41139,  44747,  48352,  55562,  59169,  62774,  71187,  74792,
        80801,  82003,  90416,  92819, 100030, 102433, 109644, 110846,
       115653, 122863, 125268, 130074, 137286, 142093]), array([   481,   8894,  11298,  14905,  19709,  26922,  29324,  34134,
        41341,  44948,  48556,  55766,  59370,  62975,  71388,  74993,
        81002,  82204,  90616,  93020, 100233, 102635, 109845, 111047,
       115855, 123065, 125469, 130276, 137487, 142294])]
2.0
[array([  1481,   5087,  12298,  18307,  20710,  27922,  31527,  37536,
        38738,  47150,  49554,  56765,  60370,  65178,  68783,  72388,
        79599,  83204,  86810,  94021,  97626, 101232, 106039, 113250,
       119259, 120460, 128873, 131277, 138488, 139689]), array([  1683,   5288,  12498,  18508,  20912,  28122,  31728,  37739,
        38941,  47352,  49755,  56966,  60573,  65379,  68982,  72589,
        79800,  83405,  87014,  94222,

In [18]:
import pyqtgraph as pq
%gui qt

In [28]:
pq.image(triggered[1.0] - triggered[4.0])

<pyqtgraph.graphicsWindows.ImageWindow at 0xc5f2708>