In [None]:
from astrocast import analysis, detection

import numpy as np
import scipy
import matplotlib.pyplot as plt
import napari
import dask.array as da
from dask.diagnostics import ProgressBar

from skimage.filters import threshold_triangle
import scipy.signal as signal
from skimage import morphology

In [None]:
# Input recording and the dataset inside it that is analysed below.
# NOTE(review): absolute local path — only valid on the author's machine.
path = "/media/janrei1/data/astrocyte_examples/31570865_AquA/GlusnfrSuppRaw.h5"
h5_loc = "dff/ch0"

# Example pixels as (axis-1 index, axis-2 index, plot colour); only referenced
# by the commented-out trace-plotting cells at the end of the notebook.
pixel = [
    (41, 185, "blue"),
    (57, 43, "red"),
    (104, 24, "green"),
]

# Frame indices (axis 0) shown as columns in every plot_images grid.
frames = [8, 41, 174]

In [None]:
def plot_images(arr, frames, lbls=None, figsize=(10, 3), vmin=None, vmax=None):
    """Show selected frames of one or more image stacks in a grid.

    Rows correspond to the stacks in ``arr`` and columns to the indices in
    ``frames``; each panel shows ``img[f, :, :]`` titled ``"<label> #<frame>"``.

    Args:
        arr: A single stack, or a list/tuple of stacks, each indexed as
            ``[frame, :, :]``.
        frames: Frame indices (axis 0) to display, one column per index.
        lbls: Row labels, one per stack; defaults to ``0, 1, ...``.
        figsize: Figure size passed to ``plt.subplots``.
        vmin, vmax: Colour limits passed to every ``imshow`` call.
    """
    # A tuple of stacks can never be indexed as ``[f, :, :]``, so treating it as
    # a sequence of stacks (like a list) only turns a former crash into a plot.
    if not isinstance(arr, (list, tuple)):
        arr = [arr]

    if lbls is None:
        lbls = range(len(arr))

    # squeeze=False keeps ``axx`` 2-D even for one stack or one frame; with the
    # default squeeze=True matplotlib returns a 1-D array (or a bare Axes) and
    # the ``axx[x, y]`` indexing below would fail.
    fig, axx = plt.subplots(len(arr), len(frames), figsize=figsize, squeeze=False)
    for x, img in enumerate(arr):
        for y, f in enumerate(frames):
            ax = axx[x, y]
            ax.imshow(img[f, :, :], vmin=vmin, vmax=vmax)
            ax.set_title(f"{lbls[x]} #{f}")
            ax.axis("off")

In [None]:
# Load the recording with lazy=False, then wrap the data as a dask array so the
# chunked detection steps below can operate on it.
vid = analysis.Video(path, h5_loc=h5_loc, lazy=False, name=h5_loc)
arr = da.from_array(vid.get_data())
display(arr)

In [None]:
# 3-D Gaussian smoothing of the loaded stack (lazy; computed on demand).
smooth = detection.Detector.gaussian_smooth_3d(
    arr,
    sigma=3,
    radius=2,
    chunks=(25, 25, 25),
)
display(smooth)

# Optional visual check:
# plot_images([arr, smooth], frames, lbls=("original", "smooth"))

In [None]:
# Spatial thresholding of the raw stack, and of the smoothed stack at two
# threshold_z_depth values.
spat_orig = detection.Detector.spatial_threshold(
    arr, min_ratio=1, threshold_z_depth=1
)
display(spat_orig)

# NOTE(review): the "_d3" / "D3" naming corresponds to threshold_z_depth=2 —
# confirm whether the suffix refers to the window size or is stale.
spat_smooth, spat_smooth_d3 = (
    detection.Detector.spatial_threshold(smooth, min_ratio=1, threshold_z_depth=z_depth)
    for z_depth in (1, 2)
)

plot_images(
    [spat_orig, spat_smooth, spat_smooth_d3],
    frames,
    lbls=("orig:spatial", "smooth:spatial", "smooth:spatial_D3"),
    vmin=0,
    vmax=1,
    figsize=(15, 5),
)

In [None]:
# Scratch check: for each centre index i, the slice [i - depth, i + depth + 1)
# covers 2 * depth + 1 frames; centres too close to either edge are skipped.
A = np.zeros((11, 5, 5))
depth = 3
for i in range(depth, len(A) - depth):
    z0, z1 = i - depth, i + depth + 1
    print(i, z0, z1, A[z0:z1].shape)

In [None]:
# Temporal thresholding of the raw and smoothed stacks with identical settings
# (keyword names mirror scipy.signal.find_peaks). Raw is processed first.
temp_orig, temp_smooth = (
    detection.Detector.temporal_threshold(
        stack, prominence=10, width=3, rel_height=0.9, wlen=60, plateau_size=None
    )
    for stack in (arr, smooth)
)
display(temp_smooth)
# TODO bigger chunks


In [None]:
# Materialise the four candidate masks in a single dask.compute call (one merged
# graph, so dask can share work common to them). `labels` follows the same order.
arrays = list(da.compute(spat_orig, spat_smooth_d3, temp_orig, temp_smooth))
labels = ["SP_orig", "SP_sm_d3", "T_orig", "T_sm"]

In [None]:
# Unpack in the same order the masks were passed to compute: the second entry is
# the threshold_z_depth=2 mask, so bind it to `spat_smooth_d3` instead of
# silently overwriting the depth-1 `spat_smooth` with it.
spat_orig, spat_smooth_d3, temp_orig, temp_smooth = arrays
plot_images([temp_orig, temp_smooth], frames, lbls=("T_orig", "T_sm"), vmin=0, vmax=1)

