In [None]:
import ants
from glob import glob
from importlib import reload
from ipywidgets import interact
import matplotlib.pyplot as plt
import numpy as np
import os.path
import pickle
from pyprind import prog_percent
from skimage.io import imread
import re

from zebrafishframework import ants_cmd
from zebrafishframework import img
from zebrafishframework import io
from zebrafishframework import regtools
from zebrafishframework import rendering
from zebrafishframework import segmentation
from zebrafishframework import signal

# Re-import the project modules so edits made on disk are picked up without
# restarting the kernel.
# NOTE(review): `io` and `signal` are imported above but not reloaded here;
# `%load_ext autoreload` + `%autoreload 2` would cover every module.
regtools = reload(regtools)
segmentation = reload(segmentation)
img = reload(img)
ants_cmd = reload(ants_cmd)
rendering = reload(rendering)

In [None]:
# Input locations. NOTE(review): absolute, user-specific paths; consider a
# configurable base directory. `base_mask` is not used in this notebook.
base = '/Users/koesterlab/segmented/control/'
base_mask = '/Users/koesterlab/masks/'
# Per-fish filename prefix; '%02d' is replaced by the fish number.
base_fn = base + 'fish%02d_6dpf_medium'
# Raw string: a backslash-d inside a plain string literal is an invalid escape
# sequence (DeprecationWarning, and SyntaxWarning since Python 3.12).
r = re.compile(r'.*fish(?P<num>\d+).*')

# Fish numbers found in `base`. One fish can own several .h5 files there (the
# '_std_dev.h5' stacks loaded below live in this folder too), so the numbers
# are deduplicated. They are sorted so the fish order does not depend on
# glob() order; every positional index into the per-fish lists relies on it.
fish_ids = sorted({int(r.match(f).group('num')) for f in glob(base + '*.h5')})

# Per-fish ROI tables, traces and std-dev anatomy stacks, in `fish_ids` order.
all_rois_raw = [np.load((base_fn + '_rois.npy') % i) for i in prog_percent(fish_ids)]
all_traces_raw = [np.load((base_fn + '_traces.npy') % i) for i in prog_percent(fish_ids)]
all_anatomies_raw = [io.load((base_fn + '_std_dev.h5') % i)[0] for i in prog_percent(fish_ids)]

fish_ids

In [None]:
# Padding (in voxels) added on each side: grow the 1024-px planes to 1500 px
# and add 3 planes at each end of the first axis.
# Integer division keeps the pad widths integral (476 // 2 == 238); `/`
# produced the float 238.0, which is not a valid pixel count.
enlarge_xy = (1500 - 1024) // 2
enlarge_z = 3
# One (before, after) pair per axis: the z pair first, then one pair for each
# of the two in-plane axes.
enlarge_by = [(enlarge_z, enlarge_z)] + [(enlarge_xy, enlarge_xy)]*2
enlarge_by

In [None]:
# Pad every fish's ROI coordinates and anatomy stack by `enlarge_by`.
# NOTE(review): the next cell rebinds `all_rois` / `all_anatomies` to the raw,
# un-enlarged data, so this result is discarded on Run All — keep one of them.
all_rois = [img.enlarge_points(rois, enlarge_by) for rois in prog_percent(all_rois_raw)]
all_anatomies = [img.enlarge_image(anatomy, enlarge_by) for anatomy in prog_percent(all_anatomies_raw)]

In [None]:
# Working copies of the raw per-fish lists. `list.copy()` is shallow: the
# arrays inside are still shared with the *_raw lists.
# NOTE(review): this overwrites the enlarged data from the previous cell.
all_anatomies = all_anatomies_raw.copy()
all_rois = all_rois_raw.copy()
all_traces = all_traces_raw.copy()

In [None]:
# Fish number used as the registration reference, and its position in
# `fish_ids` (and hence in every per-fish list). Raises ValueError if absent.
align_to_fish = 12
align_to = fish_ids.index(align_to_fish)

In [None]:
# Register every fish's anatomy to the reference fish with the command-line
# ANTs wrapper, using only the first entry of the default parameter list.
params = ants_cmd.get_default_params()[:1]

