In [None]:
import ants
from glob import glob
from importlib import reload
import matplotlib.pyplot as plt
import numpy as np
import os
from pyprind import prog_percent

from zebrafishframework import io
from zebrafishframework import img
from zebrafishframework import regtools
from zebrafishframework import rendering
from zebrafishframework import segmentation

from zebrafishframework.regtools import to_ants, to_numpy

img = reload(img)
regtools = reload(regtools)
rendering = reload(rendering)

Load the reference anatomy (Elavl3-H2BRFP) that every fish will be registered to, and set its spacing to the ZBB spacing.

In [None]:
# Reference anatomy that all fish are registered to (see the registration cells below).
registrations_dir = '/Users/koesterlab/Registrations'
zb_ref_fn = registrations_dir + '/Elavl3-H2BRFP.tif'

zb_ref_ants = ants.image_read(zb_ref_fn)
# Set the voxel spacing explicitly to the ZBB spacing.
zb_ref_ants.set_spacing(io.SPACING_ZBB)

# NumPy view of the reference, used for plotting and for deriving the output frame shape.
zb_ref_np = to_numpy(zb_ref_ants)

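Load each fish's traces, ROIs and anatomy (standard-deviation image) for both the control and stimulus groups. Masks are applied where a mask file exists. ROIs and anatomies are then pre-oriented (flipped and rotated) to match the ZBB reference.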

In [None]:
def find_bases(base_glob, suffix='_std_dev.h5'):
    """Find recording base paths by globbing for a marker file.

    A recording is identified by the file ``<base><suffix>``. The returned base
    path (the match with ``suffix`` stripped) is the prefix that ``load_stuff``
    appends its own suffixes to.

    Parameters
    ----------
    base_glob : str
        Glob pattern for the base path, e.g. ``'/data/control/fish*_6dpf_medium'``.
    suffix : str, optional
        Suffix of the file that marks a recording. Defaults to ``'_std_dev.h5'``.

    Returns
    -------
    list of str
        Base paths in sorted order. ``glob`` itself returns matches in an
        arbitrary, filesystem-dependent order. Sorting keeps the fish order, and
        therefore the concatenation order of ROIs and traces downstream,
        reproducible across runs and machines.
    """
    matches = sorted(glob(base_glob + suffix))
    # Slice by an explicit end index: with an empty suffix, m[:-0] would give ''.
    return [m[:len(m) - len(suffix)] for m in matches]

def load_stuff(bases):
    """Load, mask and pre-orient traces, ROIs and anatomies for a list of recordings.

    For every base path ``b`` this reads ``b + '_traces.npy'``,
    ``b + '_rois.npy'``, ``b + '_std_dev.h5'`` and, if it exists,
    ``b + '_mask.h5'``. A recording without a mask file gets an all-ones mask.

    Parameters
    ----------
    bases : list of str
        Base paths, e.g. as returned by ``find_bases``.

    Returns
    -------
    all_traces : list
        Per-recording traces, indexed by the same ROI filter as ``all_rois``.
    all_rois : list
        Per-recording ROIs kept by ``segmentation.mask_rois``, converted with
        ``img.our_view_to_zbrain_rois``.
    all_std_devs : list
        Per-recording std-dev anatomy, zeroed where the mask is not > 0,
        converted with ``img.our_view_to_zbrain_img``, wrapped with ``to_ants``
        and given spacing ``io.SPACING_JAKOB``.
    """
    all_traces = [np.load(b + '_traces.npy') for b in bases]
    all_rois = [np.load(b + '_rois.npy') for b in bases]
    # io.load appears to return a (data, ...) tuple; element [0] is the volume.
    all_std_devs = [io.load(b + '_std_dev.h5')[0] for b in bases]
    # Index [0] here as well. Without it the "mask" would be the whole tuple
    # returned by io.load, and mask_rois / `mask > 0` below would fail on it.
    all_masks = [io.load(b + '_mask.h5')[0]
                 if os.path.exists(b + '_mask.h5')
                 else np.ones(std_dev.shape)
                 for b, std_dev in zip(bases, all_std_devs)]

    filters = [segmentation.mask_rois(rois, mask) for rois, mask in zip(all_rois, all_masks)]

    # Apply masks: keep the ROIs and traces selected by mask_rois, and zero the
    # anatomy wherever the mask is not > 0.
    all_rois = [rois[filt] for rois, filt in zip(all_rois, filters)]
    all_traces = [traces[filt] for traces, filt in zip(all_traces, filters)]
    all_std_devs = [std_dev * (mask > 0) for std_dev, mask in zip(all_std_devs, all_masks)]

    # Pretransform to ZBB orientation (flip and rotate). The ROI conversion takes
    # the *unconverted* image shape, so it must run before the images are converted.
    all_rois = [img.our_view_to_zbrain_rois(rois, std_dev.shape) for rois, std_dev in zip(all_rois, all_std_devs)]
    all_std_devs = [to_ants(img.our_view_to_zbrain_img(std_dev)) for std_dev in all_std_devs]

    for std_dev in all_std_devs:
        std_dev.set_spacing(io.SPACING_JAKOB)

    return all_traces, all_rois, all_std_devs

In [None]:
segmented_root = '/Users/koesterlab/segmented'

control_bases = find_bases(segmented_root + '/control/fish*_6dpf_medium')
stimulus_bases = find_bases(segmented_root + '/stimulus/fish*_6dpf_amph')

# Each call yields parallel per-fish lists (traces, rois, std-dev anatomies), in the order of the bases.
control_traces, control_rois, control_std_devs = load_stuff(control_bases)
stimulus_traces, stimulus_rois, stimulus_std_devs = load_stuff(stimulus_bases)

In [None]:
# Show which recordings were found for each group, in load order.
(control_bases, stimulus_bases)

Cut all traces to a common window (the first 1800 frames) and compute dF/F.

In [None]:
# Common analysis window: the first 1800 frames of every recording.
ts = np.arange(1800)
# Frame range passed to segmentation.dFF as its reference window (presumably the F0 baseline — confirm).
dff_baseline = np.arange(110, 160)

control_traces_cut = [fish[:, ts] for fish in control_traces]
stimulus_traces_cut = [fish[:, ts] for fish in stimulus_traces]

control_dFF = [segmentation.dFF(fish, dff_baseline) for fish in control_traces_cut]
stimulus_dFF = [segmentation.dFF(fish, dff_baseline) for fish in stimulus_traces_cut]

Register each fish's anatomy to the reference and transform its ROIs accordingly.

In [None]:
# One registration per control fish: align its anatomy to the reference and carry its ROIs along.
control_rois_transformed = [
    regtools.transform_rois(zb_ref_ants, anatomy, fish_rois)
    for anatomy, fish_rois in zip(control_std_devs, prog_percent(control_rois))
]

In [None]:
# One registration per stimulus fish: align its anatomy to the reference and carry its ROIs along.
stimulus_rois_transformed = [
    regtools.transform_rois(zb_ref_ants, anatomy, fish_rois)
    for anatomy, fish_rois in zip(stimulus_std_devs, prog_percent(stimulus_rois))
]

Visualize the transformed ROIs of all stimulus-group fish on the reference, with the z axis collapsed (averaged).

In [None]:
# Average the reference over its first (z) axis, keeping a singleton z axis so that
# draw_rois receives a volume; fixed_z=0 presumably draws every ROI into that one plane.
zb_ref_avg = zb_ref_np.mean(axis=0, keepdims=True)

# Only the stimulus group is drawn; ROIs from all of its fish are pooled.
rois = np.concatenate(stimulus_rois_transformed, axis=0)
roi_map = segmentation.draw_rois(rois, zb_ref_avg, fixed_z=0)

fig, ax = plt.subplots(figsize=(12, 10))
ax.imshow(roi_map[0])
ax.set_title('Stimulus-group ROIs on z-averaged reference (n = {})'.format(len(rois)))
plt.show()

Generate pixel maps. A pixel map associates each pixel of the output frame with the ROIs that contribute to it, and with the fraction each ROI contributes to that pixel's color.

In [None]:
# Output frame shape: the reference's shape with its first axis (z, the axis collapsed in the ROI overview above) dropped.
out_shape = zb_ref_np.shape[1:]

In [None]:
# Pool the control ROIs of all fish and keep only their first two coordinate columns
# (presumably the in-plane position, matching the 2-D out_shape).
control_rois_2d = np.concatenate(control_rois_transformed, axis=0)[:, :2]
control_pix_map_raw = rendering.pixel_map(control_rois_2d, out_shape)

# Filter out pixels to which fewer than this many ROIs contribute.
min_rois_per_pixel = 3
control_pix_map = rendering.pix_map_filter(control_pix_map_raw, min_rois_per_pixel)

In [None]:
# Pool the stimulus ROIs of all fish and keep only their first two coordinate columns
# (presumably the in-plane position, matching the 2-D out_shape).
stimulus_rois_2d = np.concatenate(stimulus_rois_transformed, axis=0)[:, :2]
stimulus_pix_map_raw = rendering.pixel_map(stimulus_rois_2d, out_shape)

# Filter out pixels to which fewer than this many ROIs contribute.
min_rois_per_pixel = 3
stimulus_pix_map = rendering.pix_map_filter(stimulus_pix_map_raw, min_rois_per_pixel)

Render the movie frames over the time points in `ts` by mapping ROI dF/F values through the pixel map.

In [None]:
# Render the control group. The pixel map and the dF/F rows must describe the same
# ROIs in the same order: control_pix_map was built from the concatenated control
# ROIs, so the traces must be the concatenated control dF/F. Previously the
# stimulus dF/F was passed here, which mixed groups with different ROI counts.
rendered_frames = rendering.orthogonal_averaged(control_pix_map, np.concatenate(control_dFF),
                                                ts, out_shape, fill_value=0)

Encode the frames into a video file

In [None]:
# Encode the rendered frames as an MP4 at 30 frames per second.
movie_fn = '/Users/koesterlab/rendered.mp4'
rendering.to_file(movie_fn, rendered_frames, fps=30)

Save the rendered frames to an HDF5 file, with ZBB spacing.

In [None]:
# The frames were rendered on the reference's pixel grid (out_shape), so store them with the ZBB spacing.
h5_fn = '/Users/koesterlab/rendered_control.h5'
io.save(h5_fn, rendered_frames, io.SPACING_ZBB)

In [None]:
# Sanity check: dimensions of the rendered frame stack.
rendered_frames.shape