In [1]:
# Use the interactive widget (ipympl) matplotlib backend for figures.
%matplotlib widget
# Reload edited modules automatically so changes to project code are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm.auto import tqdm
from glob import glob
from IPython.display import display

In [3]:
# initialize_parameters.py, voluseg_submit.py
# set up
import os
import pprint
import voluseg
import sys
import time
import shutil

In [4]:
import dask
import dask.array as da
from analysis_toolbox.parallel import start_cluster
from dask.distributed import Client

# Overrides for long-running jobs: a very high allowed-failure count on the
# scheduler plus generous connect / TCP communication timeouts.
dask_overrides = {
    'distributed.scheduler.allowed-failures': 9999,
    'distributed.comm.timeouts.connect': '1440s',
    'distributed.comm.timeouts.tcp': '4320s',
}
dask.config.set(dask_overrides)

# Show the resulting configuration as the cell output.
dask.config.config

{'distributed': {'worker': {'use-file-locking': False,
   'blocked-handlers': [],
   'multiprocessing-method': 'spawn',
   'connections': {'outgoing': 50, 'incoming': 10},
   'preload': [],
   'preload-argv': [],
   'daemon': True,
   'validate': False,
   'lifetime': {'duration': None, 'stagger': '0 seconds', 'restart': False},
   'profile': {'interval': '10ms', 'cycle': '1000ms', 'low-level': False},
   'memory': {'target': 0.6, 'spill': 0.7, 'pause': 0.8, 'terminate': 0.95},
   'http': {'routes': ['distributed.http.worker.prometheus',
     'distributed.http.health',
     'distributed.http.statics']}},
  'scheduler': {'allowed-failures': 9999,
   'memory': {'target': 0.6, 'spill': 0.7, 'pause': 0.8, 'terminate': 0.95},
   'bandwidth': 100000000,
   'blocked-handlers': [],
   'default-data-size': '1kiB',
   'events-cleanup-delay': '1h',
   'idle-timeout': None,
   'transition-log-length': 100000,
   'work-stealing': True,
   'work-stealing-interval': '100ms',
   'worker-ttl': None,
  

In [5]:
# Root folder of one acquisition; every other location is derived from it.
root_dir = '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/'

# Input images and the voluseg output tree.
dir_input = f'{root_dir}im_CM0/'
dir_output = f'{root_dir}im_CM0_voluseg/'

# Where the .h5 volumes currently live (presumably already downsampled —
# TODO confirm) and the folder they are moved to further down.
src_ds_dir = f'{root_dir}im_CM0_ds/'
dest_ds_dir = f'{dir_output}volumes/0/'

In [6]:
# Zarr stores under the voluseg output tree. Judging by the file names the
# second is a rechunked copy of the first — TODO confirm.
data_cit = f'{dir_output}volumes/0.zarr'
data_cis = f'{dir_output}volumes/0_rechunked.zarr'

## Prepare to start at step 3

In [7]:
from analysis_toolbox.utils import disk_usage

# Report total / used / free disk space for root_dir.
disk_usage(root_dir)

Total: 48.4TiB
Used: 32.5TiB
Free: 15.9TiB


### Rename h5 files and datasets

In [8]:
from analysis_toolbox.fileio import rename_file_with_replace, rename_h5_dataset

In [9]:
# Collect every .h5 file in the source folder, in sorted order.
h5_pattern = f"{src_ds_dir}*.h5"
h5_paths = sorted(glob(h5_pattern))

# Preview the first few matches.
h5_paths[:10]

['/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000000_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000001_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000002_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000003_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000004_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_ds/TM0000005_CM0_CHN00.h5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregec

In [10]:
%%time
# Rename the dataset inside each HDF5 file from 'default' to 'volume'.
# NOTE(review): presumably not re-runnable once 'default' is gone — confirm
# how rename_h5_dataset treats a missing source dataset.
for h5_path in tqdm(h5_paths):
    rename_h5_dataset(h5_path, 'default', 'volume')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47767.0), HTML(value='')))


CPU times: user 8.72 s, sys: 1.72 s, total: 10.4 s
Wall time: 16.6 s


In [11]:
# Rename each file on disk, replacing 'CHN00.h5' with 'CHN00_aligned.hdf5'
# in its name — presumably the naming voluseg expects for aligned volumes;
# TODO confirm.
for h5_path in tqdm(h5_paths):
    rename_file_with_replace(h5_path, 'CHN00.h5', 'CHN00_aligned.hdf5')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47767.0), HTML(value='')))




### Move h5 files to expected directory

In [12]:
from analysis_toolbox.utils import create_dir
# Create the parent 'volumes' folder under the voluseg output directory.
create_dir(dir_output + 'volumes')

In [13]:
import os
import shutil

# Relocate the folder of renamed volumes to dest_ds_dir.
#
# The move is guarded so the cell survives a re-run of the notebook:
#   * source present, destination absent -> perform the move;
#   * source absent, destination present -> already moved, nothing to do;
#   * both present -> refuse, because shutil.move would nest the source
#     folder *inside* the existing destination instead of replacing it;
#   * both absent -> fail loudly, there is nothing to work with.
if os.path.isdir(src_ds_dir):
    if os.path.exists(dest_ds_dir):
        raise FileExistsError(
            'both %s and %s exist; refusing to nest one inside the other.'
            % (src_ds_dir, dest_ds_dir)
        )
    shutil.move(src_ds_dir, dest_ds_dir)
elif not os.path.isdir(dest_ds_dir):
    raise FileNotFoundError(
        'neither %s nor %s exists; nothing to move.' % (src_ds_dir, dest_ds_dir)
    )

# Show the final location (the same value shutil.move returns).
dest_ds_dir

'/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/'

## Initialize parameters

In [15]:
from voluseg._tools.parameter_dictionary import parameter_dictionary
# Display voluseg's default parameter dictionary for reference.
parameter_dictionary()

{'detrending': 'standard',
 'registration': 'medium',
 'diam_cell': 6.0,
 'dir_ants': '',
 'dir_input': '',
 'dir_output': '',
 'dir_transform': '',
 'ds': 2,
 'planes_pad': 0,
 'planes_packed': False,
 'parallel_clean': True,
 'parallel_volume': True,
 'save_volume': False,
 'type_timepoints': 'dff',
 'type_mask': 'geomean',
 'timepoints': 1000,
 'f_hipass': 0,
 'f_volume': 2.0,
 'n_cells_block': 100,
 'n_colors': 1,
 'res_x': 0.40625,
 'res_y': 0.40625,
 'res_z': 5.0,
 't_baseline': 300,
 't_section': 0.01,
 'thr_mask': 0.5}

In [16]:
# initialize_parameters.py

### set these parameters ###
channel_file = os.path.join(dir_input, 'ch0.xml')
stack_file = os.path.join(dir_input, 'Stack_frequency.txt')
### end set these parameters ###

# Begin with voluseg's defaults and point them at this dataset.
parameters0 = voluseg.parameter_dictionary()
parameters0.update(dir_input=dir_input, dir_output=dir_output)

# Let voluseg fill in metadata from the channel and stack files.
parameters0 = voluseg.load_metadata(parameters0, channel_file, stack_file)

# Dataset-specific overrides, applied after the metadata load so they win.
parameters0.update({
    # default is 2 | if downsampled already, make sure this matches that
    'ds': 2,
    # cell_diameter = 6: 100-150k cells, cell_diameter=5: ~300-400k cells
    'diam_cell': 6.0,
    # increase block size to reduce blockiness in segments
    'n_cells_block': 200,
    # comment or set to 'medium' to enable default (i.e. run registration)
    'registration': 'none',
    # False if running on local workstation
    'parallel_volume': False,
    # False if running on local workstation
    'parallel_clean': False,
    'type_timepoints': 'periodic',
    't_section': 0,
    # microns; manually set to ensure cells belong to a single plane
    'res_z': 100,
})