# NOTE(review): hard-coded, user-specific temporary directory.
tmpdir = '/Users/koesterlab/tmp/'
ref_fn = tmpdir + 'ref.nrrd'
io.save(ref_fn, all_anatomies[align_to], spacing=io.SPACING_JAKOB)
transforms = []
# NOTE(review): `rois` is unpacked but not used in this loop.
for i, (anatomy, rois) in prog_percent(list(enumerate(zip(all_anatomies, all_rois)))):
    if i == align_to:
        transforms.append(ants.new_ants_transform()) # unity transform
    else:
        # Temporary file is named by list position `i`, not by fish number.
        in_fn = tmpdir + 'in_tmp_%02d.nrrd' % i
        io.save(in_fn, anatomy, spacing=io.SPACING_JAKOB)
        
        args = ants_cmd.Arguments(in_fn, ref_fn, params, output_folder=tmpdir)
        res = ants_cmd.run_antsreg(args)
        # Path of the affine written by the registration, read back as a transform.
        t_fn = res.get_generic_affine()
        transforms.append(ants.read_transform(t_fn))

In [None]:
# Register the reference fish, converted to Z-Brain orientation, to the
# Elavl3-H2BRFP Z-Brain stack using the first two default parameter entries.
# NOTE(review): `zb_ref` is a file path here but is rebound to an ANTs image in
# the next cell; `tmpdir` comes from the previous registration cell.
zb_ref = '/Users/koesterlab/Registrations/Elavl3-H2BRFP.tif' # metadata!!!
in_fn = tmpdir + 'in_tmp_%02d.nrrd' % align_to
io.save(in_fn, img.our_view_to_zbrain_img(all_anatomies[align_to]), spacing=io.SPACING_JAKOB)
params = ants_cmd.get_default_params()[:2]
args = ants_cmd.Arguments(in_fn, zb_ref, params, output_folder=tmpdir)
res = ants_cmd.run_antsreg(args)
t_to_zb = ants.read_transform(res.get_generic_affine())

In [None]:
# In-process alternative: affine-register every fish (in Z-Brain orientation,
# as moving image) to the Z-Brain reference stack (fixed image) with ANTsPy.
# NOTE(review): rebinds `transforms`, discarding the fish-to-reference-fish
# transforms computed two cells above.
zb_ref_fn = '/Users/koesterlab/Registrations/Elavl3-H2BRFP.tif' # metadata!!!
zb_ref = ants.image_read(zb_ref_fn)
zb_ref.set_spacing(io.SPACING_ZBB)
transforms = []
for anatomy in prog_percent(all_anatomies):
    our = regtools.to_ants(img.our_view_to_zbrain_img(anatomy))
    our.set_spacing(io.SPACING_JAKOB)
    res = ants.registration(zb_ref, our, type_of_transform='Affine')
    # First forward transform file, loaded as an ANTs transform object.
    transforms.append(ants.read_transform(res['fwdtransforms'][0]))

In [None]:
# Resample the reference fish into Z-Brain space with `t_to_zb`.
# NOTE(review): `our_ref` is only assigned in a later cell, so this fails on a
# fresh kernel. Also, `t_to_zb` was estimated on
# `img.our_view_to_zbrain_img(...)` of this fish, while `our_ref` is built from
# the raw view — confirm the orientation matches.
our_ref_warped = ants.apply_ants_transform_to_image(t_to_zb, our_ref, zb_ref)

In [None]:
@interact
def browse(i:(0,137)):
    """Slider callback: show plane `i` of `our_ref_warped` in a 12x12 figure.

    NOTE(review): the upper bound 137 is hard-coded; presumably it should be
    the number of planes of `our_ref_warped` minus one — confirm.
    """
    plt.figure(figsize=(12,12))
    plt.imshow(regtools.to_numpy(our_ref_warped)[i])

In [None]:
# ANTs image of the reference fish, with our acquisition voxel spacing.
# NOTE(review): `our_ref` is already used by an earlier cell (the one applying
# `t_to_zb`); move these lines above it so Restart & Run All works.
our_ref = regtools.to_ants(all_anatomies[align_to])
our_ref.set_spacing(io.SPACING_JAKOB)
def transform_xyz(t, roi, img_from, img_to):
    """Map one ROI position from `img_from`'s index space through transform `t`.

    Parameters
    ----------
    t : ANTs transform applied to the physical point.
    roi : sequence of coordinates; rounded to integers and interpreted as a
        voxel index of `img_from`.
    img_from : ANTs image whose index -> physical mapping is used.
    img_to : ANTs image whose `spacing` converts the result back to voxels.

    Returns
    -------
    numpy.ndarray of floats: the transformed physical point divided
    element-wise by `img_to.spacing`.
    """
    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
    index = np.round(roi).astype(int)
    phys = ants.transform_index_to_physical_point(img_from, index)
    trans = ants.apply_ants_transform_to_point(t, phys)
    # NOTE(review): dividing by the spacing ignores `img_to`'s origin and
    # direction. `ants.transform_physical_point_to_index(img_to, trans)` would
    # use the image's full physical->index mapping; it was commented out in the
    # original. This is kept as-is to preserve existing results — confirm intent.
    return np.array(trans) / img_to.spacing

