In [1]:
import os
import h5py
import numpy as np
import pandas as pd
import threading
import time

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display

# Import the updated functions.
from ICFTOTAL import center_of_mass_initial_guess, find_diffraction_center

# Polled by the worker loop in process_images(); set by the Stop button.
stop_processing = False

# -------------------------------
# File chooser widgets.
# -------------------------------
def _make_h5_chooser(start_dir, default_name, title):
    """Return a FileChooser opened at *start_dir*, restricted to ``*.h5`` files."""
    chooser = FileChooser(start_dir, filename=default_name)
    chooser.title = title
    chooser.filter_pattern = "*.h5"
    return chooser


image_file_chooser = _make_h5_chooser(
    "/Users/xiaodong/Desktop/UOX-data/UOX1_sub/", "UOX1_sub.h5", "Select H5 Image File"
)
mask_file_chooser = _make_h5_chooser(
    "/Users/xiaodong/mask/", "pxmask.h5", "Select H5 Mask File"
)


def _narrow_layout():
    """Return a new 200px-wide Layout (one instance per widget, never shared)."""
    return widgets.Layout(width='200px')


# -------------------------------
# Processing parameters.
# -------------------------------
n_wedges_widget = widgets.IntText(value=4, description="n_wedges:", layout=_narrow_layout())
n_rad_bins_widget = widgets.IntText(value=100, description="n_rad_bins:", layout=_narrow_layout())

# Tolerances forwarded unchanged to find_diffraction_center().
xatol_widget = widgets.FloatText(value=1, description='xatol:', layout=_narrow_layout())
fatol_widget = widgets.FloatText(value=10, description='fatol:', layout=_narrow_layout())


# -------------------------------
# Advanced parameters: center specification.
# -------------------------------
use_default_center_widget = widgets.Checkbox(
    value=True, description="Use default center-of-mass guess"
)
# Only consulted when the checkbox above is unchecked.
initial_center_x_widget = widgets.FloatText(
    value=0.0, description="Initial Center X:", layout=_narrow_layout()
)
initial_center_y_widget = widgets.FloatText(
    value=0.0, description="Initial Center Y:", layout=_narrow_layout()
)
def get_initial_center():
    """Return the user-supplied starting center, or None for an automatic guess.

    Returns
    -------
    tuple | None
        ``(x, y)`` read from the two "Initial Center" inputs when the
        "Use default center-of-mass guess" box is unchecked; otherwise None,
        which process_images() treats as "compute a center-of-mass guess".
    """
    if use_default_center_widget.value:
        return None
    return (initial_center_x_widget.value, initial_center_y_widget.value)

# Forwarded as `verbose=` to find_diffraction_center().
verbose_widget = widgets.Checkbox(value=True, description="Verbose:")

# -------------------------------
# Control buttons.
# -------------------------------
process_button = widgets.Button(description="Process Images", button_style="primary")
stop_button = widgets.Button(description="Stop Processing", button_style="danger")

# -------------------------------
# Progress bar widget.
# -------------------------------
# Both the bar and its description area are widened because, during a run,
# the description is replaced with percent / elapsed / remaining-time text.
progress_bar = widgets.FloatProgress(
    min=0,
    max=100,
    value=0,
    description="Progress:",
    layout=widgets.Layout(width='600px'),
    style={'description_width': '250px'},
)

def update_ui_state(processing=False):
    """Lock or unlock the run controls.

    Parameters
    ----------
    processing : bool
        True while a run is active: the Process button and both file
        choosers get ``disabled = True``. False re-enables them.
    """
    for control in (process_button, image_file_chooser, mask_file_chooser):
        control.disabled = processing

def _append_center_row(csv_file, frame_number, center, write_header):
    """Write one ``frame_number, center_x, center_y`` row to *csv_file*.

    With ``write_header=True`` the file is (re)created and the header written;
    otherwise the row is appended. Writing per frame keeps partial results on
    disk if the run is stopped or crashes.
    """
    df_chunk = pd.DataFrame(
        [[frame_number, center[0], center[1]]],
        columns=["frame_number", "center_x", "center_y"],
    )
    df_chunk.to_csv(
        csv_file, index=False, mode="w" if write_header else "a", header=write_header
    )


def _update_progress(n_done, n_total, start_time):
    """Advance the progress bar, show timing in its description, and log a line.

    Requires ``1 <= n_done <= n_total`` (it divides by both).
    """
    progress_bar.value = n_done
    percent = n_done / n_total * 100
    elapsed = time.time() - start_time
    remaining = (elapsed / n_done) * (n_total - n_done)
    progress_bar.description = f"{percent:.1f}% - Elapsed: {elapsed:.1f}s, Rem: {remaining:.1f}s"
    print(f"[{time.strftime('%H:%M:%S')}] Processed image {n_done}/{n_total} ({percent:.1f}%)")