# Write the parameter file (with metadata) to the output directory.
voluseg.step0_process_parameters(parameters0)

# Read the saved file back and print it as a check.
saved_parameters_path = os.path.join(dir_output, 'parameters.pickle')
parameters = voluseg.load_parameters(saved_parameters_path)
pprint.pprint(parameters)

fetched t_section.
fetched res_z.
fetched f_volume.
Checking 'timepoints' for 'type_timepoints'='periodic'.
parameter file successfully saved.
parameter file successfully loaded.
{'affine_mat': array([[  0.8125,   0.    ,   0.    ,   0.    ],
       [  0.    ,   0.8125,   0.    ,   0.    ],
       [  0.    ,   0.    , 100.    ,   0.    ],
       [  0.    ,   0.    ,   0.    ,   1.    ]]),
 'detrending': 'standard',
 'diam_cell': 6.0,
 'dir_ants': '',
 'dir_input': '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0/',
 'dir_output': '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/',
 'dir_transform': '',
 'ds': 2,
 'ext': '.h5',
 'f_hipass': 0,
 'f_volume': 3.01,
 'lt': 47767,
 'n_cells_block': 200,
 'n_colors': 1,
 'parallel_clean': False,
 'parallel_volume': False,
 'planes_packed': False,
 'planes_pad': 0,
 'registration': No

In [17]:
# Manual switch: when True, a later cell rewrites parameters.pickle with the
# edited `parameters` dict.
fix_parameters = False

In [18]:
# check saved parameters
# Location of the pickled parameter file inside the voluseg output folder.
filename_parameters = os.path.join(dir_output, 'parameters.pickle')

# Load it back from disk and pretty-print the contents.
parameters = voluseg.load_parameters(filename_parameters)
pprint.pprint(parameters)

parameter file successfully loaded.
{'affine_mat': array([[  0.8125,   0.    ,   0.    ,   0.    ],
       [  0.    ,   0.8125,   0.    ,   0.    ],
       [  0.    ,   0.    , 100.    ,   0.    ],
       [  0.    ,   0.    ,   0.    ,   1.    ]]),
 'detrending': 'standard',
 'diam_cell': 6.0,
 'dir_ants': '',
 'dir_input': '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0/',
 'dir_output': '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/',
 'dir_transform': '',
 'ds': 2,
 'ext': '.h5',
 'f_hipass': 0,
 'f_volume': 3.01,
 'lt': 47767,
 'n_cells_block': 200,
 'n_colors': 1,
 'parallel_clean': False,
 'parallel_volume': False,
 'planes_packed': False,
 'planes_pad': 0,
 'registration': None,
 'res_x': 0.40625,
 'res_y': 0.40625,
 'res_z': 100,
 'save_volume': False,
 't_baseline': 300,
 't_section': 0,
 'thr_mask': 0.5,
 'timepo

In [19]:
# List the .hdf5 volumes in the destination folder, in sorted order.
hdf5_pattern = f"{dest_ds_dir}*.hdf5"
hdf5_paths = sorted(glob(hdf5_pattern))

# Preview the first few matches.
hdf5_paths[:10]

['/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/TM0000000_CM0_CHN00_aligned.hdf5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/TM0000001_CM0_CHN00_aligned.hdf5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/TM0000002_CM0_CHN00_aligned.hdf5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/TM0000003_CM0_CHN00_aligned.hdf5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/volumes/0/TM0000004_CM0_CHN00_aligned.hdf5',
 '/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_ex

In [20]:
from analysis_toolbox.utils import path_leaf
# Reduce each full path to its file name (see the preview in the cell output).
hdf5_basenames = [path_leaf(path) for path in tqdm(hdf5_paths)]
hdf5_basenames[:10]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=47767.0), HTML(value='')))




['TM0000000_CM0_CHN00_aligned.hdf5',
 'TM0000001_CM0_CHN00_aligned.hdf5',
 'TM0000002_CM0_CHN00_aligned.hdf5',
 'TM0000003_CM0_CHN00_aligned.hdf5',
 'TM0000004_CM0_CHN00_aligned.hdf5',
 'TM0000005_CM0_CHN00_aligned.hdf5',
 'TM0000006_CM0_CHN00_aligned.hdf5',
 'TM0000007_CM0_CHN00_aligned.hdf5',
 'TM0000008_CM0_CHN00_aligned.hdf5',
 'TM0000009_CM0_CHN00_aligned.hdf5']

In [114]:
# Drop the '_aligned.hdf5' suffix so only the bare volume names remain.
aligned_suffix = '_aligned.hdf5'
hf_basenames = []
for basename in tqdm(hdf5_basenames):
    hf_basenames.append(basename.replace(aligned_suffix, ''))

# Preview the first few names.
hf_basenames[:10]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=54202.0), HTML(value='')))




['TM0000000_CM1_CHN00',
 'TM0000001_CM1_CHN00',
 'TM0000002_CM1_CHN00',
 'TM0000003_CM1_CHN00',
 'TM0000004_CM1_CHN00',
 'TM0000005_CM1_CHN00',
 'TM0000006_CM1_CHN00',
 'TM0000007_CM1_CHN00',
 'TM0000008_CM1_CHN00',
 'TM0000009_CM1_CHN00']

In [115]:
# Number of volumes, taken from the list of volume names.
lt = len(hf_basenames)
lt

54202

In [117]:
# Reset 'timepoints' to the voluseg default value and display it.
parameters['timepoints'] = parameter_dictionary()['timepoints']
parameters['timepoints']

1000

In [118]:
# check timepoints
# Imported in this cell because it runs before the notebook's main numpy
# import and nothing else brings a warning function into scope.
import warnings

import numpy as np

# Validate parameters['timepoints'] against the number of volumes `lt` and
# leave the validated value in `tp`.
parameters['type_timepoints'] = parameters['type_timepoints'].lower()
if parameters['type_timepoints'] not in ['dff', 'periodic', 'custom']:
    raise Exception('\'type_timepoints\' must be \'dff\', \'periodic\' or \'custom\'.')

print('Checking \'timepoints\' for \'type_timepoints\'=\'%s\'.'%parameters['type_timepoints'])
tp = parameters['timepoints']
if parameters['type_timepoints'] in ['dff', 'periodic']:
    # a single nonnegative integer: how many timepoints to select
    if not (np.isscalar(tp) and (tp >= 0) and (tp == np.round(tp))):
        raise Exception('\'timepoints\' must be a nonnegative integer.')
    elif tp >= lt:
        warnings.warn('specified number of timepoints is greater than the number of volumes, overriding.')
        # 0 makes the later selection step fall back to every volume
        tp = 0
elif parameters['type_timepoints'] in ['custom']:
    # an explicit list of volume indices: deduplicate, validate, clip to range
    tp = np.unique(tp)
    if not ((np.ndim(tp) == 1) and np.all(tp >= 0) and np.all(tp == np.round(tp))):
        raise Exception('\'timepoints\' must be a one-dimensional vector of nonnegative integers.')
    elif np.any(tp >= lt):
        warnings.warn('discarding timepoints that exceed the number of volumes.')
        tp = tp[tp < lt]
    tp = tp.astype(int)

print(tp)

Checking 'timepoints' for 'type_timepoints'='periodic'.
1000


In [119]:
# Store the volume names, their count, the validated timepoints and the file
# extension on the parameter dict in one go.
parameters.update({
    'volume_names': hf_basenames,
    'lt': lt,
    'timepoints': tp,
    'ext': '.h5',
})

In [120]:
if fix_parameters:

    import pickle

    # Write the edited parameters to a sibling temp file, then swap it into
    # place. Unlike removing the old pickle first, this works when the file
    # does not exist yet and keeps the previous copy intact if the dump fails.
    filename_parameters_tmp = filename_parameters + '.tmp'
    with open(filename_parameters_tmp, 'wb') as file_handle:
        pickle.dump(parameters, file_handle)
    os.replace(filename_parameters_tmp, filename_parameters)

    print('parameter file successfully saved.')