In [None]:
def transform_all_xyzs(transforms, all_xyzs, img_from, img_to):
    """Apply `transform_xyz` to every ROI of every fish.

    `transforms[k]` is used for every ROI in `all_xyzs[k]`. Pairing stops at
    the shorter of the two sequences. Returns a list with one numpy array of
    transformed points per fish.
    """
    per_fish = []
    for transform, fish_rois in zip(transforms, all_xyzs):
        points = [transform_xyz(transform, roi, img_from, img_to) for roi in fish_rois]
        per_fish.append(np.array(points))
    return per_fish

In [None]:
# Scratch: inspect the Z-Brain voxel spacing constant.
io.SPACING_ZBB

In [None]:
# First three columns of every fish's ROI table.
all_xyzs = [rois[:,:3] for rois in all_rois]
# Stack shape handed to the view conversion.
# NOTE(review): hard-coded; presumably equal to all_anatomies[k].shape — confirm.
shape = (21, 1024, 1024)

print('pretransform')
# make the pretransform to zb (flip z and rotate 90°)
rois_transformed = [[img.our_view_to_zbrain_point(xyz, shape) for xyz in xyzs] for xyzs in prog_percent(all_xyzs)]

# NOTE(review): the step announced here is commented out, so
# `rois_transformed` only holds the pretransformed points.
print('All to reference')
#rois_transformed = transform_all_xyzs([t.invert() for t in transforms], prog_percent(rois_transformed), our_ref, zb_ref)

In [None]:
# NOTE(review): `image` is not defined in any cell above; this raises
# NameError on a fresh kernel. Scratch cell — delete or point it at an image.
ants.transform_index_to_physical_point(image, (2, 3, 1))

In [None]:
# NOTE(review): `rois_concat` is first assigned further down; this fails on a
# fresh kernel.
rois_concat.shape

In [None]:
def add_r(rois, r=5):
    """Append a constant radius column to (x, y, z) ROI positions.

    Parameters
    ----------
    rois : array-like of shape (n, 3)
    r : value written into the new fourth column (default 5).

    Returns
    -------
    numpy.ndarray of shape (n, 4) holding x, y, z, r. Empty input yields shape
    (0, 4); the previous row-by-row version returned a 1-D empty array, which
    broke column slicing such as ``[:, :2]`` downstream.

    Raises
    ------
    ValueError
        If `rois` is not a 2-D array with exactly three columns.
    """
    xyz = np.asarray(rois)
    if xyz.size == 0:
        xyz = xyz.reshape(0, 3)
    if xyz.ndim != 2 or xyz.shape[1] != 3:
        raise ValueError('rois must have shape (n, 3), got %s' % (xyz.shape,))
    return np.column_stack([xyz, np.full(xyz.shape[0], r)])

In [None]:
# Average of the Z-Brain reference over axis 0, with a leading length-1 axis
# re-inserted so the result keeps the reference's number of dimensions.
zb_avg = np.average(regtools.to_numpy(zb_ref), axis=0)[np.newaxis, ...]

In [None]:
# NOTE(review): dead code kept for reference; `zb_image` is not defined here.
#roi_map = segmentation.draw_rois(add_r(rois_transformed[1]), regtools.to_numpy(zb_image))
def color_func(i):
    """Grey RGB colour for ROI `i` of the first fish.

    Takes the first sample of `all_traces[0][i]`, caps it at 1000 and scales it
    linearly onto 0..255 for all three channels. Returns a tuple of three ints.
    """
    F = all_traces[0][i][0]
    m = 1000  # intensity that maps to full white
    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
    return tuple((np.array((255, 255, 255)) * np.min((F, m))/m).astype(int))