In [None]:
# Create a dictionary to store the results
result_dict = {}

# Loop over all combinations of arrays
for i in range(len(arrays)):
    for j in range(i+1, len(arrays)):
        # Create a key to identify this combination using the labels
        key = f"{labels[i]}_{labels[j]}"
        
        # Calculate the union of the two arrays
        result_dict[key] = arrays[i] | arrays[j]

display(result_dict.keys())

In [None]:
plot_images(list(result_dict.values()), frames, lbls=list(result_dict.keys()), vmin=0, vmax=1, figsize=(15, 15))

In [None]:
def remove_small_objects_2D(arr, min_size=10, connectivity=4):
    """Remove small connected objects from each frame of a mask stack.

    Every frame (axis 0) is labelled on its own, so objects are never joined
    across frames.

    Args:
        arr: Mask stack indexed as ``[frame, :, :]``; non-dask inputs are wrapped
            with ``da.from_array``.
        min_size: Objects with fewer pixels than this are removed (forwarded to
            ``skimage.morphology.remove_small_objects``).
        connectivity: skimage-style connectivity for 2-D labelling:
            1 = 4-neighbourhood, 2 = 8-neighbourhood. Any value >= 2, including
            the default 4, yields the 8-neighbourhood; it is NOT a
            "4-connectivity" switch.

    Returns:
        A lazy dask array with the same shape and dtype as ``arr``, one frame per
        chunk, where pixels belonging to removed objects are set to 0.
    """
    if not isinstance(arr, da.Array):
        arr = da.from_array(arr)

    # One frame per block so each block is an independent 2-D problem.
    arr = arr.rechunk((1, -1, -1))

    def rm_small(block):
        out = np.empty_like(block)
        # Loop over frames so labelling is truly 2-D regardless of block depth.
        for k, frame in enumerate(block):
            # remove_small_objects treats non-boolean input as a *label* image
            # (each distinct value is one object); binarize so objects are
            # defined purely by connectivity, then keep the original values.
            keep = morphology.remove_small_objects(
                frame.astype(bool), min_size=min_size, connectivity=connectivity
            )
            out[k] = np.where(keep, frame, 0)
        return out

    return da.map_blocks(rm_small, arr, dtype=arr.dtype)

In [None]:
# Clean the two chosen union masks, dropping objects smaller than 50 pixels
# (connectivity=1, i.e. 4-neighbourhood per frame).
use_lbls = ['SP_orig_T_sm', 'SP_sm_d3_T_sm']
rm_2d = [
    remove_small_objects_2D(mask, min_size=50, connectivity=1)
    for mask in (result_dict[lbl] for lbl in use_lbls)
]
display(rm_2d[0])

plot_images(rm_2d, frames, lbls=use_lbls, vmin=0, vmax=1, figsize=(15, 7))

In [None]:
# Same clean-up with a looser 25-pixel minimum object size (replaces `rm_2d`).
rm_2d = [
    remove_small_objects_2D(mask, min_size=25, connectivity=1)
    for mask in (result_dict[lbl] for lbl in use_lbls)
]
display(rm_2d[0])

plot_images(rm_2d, frames, lbls=use_lbls, vmin=0, vmax=1, figsize=(15, 7))

In [None]:

# fig, ax = plt.subplots(1, 1, figsize=(20, 3))
# for (x, y, color) in pixel:

#     line = arr[:, x, y]
    
#     peaks, prominences = signal.find_peaks(line, prominence=15, width=3, rel_height=0.9, wlen=60)
#     prominences["peaks"]=peaks
#     # print(prominences)
    
#     for n in range(len(peaks)):

#         for key in prominences.keys():

#             if key in ["prominences", "widths", "width_heights", 'left_bases', 'right_bases']:
#                 continue
            
#             val = prominences[key][n]
#             ax.axvline(val, color=color, linestyle="--", alpha=0.5)
        
#         # center = prominences["peaks"][n]
#         # ax.axvline(center, color=color, linestyle="--", alpha=0.5)
        
#         # left = prominences["left_bases"][n]
#         # ax.axvline(left, color=color, linestyle="dotted", alpha=0.5)
        
#         # right = prominences["right_bases"][n]
#         # ax.axvline(right, color=color, linestyle="dotted", alpha=0.5)
    
#     ax.plot(line, color=color)
    

In [None]:
# def find_peaks(x, prominence=10, width=3, rel_height=0.9, wlen=60, plateau_size=None):
#     peaks, prominences = signal.find_peaks(np.squeeze(x), prominence=prominence, wlen=wlen,
#                                            width=width, rel_height=rel_height, plateau_size=plateau_size)
    
#     active_pixels = np.zeros(x.shape, dtype=int)
#     for (left, right, prom) in list(zip(prominences['left_ips'], prominences['right_ips'], prominences['prominences'])):
#         active_pixels[int(left):int(right)] = prom

#     return active_pixels

# act = da.map_blocks(find_peaks, arr, dtype=int)
# display(act)

# fig, ax = plt.subplots(1, 1, figsize=(25, 3))
# ax2 = ax.twinx()
# for (x, y, color) in pixel:

#     ax.plot(arr[:, x, y], color=color)
#     ax2.plot(act[:, x, y], color=color)

In [None]:
# act_c = act.compute()
# arr_c = arr.compute()

# viewer = napari.Viewer()
# viewer.add_image(act_c, colormap="red")
# viewer.add_image(arr_c, colormap="gray")
# viewer.show()