parameter file successfully saved.


## Run the pipeline

In [7]:
# Path of the 'prepro.output' text file inside the voluseg output folder.
file_output = os.path.join(dir_output, 'prepro.output')
file_output

'/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/prepro.output'

### Load parameters

In [8]:
# voluseg_submit.py

# Reload the saved parameters and append a pretty-printed copy to file_output.
parameters = voluseg.load_parameters(os.path.join(dir_output, 'parameters.pickle'))
with open(file_output, 'a') as fh:
    pprint.pprint(parameters, fh)

parameter file successfully loaded.


In [9]:
## Prepare data
# Open both zarr stores as dask arrays. Per the cell output, data_cis_da is
# chunked with all timepoints in each chunk and small spatial tiles, while
# data_cit_da holds one whole volume per chunk.
data_cis_da = da.from_zarr(data_cis, inline_array=True)
data_cit_da = da.from_zarr(data_cit, inline_array=True)

# Show the array / chunk summaries.
display(data_cis_da)
display(data_cit_da)

Unnamed: 0,Array,Chunk
Bytes,11.86 TB,391.31 MB
Shape,"(47767, 53, 572, 1024)","(47767, 1, 32, 32)"
Count,30528 Tasks,30528 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 11.86 TB 391.31 MB Shape (47767, 53, 572, 1024) (47767, 1, 32, 32) Count 30528 Tasks 30528 Chunks Type float64 numpy.ndarray",47767  1  1024  572  53,

Unnamed: 0,Array,Chunk
Bytes,11.86 TB,391.31 MB
Shape,"(47767, 53, 572, 1024)","(47767, 1, 32, 32)"
Count,30528 Tasks,30528 Chunks
Type,float64,numpy.ndarray


Unnamed: 0,Array,Chunk
Bytes,11.86 TB,248.35 MB
Shape,"(47767, 53, 572, 1024)","(1, 53, 572, 1024)"
Count,47767 Tasks,47767 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 11.86 TB 248.35 MB Shape (47767, 53, 572, 1024) (1, 53, 572, 1024) Count 47767 Tasks 47767 Chunks Type float64 numpy.ndarray",47767  1  1024  572  53,

Unnamed: 0,Array,Chunk
Bytes,11.86 TB,248.35 MB
Shape,"(47767, 53, 572, 1024)","(1, 53, 572, 1024)"
Count,47767 Tasks,47767 Chunks
Type,float64,numpy.ndarray


In [10]:
import os
import h5py
import numpy as np
import pandas as pd
import time
from types import SimpleNamespace

### Step 3: Mask volume

In [19]:
# step3.py

from scipy import stats
from sklearn import mixture
from skimage import morphology
from scipy.ndimage.filters import median_filter
from voluseg._tools.ball import ball
from voluseg._tools.constants import ali, hdf
from voluseg._tools.load_volume import load_volume
from voluseg._tools.clean_signal import clean_signal

In [20]:
# set up matplotlib
import matplotlib
import matplotlib.pyplot as plt

In [21]:
# Attribute-style view of the parameter dict (p.key instead of parameters['key']).
p = SimpleNamespace(**parameters)

In [22]:
# compute mean timeseries and ranked dff
# Base path (without extension) of the mean-timeseries file; the `hdf`
# extension is appended where the path is used.
fullname_timemean = os.path.join(p.dir_output, 'mean_timeseries')
fullname_timemean

'/scratch/limj2/data/raw/20211007/fish00/5dpf_elavl3-gc8f-gfap-jregeco_MG-vs_NGGU-trunc65_fish00_exp07_20211007_215308/im_CM0_voluseg/mean_timeseries'

In [29]:
%%time
if not os.path.isfile(fullname_timemean+hdf):
    
    dff_rank = np.zeros(p.lt)
    mean_timeseries_raw = np.zeros((p.n_colors, p.lt))
    mean_timeseries = np.zeros((p.n_colors, p.lt))
    mean_baseline = np.zeros((p.n_colors, p.lt))
    mean_timeseries_raw.shape    
    
    for color_i in range(p.n_colors):
        
        dir_volume = os.path.join(p.dir_output, 'volumes', str(color_i))       
        
        mean_timeseries_raw_da = da.nanmean(data_cit_da, axis=(1,2,3))
        display(mean_timeseries_raw_da)
        
        with start_cluster(
            cores=8,
            force_local=False, processes=32,
#             worker_lifetime="350s", worker_lifetime_stagger="2m", 
            death_timeout=1200, store_logs=True, spillover=True, verbose=True) as cluster, Client(cluster) as client:

            display(client)
                  
            mean_timeseries_raw[color_i] = mean_timeseries_raw_da.compute()
                
    time, base = clean_signal(parameters, mean_timeseries_raw[color_i])
    mean_timeseries[color_i], mean_baseline[color_i] = time, base
    dff_rank += stats.rankdata((time - base) / time)        
    
    # get high delta-f/f timepoints
    if p.type_timepoints == 'custom':
        timepoints = p.timepoints
    else:
        nt = p.timepoints
        if not nt:
            timepoints = np.arange(p.lt)
        else:
            if p.type_timepoints == 'dff':
                timepoints = np.sort(np.argsort(dff_rank)[::-1][:nt])
            elif p.type_timepoints == 'periodic':
                timepoints = np.linspace(0, p.lt, nt, dtype='int', endpoint=False)

    with h5py.File(fullname_timemean+hdf, 'w') as file_handle:
        file_handle['mean_timeseries_raw'] = mean_timeseries_raw
        file_handle['mean_timeseries'] = mean_timeseries
        file_handle['mean_baseline'] = mean_baseline
        file_handle['timepoints'] = timepoints    

Kernel currently running on e06u03.int.janelia.org
Scheduler: tcp://127.0.0.1:45661
Dashboard link: http://127.0.0.1:8787/status


0,1
Client  Scheduler: tcp://127.0.0.1:45661  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 32  Cores: 256  Memory: 1.01 TB


CPU times: user 12min 27s, sys: 2min 9s, total: 14min 37s
Wall time: 18min 3s


In [32]:
# Step 3: for each colour, reduce the selected timepoints to a single volume,
# derive a brain mask from it, save diagnostic plots, and cache the result in
# fullname_volmean+hdf (colours whose file already exists are skipped).
from analysis_toolbox.stacks_helper import plot_planes

# load timepoints
with h5py.File(fullname_timemean+hdf, 'r') as file_handle:
    timepoints = file_handle['timepoints'][()]

