# In[1]:
import os
import time
import h5py
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
from multiprocessing import Pool
from tqdm import tqdm
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore', message='invalid value encountered in subtract', category=RuntimeWarning)

# Import custom functions
from image_processing import process_single_image
from update_h5 import create_updated_h5
from statsmodels.nonparametric.smoothers_lowess import lowess

# ----------------------------------
# Section 1: Process Images UI
# ----------------------------------

def load_chunk(image_file, start, end):
    """Read a contiguous run of frames and their index values from an HDF5 file.

    Parameters
    ----------
    image_file : str
        Path to an HDF5 file containing ``/entry/data/images`` and
        ``/entry/data/index``.
    start, end : int
        Half-open frame range ``[start, end)`` to read.

    Returns
    -------
    tuple
        ``(frames, frame_index)``: the image slice cast to ``float32`` and the
        matching slice of the index dataset.
    """
    window = slice(start, end)
    with h5py.File(image_file, 'r') as h5:
        frames = h5['/entry/data/images'][window].astype(np.float32)
        frame_index = h5['/entry/data/index'][window]
    return frames, frame_index

def _select_frames(start, end, n_images, frame_interval):
    """Return the frame numbers in ``[start, end)`` that process_images analyses.

    A frame is selected if it is the first frame (0), the last frame
    (``n_images - 1``), or a multiple of ``frame_interval``.
    """
    return [
        i for i in range(start, end)
        if i == 0 or i == n_images - 1 or i % frame_interval == 0
    ]


def process_images(image_file, mask, n_wedges=4, n_rad_bins=100, xatol=0.01, fatol=10,
                   chunk_size=1000, frame_interval=10, verbose=False,
                   xmin=0, xmax=99999, ymin=0, ymax=99999):
    """
    Compute centers for a subset of frames and write the valid ones to CSV.

    Only frame 0, the last frame, and frames whose number is a multiple of
    ``frame_interval`` are processed. Frames are read ``chunk_size`` at a time
    with ``load_chunk``. Each selected frame is passed, together with ``mask``
    and the fitting parameters, to ``process_single_image`` in a process pool.
    The first two elements of its result are taken as ``(center_x, center_y)``.
    A center is kept only if both coordinates are finite and lie in the
    half-open region ``[xmin, xmax) x [ymin, ymax)``.

    Parameters
    ----------
    image_file : str
        HDF5 file with ``/entry/data/images`` and ``/entry/data/index``.
    mask : array-like
        Passed unchanged to ``process_single_image``.
    n_wedges, n_rad_bins, xatol, fatol, verbose :
        Passed unchanged to ``process_single_image``. ``verbose`` also enables
        progress / skip messages here.
    chunk_size : int
        Number of consecutive frames loaded from disk at once (>= 1).
    frame_interval : int
        Stride between processed frames (>= 1).
    xmin, xmax, ymin, ymax : float
        Bounds a center must satisfy to be written.

    Output
    ------
    ``centers_xatol_{xatol}_frameinterval_{frame_interval}.csv`` in the image
    file's directory, with columns ``frame_number, data_index, center_x,
    center_y``. An existing file of that name is replaced. If no frame yields a
    valid center, a header-only CSV is written.

    Raises
    ------
    ValueError
        If ``frame_interval`` or ``chunk_size`` is less than 1.
    """
    # frame_interval is used as a modulus and chunk_size as a range step;
    # 0 would raise deep inside the loop (or silently skip it when negative).
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    with h5py.File(image_file, 'r') as f_img:
        n_images = f_img['/entry/data/images'].shape[0]

    # Use the same selection rule as the per-chunk loop so the progress total
    # always matches; this is 0 (not 2) for an empty dataset.
    total_centers = len(_select_frames(0, n_images, n_images, frame_interval))

    columns = ["frame_number", "data_index", "center_x", "center_y"]
    csv_file = os.path.join(
        os.path.dirname(image_file),
        f"centers_xatol_{xatol}_frameinterval_{frame_interval}.csv"
    )
    if os.path.exists(csv_file):
        os.remove(csv_file)
    header_written = False
    start_time = time.time()

    # Context managers release the pool and close the progress bar even if a
    # worker or an HDF5 read raises.
    with Pool() as pool, tqdm(total=total_centers, desc="Calculating centers") as pbar:
        for start_idx in range(0, n_images, chunk_size):
            end_idx = min(start_idx + chunk_size, n_images)
            chunk_frame_indices = _select_frames(start_idx, end_idx, n_images, frame_interval)
            if not chunk_frame_indices:
                continue
            current_chunk, current_indices = load_chunk(image_file, start_idx, end_idx)

            args = [
                (current_chunk[i - start_idx], mask, n_wedges, n_rad_bins, xatol, fatol, verbose)
                for i in chunk_frame_indices
            ]
            results = pool.starmap(process_single_image, args)

            # Keep only centers that are finite and inside the requested region.
            rows = []
            for i, res in zip(chunk_frame_indices, results):
                center_x, center_y = res[0], res[1]
                if (np.isfinite(center_x) and np.isfinite(center_y)
                        and xmin <= center_x < xmax and ymin <= center_y < ymax):
                    rows.append([i, current_indices[i - start_idx], center_x, center_y])
                elif verbose:
                    print(f"Skipping frame {i} due to invalid center: ({center_x}, {center_y})")

            if rows:
                # The first write creates the file with a header; later chunks append.
                pd.DataFrame(rows, columns=columns).to_csv(
                    csv_file,
                    index=False,
                    mode="a" if header_written else "w",
                    header=not header_written,
                )
                header_written = True

            pbar.update(len(chunk_frame_indices))
            if verbose:
                print(f"Processed frames {chunk_frame_indices[0]} to {chunk_frame_indices[-1]} "
                      f"from chunk {start_idx} to {end_idx}")

    if not header_written:
        # Still create the file so the reported path exists and downstream
        # readers find the expected columns.
        pd.DataFrame(columns=columns).to_csv(csv_file, index=False)
        print("No valid centers found; wrote a header-only CSV.")

    elapsed = time.time() - start_time
    print("Processing complete in {:.1f}s".format(elapsed))
    print("CSV file written to:", csv_file)

