In [None]:
import ants
from glob import glob
from importlib import reload
import matplotlib.pyplot as plt
import numpy as np
import os.path
import pickle
from pyprind import prog_percent
from skimage.io import imread
import re

from zebrafishframework import ants_cmd
from zebrafishframework import io
from zebrafishframework import regtools
from zebrafishframework import rendering
from zebrafishframework import signal

# Re-execute the already imported regtools module (importlib.reload) and rebind
# the name to the module object that reload() returns.
regtools = reload(regtools)

In [None]:
# Location of the segmented recordings and the per-fish file name stem.
base = '/Users/koesterlab/segmented/control/'
base_mask = '/Users/koesterlab/masks/'
base_fn = base + 'fish%02d_6dpf_medium'

# Raw string: '\d' inside a normal string literal is an invalid escape sequence.
r = re.compile(r'.*fish(?P<num>\d+).*')

# Collect the fish numbers from the *.h5 file names. glob() returns files in
# arbitrary order, so the paths are sorted to make fish_ids (and every list
# derived from it) reproducible between runs. Files whose name carries no
# fish number are skipped instead of crashing on a failed match.
fish_ids = []
for f in sorted(glob(base + '*.h5')):
    m = r.match(f)
    if m is None:
        continue
    num = int(m.group('num'))
    fish_ids.append(num)

# One ROI array and one trace array per fish, in the same order as fish_ids.
all_rois = [np.load((base_fn + '_rois.npy') % i) for i in fish_ids]
all_traces = [np.load((base_fn + '_traces.npy') % i) for i in fish_ids]

fish_ids

In [None]:
# Number of the fish that the other fish are aligned to.
align_to_fish = 12
# Position of that fish inside fish_ids; list.index raises ValueError when the
# number is not present.
align_to = fish_ids.index(align_to_fish)

In [None]:
# Convert each fish's ROI list with regtools.points_to_image, with a progress
# display while iterating.
cell_patterns = [regtools.points_to_image(rois) for rois in prog_percent(all_rois)]

In [None]:
# Show element 10 of every fish's cell pattern in its own 12x12 inch figure.
for cp in cell_patterns:
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(cp[10])

In [None]:
# Register every fish's cell pattern onto the reference fish (index align_to)
# and carry its ROIs along. transformations[k] is None for the reference fish.
transformed_rois = []
transformations = []
for fish_id, rois, cp in prog_percent(list(zip(fish_ids, all_rois, cell_patterns))):
    if fish_id == fish_ids[align_to]:
        # The reference fish needs no transform: keep only the first three
        # columns as integers. The builtin int replaces the np.int alias,
        # which no longer exists in current NumPy, and the result is a single
        # (n_rois, 3) array rather than a Python list of small arrays so that
        # later cells can use .shape and index-array lookups on it.
        transformations.append(None)
        transformed_rois.append(np.asarray(rois)[:, :3].astype(int))
        continue
    warped, tform = regtools.planewise_affine(cell_patterns[align_to], cp, return_transforms=True)
    transformed_rois.append(regtools.transform_planewise_points(rois, tform))
    transformations.append(tform)

In [None]:
# Mean over element 1 of every stored transform. The reference fish has None
# in its slot and is skipped; the explicit `is not None` test avoids relying on
# the truthiness of the transform object itself (ambiguous for NumPy arrays).
# NOTE(review): what t[1] holds depends on regtools.planewise_affine - confirm.
np.mean([t[1] for t in transformations if t is not None])

In [None]:
# One colour per fish, sampled from the 'hsv' colormap and converted to
# 8-bit RGB (the alpha channel returned by the colormap is dropped).
n_fish = len(transformed_rois)
cmap = plt.get_cmap('hsv', n_fish)
colors = []
for fish_idx in range(n_fish):
    rgba = np.array(cmap(fish_idx)) * 255
    colors.append(rgba[:3].astype(np.uint8))

In [None]:
# Shallow copy of the list, so that transformed_rois can be rebound or
# re-derived afterwards without losing the registration output.
transformed_rois_bak = list(transformed_rois)