for color_i in range(p.n_colors):

    fullname_volmean = os.path.join(p.dir_output, 'volume%d'%(color_i))
    if os.path.isfile(fullname_volmean+hdf):
        continue

    dir_volume = os.path.join(p.dir_output, 'volumes', str(color_i))
    dir_plot = os.path.join(p.dir_output, 'mask_plots', str(color_i))
    os.makedirs(dir_plot, exist_ok=True)

    # volume dimensions are read from the first volume file on disk
    fullname_volume = os.path.join(dir_volume, p.volume_names[0])
    lx, ly, lz = load_volume(fullname_volume+ali+hdf).T.shape

    # build the (lazy) reduction over the selected timepoints
    if p.type_mask == 'max':
        volume_mean_da = da.nanmax(data_cis_da[timepoints], axis=0).T
    elif p.type_mask == 'mean':
        volume_mean_da = da.nanmean(data_cis_da[timepoints], axis=0).T
    elif p.type_mask == 'geomean':
        volume_mean_da = 10 ** da.nanmean(da.log10(data_cis_da[timepoints]), axis=0).T
    else:
        # fail explicitly rather than hitting a NameError below or silently
        # reusing a volume_mean_da left over from an earlier run
        raise ValueError('\'type_mask\' must be \'max\', \'mean\' or \'geomean\'.')

    display(volume_mean_da)

    with start_cluster(
        cores=8,
        force_local=False, processes=32,
        # worker_lifetime="350s", worker_lifetime_stagger="2m",
        death_timeout=1200, store_logs=True, spillover=True, verbose=True) as cluster, Client(cluster) as client:

        display(client)

        volume_mean = volume_mean_da.compute()

    plot_planes(volume_mean.T, percentile_min=20, percentile_max=99)

    # get peaks by comparing to a median-smoothed volume
    ball_radi = ball(0.5 * p.diam_cell, p.affine_mat)[0]
    volume_peak = volume_mean >= median_filter(volume_mean, footprint=ball_radi)

    # compute power and probability
    # (a two-component Gaussian mixture is fitted to the 5th..95th percentile
    # intensities of the positive voxels; the probability kept is that of the
    # component whose most confident sample has the higher intensity)
    voxel_intensity = np.percentile(volume_mean[volume_mean>0], np.r_[5:95:0.001])[:, None]
    gmm = mixture.GaussianMixture(n_components=2, max_iter=100, n_init=100).fit(voxel_intensity)
    voxel_probability = gmm.predict_proba(voxel_intensity)
    voxel_probability = voxel_probability[:, np.argmax(voxel_intensity[np.argmax(voxel_probability, 0)])]

    # compute intensity threshold
    if (p.thr_mask > 0) and (p.thr_mask <= 1):
        # thr_mask in (0, 1] is read as a probability threshold
        thr_probability = p.thr_mask
        ix = np.argmin(np.abs(voxel_probability - thr_probability))
        thr_intensity = voxel_intensity[ix][0]
    elif p.thr_mask > 1:
        # thr_mask above 1 is read as an intensity threshold
        thr_intensity = p.thr_mask
        ix = np.argmin(np.abs(voxel_intensity - thr_intensity))
        thr_probability = voxel_probability[ix]
    else:
        # otherwise no thresholding: every voxel passes
        thr_intensity = - np.inf
        thr_probability = 0

    print('using probability threshold of %f.'%(thr_probability))
    print('using intensity threshold of %f.'%(thr_intensity))

    # get and save brain mask
    fig = plt.figure(1, (18, 6))
    plt.subplot(131)
    _ = plt.hist(voxel_intensity, 100)
    plt.plot(thr_intensity, 0, '|', color='r', markersize=200)
    plt.xlabel('voxel intensity')
    plt.title('intensity histogram with threshold (red)')

    plt.subplot(132)
    _ = plt.hist(voxel_probability, 100)
    plt.plot(thr_probability, 0, '|', color='r', markersize=200)
    plt.xlabel('voxel probability')
    plt.title('probability histogram with threshold (red)')

    plt.subplot(133)
    plt.plot(voxel_intensity, voxel_probability, linewidth=3)
    plt.plot(thr_intensity, thr_probability, 'x', color='r', markersize=10)
    plt.xlabel('voxel intensity')
    plt.ylabel('voxel probability')
    plt.title('intensity-probability plot with threshold (red)')

    plt.savefig(os.path.join(dir_plot, 'histogram.png'))
    plt.close(fig)

    # remove all disconnected components smaller than thr_size
    # NOTE(review): the threshold is computed as 5000 * rx * ry * rz and passed
    # to remove_small_objects as a size; the earlier comment called this
    # "5000 cubic microliters" — confirm the intended units.
    rx, ry, rz, _ = np.diag(p.affine_mat)
    volume_mask = (volume_mean > thr_intensity).astype(bool)
    thr_size = np.round(5000 * rx * ry * rz).astype(int)
    volume_mask = morphology.remove_small_objects(volume_mask, thr_size)

    # compute background fluorescence (median of the voxels outside the mask)
    background = np.median(volume_mean[volume_mask==0])

    # save brain mask figures, one PNG per plane
    for i in range(lz):
        fig = plt.figure(1, (18, 6))
        plt.subplot(131)
        plt.imshow(volume_mean[:, :, i].T, vmin=voxel_intensity[0], vmax=voxel_intensity[-1])
        plt.title('volume intensity (plane %d)'%(i))

        plt.subplot(132)
        plt.imshow(volume_mask[:, :, i].T)
        plt.title('volume mask (plane %d)'%(i))

        plt.subplot(133)
        # overlay: first channel = intensity rescaled to [0, 1], other two = mask
        img = np.stack((volume_mean[:, :, i], volume_mask[:, :, i], volume_mask[:, :, i]), axis=2)
        img[:, :, 0] = (img[:, :, 0] - voxel_intensity[0]) / (voxel_intensity[-1] - voxel_intensity[0])
        img[:, :, 0] = np.minimum(np.maximum(img[:, :, 0], 0), 1)
        plt.imshow(np.transpose(img, [1, 0, 2]))
        plt.title('volume mask/intensity overlay (plane %d)'%(i))

        plt.savefig(os.path.join(dir_plot, 'mask_z%03d.png'%(i)))
        plt.close(fig)

    # cache the results (arrays are transposed back before saving)
    with h5py.File(fullname_volmean+hdf, 'w') as file_handle:
        file_handle['volume_mask'] = volume_mask.T
        file_handle['volume_mean'] = volume_mean.T
        file_handle['volume_peak'] = volume_peak.T
        file_handle['thr_intensity'] = thr_intensity
        file_handle['thr_probability'] = thr_probability
        file_handle['background'] = background

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]


Unnamed: 0,Array,Chunk
Bytes,248.35 MB,8.19 kB
Shape,"(1024, 572, 53)","(32, 32, 1)"
Count,213696 Tasks,30528 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 248.35 MB 8.19 kB Shape (1024, 572, 53) (32, 32, 1) Count 213696 Tasks 30528 Chunks Type float64 numpy.ndarray",53  572  1024,

Unnamed: 0,Array,Chunk
Bytes,248.35 MB,8.19 kB
Shape,"(1024, 572, 53)","(32, 32, 1)"
Count,213696 Tasks,30528 Chunks
Type,float64,numpy.ndarray


Kernel currently running on e06u03.int.janelia.org
Scheduler: tcp://127.0.0.1:37493
Dashboard link: http://127.0.0.1:8787/status


0,1
Client  Scheduler: tcp://127.0.0.1:37493  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 32  Cores: 256  Memory: 1.01 TB


Task was destroyed but it is pending!
task: <Task pending coro=<RequestHandler._execute() running at /groups/ahrens/home/limj2/anaconda3/envs/python37/lib/python3.6/site-packages/tornado/web.py:1704> wait_for=<Future pending cb=[<TaskWakeupMethWrapper object at 0x7f12aaccb318>()]> cb=[_HandlerDelegate.execute.<locals>.<lambda>() at /groups/ahrens/home/limj2/anaconda3/envs/python37/lib/python3.6/site-packages/tornado/web.py:2326]>


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  if colorbar: fig.tight_layout(rect=[0, 0.03, 0.85, 0.95])


using probability threshold of 0.500000.
using intensity threshold of 106.145482.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Detect cells

In [11]:
from voluseg._steps.step4a import define_blocks
from voluseg._steps.step4b import process_block_data
from voluseg._steps.step4c import initialize_block_cells
from voluseg._steps.step4d import nnmf_sparse
from voluseg._tools.ball import ball
from voluseg._tools.constants import hdf

In [12]:
# Wrap the parameter dict in a namespace so that settings can be read as
# attributes (p.dir_output, p.n_colors, p.diam_cell, ...) in the cells below.
p = SimpleNamespace(**parameters)

In [13]:
# Build the ball for the configured cell diameter and affine matrix.
# ball_diam is used below as the structuring element of the binary dilation and
# both values are passed on to initialize_block_cells.
# NOTE(review): the exact contents of the two return values are defined in
# voluseg._tools.ball and are not visible here.
ball_diam, ball_diam_xyz0 = ball(1.0 * p.diam_cell, p.affine_mat)