# Build Process Images UI components
image_file_chooser = FileChooser(os.getcwd())
image_file_chooser.title = "Select H5 Image File"
image_file_chooser.filter_pattern = "*.h5"

# HDF5 file whose /mask dataset supplies the mask.
mask_file_chooser = FileChooser(os.getcwd())
mask_file_chooser.title = "Select Mask H5 File"
mask_file_chooser.filter_pattern = "*.h5"

use_mask_checkbox = widgets.Checkbox(value=True, description="Use Mask")
xatol_widget = widgets.FloatText(value=0.01, description="xatol:", layout=widgets.Layout(width="150px"))
# The frame interval is a stride (frames are selected with i % frame_interval),
# so 0 raises ZeroDivisionError and negative values are meaningless.
# BoundedIntText clamps input to [min, max]; max is set explicitly because
# BoundedIntText defaults to max=100.
frame_interval_widget = widgets.BoundedIntText(
    value=15, min=1, max=1_000_000,
    description="Frame Interval:", layout=widgets.Layout(width="150px"),
)
verbose_checkbox = widgets.Checkbox(value=False, description="Verbose")

# Region [xmin, xmax) x [ymin, ymax) that a computed center must fall inside
# to be written to the CSV.
xmin_widget = widgets.IntText(value=0, description="xmin:", layout=widgets.Layout(width="150px"))
xmax_widget = widgets.IntText(value=2048, description="xmax:", layout=widgets.Layout(width="150px"))
ymin_widget = widgets.IntText(value=0, description="ymin:", layout=widgets.Layout(width="150px"))
ymax_widget = widgets.IntText(value=2048, description="ymax:", layout=widgets.Layout(width="150px"))

