# In [None]:
import os
import h5py
import numpy as np
import pandas as pd
import time
from multiprocessing import Pool
from tqdm import tqdm
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output

from image_processing import process_single_image

def load_chunk(image_file, start, end):
    """
    Read frames ``start`` (inclusive) up to ``end`` (exclusive) from the
    ``/entry/data/images`` dataset of an HDF5 file.

    Only the requested slice is pulled from disk; the frames are returned
    as a float32 NumPy array.
    """
    with h5py.File(image_file, 'r') as h5:
        raw_frames = h5['/entry/data/images'][start:end]
    # Conversion happens on the in-memory copy, after the file is closed.
    return raw_frames.astype(np.float32)

def process_images(image_file, mask, n_wedges=4, n_rad_bins=100, xatol=0.01, fatol=10,
                   chunk_size=1000, frame_interval=10, verbose=False):
    """
    Compute centers for a subset of frames of an HDF5 image stack and save
    them to a CSV file next to the input.

    A frame with global index ``i`` is processed when it is the first frame,
    the last frame, or ``i % frame_interval == 0``. Frames are read from
    ``/entry/data/images`` one chunk of ``chunk_size`` indices at a time, and
    each selected frame is handed to ``process_single_image`` on a
    multiprocessing pool. Results are written to
    ``<image dir>/centers_xatol_<xatol>_frameinterval_<frame_interval>.csv``
    with columns ``frame_number, center_x, center_y``.

    An existing file with that name is removed first. Rows are appended after
    every chunk, so the rows of completed chunks remain on disk if a later
    chunk fails.

    Parameters
    ----------
    image_file : str
        Path to the HDF5 file holding the image stack.
    mask : array-like
        Mask forwarded unchanged to ``process_single_image``.
    n_wedges, n_rad_bins, xatol, fatol : number
        Forwarded unchanged to ``process_single_image``. ``xatol`` also
        appears in the output file name.
    chunk_size : int
        Number of consecutive frame indices considered per read (>= 1).
    frame_interval : int
        Stride between processed frames (>= 1).
    verbose : bool
        Forwarded to ``process_single_image``; also enables per-chunk log
        lines here.

    Raises
    ------
    ValueError
        If ``chunk_size`` or ``frame_interval`` is smaller than 1.
    """
    # Validate up front. frame_interval == 0 would raise ZeroDivisionError in
    # the modulo test. chunk_size <= 0 would either crash range() or silently
    # process nothing while still reporting a written CSV.
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Open the image file only to determine the total number of images.
    with h5py.File(image_file, 'r') as f_img:
        n_images = f_img['/entry/data/images'].shape[0]

    def is_selected(i):
        # Single source of truth for frame selection, shared by the
        # progress-bar total and the per-chunk filter.
        return i == 0 or i == n_images - 1 or i % frame_interval == 0

    csv_file = os.path.join(os.path.dirname(image_file),
                            f"centers_xatol_{xatol}_frameinterval_{frame_interval}.csv")
    # Remove results of a previous run so they cannot be mistaken for this one.
    if os.path.exists(csv_file):
        os.remove(csv_file)

    if n_images == 0:
        # Without this guard the progress total would be 2 (from {0, -1}) and
        # a CSV that was never created would be reported as written.
        print("No images found in", image_file, "- nothing to process.")
        return

    total_centers = sum(1 for i in range(n_images) if is_selected(i))
    header_written = False
    start_time = time.time()

    # Context managers close the progress bar and the pool even if a worker
    # raises.
    with tqdm(total=total_centers, desc="Calculating centers") as pbar, Pool() as pool:
        for start_idx in range(0, n_images, chunk_size):
            end_idx = min(start_idx + chunk_size, n_images)
            chunk_frame_indices = [i for i in range(start_idx, end_idx) if is_selected(i)]
            if not chunk_frame_indices:
                continue  # Skip chunks that contain no selected frames.

            # Read only the span that contains selected frames instead of the
            # whole chunk. Frames after the last selected one are never needed.
            first_idx = chunk_frame_indices[0]
            current_chunk = load_chunk(image_file, first_idx, chunk_frame_indices[-1] + 1)

            args = [
                (current_chunk[i - first_idx], mask, n_wedges, n_rad_bins, xatol, fatol, verbose)
                for i in chunk_frame_indices
            ]

            # Process the selected frames in parallel. starmap preserves input
            # order, so results line up with chunk_frame_indices.
            results = pool.starmap(process_single_image, args)

            df_chunk = pd.DataFrame(
                [[i, res[0], res[1]] for i, res in zip(chunk_frame_indices, results)],
                columns=["frame_number", "center_x", "center_y"],
            )
            # The first write creates the file with a header; later chunks append rows.
            df_chunk.to_csv(csv_file, index=False,
                            mode="a" if header_written else "w",
                            header=not header_written)
            header_written = True

            pbar.update(len(chunk_frame_indices))
            if verbose:
                print(f"Processed frames {chunk_frame_indices[0]} to {chunk_frame_indices[-1]} "
                      f"from chunk {start_idx} to {end_idx}")

    elapsed = time.time() - start_time
    print("Processing complete in {:.1f}s".format(elapsed))
    print("CSV file written to:", csv_file)