In [None]:
# Normalise every entry to an (n_rois, 3) array: entries that already have
# three columns are kept, wider ones are cut to their first three columns and
# cast to integers. np.asarray lets list entries through as well, and the
# builtin int replaces the np.int alias, which no longer exists in current NumPy.
transformed_rois = [
    rois_arr if rois_arr.shape[1] == 3 else rois_arr[:, :3].astype(int)
    for rois_arr in map(np.asarray, transformed_rois_bak)
]

In [None]:
def argfilter_rois(rois, shape=(1024, 1024, 21)):
    """Return the indices of the ROIs that lie inside a volume.

    Parameters
    ----------
    rois : array-like, shape (n_rois, n_dims)
        One coordinate vector per row.
    shape : sequence of int
        Exclusive upper bound for each coordinate; compared element-wise
        against every row of ``rois``.

    Returns
    -------
    numpy.ndarray
        Ascending integer indices of the rows in which no coordinate is
        negative and no coordinate is ``>=`` its bound. The dtype is an
        integer type even when no row qualifies, so the result can always be
        used directly as an index array.
    """
    rois = np.asarray(rois)
    if rois.size == 0:
        return np.empty(0, dtype=np.intp)
    bounds = np.asarray(shape)
    outside = np.any(rois < 0, axis=1) | np.any(rois >= bounds, axis=1)
    return np.flatnonzero(~outside)

# One index array per fish. Kept as a plain list: the per-fish arrays differ
# in length in general, and current NumPy refuses to build an array from such
# ragged input.
rois_ids = [argfilter_rois(rois) for rois in transformed_rois]

In [None]:
# Keep, per fish, only the ROIs (and the traces at the same positions) whose
# indices were selected into rois_ids.
filtered_rois = []
filtered_traces = []
for rois, traces, ids in zip(transformed_rois, all_traces, rois_ids):
    filtered_rois.append(rois[ids])
    filtered_traces.append(traces[ids])

In [None]:
def render_rois(ndas, colors, matching=None, planes=(10,), shape=(1024, 1024)):
    """Draw the ROIs of several datasets into one RGB image.

    Parameters
    ----------
    ndas : sequence of array-like
        One ``(n_rois, 3)`` integer array per dataset; columns are x, y, z.
    colors : sequence
        One RGB colour per dataset, added to every pixel hit by that
        dataset's ROIs.
    matching : optional
        Accepted for interface compatibility; currently unused.
    planes : iterable of int
        Only ROIs whose z value is contained in ``planes`` are drawn.
    shape : tuple of int
        ``(height, width)`` of the output image.

    Returns
    -------
    numpy.ndarray
        ``uint8`` image of shape ``shape + (3,)``. Contributions of ROIs that
        fall on the same pixel are summed and saturate at 255 instead of
        wrapping around.
    """
    plane_list = list(planes)
    # Accumulate in a wide integer type; adding directly into a uint8 image
    # overflows (e.g. 200 + 200 -> 144) wherever ROIs overlap.
    accum = np.zeros(tuple(shape) + (3,), dtype=np.uint32)
    for nda, color in zip(ndas, colors):
        nda = np.asarray(nda)
        if nda.size == 0:
            continue
        in_plane = np.isin(nda[:, 2], plane_list)
        xs = nda[in_plane, 0]
        ys = nda[in_plane, 1]
        # np.add.at sums repeated (y, x) pairs; plain fancy-index += would
        # count each repeated pixel only once.
        np.add.at(accum, (ys, xs), np.asarray(color, dtype=np.uint32))
    return np.minimum(accum, 255).astype(np.uint8)

# Render planes 1 through 18, one 12x12 inch figure per plane.
for p in range(1, 19):
    image = render_rois(filtered_rois, colors, planes=[p])
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(image)

In [None]:
rendering = reload(rendering)

# Time axis handed to the bleaching correction and the renderer.
render_ts = np.arange(0, 1800, 1)