process_images_button = widgets.Button(description="Process Images", button_style="primary")
output_area = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_process_images_clicked(b):
    """Callback for ``process_images_button``: read the UI inputs and run process_images.

    All messages are printed into ``output_area``.

    If "Use Mask" is checked, the mask is the ``/mask`` dataset of the selected
    mask file, cast to bool. If it is unchecked, no mask file is needed: an
    all-True mask shaped like one frame of ``/entry/data/images`` (the dataset
    shape without its first axis) is used instead.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button. It is unused but required by ``on_click``.
    """
    with output_area:
        clear_output()
        image_file = image_file_chooser.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        use_mask_val = use_mask_checkbox.value
        mask_file = mask_file_chooser.selected
        if use_mask_val:
            if not mask_file:
                print("Please select a mask H5 file.")
                return
            try:
                with h5py.File(mask_file, 'r') as f_mask:
                    mask = f_mask['/mask'][:].astype(bool)
            except Exception as e:
                print("Error loading mask file:", e)
                return
        else:
            # Size the pass-through mask from the frame shape. /mask is used
            # whole as the per-frame mask in the branch above, so presumably it
            # is 2-D (H, W); indexing it with [0] would yield a single row,
            # i.e. a mask of the wrong shape.
            try:
                with h5py.File(image_file, 'r') as f_img:
                    frame_shape = f_img['/entry/data/images'].shape[1:]
            except Exception as e:
                print("Error reading image file:", e)
                return
            mask = np.ones(frame_shape, dtype=bool)

        xatol_val = xatol_widget.value
        frame_interval_val = frame_interval_widget.value
        verbose_val = verbose_checkbox.value

        # Region a center must fall inside to be kept.
        xmin_val = xmin_widget.value
        xmax_val = xmax_widget.value
        ymin_val = ymin_widget.value
        ymax_val = ymax_widget.value

        # Reject inputs that would crash processing (i % 0) or could never
        # accept any center.
        if frame_interval_val < 1:
            print("Frame Interval must be at least 1.")
            return
        if xmin_val >= xmax_val or ymin_val >= ymax_val:
            print("Bounds must satisfy xmin < xmax and ymin < ymax.")
            return

        print("Starting image processing...")
        print(f"Image file: {image_file}")
        print(f"Mask file: {mask_file if use_mask_val else '(not used)'}")
        print(f"xatol: {xatol_val}, Frame Interval: {frame_interval_val}, Verbose: {verbose_val}")
        print(f"xmin={xmin_val}, xmax={xmax_val}, ymin={ymin_val}, ymax={ymax_val}")

        process_images(
            image_file,
            mask,
            xatol=xatol_val,
            frame_interval=frame_interval_val,
            verbose=verbose_val,
            xmin=xmin_val,
            xmax=xmax_val,
            ymin=ymin_val,
            ymax=ymax_val
        )

# Connect the "Process Images" button to its callback.
process_images_button.on_click(on_process_images_clicked)

# Layout of the "Process Images" tab, top to bottom: title, image and mask file
# choosers, mask toggle, parameter row, bounds row, run button, log output.
process_images_ui = widgets.VBox([
    widgets.HTML("<h2>Process Images from H5 File</h2>"),
    image_file_chooser,
    mask_file_chooser,
    use_mask_checkbox,
    widgets.HBox([xatol_widget, frame_interval_widget, verbose_checkbox]),
    # Show bounding inputs in a row:
    widgets.HBox([xmin_widget, xmax_widget, ymin_widget, ymax_widget]),
    process_images_button,
    output_area
])

# ----------------------------------
# Section 2: Lowess-Fit & Update H5 UI
# ----------------------------------

# Part A: Lowess-Fit Centers & Shift CSV
# Chooser for the centers CSV to smooth. on_process_csv_clicked requires the
# columns frame_number, data_index, center_x and center_y.
csv_file_chooser = FileChooser(os.getcwd())
csv_file_chooser.title = "Select Input CSV File"
csv_file_chooser.filter_pattern = "*.csv"

# Constant offsets added to the smoothed center_x / center_y values.
shift_x_widget = widgets.FloatText(value=0, description="Shift X:", layout=widgets.Layout(width="200px"))
shift_y_widget = widgets.FloatText(value=0, description="Shift Y:", layout=widgets.Layout(width="200px"))
# Passed as lowess(frac=...): the fraction of points used for each local fit.
# With continuous_update=False the value is only reported when the slider is released.
lowess_frac_widget = widgets.FloatSlider(
    value=0.1, min=0.01, max=1.0, step=0.01,
    description="Lowess frac:",
    continuous_update=False,
    layout=widgets.Layout(width="300px")
)
process_csv_button = widgets.Button(description="Lowess & Save CSV", button_style="primary")
csv_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})
# Set (via `global`) by on_process_csv_clicked and read by on_update_h5_clicked.
shifted_csv_path = None  # Will hold path to the new CSV with smoothed centers