# ---------------------------
# Build the UI
# ---------------------------

def _make_h5_chooser(title):
    """Return a FileChooser rooted at the current directory that shows only ``*.h5`` files."""
    chooser = FileChooser(os.getcwd())
    chooser.title = title
    chooser.filter_pattern = "*.h5"
    return chooser


def _param_field_layout():
    """Return a new 200px-wide Layout (each parameter widget gets its own instance)."""
    return widgets.Layout(width="200px")


# Input files: the image stack and the mask.
image_file_chooser = _make_h5_chooser("Select H5 Image File")
mask_file_chooser = _make_h5_chooser("Select Mask H5 File")

# Whether to apply the mask stored in the mask file.
use_mask_checkbox = widgets.Checkbox(value=True, description="Use Mask")

# Numeric processing parameters.
xatol_widget = widgets.FloatText(
    value=0.01, description="xatol:", layout=_param_field_layout()
)
frame_interval_widget = widgets.IntText(
    value=15, description="Frame Interval:", layout=_param_field_layout()
)

# Extra logging during processing.
verbose_checkbox = widgets.Checkbox(value=False, description="Verbose")

# Starts the processing run.
process_images_button = widgets.Button(description="Process Images", button_style="primary")

# Bordered area that receives the click handler's log messages.
output_area = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_process_images_clicked(b):
    """
    Button callback: validate the UI inputs, build the mask and run
    ``process_images``.

    Messages are printed into ``output_area``, which is cleared on every
    click. When "Use Mask" is unchecked no mask file is required: an all-True
    mask with the shape of one image frame is used instead.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button, supplied by ``on_click`` (unused).
    """
    with output_area:
        clear_output()
        image_file = image_file_chooser.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        use_mask_val = use_mask_checkbox.value
        mask_file = mask_file_chooser.selected
        # A mask file is only needed when the mask is actually applied.
        if use_mask_val and not mask_file:
            print("Please select a mask H5 file.")
            return

        # Shape of a single frame: used for the "no mask" mask and to
        # sanity-check a loaded mask.
        try:
            with h5py.File(image_file, 'r') as f_img:
                frame_shape = f_img['/entry/data/images'].shape[1:]
        except Exception as e:
            print("Error reading image file:", e)
            return

        if use_mask_val:
            try:
                with h5py.File(mask_file, 'r') as f_mask:
                    mask = f_mask['/mask'][:].astype(bool)
            except Exception as e:
                print("Error loading mask file:", e)
                return
            if mask.shape != frame_shape:
                print(f"Warning: mask shape {mask.shape} does not match "
                      f"image frame shape {frame_shape}.")
        else:
            # All-True mask with exactly one frame's shape, so every pixel is
            # used. The previous np.ones_like(f_mask['/mask'][0]) took the
            # shape of the mask's first row/slice instead.
            mask = np.ones(frame_shape, dtype=bool)

        xatol_val = xatol_widget.value
        frame_interval_val = frame_interval_widget.value
        verbose_val = verbose_checkbox.value
        # process_images cannot work with a non-positive stride; report it here
        # instead of surfacing a traceback.
        if frame_interval_val < 1:
            print("Frame Interval must be a positive integer.")
            return

        print("Starting image processing...")
        print(f"Image file: {image_file}")
        print(f"Mask file: {mask_file if use_mask_val else 'not used'}")
        print(f"xatol: {xatol_val}, Frame Interval: {frame_interval_val}, Verbose: {verbose_val}")

        process_images(image_file, mask, xatol=xatol_val, frame_interval=frame_interval_val,
                       verbose=verbose_val)

# Route button clicks to the processing handler.
process_images_button.on_click(on_process_images_clicked)

# Page layout, top to bottom: heading, file pickers, mask toggle,
# parameter row, run button, log area.
_ui_heading = widgets.HTML("<h2>Process Images from H5 File</h2>")
_ui_param_row = widgets.HBox([xatol_widget, frame_interval_widget, verbose_checkbox])
ui = widgets.VBox([
    _ui_heading,
    image_file_chooser,
    mask_file_chooser,
    use_mask_checkbox,
    _ui_param_row,
    process_images_button,
    output_area,
])

display(ui)


# Output (truncated): VBox(children=(HTML(value='<h2>Process Images from H5 File</h2>'), FileChooser(path='/Users/xiaodong/Desktop/i…