In [2]:
from fish.ephys.ephys import load, chop_trials, estimate_onset
from fish.util.fileio import load_image
from glob import glob
import numpy as np

In [3]:
# Raw data locations for this recording: filtered ephys channels and the
# directory of per-frame image projections.
# NOTE(review): the ephys path says "cy171xcy421" while the imaging path says
# "cy421xcy171" -- confirm both refer to the same fish/session.
dirs = {
    'ephys': 'W:/davis/data/ephys/20170613/6dpf_cy171xcy421_f1_opto_4/6dpf_cy171xcy421_f1_opto_4.10chFlt',
    'im': 'W:/davis/data/spim/proc/20170613/6dpf_cy421xcy171_f1_opto_4_20170613_214237/proj/',
}

# Image files sorted by name; later cells index this list by camera-frame number.
# NOTE(review): name order equals time order only if the frame numbers in the
# t_* names are zero-padded -- confirm.
fnames = sorted(glob(dirs['im'] + 't_*'))

ep_dat = load(dirs['ephys'])
# Camera frame times detected on ephys channel 2. The values 3.7 and 100 are
# passed straight to estimate_onset (presumably a threshold and a minimum
# spacing -- TODO confirm). Truncated so there are no more times than image files.
cam_times = estimate_onset(ep_dat[2], 3.7, 100)[:len(fnames)]
# Trials chopped from ephys channel 4; each value is later read as
# [onsets, offsets].
trials = chop_trials(ep_dat[4])

In [7]:
# For each trial onset look for the nearest cam time before the onset;
# for each offset, look for the nearest cam time at or after the offset.
def match_cam_time(events, cam, timing):
    """Map event times onto camera frame indices.

    Parameters
    ----------
    events : array_like
        Event times, in the same units as ``cam``. A scalar is treated as a
        single event.
    cam : array_like
        Camera frame times, assumed ascending (so that a count of earlier
        frames is also a frame index).
    timing : {'pre', 'post'}
        'pre'  -> index of the last frame strictly before each event.
        'post' -> index of the first frame at or after each event.

    Returns
    -------
    numpy.ndarray of int
        One index per event that has at least one frame strictly before it
        and at least one frame at or after it. Events outside that range are
        dropped, so the result can be shorter than ``events``.

    Raises
    ------
    ValueError
        If ``timing`` is not 'pre' or 'post' (previously this silently
        returned an empty array).
    """
    if timing not in ('pre', 'post'):
        raise ValueError("timing must be 'pre' or 'post', got {!r}".format(timing))

    events = np.ravel(events)
    cam_sorted = np.sort(np.asarray(cam))
    # Number of camera times strictly earlier than each event. Binary search
    # on a sorted copy gives the same count as comparing against every frame.
    n_earlier = np.searchsorted(cam_sorted, events, side='left')
    before = n_earlier - 1
    after = n_earlier

    valid = (before >= 0) & (after < len(cam_sorted))
    chosen = before if timing == 'pre' else after
    # searchsorted yields integer indices, so even an empty result is usable
    # as an index array.
    return chosen[valid]
    
def trigger_data(triggers, window, fnames, average=False, loader=None):
    """Collect, and optionally average, image frames around trigger indices.

    For each index ``trig`` in ``triggers``, the frames ``fnames[trig + w]``
    for every offset ``w`` in ``window`` are loaded and stacked into one trial.

    Parameters
    ----------
    triggers : sequence of int
        Frame indices (into ``fnames``) to align on.
    window : iterable of int
        Offsets relative to each trigger, e.g. ``np.arange(-40, 40)``.
    fnames : sequence
        Ordered image file names; ``fnames[i]`` is passed to ``loader``.
    average : bool, default False
        If True, return the mean over trials instead of every trial.
    loader : callable, optional
        Maps one entry of ``fnames`` to an array. Defaults to ``load_image``.

    Returns
    -------
    numpy.ndarray
        average=False: the trials stacked with ``np.array``, i.e.
        ``(len(triggers), len(window)) + frame.shape`` when all frames share
        a shape.
        average=True: float32 mean over trials,
        ``(len(window),) + frame.shape``.

    Raises
    ------
    IndexError
        If any ``trig + w`` falls outside ``0 .. len(fnames) - 1``. Negative
        indices previously wrapped silently to the end of ``fnames``.
    ValueError
        If ``average`` is True and ``triggers`` is empty.
    """
    if loader is None:
        loader = load_image
    window = list(window)
    n_files = len(fnames)

    def _load_trial(trig):
        # Load one trial's frames, rejecting indices that would wrap around
        # (negative) or run off the end of `fnames`.
        idx = [trig + w for w in window]
        if idx and (min(idx) < 0 or max(idx) >= n_files):
            raise IndexError(
                'frames {}..{} around trigger {} fall outside fnames (0..{})'.format(
                    min(idx), max(idx), trig, n_files - 1))
        return np.array([loader(fnames[i]) for i in idx])

    if not average:
        return np.array([_load_trial(trig) for trig in triggers])

    triggers = list(triggers)
    if not triggers:
        raise ValueError('cannot average over an empty set of triggers')

    # Sum every trial exactly once, then divide. The previous version added the
    # first trial twice.
    total = None
    for trig in triggers:
        trial = _load_trial(trig).astype('float64')
        if total is None:
            total = trial
        else:
            total += trial
    return (total / len(triggers)).astype('float32')
    