def on_process_csv_clicked(b):
    """Callback for ``process_csv_button``: smooth, gap-fill, shift, plot and save centers.

    Steps:
    1. Read the CSV chosen in ``csv_file_chooser``. It must contain
       ``frame_number``, ``data_index``, ``center_x`` and ``center_y``.
    2. Drop rows whose index or centers are NaN/inf, then sort by ``data_index``.
    3. Fit a LOWESS curve (``frac`` from ``lowess_frac_widget``) to each
       coordinate against ``data_index``.
    4. Linearly interpolate the fits onto every integer ``data_index`` from the
       minimum to the maximum, and add the user shifts.
    5. Give indices absent from the input ``frame_number = -1``.

    The original and smoothed curves are plotted into ``csv_output``. The
    result is written next to the input as
    ``centers_lowess_{frac:.2f}_shifted_{shift_x}_{shift_y}.csv``, and the
    module-level ``shifted_csv_path`` is set to that path.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button. It is unused but required by ``on_click``.
    """
    global shifted_csv_path
    with csv_output:
        clear_output()
        input_csv = csv_file_chooser.selected
        if not input_csv:
            print("Please select an input CSV file.")
            return
        try:
            df = pd.read_csv(input_csv)
        except Exception as e:
            print(f"Error reading CSV: {e}")
            return

        # Report every missing column at once.
        required = ['frame_number', 'data_index', 'center_x', 'center_y']
        missing = [col for col in required if col not in df.columns]
        if missing:
            for col in missing:
                print(f"CSV must contain '{col}' column.")
            return

        # Rows without a finite index/center cannot anchor the fit, and a NaN
        # index would break the gap-filling np.arange below.
        df = (df.replace([np.inf, -np.inf], np.nan)
                .dropna(subset=['data_index', 'center_x', 'center_y']))
        if df.empty:
            print("CSV has no rows with finite data_index/center values; nothing to smooth.")
            return

        # Smoothing is done against data_index, so order by it.
        df = df.sort_values('data_index').reset_index(drop=True)
        frames_original = df['frame_number'].to_numpy()
        indices_original = df['data_index'].to_numpy()
        original_x = df['center_x'].to_numpy()
        original_y = df['center_y'].to_numpy()

        frac_val = lowess_frac_widget.value

        # LOWESS on the existing points; return_sorted=True yields (x, fit) pairs
        # sorted by x, which np.interp needs.
        try:
            lowess_x = lowess(endog=original_x, exog=indices_original,
                              frac=frac_val, return_sorted=True)
            lowess_y = lowess(endog=original_y, exog=indices_original,
                              frac=frac_val, return_sorted=True)
        except Exception as e:
            print(f"LOWESS fit failed: {e}")
            return

        # A full range of data_index from min to max, so that gaps are filled.
        all_indices = np.arange(indices_original.min(), indices_original.max() + 1)
        smoothed_x_all = np.interp(all_indices, lowess_x[:, 0], lowess_x[:, 1])
        smoothed_y_all = np.interp(all_indices, lowess_y[:, 0], lowess_y[:, 1])

        # Apply user shifts.
        shift_x = shift_x_widget.value
        shift_y = shift_y_widget.value
        smoothed_x_all += shift_x
        smoothed_y_all += shift_y

        # Indices that were not in the input have no real frame number: use -1.
        idx_to_frame = dict(zip(indices_original, frames_original))
        frame_mapped = [idx_to_frame.get(idx, -1) for idx in all_indices]

        output_df = pd.DataFrame({
            'frame_number': frame_mapped,
            'data_index': all_indices,
            'center_x': smoothed_x_all,
            'center_y': smoothed_y_all
        })

        # Compare original vs smoothed; the y-range is fitted to the smoothed
        # curve so outliers in the raw data do not flatten it.
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))

        axs[0].plot(indices_original, original_x, 'o--', label='Original X', markersize=4)
        axs[0].plot(all_indices, smoothed_x_all, 'o-', label='Lowess+Shift X', markersize=4)
        axs[0].set_title('Center X vs Data Index')
        axs[0].legend()
        axs[0].set_ylim(smoothed_x_all.min() - 1, smoothed_x_all.max() + 1)

        axs[1].plot(indices_original, original_y, 'o--', label='Original Y', markersize=4)
        axs[1].plot(all_indices, smoothed_y_all, 'o-', label='Lowess+Shift Y', markersize=4)
        axs[1].set_title('Center Y vs Data Index')
        axs[1].legend()
        axs[1].set_ylim(smoothed_y_all.min() - 1, smoothed_y_all.max() + 1)

        plt.show()

        # Save the new CSV next to the input and remember it for the H5 step.
        shifted_csv_path = os.path.join(
            os.path.dirname(input_csv),
            f"centers_lowess_{frac_val:.2f}_shifted_{shift_x}_{shift_y}.csv"
        )
        output_df.to_csv(shifted_csv_path, index=False)
        print(f"Created CSV with smoothed centers (including filled gaps):\n{shifted_csv_path}")

