# In [None]:
import os
import h5py
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm  # Use notebook-friendly tqdm

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output

from helpers import process_chunk

# Preferred starting directories for the file choosers. These are
# machine-specific paths; when one does not exist (e.g. the notebook is run on
# another machine) the chooser opens in the current working directory instead
# of failing on a missing folder.
DEFAULT_IMAGE_DIR = "/Users/xiaodong/Desktop/UOX-data/UOX1_sub/"
DEFAULT_MASK_DIR = "/Users/xiaodong/mask/"


def _existing_dir_or_cwd(path):
    """Return *path* if it is an existing directory, else the current working directory."""
    return path if os.path.isdir(path) else os.getcwd()


# Create file chooser widgets for the image and mask files.
image_file_chooser = FileChooser(_existing_dir_or_cwd(DEFAULT_IMAGE_DIR), filename="UOX1_sub.h5")
image_file_chooser.title = "Select H5 Image File"
image_file_chooser.filter_pattern = "*.h5"

mask_file_chooser = FileChooser(_existing_dir_or_cwd(DEFAULT_MASK_DIR), filename="pxmask.h5")
mask_file_chooser.title = "Select H5 Mask File"
mask_file_chooser.filter_pattern = "*.h5"

# Create widgets for processing parameters. Every input is a 200px-wide text
# box whose description label is allowed to take its natural width.
def _param_widget(widget_cls, value, description):
    """Build one parameter input of type *widget_cls* with the shared look.

    A fresh ``Layout`` and style dict are created on every call, so no two
    widgets share layout state.
    """
    return widget_cls(
        value=value,
        description=description,
        layout=widgets.Layout(width='200px'),
        style={'description_width': 'initial'},
    )


threshold_widget = _param_widget(widgets.FloatText, 0.1, "Threshold:")
max_iters_widget = _param_widget(widgets.IntText, 10, "Max Iters:")
step_size_widget = _param_widget(widgets.IntText, 1, "Step Size:")
n_steps_widget = _param_widget(widgets.IntText, 5, "n_steps:")
n_wedges_widget = _param_widget(widgets.IntText, 4, "n_wedges:")
n_rad_bins_widget = _param_widget(widgets.IntText, 100, "n_rad_bins:")
chunk_size_widget = _param_widget(widgets.IntText, 100, "Chunk Size:")

# The button that starts a processing run.
process_button = widgets.Button(description="Process Images", button_style="primary")

# Log panel: bordered, fixed 400px tall, and scrolls vertically when the
# captured output overflows.
_log_panel_layout = {
    'border': '1px solid black',
    'padding': '5px',
    'height': '400px',
    'overflow_y': 'auto',
}
processing_output = widgets.Output(layout=_log_panel_layout)

# Column names written to the CSV.
# NOTE(review): the axis order of the two center components is whatever
# ``helpers.process_chunk`` returns — confirm it is (x, y) and not (row, col).
CENTER_CSV_COLUMNS = ["frame_number", "center_x", "center_y"]


def _split_center(center):
    """Return *center* as a 2-tuple, or ``(nan, nan)`` when it is ``None``.

    NOTE(review): assumes ``process_chunk`` reports each center as a
    2-element sequence and may use ``None`` for frames where no center was
    found — confirm against ``helpers.process_chunk``.
    """
    if center is None:
        return (np.nan, np.nan)
    first, second = center
    return (first, second)


def _chunk_results_to_frame(chunk_results):
    """Turn ``(frame_number, center)`` pairs into a DataFrame with one column per center component."""
    rows = [(frame, *_split_center(center)) for frame, center in chunk_results]
    return pd.DataFrame(rows, columns=CENTER_CSV_COLUMNS)


