# Phase image processing code

This code processes the already unwrapped (PUMA) phase images produced by the MATLAB code, for quantification.

## Load libraries

In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import math
import numpy as np
import pathlib
import h5py
import rembg


## Load cell phase images

In [None]:
# Folder holding the .hdf5 files to process.
# Alternative dataset: "/home/jvasquez/updepla/users/jvasquez/machine_learning_datasets/20241015_Aug_to_Oct_2024_segment/hek"
directory_name = "/home/jvasquez/updepla/users/jvasquez/machine_learning_datasets/20251004_senes_processing/segment"

In [3]:
# Create a Path object for the directory
data_dir = pathlib.Path(directory_name)

# Count the number of HDF5 files in the directory
hdf5_files = list(data_dir.glob('*.hdf5'))
image_count = len(hdf5_files)
print("Total HDF5 files:", image_count)

# Create an empty dictionary to store image data
image_data_dict = {}

# Loop through each HDF5 file
for hdf5_file in hdf5_files:
    # Open the HDF5 file in read mode
    with h5py.File(hdf5_file, 'r') as hdf_file:
        # Check if 'phase_image' dataset exists in the HDF5 file
        if 'phase_image' in hdf_file:
            # Access the 'phase_image' dataset
            phase_images_dataset = hdf_file['phase_image']
            
            # Get the image name
            image_name = hdf5_file.stem
            
            # Read data from the dataset into an array
            image_data = phase_images_dataset[:]
            
            # Save image name and data in the dictionary
            image_data_dict[image_name] = image_data
        else:
            print(f"Dataset 'phase_image' not found in {hdf5_file}")

# Now image_data_dict contains image names as keys and corresponding image arrays as values


Total HDF5 files: 593


In [4]:
# Display the stem of the file bound to `hdf5_file` by the most recently executed
# `for hdf5_file in hdf5_files` loop (the loop variable persists after the loop).
hdf5_file.stem

'20250917_BJ_ctrl_556'

## Plotting

In [None]:
# Select 3 random indices
random_indices = np.random.choice(len(image_data_dict), 3, replace=False)

# Store the selected indices for downstream analysis
# For example, you can save them to a file
np.savetxt("random_indices.txt", random_indices)

# Plot the selected random images
for idx in random_indices:
    image_names = list(image_data_dict.keys())
    if idx < len(image_names):
        image_name = image_names[idx]
        test_image = image_data_dict[image_name]
        
        # Plot the image with turbo colormap
        plt.figure()
        plt.imshow(test_image, cmap='turbo')
        plt.colorbar()
        plt.title(f"{image_name}")
        plt.show()
    else:
        print(f"Index {idx} is out of range.")


### Function for processing

In [6]:
import rembg

# rembg model used to predict the foreground mask.
model_name = 'silueta'

# rembg sessions keyed by model name; the one for `model_name` is bound to
# `session`, which the segmentation function uses.
sessions: dict[str, rembg.sessions.BaseSession] = {}
if model_name not in sessions:
    sessions[model_name] = rembg.new_session(model_name)
session = sessions[model_name]

def _apply_mask_and_offset(image: np.ndarray, segmentation_mask: np.ndarray) -> np.ndarray:
    """Zero `image` outside the mask, then shift it so no value is negative.

    The offset is the minimum of ``image * segmentation_mask`` *including* the
    zeroed background. When the mask leaves at least one pixel unset, that
    minimum is 0 unless some in-mask value is negative. The shift therefore only
    lifts negative phase values (the most negative one ends up at 0) instead of
    clipping them to 0. If the mask covers every pixel, the offset is simply the
    image minimum. The background is zeroed again after the shift.

    Args:
        image: Phase image.
        segmentation_mask: 0/1 mask broadcastable against `image`.

    Returns:
        The masked, offset image.
    """
    segmented_image = image * segmentation_mask
    # Background zeros take part in the minimum, so in-mask values below 0 are
    # shifted up rather than zeroed.
    min_phase_value = np.min(segmented_image)
    return (segmented_image - min_phase_value) * segmentation_mask


def segmentation_pipeline_rembg(image: np.ndarray) -> np.ndarray:
    """Segment a phase image with rembg and return the masked, offset phase image.

    The foreground mask comes from ``rembg.remove`` using the module-level
    ``session`` (created for ``model_name``) with ``only_mask=True`` and
    ``post_process_mask=True``. Mask pixels equal to 255 are foreground.

    Args:
        image: Phase image (presumably a 2-D array, matching the
            ``(n,m)->(n,m)`` signature of the vectorized wrapper below).

    Returns:
        ``image`` zeroed outside the foreground mask, with in-mask values offset
        as described in ``_apply_mask_and_offset``.
    """
    # NOTE(review): rembg receives the raw phase array; confirm it handles this
    # (non-uint8) input the way the pipeline expects.
    # Assumes the returned mask has the same height/width as `image` - TODO confirm.
    mask = rembg.remove(
        data=image,
        session=session,
        only_mask=True,
        post_process_mask=True,
    )
    # Binarize: 1 where rembg marked foreground (255), 0 elsewhere.
    segmentation_mask = (mask == 255).astype(np.uint8)
    return _apply_mask_and_offset(image, segmentation_mask)


# Apply the pipeline independently to each (n, m) image of a stack whose last
# two axes are the image axes.
segmentation_pipeline_rembg_vectorized = np.vectorize(
    segmentation_pipeline_rembg, signature="(n,m)->(n,m)"
)

## High-throughput processing

In [None]:
# Segment every file found by the loading cell (`hdf5_files`, with arrays taken
# from `image_data_dict`) and write the result back into the same file as the
# 'segmented_phase_image' dataset.

# h5py's create_dataset raises if the dataset name already exists, so a re-run
# would fail without this switch. True: replace an existing
# 'segmented_phase_image'. False: skip files that already have one.
# (HDF5 does not reclaim the space of a deleted dataset; repack if file size matters.)
overwrite_segmented = True

for hdf5_file in hdf5_files:
    image_name = hdf5_file.stem  # Same key the loading cell used for this file
    if image_name not in image_data_dict:
        # e.g. the file had no 'phase_image' dataset when it was loaded; don't
        # open it for writing at all.
        print(f"Image data for {image_name} not found.")
        continue

    with h5py.File(hdf5_file, 'a') as f:
        if 'segmented_phase_image' in f:
            if not overwrite_segmented:
                print(f"Skipping {hdf5_file} as 'segmented_phase_image' already exists.")
                continue
            del f['segmented_phase_image']

        phase_image_data = image_data_dict[image_name]
        segmented_phase_image = segmentation_pipeline_rembg(phase_image_data)
        f.create_dataset('segmented_phase_image', data=segmented_phase_image)

    # Plot the segmented phase image, then close the figure so hundreds of
    # figures are not kept open across the loop.
    fig = plt.figure()
    plt.imshow(segmented_phase_image, cmap='turbo')
    plt.colorbar()
    plt.title(f"Segmented Phase Image: {image_name}")
    plt.show()
    plt.close(fig)
