In [1]:
# Imports
import numpy as np
import pandas as pd
# import dask.array as da
from skimage.measure import regionprops_table
import time
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import feature_extraction as fe
import sparse


# Label-mask volume; opened by load_all_data via fe.load_n5_zarr_array.
mask_path = (
    '/net/beliveau/vol2/instrument/E9.5_290/Zoom_290_subset_test/dataset_fused_masks_cpsamr5.zarr'
)
# Fused image volume; load_all_data reads its 'ch0/s0' sub-path.
n5_image_path = (
    '/net/beliveau/vol2/instrument/E9.5_290/Zoom_290_subset_test/dataset_fused.n5'
)

def load_all_data(mask_path, 
                  n5_path, 
                  do_crop=False,
                  subset_size=None,
                  crop_size=(800, 800, 800),
                  seed=None):
    """Load the mask and image volumes and find the bounding box of every object.

    Parameters
    ----------
    mask_path : str
        Location of the label mask, passed to ``fe.load_n5_zarr_array``.
    n5_path : str
        Location of the image container, passed to ``fe.load_n5_zarr_array``
        with ``n5_subpath='ch0/s0'``.
    do_crop : bool, default False
        If True, both arrays are truncated to their first ``crop_size``
        elements along each axis before any further work.
    subset_size : int or None, default None
        If given, a random subset of at most this many objects is returned
        as a third value. Objects are drawn WITHOUT replacement, so the
        subset never contains the same object twice; if fewer objects exist
        than requested, all of them are returned (in random order).
    crop_size : tuple of int, default (800, 800, 800)
        Per-axis extent (z, y, x) kept when ``do_crop`` is True.
    seed : int or None, default None
        Seed for the random subset. ``None`` draws a fresh, non-reproducible
        subset; pass an integer to get the same subset on every run.

    Returns
    -------
    (mask_da, image_da) when ``subset_size`` is None, otherwise
    (mask_da, image_da, test_objects), where ``test_objects`` is the sampled
    slice of the bounding-box DataFrame.
    NOTE: the arity of the return value depends on ``subset_size``; this is
    kept for compatibility with existing callers.
    """
    # load mask and image data
    mask_da = fe.load_n5_zarr_array(mask_path)
    image_da = fe.load_n5_zarr_array(n5_path, n5_subpath='ch0/s0')

    if do_crop:
        crop = tuple(slice(None, extent) for extent in crop_size)
        mask_da = mask_da[crop]
        image_da = image_da[crop]

    # convert to sparse: every block goes through fe.to_sparse; the meta
    # object tells the array library which block type to expect without
    # computing anything.
    chunk_shape = tuple(c[0] for c in mask_da.chunks)
    meta_block = sparse.COO.from_numpy(np.zeros(chunk_shape, 
                                                dtype=mask_da.dtype))
    mask_sparse = mask_da.map_blocks(
        fe.to_sparse,
        dtype=mask_da.dtype,
        meta=meta_block,
        chunks=mask_da.chunks
    )

    # find bounding boxes
    df_bboxes = fe.find_objects(mask_sparse).compute()
    df_bboxes = pd.DataFrame(df_bboxes)
    print(f"Found {len(df_bboxes)} objects")

    if subset_size is not None:
        n_objects = len(df_bboxes)
        n_draw = min(subset_size, n_objects)
        # A truncated permutation gives unique row positions and also works
        # when no objects were found (empty selection instead of an error).
        rng = np.random.default_rng(seed)
        obj_idxs = rng.permutation(n_objects)[:n_draw]
        test_objects = df_bboxes.iloc[obj_idxs]
        return mask_da, image_da, test_objects
    return mask_da, image_da 

# Load the cropped volumes and draw 100 objects to experiment on.
mask_da, image_da, test_objects = load_all_data(
    mask_path=mask_path,
    n5_path=n5_image_path,
    do_crop=True,
    subset_size=100,
)

2025-07-10 10:28:40,236 - INFO - --- Environment Versions ---
2025-07-10 10:28:40,236 - INFO - Platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
2025-07-10 10:28:40,237 - INFO - Python: 3.13.3 | packaged by conda-forge | (main, Apr 14 2025, 20:44:03) [GCC 13.3.0]
2025-07-10 10:28:40,238 - INFO - Dask: 2025.2.0
2025-07-10 10:28:40,240 - INFO - Distributed: 2025.2.0
2025-07-10 10:28:40,241 - INFO - Cloudpickle: 3.0.0
2025-07-10 10:28:40,241 - INFO - Msgpack: 1.0.8
2025-07-10 10:28:40,242 - INFO - Zarr: 2.13.3
2025-07-10 10:28:40,243 - INFO - NumPy: 2.2.6
2025-07-10 10:28:40,243 - INFO - Scikit-image: 0.25.0
2025-07-10 10:28:40,244 - INFO - --- Dask Config (relevant parts) ---
2025-07-10 10:28:40,245 - INFO - distributed.comm.compression: False
2025-07-10 10:28:40,246 - INFO - --- End Environment Info ---
2025-07-10 10:28:40,247 - INFO - Attempting to load from: /net/beliveau/vol2/instrument/E9.5_290/Zoom_290_subset_test/dataset_fused_masks_cpsamr5.zarr
2025-07-10 10:28:40,251 - IN

Found 976 objects


In [None]:
# 5. Define enhanced process_object function
# Local helper modules. They are reloaded right after import so that edits to
# their source files take effect without restarting the kernel.
import align_3d as align
import plot_utils
from importlib import reload
reload(align)
reload(plot_utils)
# import warnings
# NOTE(review): center_of_mass is not referenced in the cells visible here;
# confirm it is still needed before keeping the import.
from scipy.ndimage import center_of_mass
# warnings.simplefilter("always")
from aicsshparam import shparam

# Module-level logger used by process_object to report failures.
# NOTE(review): logging.basicConfig does nothing if the root logger already
# has handlers (e.g. configured by an earlier import) — verify the format
# below is the one actually in effect.
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def process_object(obj, mask_da, image_da, lmax, min_area=4000):
    """Compute region properties and spherical-harmonic coefficients for one object.

    Parameters
    ----------
    obj : pandas.Series
        One bounding-box row. ``obj.name`` is the integer object label and
        entries ``0``, ``1``, ``2`` are the bounding-box slices, read here as
        z, y, x.
    mask_da : array-like
        Label mask restricted to the object's bounding box. Voxels whose
        value differs from the object's label are zeroed before measuring.
    image_da : array-like
        Intensity image restricted to the same bounding box.
    lmax : int
        Forwarded to ``shparam.get_shcoeffs`` as ``lmax``.
    min_area : int or float, default 4000
        Objects whose ``area`` (as reported by ``regionprops_table``) is
        below this value are skipped.

    Returns
    -------
    dict or None
        The region properties (as updated by ``align.align_object``) merged
        with the spherical-harmonic coefficients; ``'label'`` is set to the
        plain integer object id. ``None`` is returned when the object is
        smaller than ``min_area``, has no voxels in the crop, or any step
        raises (the error is logged with its traceback).
    """
    obj_id = int(obj.name)
    slice_z = obj[0]
    slice_y = obj[1] 
    slice_x = obj[2]

    # Midpoint of the bounding box in global coordinates. NOTE: this is the
    # geometric centre of the box, not a mass/intensity centroid; it is
    # stored under 'centroid_*' keys to keep the output schema unchanged.
    centroid_z = slice_z.start + (slice_z.stop - slice_z.start) // 2
    centroid_y = slice_y.start + (slice_y.stop - slice_y.start) // 2
    centroid_x = slice_x.start + (slice_x.stop - slice_x.start) // 2

    try:
        # Keep only the voxels belonging to the object of interest; other
        # labels that fall inside the same bounding box become background.
        label_slice = np.where(mask_da == obj_id, mask_da, 0)
        # Basic regionprops
        props = regionprops_table(label_slice, image_da, properties=["label", 
                                                                     "area", 
                                                                     "mean_intensity"])
        # No region at all (label absent from this crop): skip explicitly
        # instead of failing on the [0] lookup below.
        if len(props['area']) == 0:
            logger.warning("Object %s has no voxels in its bounding box. Skipping...", obj_id)
            return None

        area = props['area'][0]
        if area < min_area:
            return None

        # Add global bounding-box centre coordinates
        props['centroid_z'] = centroid_z
        props['centroid_y'] = centroid_y
        props['centroid_x'] = centroid_x
        # Align object; align_object returns the aligned volume together
        # with the (possibly extended) props dict.
        aligned_slice, props = align.align_object(label_slice, props)

        # Spherical harmonics computation
        (coeffs, _), _ = shparam.get_shcoeffs(
            aligned_slice, 
            lmax=lmax,
            alignment_2d=False)

        # The coefficient dict wins on key clashes, so 'label' ends up as
        # the scalar obj_id.
        coeffs.update({'label': obj_id})
        final_dict = props | coeffs
        return final_dict
    except Exception as e:
        # Best effort: one bad object must not abort the whole sweep.
        logger.error("Error processing object %s: %s", obj_id, e, exc_info=True)
        return None


In [None]:
import time

# NOTE(review): lmax_times and i are not used in this cell; they are kept
# only because later cells may reference them — remove if nothing does.
lmax_times = []
i = 1

# Spherical-harmonic orders to evaluate: 4, 8, ..., 28.
lmax_range = list(range(4, 32, 4))

# One result list per lmax; each list follows the row order of test_objects.
processed_objects = {lmax: [] for lmax in lmax_range}

for idx, obj in test_objects.iterrows():
    # Bounding box of this object
    slice_z, slice_y, slice_x = obj[0], obj[1], obj[2]
    # Materialise the mask/image crops once per object: they do not depend
    # on lmax, so re-reading them for every lmax only repeats the I/O.
    mask_da_obj = mask_da[slice_z, slice_y, slice_x].compute()
    image_da_obj = image_da[slice_z, slice_y, slice_x].compute()
    for lmax in lmax_range:
        # iterrows already yields a Series whose name is the row index
        # (the object id), which is what process_object expects.
        result = process_object(obj, 
                                mask_da_obj, 
                                image_da_obj,
                                lmax=lmax)
        processed_objects[lmax].append(result)

# Compact summary instead of dumping every coefficient dict.
for lmax, results in processed_objects.items():
    n_ok = sum(r is not None for r in results)
    print(f"lmax={lmax}: {n_ok}/{len(results)} objects processed")

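# NOTE: disabled draft, kept for reference only. Before re-enabling it:
#   - it is written as a method (takes `self`, reads `self.lmax_range`) but lives at cell level;
#   - `compute_reconstruction_error` is not defined in the cells visible here;
#   - `np.random.choice` needs a 1-D sequence, whereas `processed_objects` built above is a
#     dict keyed by lmax, so a flat list of objects would have to be passed in.
# def bootstrap_lmax_selection(self, processed_objects, n_bootstrap=100):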
#     lmax_errors = {}
    
#     for lmax in self.lmax_range:
#         bootstrap_errors = []
        
#         for _ in range(n_bootstrap):
#             # Sample objects with replacement
#             sample = np.random.choice(processed_objects, 
#                                     size=len(processed_objects), 
#                                     replace=True)
            
#             sample_errors = []
#             for obj in sample:
#                 # Use existing optimized fitting
#                 error = compute_reconstruction_error(obj, lmax)
#                 sample_errors.append(error)
            
#             bootstrap_errors.append(np.mean(sample_errors))
        
#         lmax_errors[lmax] = np.mean(bootstrap_errors)
    
#     return min(lmax_errors.keys(), key=lambda k: lmax_errors[k])

2025-07-10 10:29:12,083 - INFO - Object 5796 too small (area=729.0). Skipping...
2025-07-10 10:29:42,845 - INFO - Object 7767 too small (area=3862.0). Skipping...
2025-07-10 10:29:45,136 - INFO - Object 3108 too small (area=73.0). Skipping...