# Connect the "Lowess & Save CSV" button to its callback.
process_csv_button.on_click(on_process_csv_clicked)

# Layout of the LOWESS section: title, CSV chooser, shift inputs, frac slider,
# run button, and the output area (which also shows the comparison plot).
lowess_ui = widgets.VBox([
    widgets.HTML("<h2>Lowess-Fit (Fill Missing Frames) + Shift</h2>"),
    csv_file_chooser,
    widgets.HBox([shift_x_widget, shift_y_widget]),
    lowess_frac_widget,
    process_csv_button,
    csv_output
])

# Part B: Update H5 with New Centers
# Chooser for the source H5 file. on_update_h5_clicked builds the output
# path in this file's directory.
image_file_chooser_h5 = FileChooser(os.getcwd())
image_file_chooser_h5.title = "Select H5 Image File for Updating"
image_file_chooser_h5.filter_pattern = "*.h5"

update_h5_button = widgets.Button(description="Update H5 with Shifted Centers", button_style="primary")
h5_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_update_h5_clicked(b):
    """Callback for ``update_h5_button``: create an H5 file from the latest shifted CSV.

    Uses the module-level ``shifted_csv_path`` set by the LOWESS step and the
    file chosen in ``image_file_chooser_h5``. The output path is in the chosen
    file's directory, named after the CSV with a ``.h5`` extension. It is passed
    to ``create_updated_h5(image_file, new_h5_path, shifted_csv_path)``. All
    messages and errors are printed into ``h5_output``.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button. It is unused but required by ``on_click``.
    """
    with h5_output:
        clear_output()
        if shifted_csv_path is None:
            print("No shifted CSV available. Please run the CSV shifting process first.")
            return
        # The CSV may have been moved or deleted since the LOWESS step ran.
        if not os.path.isfile(shifted_csv_path):
            print(f"Shifted CSV not found:\n{shifted_csv_path}\n"
                  "Please run the CSV shifting process again.")
            return
        image_file = image_file_chooser_h5.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        new_h5_path = os.path.join(
            os.path.dirname(image_file),
            os.path.splitext(os.path.basename(shifted_csv_path))[0] + '.h5'
        )
        # The output name comes from the CSV, so it can coincide with the
        # selected input. create_updated_h5 presumably reads image_file while
        # writing new_h5_path, so never let it target its own source.
        same_path = (os.path.normcase(os.path.abspath(new_h5_path))
                     == os.path.normcase(os.path.abspath(image_file)))
        if same_path:
            print(f"Output path would overwrite the selected input file:\n{image_file}\n"
                  "Rename the input H5 file or the shifted CSV first.")
            return
        try:
            create_updated_h5(image_file, new_h5_path, shifted_csv_path)
            print(f"Updated H5 file created at:\n{new_h5_path}")
        except Exception as e:
            print("Error updating H5 file:", e)

# Connect the "Update H5" button to its callback.
update_h5_button.on_click(on_update_h5_clicked)

# Layout of the H5 update section: title, H5 chooser, run button, log output.
h5_ui = widgets.VBox([
    widgets.HTML("<h2>Update H5 with Shifted Centers</h2>"),
    image_file_chooser_h5,
    update_h5_button,
    h5_output
])

# Combine Lowess and H5 Update UIs vertically. The H5 step consumes the CSV
# produced by the LOWESS step (shifted_csv_path), so both live on one tab.
csv_h5_ui = widgets.VBox([lowess_ui, h5_ui])

# ----------------------------------
# Final Combined UI using Tabs
# ----------------------------------
# Two tabs: (0) image processing, (1) LOWESS smoothing plus H5 update.
# display() renders the widget as this cell's output.
tab = widgets.Tab(children=[process_images_ui, csv_h5_ui])
tab.set_title(0, "Process Images")
tab.set_title(1, "Lowess & H5 Update")
display(tab)


# Output: Tab(children=(VBox(children=(HTML(value='<h2>Process Images from H5 File</h2>'), FileChooser(path='/Users/xiao…