# Merge two adjacent axes: 4D -> 3D, or 3D -> 2D.
def unroll(v):
    """Collapse a 3D or 4D array by merging two adjacent axes.

    - 3D ``(a, b, c)``    -> ``(a * b, c)``     (first two axes merged)
    - 4D ``(a, b, c, d)`` -> ``(a, b * c, d)``  (middle two axes merged)

    Parameters
    ----------
    v : array_like
        3D or 4D input. Lists are converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        The reshaped array (a view when ``reshape`` allows it).

    Raises
    ------
    ValueError
        For any other dimensionality. This previously returned None silently.
    """
    v = np.asarray(v)
    if v.ndim == 3:
        return v.reshape(v.shape[0] * v.shape[1], v.shape[2])
    if v.ndim == 4:
        return v.reshape(v.shape[0], v.shape[1] * v.shape[2], v.shape[3])
    raise ValueError('unroll expects a 3D or 4D array, got {}D'.format(v.ndim))

In [8]:
# Convert each trial type's ephys onset/offset times into camera frame
# indices. Assuming cam_times is ascending, onsets map to the last frame
# strictly before them and offsets to the first frame at or after them.
# NOTE(review): onsets and offsets are matched independently, and events
# outside the recorded frames are dropped on each side separately, so
# pre[i] and post[i] may not belong to the same trial if any are dropped.
trials_cam = {
    key: [
        match_cam_time(val[0], cam_times, timing='pre'),
        match_cam_time(val[1], cam_times, timing='post'),
    ]
    for key, val in trials.items()
}

In [16]:
# Trigger-averaged image stacks: for every trial type with key > 0, average
# the frames from 40 before to 39 after each stimulus-offset frame
# (value[-1] holds the 'post' frame indices).
# NOTE(review): key 0 is skipped -- presumably a baseline condition; confirm.
window = np.arange(-40, 40)
triggered = {
    key: trigger_data(value[-1], window, fnames, average=True)
    for key, value in trials_cam.items()
    if key > 0
}

In [20]:
# Save each trigger-averaged stack as a float32 TIFF. File indices are the
# keys of `triggered` shifted down by one (key k -> pattern_{k-1}).
out_dir = 'W:/davis/data/spim/proc/20170613/6dpf_cy421xcy171_f1_opto_4_20170613_214237/'
from skimage.io import imsave

# A plain loop: imsave is called only for its side effect, so there is no
# list of None to build and display.
for k, volume in triggered.items():
    out_path = out_dir + 'pattern_{0}_triggered_average.tif'.format(int(k - 1))
    imsave(out_path, volume.astype('float32'), compress=1)

[None, None, None, None, None, None, None]

In [11]:
import pyqtgraph as pq
%gui qt

In [15]:
# Open one pyqtgraph image window per trigger-averaged stack (axes merged by
# unroll). The list of windows is the cell's output.
[pq.image(unroll(volume)) for volume in triggered.values()]

[<pyqtgraph.graphicsWindows.ImageWindow at 0xb0820d8>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xcaef9d8>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xcab2318>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xcaf6c18>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xc798558>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xc7bee58>,
 <pyqtgraph.graphicsWindows.ImageWindow at 0xc7e4798>]