def process_images():
    """Refine the diffraction center of every frame and stream results to CSV.

    Reads ``/mask`` from the selected mask file and iterates over
    ``/entry/data/images`` in the selected image file. Each frame is refined
    with ``find_diffraction_center``. The starting guess is the previous
    frame's refined center, or, for the first frame, the user-supplied center.
    If there is none, ``center_of_mass_initial_guess`` is used. Results are
    written to ``centers_xatol_<xatol>_fatol_<fatol>.csv`` next to the image
    file, replacing any existing file of that name.

    Stops early when ``stop_processing`` is set (Stop button). Errors are
    reported with a traceback rather than raised, and the UI is re-enabled in
    every case.
    """
    import traceback

    global stop_processing
    stop_processing = False  # Reset stop flag.
    update_ui_state(processing=True)

    try:
        # Retrieve selected file paths.
        image_file = image_file_chooser.selected
        mask_file = mask_file_chooser.selected

        if not image_file or not mask_file:
            print("Please select both an image file and a mask file.")
            return

        # Load the mask and convert to boolean.
        with h5py.File(mask_file, 'r') as f_mask:
            mask = f_mask['/mask'][:].astype(bool)

        # Snapshot parameter values once so edits mid-run have no effect.
        n_wedges = n_wedges_widget.value
        n_rad_bins = n_rad_bins_widget.value
        xatol = xatol_widget.value
        fatol = fatol_widget.value
        user_defined_center = get_initial_center()
        verbose = verbose_widget.value

        # Define output CSV file path.
        csv_file = os.path.join(os.path.dirname(image_file), f"centers_xatol_{xatol}_fatol_{fatol}.csv")
        if os.path.exists(csv_file):
            os.remove(csv_file)
        rows_written = 0
        stopped_by_user = False

        print("Opening image file...")
        with h5py.File(image_file, 'r') as f_img:
            images = f_img['/entry/data/images']
            n_images = images.shape[0]
            print(f"Total images: {n_images}")

            progress_bar.max = n_images
            progress_bar.value = 0
            # Clear the previous run's timing text.
            progress_bar.description = "Progress:"
            start_time = time.time()

            # None means "compute a center-of-mass guess" for the first frame.
            prev_center = user_defined_center

            for i in range(n_images):
                if stop_processing:
                    stopped_by_user = True
                    print("\nProcessing stopped by user.")
                    break

                # Index the dataset per frame instead of loading the whole stack.
                image = images[i].astype(np.float32)

                if prev_center is None:
                    init_center = center_of_mass_initial_guess(image, mask)
                else:
                    init_center = prev_center

                refined_center = find_diffraction_center(
                    image, mask,
                    initial_center=init_center,
                    n_wedges=n_wedges,
                    n_rad_bins=n_rad_bins,
                    xatol=xatol,
                    fatol=fatol,
                    verbose=verbose,
                )
                # Seed the next frame with this frame's refined center.
                prev_center = refined_center

                _append_center_row(csv_file, i, refined_center, write_header=rows_written == 0)
                rows_written += 1

                _update_progress(i + 1, n_images, start_time)

        if not stopped_by_user:
            print("\nProcessing complete.")
        if rows_written:
            print("CSV file written to:", csv_file)
        else:
            print("No frames were processed; no CSV file was written.")
    except Exception:
        # Top-level boundary of a background thread: report with the full
        # traceback instead of just the message, then fall through to cleanup.
        print("An error occurred during processing:")
        traceback.print_exc()
    finally:
        update_ui_state(processing=False)

def on_process_button_clicked(b):
    """Process-button callback: start process_images() on a daemon thread.

    The thread keeps the notebook responsive during a long run. The controls
    are locked here, before the thread starts, rather than only inside the
    worker. Otherwise a second click handled before the worker runs could
    start another concurrent run writing the same CSV file. Clicks that
    arrive while the controls are locked are ignored.
    """
    if process_button.disabled:
        return
    update_ui_state(processing=True)
    threading.Thread(target=process_images, daemon=True).start()

def on_stop_button_clicked(b):
    """Stop-button callback: request that the running worker stop.

    The worker checks this flag at the top of each frame, so the frame in
    progress finishes before the loop exits.
    """
    global stop_processing
    stop_processing = True

# Wire the buttons to their callbacks.
process_button.on_click(on_process_button_clicked)
stop_button.on_click(on_stop_button_clicked)

# -------------------------------
# Assemble the UI.
# -------------------------------
# Sub-containers are kept as module-level names so later cells can reach them.
basic_params_box = widgets.VBox([
    n_wedges_widget,
    n_rad_bins_widget,
])
center_box = widgets.HBox([
    use_default_center_widget,
    initial_center_x_widget,
    initial_center_y_widget,
])
advanced_params_box = widgets.VBox([
    center_box,
    verbose_widget,
    xatol_widget,
    fatol_widget,
])
file_chooser_box = widgets.HBox([image_file_chooser, mask_file_chooser])
button_box = widgets.HBox([process_button, stop_button])

# Top-to-bottom layout: title, file pickers, parameters, controls, progress.
ui = widgets.VBox(
    [
        widgets.HTML("<h2>Interactive Image Processing Tool</h2>"),
        file_chooser_box,
        widgets.HTML("<h3>Processing Parameters</h3>"),
        basic_params_box,
        widgets.HTML("<h3>Advanced Parameters</h3>"),
        advanced_params_box,
        button_box,
        progress_bar,
    ]
)

display(ui)


VBox(children=(HTML(value='<h2>Interactive Image Processing Tool</h2>'), HBox(children=(FileChooser(path='/Use…

Opening image file...
Total images: 100
Starting center refinement with initial center: (np.float64(515.6643540693383), np.float64(513.4013368936525))
Candidate center: (np.float64(515.6643540693383), np.float64(513.4013368936525)), Metric: 319.8850267379679
Metric at initial center: 319.8850267379679
Initial center: [515.66435407 513.40133689]
Candidate center: [515.66435407 513.40133689], Metric: 319.8850267379679
Candidate center: [541.44757177 513.40133689], Metric: 1510.8473684210526
Candidate center: [515.66435407 539.07140374], Metric: 1438.3567708333333
Candidate center: [489.88113637 539.07140374], Metric: 1764.6806282722514
Candidate center: [528.55596292 519.8188536 ], Metric: 813.9465240641712
Candidate center: [528.55596292 494.14878676], Metric: 2167.103260869565
Candidate center: [518.88725628 527.84074949], Metric: 771.3101604278074
Candidate center: [505.99564743 521.42323278], Metric: 182.5752688172043
Candidate center: [494.71548969 522.22542237], Metric: 703.0026595