# Pretransformed ROIs of the first fish only, as one array.
rois_concat = np.concatenate(rois_transformed[:1])
# NOTE(review): `rois_to_draw` is not used anywhere.
rois_to_draw = all_rois[0]
# Draw the first fish's ROIs onto an empty volume shaped like its anatomy,
# coloured with the index-based `color_func` defined just above.
# NOTE(review): `color_func` is redefined later with a dF/F argument;
# re-running this cell after that definition gives wrong colours or errors.
roi_map = segmentation.draw_rois(all_rois[0], np.zeros_like(all_anatomies[0]), fixed_z=0, color_func=color_func)

In [None]:
# Plane 0 of the ROI map.
plt.figure(figsize=(12, 12))
plt.imshow(roi_map[0])

In [None]:
@interact
def browse(i:(0,137)):
    """Slider callback: show plane `i` of `roi_map`.

    NOTE(review): redefines the earlier `browse`. `roi_map` has as many
    planes as `all_anatomies[0]` (it is drawn on `np.zeros_like` of it), which
    presumably is far fewer than 138 — confirm the hard-coded bound.
    """
    plt.figure(figsize=(12, 12))
    plt.imshow(roi_map[i])

In [None]:
# Selection of the ROIs that fit inside the Z-Brain reference volume; used to
# index both `rois_concat` and the traces below. The reference shape is
# reversed with np.flip, presumably to match the ROI column order — confirm.
filtered = segmentation.filter_rois_shape(rois_concat, np.flip(np.array(regtools.to_numpy(zb_ref).shape)))

In [None]:
# Shape of the Z-Brain reference without its first axis; used as the render
# canvas shape below.
zb_ref_shape = tuple(np.array(regtools.to_numpy(zb_ref).shape)[1:])
zb_ref_shape

In [None]:
# In-volume ROIs as integers, with a constant radius column appended.
# `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
rois_filtered = add_r(rois_concat[filtered].astype(int))
# Traces of the same fish (`[:1]` must match the `rois_transformed[:1]` used
# for `rois_concat`), truncated to the first 1800 frames and filtered with the
# same selection so rows stay aligned with `rois_filtered`.
traces_filtered = np.concatenate([traces[:,:1800] for traces in all_traces[:1]])[filtered]

In [None]:
# Scratch: first two ROI columns in reversed order.
np.flip(rois_filtered[:,:2], axis=1)

In [None]:
# Reload the rendering module and build the pixel map from the first two ROI
# columns on the in-plane Z-Brain canvas.
rendering = reload(rendering)
pix_map = rendering.pixel_map(rois_filtered[:,:2], zb_ref_shape)

In [None]:
# NOTE(review): `pix_map_filtered` is not used anywhere below.
pix_map_filtered = rendering.pix_map_filter(pix_map, 0)

In [None]:
# For every listed pixel, count the entries of `involved` there (presumably
# the ROIs contributing to that pixel — confirm) and show it as a heat map.
involved, dists, pixel_list = pix_map
count_map = np.zeros(shape=involved.shape)
for p in pixel_list:
    count_map[p] = len(involved[p])

plt.figure(figsize=(12,10))
plt.imshow(count_map)
plt.colorbar()

In [None]:
# Largest value stored in `dists` over all pixels.
# NOTE(review): np.max raises ValueError if any `dists[p]` is empty — confirm
# that cannot happen.
np.max([np.max(dists[p]) for p in np.ndindex(dists.shape)])

In [None]:
# NOTE(review): `rendered_frames` is only computed in the next cell; this
# duplicate of that cell's plot fails on a fresh kernel — delete it.
plt.figure(figsize=(12,10))
plt.imshow(np.array(rendered_frames[0]), cmap=plt.get_cmap('Greys_r'))

In [None]:
# Render frame 0 of the filtered traces with the green/magenta dF/F colour
# map and show it.
rendering = reload(rendering)
render_ts = [0]
rendered_frames = rendering.orthogonal_averaged(pix_map, traces_filtered, 
                                               rendering.green_magenta_dFF_func,
                                               render_ts, zb_ref_shape)
plt.figure(figsize=(12,10))
plt.imshow(np.array(rendered_frames[0]), cmap=plt.get_cmap('Greys_r'))

In [None]:
# NOTE(review): the file name says fish 20, but `rendered_frames` comes from
# `traces_filtered`, which uses only `all_traces[:1]` (fish `fish_ids[0]`) —
# confirm the name. Absolute, user-specific output path.
io.save('/Users/koesterlab/fish20_zb.h5', rendered_frames, io.SPACING_ZBB)