# load timepoints (used below to index the time axis of the voxel timeseries
# and passed to nnmf_sparse)
fullname_timemean = os.path.join(p.dir_output, 'mean_timeseries')
with h5py.File(fullname_timemean+hdf, 'r') as file_handle:
    timepoints = file_handle['timepoints'][()]

In [None]:
# Detect-cells driver: for every color channel, load the mean/mask/peak volumes,
# define the spatial blocks (or reload a previously saved block layout) and work
# out which blocks still have to be processed.
for color_i in range(p.n_colors):

    # skip this channel if its final cleaned-cells file already exists
    fullname_cells = os.path.join(p.dir_output, 'cells%s_clean'%(color_i))
    if os.path.isfile(fullname_cells+hdf):
        continue

    # directory that receives one file per processed block
    dir_cell = os.path.join(p.dir_output, 'cells', str(color_i))
    os.makedirs(dir_cell, exist_ok=True)

    # volumes are transposed on load; after .T they are indexed as (x, y, z)
    # (see lx, ly, lz below)
    fullname_volmean = os.path.join(p.dir_output, 'volume%d'%(color_i))
    with h5py.File(fullname_volmean+hdf, 'r') as file_handle:
        volume_mean = file_handle['volume_mean'][()].T
        volume_mask = file_handle['volume_mask'][()].T
        volume_peak = file_handle['volume_peak'][()].T
        if 'n_blocks' in file_handle.keys():
            # block layout was saved by an earlier run: reuse it (flag = 0)
            flag = 0
            n_voxels_cell = file_handle['n_voxels_cell'][()]
            n_blocks = file_handle['n_blocks'][()]
            block_valids = file_handle['block_valids'][()]
            xyz0 = file_handle['block_xyz0'][()]
            xyz1 = file_handle['block_xyz1'][()]
        else:
            # block layout still has to be computed (flag = 1)
            flag = 1

    # dimensions and resolution
    lxyz = volume_mean.shape
    rxyz = np.diag(p.affine_mat)[:3]

    # compute number of blocks (do only once)
    if flag:
        lx, ly, lz = lxyz
        rx, ry, rz = rxyz

        # get number of voxels in cell
        if (lz == 1) or (rz >= p.diam_cell):
            # area of a circle
            n_voxels_cell = np.pi * ((p.diam_cell / 2.0)**2) / (rx * ry)
        else:
            # volume of a cylinder (change to sphere later)
            n_voxels_cell = p.diam_cell * np.pi * ((p.diam_cell / 2.0)**2) / (rx * ry * rz)

        n_voxels_cell = np.round(n_voxels_cell).astype(int)

        # partition the volume into blocks: number of blocks, validity flag per
        # block and the start/end corner of every block
        n_blocks, block_valids, xyz0, xyz1 = \
            define_blocks(lx, ly, lz, p.n_cells_block, n_voxels_cell, volume_mask)

        # save number and indices of blocks
        with h5py.File(fullname_volmean+hdf, 'r+') as file_handle:
            file_handle['n_voxels_cell'] = n_voxels_cell
            file_handle['n_blocks'] = n_blocks
            file_handle['block_valids'] = block_valids
            file_handle['block_xyz0'] = xyz0
            file_handle['block_xyz1'] = xyz1

    print('number of blocks, total: %d.'%(block_valids.sum()))

    # drop blocks whose output file already carries a truthy 'completion' entry;
    # a missing or unreadable block file (OSError) leaves the block marked as
    # remaining
    for ii in np.where(block_valids)[0]:
        try:
            fullname_block = os.path.join(dir_cell, 'block%05d'%(ii))
            with h5py.File(fullname_block+hdf, 'r') as file_handle:
                if ('completion' in file_handle.keys()) and file_handle['completion'][()]:
                    block_valids[ii] = 0
        except (NameError, OSError):
            pass

    print('number of blocks, remaining: %d.'%(block_valids.sum()))
    # work list: (block index, start corner, end corner) for every remaining block
    ix = np.where(block_valids)[0]
    block_ixyz01 = list(zip(ix, xyz0[ix], xyz1[ix]))

    ## Open data
    # data_cis_zarr is read by the process_block_data function defined below
    import zarr
    data_cis_zarr = zarr.open(data_cis, mode='r')
    # NOTE(review): a bare expression inside a loop body displays nothing; this
    # line has no effect.
    data_cis_zarr
    
    def process_block_data(xyz0, xyz1, parameters, color_i, lxyz, rxyz,
                           ball_diam, bvolume_mean, bvolume_peak, timepoints):
        '''Load the voxel timeseries of one block, slice-time correct them, and find similar peak voxels.

        Notebook-local definition that shadows the ``process_block_data`` imported
        from ``voluseg._steps.step4b``. The timeseries are read from the zarr array
        ``data_cis_zarr`` of the enclosing scope, indexed as (t, z, y, x).

        Parameters
        ----------
        xyz0, xyz1 : start (inclusive) and end (exclusive) voxel indices of the
            block along x, y and z.
        parameters : dict of settings; ``lt``, ``f_volume``, ``t_section`` and
            ``diam_cell`` are read here.
        color_i : color channel index (kept for interface compatibility, not used).
        lxyz : volume dimensions; only the number of planes ``lxyz[2]`` is used.
        rxyz : per-axis voxel resolution, multiplied with voxel indices to obtain
            physical positions.
        ball_diam : structuring element for the dilation of the peak voxels.
        bvolume_mean, bvolume_peak : mean and peak volumes indexed as (x, y, z).
            NOTE(review): bvolume_peak must be of a dtype that supports ``&``
            (bool/int) -- confirm against the stored 'volume_peak'.
        timepoints : indices of the timepoints used for the correlations.

        Returns
        -------
        tuple ``(voxel_xyz, voxel_timeseries, peak_idx, voxel_similarity_peak)``:
        coordinates of all block voxels, their timeseries (one row per voxel),
        row indices of the peak voxels within those, and a symmetric boolean
        peak-by-peak similarity matrix.
        '''

        import os
        import time
        import numpy as np
        from scipy import interpolate
        from skimage import morphology
        from types import SimpleNamespace
        from voluseg._tools.constants import dtype

        os.environ['MKL_NUM_THREADS'] = '1'

        p = SimpleNamespace(**parameters)
        lz = lxyz[2]

        # load and dilate initial voxel peak positions
        x0_, y0_, z0_ = xyz0
        x1_, y1_, z1_ = xyz1
        voxel_peak = np.zeros_like(bvolume_peak)
        voxel_peak[x0_:x1_, y0_:y1_, z0_:z1_] = 1
        voxel_peak = voxel_peak & bvolume_peak & (bvolume_mean > 0)
        voxel_mask = morphology.binary_dilation(voxel_peak, ball_diam) & (bvolume_mean > 0)

        voxel_xyz = np.argwhere(voxel_mask)
        voxel_xyz_peak = np.argwhere(voxel_peak)
        peak_idx = np.argwhere(voxel_peak[voxel_mask]).T[0]

        # read only the bounding box of the dilated block from the zarr array
        tic = time.time()
        x0, y0, z0 = voxel_xyz.min(0)
        x1, y1, z1 = voxel_xyz.max(0) + 1

        voxel_timeseries_block = data_cis_zarr[:, z0:z1, y0:y1, x0:x1].T  # tzyx -> xyzt
        voxel_timeseries = voxel_timeseries_block[voxel_mask[x0:x1, y0:y1, z0:z1]]
        voxel_timeseries = voxel_timeseries.astype(dtype)
        del voxel_timeseries_block
        print('data loading: %.1f minutes.\n'%((time.time() - tic) / 60))

        # slice-time correct if more than one slice and t_section is positive
        if (lz > 1) and (p.t_section > 0):
            # acquisition times of the volumes and of the middle plane do not
            # depend on the voxel, so they are computed once
            timepoints_volume = np.arange(p.lt) / p.f_volume
            timepoints_zm = timepoints_volume + (lz / 2) * p.t_section
            for i, zi in enumerate(voxel_xyz[:, 2]):
                # acquisition times of plane zi
                timepoints_zi = timepoints_volume + zi * p.t_section

                # make spline interpolator and interpolate timeseries
                spline_interpolator_xyzi = \
                    interpolate.InterpolatedUnivariateSpline(timepoints_zi, voxel_timeseries[i])
                voxel_timeseries[i] = spline_interpolator_xyzi(timepoints_zm)

        def normalize(timeseries):
            '''Center each row and scale it so that row dot products are correlations.'''
            mn = timeseries.mean(1)
            sd = timeseries.std(1, ddof=1)
            return (timeseries - mn[:, None]) / (sd[:, None] * np.sqrt(p.lt - 1))

        # get voxel connectivity from proximities (distances) and similarities (correlations)
        voxel_xyz_phys_peak = voxel_xyz_peak * rxyz
        voxel_timeseries_peak_nrm = normalize(voxel_timeseries[np.ix_(peak_idx, timepoints)])

        # compute voxel peak similarity: combination of high proximity and high correlation
        tic = time.time()
        n_peaks = len(peak_idx)
        voxel_similarity_peak = np.zeros((n_peaks, n_peaks), dtype=bool)
        for i in range(n_peaks):
            dist_i = (((voxel_xyz_phys_peak[i] - voxel_xyz_phys_peak)**2).sum(1))**0.5
            neib_i = dist_i < p.diam_cell
            corr_i = np.dot(voxel_timeseries_peak_nrm[i], voxel_timeseries_peak_nrm.T)
            # similar = neighbor whose correlation exceeds the median correlation
            # among the neighbors of peak i
            voxel_similarity_peak[i] = neib_i & (corr_i > np.median(corr_i[neib_i]))

        # symmetrize the relation
        voxel_similarity_peak = voxel_similarity_peak | voxel_similarity_peak.T
        print('voxel similarity: %.1f minutes.\n'%((time.time() - tic) / 60))

        return (voxel_xyz, voxel_timeseries, peak_idx, voxel_similarity_peak)

    # detect individual cells with sparse nnmf algorithm
    def detect_cells_block(tuple_i_xyz0_xyz1):
        '''Detect the cells of one block and write them to its ``block%05d`` file in ``dir_cell``.

        The block is factorized with ``nnmf_sparse``; on a ``ValueError`` the
        attempt is repeated with a smaller fraction of peak voxels
        (1.0, 0.95, ..., 0.05).

        Reads from the enclosing scope: ``fullname_volmean``, ``dir_cell``,
        ``parameters``, ``color_i``, ``lxyz``, ``rxyz``, ``ball_diam``,
        ``ball_diam_xyz0``, ``timepoints`` and ``n_voxels_cell``.

        Parameters
        ----------
        tuple_i_xyz0_xyz1 : tuple ``(ii, xyz0, xyz1)`` with the block index and
            the start/end corner of the block.

        Returns
        -------
        None. The block file always receives ``n_cells`` and ``completion = 1``.
        If every attempt failed, no ``/cell/...`` datasets are written and
        ``n_cells`` is stored as 0 so that readers iterating over
        ``range(n_cells)`` do not look up cells that do not exist.
        '''

        ## read from file
        # NOTE(review): the volumes are re-read inside the task, presumably so
        # that the arrays are not shipped with every submitted task -- confirm.
        with h5py.File(fullname_volmean+hdf, 'r') as file_handle:
            bvolume_peak = file_handle['volume_peak'][()].T
            bvolume_mean = file_handle['volume_mean'][()].T

        import time

        os.environ['MKL_NUM_THREADS'] = '1'

        ii, xyz0, xyz1 = tuple_i_xyz0_xyz1

        voxel_xyz, voxel_timeseries, peak_idx, voxel_similarity_peak = \
            process_block_data(xyz0, xyz1, parameters, color_i, lxyz, rxyz,
                               ball_diam, bvolume_mean, bvolume_peak, timepoints)

        n_voxels_block = len(voxel_xyz)                        # number of voxels in block

        # NOTE(review): np.argsort returns sorting indices, not ranks; if a
        # per-voxel rank fraction is intended this would need a double argsort.
        # Left unchanged -- confirm against upstream voluseg.
        voxel_fraction_peak = np.argsort(((voxel_timeseries[peak_idx])**2).mean(1)) / len(peak_idx)

        success = 0
        for fraction in np.r_[1:0:-0.05]:
            try:
                peak_valids = (voxel_fraction_peak >= (1 - fraction))   # valid voxel indices

                n_cells = np.round(peak_valids.sum() / (0.5 * n_voxels_cell)).astype(int)
                print((fraction, n_cells))

                tic = time.time()
                voxel_timeseries_valid, voxel_xyz_valid, cell_weight_init_valid, \
                    cell_neighborhood_valid, cell_sparseness = initialize_block_cells(
                        n_voxels_cell, n_voxels_block, n_cells, voxel_xyz, voxel_timeseries, peak_idx,
                        peak_valids, voxel_similarity_peak, lxyz, rxyz, ball_diam, ball_diam_xyz0)
                print('cell initialization: %.1f minutes.\n'%((time.time() - tic) / 60))

                tic = time.time()
                cell_weights_valid, cell_timeseries_valid, d = nnmf_sparse(
                    voxel_timeseries_valid, voxel_xyz_valid, cell_weight_init_valid,
                    cell_neighborhood_valid, cell_sparseness, timepoints=timepoints,
                    miniter=10, maxiter=100, tolfun=1e-3, verbosity=False)

                success = 1
                print('cell factorization: %.1f minutes.\n'%((time.time() - tic) / 60))
                break
            except ValueError as msg:
                print('retrying factorization of block %d: %s'%(ii, msg))

        # get cell positions and timeseries, and save cell data
        fullname_block = os.path.join(dir_cell, 'block%05d'%(ii))
        with h5py.File(fullname_block+hdf, 'w') as file_handle:
            if success:
                for ci in range(n_cells):
                    ix = cell_weights_valid[:, ci] > 0
                    xyzi = voxel_xyz_valid[ix]
                    wi = cell_weights_valid[ix, ci]
                    # weighted mean brightness of the cell, used to rescale its timeseries
                    bi = np.sum(wi * bvolume_mean[tuple(zip(*xyzi))]) / np.sum(wi)
                    ti = bi * cell_timeseries_valid[ci] / np.mean(cell_timeseries_valid[ci])

                    file_handle['/cell/%05d/xyz'%(ci)] = xyzi
                    file_handle['/cell/%05d/weights'%(ci)] = wi
                    file_handle['/cell/%05d/timeseries'%(ci)] = ti

            # a failed block stores no cells: record 0 instead of the cell count of
            # the last failed attempt, otherwise readers hit a KeyError on the
            # missing '/cell/...' datasets
            file_handle['n_cells'] = n_cells if success else 0
            file_handle['completion'] = 1
            
    import time

    # process all remaining blocks on a dask cluster, one task per block
    with start_cluster(
        cores=1,
        force_local=False, processes=32,
        death_timeout=1200, store_logs=True, spillover=True, verbose=True) as cluster, Client(cluster) as client:

        display(client)

        # wall-clock timers: time.process_time() only counts CPU time of this
        # process and therefore excludes the time spent waiting for the workers
        sch_start = time.perf_counter()
        future = client.map(detect_cells_block, block_ixyz01)
        print(f'Scheduling time: {time.perf_counter() - sch_start} s')

        # gather blocks until every task has finished (results are written to
        # the block files by the tasks themselves)
        cmp_start = time.perf_counter()
        result = client.gather(future)
        print(f'Compute time: {time.perf_counter() - cmp_start} s')