# Stack the per-fish arrays along the first axis. np.concatenate is used
# because the fish generally have different ROI counts, and np.array() cannot
# build an array from such ragged input on current NumPy. For equally sized
# inputs the result is identical to np.array(...).reshape(...).
# NOTE: from here on the name render_rois refers to this array and no longer
# to the render_rois() function defined in an earlier cell; re-run that cell
# before calling the function again.
render_rois = np.concatenate(all_rois, axis=0)
render_rois = render_rois.reshape(-1, render_rois.shape[-1])

# NOTE(review): reshape(-1, 1) puts every trace value into one single column;
# confirm that this is the layout signal.correct_bleaching / signal.dFF expect.
render_traces = np.concatenate(all_traces, axis=0)
render_traces = render_traces.reshape(-1, 1)
render_traces = signal.correct_bleaching(render_ts, render_traces, -0.000065)
render_dFF = signal.dFF(render_traces, np.arange(110, 160))

In [None]:
# Display both shapes side by side to check that the dF/F array and the ROI
# array line up before rendering.
render_dFF.shape, render_rois.shape

In [None]:
def color_func(dFF):
    """Translate a dF/F value into an RGB colour.

    Values greater than zero are drawn in magenta ``(255, 0, 255)``, all other
    values in green ``(0, 255, 0)``. The brightness is proportional to
    ``abs(dFF)`` and saturates once ``abs(dFF)`` reaches 1.

    Returns a ``uint8`` array of length 3.
    """
    max_dFF = 1
    alpha = 1
    if dFF > 0:
        rgb = (255, 0, 255)
    else:
        rgb = (0, 255, 0)
    c = np.array(rgb, dtype=np.float32)
    strength = min(abs(dFF), max_dFF) / max_dFF
    return np.array(c * alpha * strength, dtype=np.uint8)

# Render the activity with rendering.orthogonal and show element 0 of the
# result as an 8-bit image.
activity = rendering.orthogonal(render_rois, render_dFF, color_func, render_ts, (1024, 1024))
fig, ax = plt.subplots(figsize=(12,10))
ax.imshow(np.array(activity[0], dtype=np.uint8))

In [None]:
def initial_solution(ndas):
    """Build a starting assignment table for a set of ROI arrays.

    Parameters
    ----------
    ndas : sequence of sized objects
        One ROI collection per dataset.

    Returns
    -------
    numpy.ndarray
        ``int32`` array of shape ``(max_len, n_datasets)`` where ``max_len``
        is the size of the largest collection. Column ``i`` holds
        ``0 .. len(ndas[i]) - 1`` followed by ``-1`` padding. An empty
        ``ndas`` yields an array of shape ``(0, 0)``.
    """
    # len() replaces np.alen, which no longer exists in current NumPy.
    lengths = [len(rois) for rois in ndas]
    num = len(lengths)
    dim = max(lengths, default=0)
    sol = np.full((dim, num), -1, dtype=np.int32)
    for i, n_rois in enumerate(lengths):
        # NOTE(review): the original loop had no body (a SyntaxError). Filling
        # each column with the identity assignment is the reading suggested by
        # the -1 padding and the function name - confirm the intended start
        # solution.
        sol[:n_rois, i] = np.arange(n_rois)
    return sol

In [None]:
# The original call passed two names, `a` and `b`, that no cell of this
# notebook defines, so it failed on a fresh kernel. The per-fish ROI arrays
# are used instead.
# NOTE(review): confirm that filtered_rois is the intended input.
initial_solution(filtered_rois)

In [None]:
# NOTE(review): ants_zff, std_dev_a and std_dev_b are neither imported nor
# assigned in any cell of this notebook, so this cell fails on a fresh kernel.
# The import cell only brings in zebrafishframework.ants_cmd - confirm whether
# that is the module meant here and where the two inputs come from.
ants_params = ants_zff.get_default_params()[:2]
args = ants_zff.AntsArguments(input_file=std_dev_b, reference=std_dev_a, params=ants_params)
# Place the run's output folder underneath the temporary ANTs directory.
ants_tmp_dir = '/Users/koesterlab/ants_tmp'
args.output_folder = os.path.join(ants_tmp_dir, args.output_folder)
res = ants_zff.run_antsreg(args)