In [None]:
# Warp `zb_view` into Z-Brain space with `t_to_zb` and show plane 60.
# NOTE(review): `zb_view` and `zb_image` are not defined in any cell of this
# notebook; this raises NameError on a fresh kernel.
zb_view_ants = regtools.to_ants(zb_view)
zb_view_ants.set_spacing(io.SPACING_JAKOB)
zb_image.set_spacing(io.SPACING_ZBB)
transf_img = ants.apply_ants_transform_to_image(t_to_zb, zb_view_ants, zb_image)
to_show = regtools.to_numpy(transf_img)[60]
plt.figure(figsize=(12, 12))
plt.imshow(to_show)

In [None]:
# Scratch: index of the maximum along axis 0 of the warped image.
np.argmax(regtools.to_numpy(transf_img), axis=0)

In [None]:
# One distinct uint8 RGB colour per entry of `transformed_rois`, sampled from
# the hsv colour map.
# NOTE(review): `transformed_rois` is not assigned in any cell above; this
# fails on a fresh kernel.
cmap = plt.get_cmap('hsv', len(transformed_rois))
colors = [(np.array(cmap(i))*255)[:3].astype(np.uint8) for i in range(len(transformed_rois))]

In [None]:
# Backup before the next cell rebinds `transformed_rois` (for a list this copy
# is shallow).
transformed_rois_bak = transformed_rois.copy()

In [None]:
# Trim every ROI array to its first three columns. Arrays that already have
# exactly three columns are kept unchanged (not cast); wider ones are cast to
# int. `np.int` was removed in NumPy 1.24; it aliased the builtin `int`.
transformed_rois = [rois if rois.shape[1]==3 else rois[:,:3].astype(int) for rois in transformed_rois_bak]

In [None]:
def argfilter_rois(rois, shape=(1024, 1024, 21)):
    """Indices of the ROIs whose coordinates all lie inside ``[0, shape)``.

    Parameters
    ----------
    rois : array-like of shape (n, len(shape))
        One ROI position per row; compared column-wise with `shape`.
    shape : sequence of ints
        Exclusive upper bound per coordinate (lower bound is 0).

    Returns
    -------
    numpy.ndarray of integer indices in ascending order. An empty result is
    still an integer array, so it can be used to index other arrays. The
    previous version returned a float array when empty, which raises
    IndexError when used as an index.
    """
    coords = np.asarray(rois)
    if coords.size == 0:
        return np.array([], dtype=np.intp)
    inside = np.all((coords >= 0) & (coords < np.asarray(shape)), axis=-1)
    return np.flatnonzero(inside)

# Per-fish in-bounds ROI indices. Fish generally keep different numbers of
# ROIs, so the index arrays stay in a list: np.array() of a ragged sequence
# raises ValueError on NumPy >= 1.24. Only iterated with zip() below.
rois_ids = [argfilter_rois(rois) for rois in transformed_rois]

In [None]:
# Keep only in-bounds ROIs per fish, plus the matching trace rows.
# NOTE(review): assumes `transformed_rois[k]` and `all_traces[k]` describe the
# same fish with the same ROI order — confirm.
filtered_rois = [t[ids] for t, ids in zip(transformed_rois, rois_ids)]
filtered_traces = [t[ids] for t, ids in zip(all_traces, rois_ids)]

In [None]:
def render_rois(ndas, colors, matching=None, planes=(10,), shape=(1024, 1024)):
    """Paint the ROIs of several datasets into one RGB image.

    Parameters
    ----------
    ndas : sequence of integer arrays with rows (x, y, z), one per dataset.
    colors : sequence of RGB triples, one per entry of `ndas`.
    matching : accepted for compatibility; not used.
    planes : z values to draw; ROIs on other planes are skipped.
        (Default is now a tuple to avoid a shared mutable default.)
    shape : (rows, cols) of the output; ROIs are written at image[y, x].

    Returns
    -------
    numpy.ndarray of dtype uint8 and shape ``tuple(shape) + (3,)``.
    Colours of overlapping ROIs are summed and saturate at 255.
    """
    # Accumulate in a wide integer type: `+=` directly on a uint8 image wraps
    # modulo 256, so overlapping bright ROIs used to produce dark pixels.
    acc = np.zeros(tuple(shape) + (3,), dtype=np.int64)
    for nda, color in zip(ndas, colors):
        rgb = np.asarray(color, dtype=np.int64)
        for x, y, z in nda:
            if z in planes:
                acc[y, x] += rgb

    # TODO(review): `matching` is accepted but not implemented.
    return np.clip(acc, 0, 255).astype(np.uint8)