number of blocks, total: 2038.
number of blocks, remaining: 638.
Kernel currently running on e06u03.int.janelia.org
Scheduler: tcp://127.0.0.1:42001
Dashboard link: http://127.0.0.1:8787/status


0,1
Client  Scheduler: tcp://127.0.0.1:42001  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 32  Cores: 32  Memory: 1.01 TB


Scheduling time: 27.909320859999752 s


In [255]:
%%time
# Debug run: process a single block in the notebook process, without a cluster.
# NOTE(review): detect_cells_block as defined in the loop cell reads
# 'volume_peak' / 'volume_mean' from file itself, so the two assignments below
# look like leftovers of an earlier version -- confirm before removing.
bvolume_peak = volume_peak
bvolume_mean = volume_mean
detect_cells_block(block_ixyz01[1])

(58, 53, 4, 54202)
data loading: 7.0 minutes.

voxel similarity: 0.1 minutes.

(1.0, 193)


  affinity='euclidean')


cell initialization: 0.1 minutes.

(0, 1.5836856146709233, 1.7391393772706083)
(1, 0.9894178305234703, 1.4094238792319178)
(2, 0.9942896957574211, 0.014357450711416402)
(3, 0.994427605750581, 0.008356892835768914)
(4, 0.9942688643227016, 0.007092552863908608)
(5, 0.9941694241352262, 0.004094284456387411)
(6, 0.9940860618571027, 0.0026673008816530093)
(7, 0.9940560075159935, 0.002304383455858736)
(8, 0.9940360018136267, 0.0017202246323778057)
(9, 0.9939891079151312, 0.0014905202725173515)
(10, 0.9939763058200556, 0.0011266808400308705)
(11, 0.9939458224090683, 0.0011139479350170938)
cell factorization: 1.2 minutes.