def on_process_button_clicked(b):
    """Find centers for every image in the selected H5 file and write them to a CSV.

    Click handler for ``process_button``. Reads the selected image and mask
    files plus the parameter widgets, then processes the image stack
    ``/entry/data/images`` in chunks of ``chunk_size`` frames with
    ``process_chunk``. Results are appended chunk by chunk to
    ``centers.csv`` next to the image file, with columns
    ``CENTER_CSV_COLUMNS``. All messages go to ``processing_output``.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button. It is unused, but ``Button.on_click`` passes it.
    """
    with processing_output:
        clear_output()
        # Retrieve file paths.
        image_file = image_file_chooser.selected
        mask_file = mask_file_chooser.selected
        if not image_file:
            print("Please select an image file.")
            return
        if not mask_file:
            print("Please select a mask file.")
            return

        # Retrieve parameter values.
        threshold = threshold_widget.value
        max_iters = max_iters_widget.value
        step_size = step_size_widget.value
        n_steps = n_steps_widget.value
        n_wedges = n_wedges_widget.value
        n_rad_bins = n_rad_bins_widget.value
        chunk_size = chunk_size_widget.value
        # range() raises on a zero step and yields nothing for a negative one,
        # so reject non-positive chunk sizes up front.
        if chunk_size < 1:
            print("Chunk size must be a positive integer.")
            return

        # Compute the output CSV path and clear any stale result from a
        # previous run, so the file never mixes old and new rows.
        csv_file = os.path.join(os.path.dirname(image_file), "centers.csv")
        if os.path.exists(csv_file):
            os.remove(csv_file)
        header_written = False

        print("Opening image and mask files...")
        with h5py.File(image_file, 'r') as f_img, h5py.File(mask_file, 'r') as f_mask:
            dataset_images = f_img['/entry/data/images']
            # The mask is read fully into memory and cast to boolean; images
            # are read lazily, one chunk at a time.
            mask = f_mask['/mask'][:].astype(bool)

            n_images = dataset_images.shape[0]
            print(f"Total images: {n_images}")

            # The progress bar is handed to process_chunk, which presumably
            # advances it per image (its total is the image count).
            with tqdm(total=n_images, desc="Processing centers") as pbar:
                for start in range(0, n_images, chunk_size):
                    images_chunk = dataset_images[start:start + chunk_size].astype(np.float32)
                    chunk_results = process_chunk(start, images_chunk, mask, pbar,
                                                  threshold, max_iters, step_size,
                                                  n_steps, n_wedges, n_rad_bins)
                    # Sort results by frame number.
                    chunk_results = sorted(chunk_results, key=lambda result: result[0])

                    # Write this chunk right away so memory stays bounded.
                    # The first chunk creates the file with a header; later
                    # chunks append without one.
                    df_chunk = _chunk_results_to_frame(chunk_results)
                    df_chunk.to_csv(csv_file, index=False,
                                    mode="a" if header_written else "w",
                                    header=not header_written)
                    header_written = True

        if not header_written:
            print("No images found; no CSV file was written.")
            return
        print("Processing complete.")
        print("CSV file written to:")
        print(csv_file)

# Run the processing callback whenever the button is clicked.
process_button.on_click(on_process_button_clicked)

# Parameter inputs, stacked vertically in the order they are listed here.
_parameter_widgets = [
    threshold_widget,
    max_iters_widget,
    step_size_widget,
    n_steps_widget,
    n_wedges_widget,
    n_rad_bins_widget,
    chunk_size_widget,
]
param_box = widgets.VBox(_parameter_widgets)

# Image and mask choosers shown side by side.
file_chooser_box = widgets.HBox([image_file_chooser, mask_file_chooser])

# Full layout, top to bottom: title, file choosers, parameters, run button,
# then the log panel.
_ui_sections = [
    widgets.HTML("<h2>Interactive Image Processing Tool</h2>"),
    file_chooser_box,
    widgets.HTML("<h3>Processing Parameters</h3>"),
    param_box,
    process_button,
    widgets.HTML("<h3>Logs & Feedback</h3>"),
    processing_output,
]
ui = widgets.VBox(_ui_sections)

display(ui)