# One figure per plane 1..18 with every fish's in-bounds ROIs in its colour.
# NOTE(review): the next cell rebinds the name `render_rois` to an array, so
# this cell breaks if re-run after it. The plane range is hard-coded.
for p in np.arange(1, 19):
    image = render_rois(filtered_rois, colors, planes=[p])
    plt.figure(figsize=(12, 12))
    plt.imshow(image)

In [None]:
rendering = reload(rendering)

# Frame indices to render: the first 1800 frames.
render_ts = np.arange(0, 1800, 1)
# Stack the ROIs and traces of all fish row-wise, so row k of `render_rois`
# and of `render_traces` describe the same ROI. np.concatenate also works when
# fish have different ROI counts, which np.array() of ragged lists does not.
# NOTE(review): this rebinds `render_rois`, shadowing the function above.
render_rois = np.concatenate(all_rois)
# One row per ROI, one column per frame, truncated to `render_ts` (as is done
# with `traces[:, :1800]` for `traces_filtered`). The former reshape(-1, 1)
# split every trace into single-frame rows, breaking the ROI/trace pairing
# that the next cell inspects.
render_traces = np.concatenate([traces[:, :len(render_ts)] for traces in all_traces])
# NOTE(review): magic bleaching constant -0.000065 — document its origin.
render_traces = signal.correct_bleaching(render_ts, render_traces, -0.000065)
# Frames 110..159 passed to dFF, presumably as the baseline window — confirm.
render_dFF = signal.dFF(render_traces, np.arange(110, 160))

In [None]:
# Sanity check: presumably both should have one row per ROI.
render_dFF.shape, render_rois.shape

In [None]:
def color_func(dFF):
    """Map one dF/F value to an RGB colour (uint8 array of length 3).

    Values > 0 use magenta (255, 0, 255); all others use green (0, 255, 0).
    |dFF| is capped at 1 and scales the colour linearly, so 0 gives black.
    NOTE(review): shadows the index-based `color_func` defined earlier.
    """
    max_dFF = 1
    alpha = 1
    positive_rgb = (255, 0, 255)
    negative_rgb = (0, 255, 0)
    if dFF > 0:
        base = np.array(positive_rgb, dtype=np.float32)
    else:
        base = np.array(negative_rgb, dtype=np.float32)
    strength = min(abs(dFF), max_dFF) / max_dFF
    return np.array(base * alpha * strength, dtype=np.uint8)

# Per-frame activity images on a 1024x1024 canvas, coloured with the dF/F ->
# RGB `color_func` defined just above. Show the first frame.
activity = rendering.orthogonal(render_rois, render_dFF, color_func, render_ts, (1024, 1024))
plt.figure(figsize=(12,10))
plt.imshow(np.array(activity[0], dtype=np.uint8))

In [None]:
def initial_solution(ndas):
    """Build a starting ROI-correspondence table for several datasets.

    Returns an int32 array of shape (max length of the inputs, len(ndas)),
    filled with -1. Column i holds 0 .. len(ndas[i]) - 1 in its first rows.
    An empty `ndas` gives shape (0, 0).
    """
    # `np.alen` was removed in NumPy 1.23; `len` is the equivalent here.
    dim = max((len(rois) for rois in ndas), default=0)
    num = len(ndas)
    sol = np.full((dim, num), -1, dtype=np.int32)
    for i, rois in enumerate(ndas):
        # NOTE(review): the loop body was missing in the original, so the cell
        # did not even parse. Assigning each ROI its own index is an assumed
        # starting point — confirm the intended initialisation.
        sol[:len(rois), i] = np.arange(len(rois))
    return sol

In [None]:
# NOTE(review): `a` and `b` are not defined anywhere in this notebook.
initial_solution([a, b])

In [None]:
# NOTE(review): `ants_zff`, `std_dev_a` and `std_dev_b` are not imported or
# defined in this notebook, so this cell raises NameError. The imported
# `ants_cmd` module provides `Arguments` / `run_antsreg`, which may be the
# current equivalents — confirm, or delete this cell.
args = ants_zff.AntsArguments(input_file=std_dev_b, reference=std_dev_a, params=ants_zff.get_default_params()[:2])
args.output_folder = os.path.join('/Users/koesterlab/ants_tmp', args.output_folder)
res = ants_zff.run_antsreg(args)