CPU times: user 2min 6s, sys: 2min 18s, total: 4min 24s
Wall time: 8min 31s


## Collect cells

In [13]:
from itertools import combinations
from voluseg._steps.step4e import collect_blocks
from voluseg._tools.constants import hdf, dtype
from voluseg._tools.clean_signal import clean_signal

In [18]:
# namespace view of the parameter dict (attribute access to the settings)
p = SimpleNamespace(**parameters)

# threshold applied to both the voxel-overlap fraction and the timeseries
# correlation of a cell pair when removing duplicate cells
thr_similarity = 0.5

In [16]:
def collect_blocks(color_i, parameters):
    '''Gather the cells stored in the per-block files of one color channel into arrays.

    Notebook-local definition that shadows the ``collect_blocks`` imported from
    ``voluseg._steps.step4e``.

    Parameters
    ----------
    color_i : color channel index.
    parameters : dict of settings; ``parameters['dir_output']`` locates the
        ``cells/<color_i>/block%05d`` files and the ``volume<color_i>`` file.

    Returns
    -------
    tuple ``(cell_block_id, cell_xyz_array, cell_weights_array,
    cell_timeseries_array, cell_lengths)``:
    block index of every cell, voxel coordinates padded with -1, voxel weights
    padded with NaN, one timeseries row per cell, and the number of voxels of
    every cell.

    Raises
    ------
    ValueError : if no cells are found in any block.
    '''
    # use the argument instead of the notebook-global namespace `p`
    dir_output = parameters['dir_output']
    dir_cell = os.path.join(dir_output, 'cells', str(color_i))

    fullname_volmean = os.path.join(dir_output, 'volume%d'%(color_i))
    with h5py.File(fullname_volmean+hdf, 'r') as file_handle:
        block_valids = file_handle['block_valids'][()]

    idx_block_valids = np.argwhere(block_valids).T[0]

    def add_data(ii):
        '''Read all cells stored in the file of block ii.'''
        block_id_i = []
        xyz_i = []
        weights_i = []
        timeseries_i = []

        fullname_block = os.path.join(dir_cell, 'block%05d'%(ii))
        with h5py.File(fullname_block+hdf, 'r') as file_handle:
            try:
                for ci in range(file_handle['n_cells'][()]):
                    # read the three datasets first so the lists stay aligned
                    # if one of them is missing
                    xyz = file_handle['/cell/%05d/xyz'%(ci)][()]
                    weights = file_handle['/cell/%05d/weights'%(ci)][()]
                    timeseries = file_handle['/cell/%05d/timeseries'%(ci)][()]
                    block_id_i.append(ii)
                    xyz_i.append(xyz)
                    weights_i.append(weights)
                    timeseries_i.append(timeseries)
            except KeyError:
                # block files of failed factorizations may announce cells that
                # were never written
                print('block %d is empty.'%(ii))

        return [block_id_i, xyz_i, weights_i, timeseries_i]

    cell_block_id = []
    cell_xyz = []
    cell_weights = []
    cell_timeseries = []
    for ii in tqdm(idx_block_valids):  ## 6 mins
        block_id_i, xyz_i, weights_i, timeseries_i = add_data(ii)
        cell_block_id.extend(block_id_i)
        cell_xyz.extend(xyz_i)
        cell_weights.extend(weights_i)
        cell_timeseries.extend(timeseries_i)

    cn = len(cell_xyz)
    if cn == 0:
        raise ValueError('no cells found in %s'%(dir_cell))

    # convert lists to arrays
    cell_block_id = np.array(cell_block_id)
    cell_lengths = np.array([len(i) for i in cell_weights])
    max_length = np.max(cell_lengths)
    cell_xyz_array = np.full((cn, max_length, 3), -1, dtype=int)
    cell_weights_array = np.full((cn, max_length), np.nan)
    for ci, li in enumerate(cell_lengths):
        cell_xyz_array[ci, :li] = cell_xyz[ci]
        cell_weights_array[ci, :li] = cell_weights[ci]
    # numeric 2-D array: an object array (as produced by the private
    # pd._libs.lib.to_object_array) cannot be passed to np.isnan by the caller
    cell_timeseries_array = np.array(cell_timeseries)

    return cell_block_id, cell_xyz_array, cell_weights_array, cell_timeseries_array, cell_lengths

In [None]:
# Collect-cells driver, part 1: gather the cells of all blocks, then mark cells
# as invalid if they lie mostly outside the brain mask or duplicate another cell.
for color_i in range(p.n_colors):

    # skip this channel if its final cleaned-cells file already exists
    fullname_cells = os.path.join(p.dir_output, 'cells%s_clean'%(color_i))
    if os.path.isfile(fullname_cells+hdf):
        continue

    cell_block_id, cell_xyz, cell_weights, cell_timeseries, cell_lengths = collect_blocks(color_i, parameters)  ## about 55 mins

    fullname_volmean = os.path.join(p.dir_output, 'volume%d'%(color_i))
    with h5py.File(fullname_volmean+hdf, 'r') as file_handle:
        volume_mask = file_handle['volume_mask'][()].T
        x, y, z = volume_mask.shape

    # per-cell voxel coordinates (padded entries are -1) and total weight per cell
    cell_x = cell_xyz[:, :, 0]
    cell_y = cell_xyz[:, :, 1]
    cell_z = cell_xyz[:, :, 2]
    cell_w = np.nansum(cell_weights, 1)

    # timeseries that contain NaNs are set to zero (they are not dropped here,
    # despite the wording of the message)
    ix = np.any(np.isnan(cell_timeseries), 1)
    if np.any(ix):
        print('nans (to be removed): %d'%np.count_nonzero(ix))
        cell_timeseries[ix] = 0

    # a cell is valid if the mean of volume_mask over its voxels exceeds p.thr_mask
    cell_valids = np.zeros(len(cell_w), dtype=bool)
    for i, (li, xi, yi, zi) in enumerate(zip(cell_lengths, cell_x, cell_y, cell_z)):
        cell_valids[i] = np.mean(volume_mask[xi[:li], yi[:li], zi[:li]]) > p.thr_mask

    ## 29 mins for next two sections
    # brain mask array
    # volume_list[x][y][z] holds the ids of all valid cells that contain the voxel,
    # volume_cell_n the number of such cells
    volume_list = [[[[] for zi in range(z)] for yi in range(y)] for xi in range(x)]
    volume_cell_n = np.zeros((x, y, z), dtype='int')
    for i, (li, vi) in enumerate(zip(cell_lengths, cell_valids)):
        for j in range(li if vi else 0):
            xij, yij, zij = cell_x[i, j], cell_y[i, j], cell_z[i, j]
            volume_list[xij][yij][zij].append(i)
            volume_cell_n[xij, yij, zij] += 1

    # every pair of cells sharing a voxel, listed once per shared voxel
    pair_cells = [pi for a in volume_list for b in a for c in b for pi in combinations(c, 2)]
    # sanity check: a voxel shared by n cells contributes n * (n - 1) / 2 pairs
    assert(len(pair_cells) == np.sum(volume_cell_n * (volume_cell_n - 1) / 2))

    # remove duplicate cells
    # pair_count is the number of voxels shared by the pair; if both the overlap
    # relative to the mean cell size and the timeseries correlation exceed
    # thr_similarity, the cell with the smaller total weight is invalidated
    pair_id, pair_count = np.unique(pair_cells, axis=0, return_counts=True)
    for pi, fi in zip(pair_id, pair_count):
        pair_overlap = (fi / np.mean(cell_lengths[pi])) > thr_similarity
        pair_correlation = np.corrcoef(cell_timeseries[pi])[0, 1] > thr_similarity
        if (pair_overlap and pair_correlation):
            cell_valids[pi[np.argmin(cell_w[pi])]] = 0

    ## get valid version of cells
    cell_block_id = cell_block_id[cell_valids]
    cell_weights = cell_weights[cell_valids].astype(dtype)
    cell_timeseries = cell_timeseries[cell_valids].astype(dtype)
    cell_lengths = cell_lengths[cell_valids]
    cell_x = cell_x[cell_valids]
    cell_y = cell_y[cell_valids]
    cell_z = cell_z[cell_valids]
    cell_w = cell_w[cell_valids]
    ## end get valid version of cells
    
    from analysis_toolbox.utils import chunk_and_saveas_zarr

    # store the timeseries as a chunked zarr array and reload it as a dask array
    # (chunks of 200 cells, full time axis per chunk)
    cell_timeseries_da = chunk_and_saveas_zarr(cell_timeseries, chunks=(200, -1), reload='dask')

    # bind the parameters so that the function takes a single timeseries; it must
    # not be called without one (the stray `clean_signal_p()` call was removed)
    from functools import partial
    clean_signal_p = partial(clean_signal, parameters)

    # apply clean_signal to every cell (row); each call is declared to return an
    # array of shape (2, n_timepoints)
    timebase_da = da.apply_along_axis(clean_signal_p, 1, cell_timeseries_da, dtype=cell_timeseries.dtype, shape=(2, cell_timeseries.shape[1]))

    ## 7 mins
    with start_cluster(
        cores=4,
        force_local=False, processes=64,
        death_timeout=1200, store_logs=True, spillover=True, verbose=True) as cluster, Client(cluster) as client:

        display(client)

        # wall-clock timer: time.process_time() would only count CPU time of this
        # process and exclude the time spent waiting for the workers
        cmp_start = time.perf_counter()
        timebase = timebase_da.compute()
        print(f'Compute time: {time.perf_counter() - cmp_start} s')

    # first output row per cell: cleaned timeseries, second row: baseline
    # NOTE(review): presumes clean_signal returns (timeseries, baseline) in this
    # order -- confirm against voluseg._tools.clean_signal.
    cell_timeseries1 = timebase[:, 0, :]
    cell_baseline1 = timebase[:, 1, :]
    
    ## 2 mins
    # check that all series are in single precision
    # NOTE(review): `dtype` comes from voluseg._tools.constants; that it is single
    # precision is taken from the comment above, not visible here.
    assert(cell_weights.dtype == dtype)
    assert(cell_timeseries.dtype == dtype)
    assert(cell_timeseries1.dtype == dtype)
    assert(cell_baseline1.dtype == dtype)

    # number of valid cells
    n = np.count_nonzero(cell_valids)
    # label volume: for every voxel the index of the cell with the largest weight
    # at that voxel (-1 where no cell is present) and that weight
    volume_id = -1 + np.zeros((x, y, z))
    volume_weight = np.zeros((x, y, z))
    for i, li in enumerate(cell_lengths):
        for j in range(li):
            xij, yij, zij = cell_x[i, j], cell_y[i, j], cell_z[i, j]
            if cell_weights[i, j] > volume_weight[xij, yij, zij]:
                volume_id[xij, yij, zij] = i
                volume_weight[xij, yij, zij] = cell_weights[i, j]

    ## 2 mins
    # carry the background value of the volume file over to the cells file
    with h5py.File(fullname_volmean+hdf, 'r') as file_handle:
        background = file_handle['background'][()]

    # write the final cells file; its existence marks the channel as done for the
    # detect and collect loops
    with h5py.File(fullname_cells+hdf, 'w') as file_handle:
        file_handle['n'] = n
        file_handle['t'] = p.lt
        file_handle['x'] = x
        file_handle['y'] = y
        file_handle['z'] = z
        file_handle['cell_x'] = cell_x
        file_handle['cell_y'] = cell_y
        file_handle['cell_z'] = cell_z
        file_handle['cell_block_id'] = cell_block_id
        file_handle['volume_id'] = volume_id
        file_handle['volume_weight'] = volume_weight
        file_handle['cell_weights'] = cell_weights
        file_handle['cell_timeseries_raw'] = cell_timeseries
        file_handle['cell_timeseries'] = cell_timeseries1
        file_handle['cell_baseline'] = cell_baseline1
        file_handle['background'] = background

## Clean up

In [22]:
from voluseg._tools.constants import hdf

In [None]:
%%time
## have a look at results
# NOTE(review): this import rebinds the notebook-level name `voluseg` (imported
# as the voluseg package in the set-up cell) to the analysis_toolbox helper;
# cells that use the package will break if they are re-run after this one --
# consider importing the helper under an alias.
from analysis_toolbox.seg_helper import voluseg
# load the segmentation results from dir_output together with the raw images
# in dir_input
im_voluseg = voluseg(segmented_root=dir_output, im_dir=dir_input)

/scratch/limj2/data/raw/20211008/fish01/6dpf_elavl3-gc8f-gfap-jregeco_MG-vs-NGGU-trunc35_fish01_exp03_20211009_004814/im_CM1/

Opening metadata file...
Metadata not available...
Loading brain volume...
Loading segments...
Calculating dff...


In [None]:
# shape of the cell_dff array of the loaded results
# (presumably the per-cell dF/F reported as 'Calculating dff...' -- confirm)
im_voluseg.cell_dff.shape

In [23]:
# clean up
# the run counts as complete only if the cleaned-cells file of every color
# channel exists
completion = 1
for color_i in range(p.n_colors):
    fullname_cells = os.path.join(p.dir_output, 'cells%s_clean'%(color_i))
    if not os.path.isfile(fullname_cells+hdf):
        completion = 0

# remove the intermediate block files and volumes unless they were requested to
# be kept; removal stays best-effort, but each directory is handled on its own
# and unexpected failures are reported instead of being silently swallowed
if not p.save_volume:
    if completion:
        for dir_name in ('cells', 'volumes'):
            dir_path = os.path.join(p.dir_output, dir_name)
            try:
                shutil.rmtree(dir_path)
            except FileNotFoundError:
                # already removed (e.g. by an earlier run of this cell)
                pass
            except OSError as err:
                print('could not remove %s: %s'%(dir_path, err))

In [24]:
from analysis_toolbox.utils import disk_usage

# report total / used / free space for root_dir after the clean-up
disk_usage(root_dir)

Total: 48.4TiB
Used: 23.0TiB
Free: 25.5TiB
