In [None]:
import os
import time
import h5py
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
from multiprocessing import Pool
from tqdm import tqdm
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore', message='invalid value encountered in subtract', category=RuntimeWarning)

# Custom modules (update these imports if needed)
from image_processing import process_single_image
from update_h5 import create_updated_h5
from statsmodels.nonparametric.smoothers_lowess import lowess

##########################################################################
# SECTION 1: PROCESS IMAGES + CREATE CSV
##########################################################################

def load_data_index_or_default(image_file):
    """
    Return the per-frame data index of an H5 image file.

    If ``/entry/data/index`` exists, its contents are returned. Otherwise the
    default ``np.arange(n_images)`` is returned, where ``n_images`` is the
    length of the first axis of ``/entry/data/images``.

    Parameters
    ----------
    image_file : str
        Path to the H5 file.

    Returns
    -------
    numpy.ndarray
        One entry per image (length == n_images).

    Raises
    ------
    ValueError
        If ``/entry/data/index`` exists but its length differs from the number
        of images. Callers pair index values with frames one-to-one, so a
        mismatch must not pass silently.
    """
    with h5py.File(image_file, 'r') as f:
        # Only the shape is needed; this does not read the image data.
        n_images = f['/entry/data/images'].shape[0]

        index_dataset = f.get('/entry/data/index')
        if index_dataset is None:
            return np.arange(n_images)
        data_index = index_dataset[:]

    if len(data_index) != n_images:
        raise ValueError(
            f"/entry/data/index has {len(data_index)} entries but "
            f"/entry/data/images has {n_images} frames in {image_file}"
        )
    return data_index  # length == n_images

def process_images(
        image_file,
        mask,
        n_wedges=4,
        n_rad_bins=100,
        xatol=0.01,
        fatol=10,
        chunk_size=1000,
        frame_interval=10,
        verbose=False,
        xmin=0,
        xmax=9999999,
        ymin=0,
        ymax=9999999
    ):
    """
    Find centers for a subset of frames and write one CSV row per frame.

    Center finding (``process_single_image``) runs on frame 0, the last frame,
    and every frame whose number is a multiple of ``frame_interval``. All other
    frames keep ``center_x`` / ``center_y`` = NaN. The CSV always has exactly
    ``n_images`` rows with columns
    ``[frame_number, data_index, center_x, center_y]``.

    Parameters
    ----------
    image_file : str
        H5 file; frames are read from ``/entry/data/images``.
    mask
        Passed unchanged to ``process_single_image`` for every frame.
    n_wedges, n_rad_bins, xatol, fatol, verbose
        Passed unchanged to ``process_single_image``. ``xatol`` and ``fatol``
        are the center-search tolerances.
    chunk_size : int
        Number of consecutive frames read from disk at a time (>= 1).
    frame_interval : int
        Stride between processed frames (>= 1).
    xmin, xmax, ymin, ymax
        A found center is kept only if it is finite and
        ``xmin <= cx < xmax`` and ``ymin <= cy < ymax``. Otherwise NaN is
        stored, but the row is still written.

    Returns
    -------
    str
        Path of the written CSV,
        ``centers_xatol_<xatol>_frameinterval_<frame_interval>.csv``, placed
        next to ``image_file``.

    Raises
    ------
    ValueError
        If ``frame_interval`` or ``chunk_size`` is < 1, or if the data index
        length differs from the number of images.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    with h5py.File(image_file, 'r') as f:
        n_images = f['/entry/data/images'].shape[0]

    # Either /entry/data/index or a default 0..n_images-1
    full_data_index = load_data_index_or_default(image_file)
    if len(full_data_index) != n_images:
        raise ValueError(
            f"data index has {len(full_data_index)} entries but the file has "
            f"{n_images} frames"
        )

    # One row per frame. Centers stay NaN unless a frame is processed and
    # its center passes the bounds check.
    df_all = pd.DataFrame({
        "frame_number": np.arange(n_images),
        "data_index": full_data_index,
        "center_x": np.full(n_images, np.nan, dtype=float),
        "center_y": np.full(n_images, np.nan, dtype=float),
    })

    # Every frame_interval-th frame plus the last one (if any frames exist)
    frames_to_process = set(range(0, n_images, frame_interval))
    if n_images > 0:
        frames_to_process.add(n_images - 1)

    csv_file = os.path.join(
        os.path.dirname(image_file),
        f"centers_xatol_{xatol}_frameinterval_{frame_interval}.csv"
    )
    # No pre-delete of an existing CSV: to_csv overwrites it at the end, and
    # deleting it up front would lose the previous result if this run fails.

    start_time = time.time()

    # Read frames chunk by chunk so the whole stack is never in memory at once.
    with Pool() as pool:
        for start_idx in range(0, n_images, chunk_size):
            end_idx = min(start_idx + chunk_size, n_images)
            chunk_frame_list = [fn for fn in range(start_idx, end_idx) if fn in frames_to_process]
            if not chunk_frame_list:
                continue

            with h5py.File(image_file, 'r') as f:
                images_chunk = f['/entry/data/images'][start_idx:end_idx].astype(np.float32)

            args = [
                (images_chunk[fn - start_idx], mask, n_wedges, n_rad_bins, xatol, fatol, verbose)
                for fn in chunk_frame_list
            ]

            # Parallel center finding. Each result is unpacked as an (x, y) pair.
            results = pool.starmap(process_single_image, args)

            for fn, (cx, cy) in zip(chunk_frame_list, results):
                if not (np.isfinite(cx) and np.isfinite(cy) and
                        xmin <= cx < xmax and ymin <= cy < ymax):
                    cx, cy = np.nan, np.nan
                df_all.at[fn, "center_x"] = cx
                df_all.at[fn, "center_y"] = cy

    elapsed = time.time() - start_time
    print(f"Processed {len(frames_to_process)} frames with center-finding in {elapsed:.1f}s")

    df_all.to_csv(csv_file, index=False)
    print(f"CSV (with {n_images} rows) written to: {csv_file}")
    return csv_file


# --- UI Widgets for Section 1 ---
# Every name below is a module-level global that on_process_images_clicked
# reads and/or process_images_ui lays out.

# Source H5 file; process_images reads its /entry/data/images dataset.
image_file_chooser = FileChooser(os.getcwd())
image_file_chooser.title = "Select H5 Image File"
image_file_chooser.filter_pattern = "*.h5"

# Mask H5 file; the callback reads its /mask dataset.
mask_file_chooser = FileChooser(os.getcwd())
mask_file_chooser.title = "Select Mask H5 File"
mask_file_chooser.filter_pattern = "*.h5"

# When unchecked, the callback substitutes an all-True mask.
use_mask_checkbox = widgets.Checkbox(value=True, description="Use Mask")

# Forwarded to process_images as xatol=, frame_interval= and verbose=.
# n_wedges, n_rad_bins, fatol and chunk_size are not exposed in the UI, so
# process_images' own defaults apply to them.
xatol_widget = widgets.FloatText(value=0.01, description="xatol:", layout=widgets.Layout(width="140px"))
frame_interval_widget = widgets.IntText(value=10, description="Interval:", layout=widgets.Layout(width="140px"))
verbose_checkbox = widgets.Checkbox(value=False, description="Verbose")

# Bounding box for accepted centers (xmin <= cx < xmax, ymin <= cy < ymax).
# These 2048 defaults replace process_images' 9999999 defaults when run from the UI.
xmin_widget = widgets.IntText(value=0, description="xmin:", layout=widgets.Layout(width="140px"))
xmax_widget = widgets.IntText(value=2048, description="xmax:", layout=widgets.Layout(width="140px"))
ymin_widget = widgets.IntText(value=0, description="ymin:", layout=widgets.Layout(width="140px"))
ymax_widget = widgets.IntText(value=2048, description="ymax:", layout=widgets.Layout(width="140px"))

process_images_button = widgets.Button(description="Process Images", button_style="primary")
# Captures everything the Section 1 callback prints.
output_area = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_process_images_clicked(b):
    """
    Callback for the "Process Images" button (Section 1).

    Reads the widget values, builds the boolean mask and calls
    ``process_images``. All messages are printed into ``output_area``.

    Mask handling:
      * "Use Mask" checked   -> ``/mask`` is read from the selected mask file.
      * "Use Mask" unchecked -> no mask file is required. An all-True mask with
        the per-frame shape of ``/entry/data/images`` is used instead.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    with output_area:
        clear_output()
        image_file = image_file_chooser.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        use_mask = use_mask_checkbox.value
        mask_file = mask_file_chooser.selected

        if use_mask:
            if not mask_file:
                print("Please select a mask H5 file.")
                return
            try:
                with h5py.File(mask_file, 'r') as f_mask:
                    mask = f_mask['/mask'][:].astype(bool)
            except Exception as e:
                print("Error loading mask file:", e)
                return
        else:
            # Size the all-True mask from the image frame shape, not from a
            # slice of /mask. It then always matches the individual frames
            # process_images passes to process_single_image, and no mask file
            # is needed.
            try:
                with h5py.File(image_file, 'r') as f_img:
                    frame_shape = f_img['/entry/data/images'].shape[1:]
            except Exception as e:
                print("Error reading image file:", e)
                return
            mask = np.ones(frame_shape, dtype=bool)

        xatol_val = xatol_widget.value
        frame_interval_val = frame_interval_widget.value
        verbose_val = verbose_checkbox.value

        xmin_val = xmin_widget.value
        xmax_val = xmax_widget.value
        ymin_val = ymin_widget.value
        ymax_val = ymax_widget.value

        # Validate before starting the worker pool.
        if frame_interval_val < 1:
            print(f"Interval must be >= 1 (got {frame_interval_val}).")
            return
        if xmin_val >= xmax_val or ymin_val >= ymax_val:
            print("Bounds must satisfy xmin < xmax and ymin < ymax; "
                  "otherwise every center would be rejected.")
            return

        print("Starting Section 1: Process Images, output CSV with n_images rows")
        print(f"Image file: {image_file}")
        print(f"Mask file: {mask_file if use_mask else '(not used)'}")
        print(f"xatol={xatol_val}, Interval={frame_interval_val}, Verbose={verbose_val}")
        print(f"Bounds: xmin={xmin_val}, xmax={xmax_val}, ymin={ymin_val}, ymax={ymax_val}")

        process_images(
            image_file,
            mask,
            xatol=xatol_val,
            frame_interval=frame_interval_val,
            verbose=verbose_val,
            xmin=xmin_val,
            xmax=xmax_val,
            ymin=ymin_val,
            ymax=ymax_val
        )

# Run on_process_images_clicked whenever the "Process Images" button is pressed.
process_images_button.on_click(on_process_images_clicked)

# Section 1 layout: file pickers, mask toggle, options, bounds, run button and
# log area. Displayed as the first tab of the combined UI at the end of the cell.
process_images_ui = widgets.VBox([
    widgets.HTML("<h2>Section 1: Process Images from H5 File (One Row Per Frame)</h2>"),
    image_file_chooser,
    mask_file_chooser,
    use_mask_checkbox,
    widgets.HBox([xatol_widget, frame_interval_widget, verbose_checkbox]),
    widgets.HBox([xmin_widget, xmax_widget, ymin_widget, ymax_widget]),
    process_images_button,
    output_area
])

##########################################################################
# SECTION 2: LOWESS-FIT + UPDATE H5
##########################################################################

# PART A: Lowess-Fit & SHIFT
# Widgets and state read by on_process_csv_clicked.

# CSV written by Section 1 (columns frame_number, data_index, center_x, center_y).
csv_file_chooser = FileChooser(os.getcwd())
csv_file_chooser.title = "Select CSV from Section 1"
csv_file_chooser.filter_pattern = "*.csv"

# Constant offsets added to the smoothed center_x / center_y values.
shift_x_widget = widgets.FloatText(value=0, description="Shift X:", layout=widgets.Layout(width="150px"))
shift_y_widget = widgets.FloatText(value=0, description="Shift Y:", layout=widgets.Layout(width="150px"))

# Passed as `frac` to statsmodels' lowess, and also embedded (2 decimals)
# in the output CSV name.
lowess_frac_widget = widgets.FloatSlider(
    value=0.1, min=0.01, max=1.0, step=0.01,
    description="Lowess frac:",
    continuous_update=False,
    layout=widgets.Layout(width="300px")
)

process_csv_button = widgets.Button(description="Lowess & Save CSV", button_style="primary")
csv_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})
# Set (via `global`) by on_process_csv_clicked after it writes the smoothed
# CSV; read by on_update_h5_clicked. None until that step has run.
shifted_csv_path = None  # will hold path to final "smoothed" CSV

def on_process_csv_clicked(b):
    """
    Callback for the "Lowess & Save CSV" button (Section 2A).

    Steps:
      1. Read the Section 1 CSV (columns frame_number, data_index, center_x,
         center_y) and sort its rows by data_index.
      2. Fit LOWESS curves (statsmodels ``lowess``, ``frac`` from the slider)
         to center_x and center_y versus data_index. Only rows where both
         centers are finite are used.
      3. For every row whose data_index lies within [min, max] of the fitted
         points, replace center_x / center_y with the linearly interpolated
         smoothed value plus the Shift X / Shift Y offsets. Rows outside
         that range keep their original values.
      4. Plot original vs. smoothed centers at the fitted points and write
         ``centers_lowess_<frac>_shifted.csv`` next to the input CSV.

    With fewer than two valid points, the CSV is written unchanged (no
    smoothing, no shift). On success, the module-level ``shifted_csv_path``
    is set to the written file for use by Section 2B.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    global shifted_csv_path
    with csv_output:
        clear_output()

        input_csv = csv_file_chooser.selected
        if not input_csv:
            print("Please select the CSV file from Section 1.")
            return

        try:
            df = pd.read_csv(input_csv)
        except Exception as e:
            print(f"Error reading CSV: {e}")
            return

        required_cols = ["frame_number", "data_index", "center_x", "center_y"]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            print(f"CSV must contain '{missing_cols[0]}' column.")
            return

        # Smoothing and interpolation both run along data_index.
        df = df.sort_values("data_index").reset_index(drop=True)
        # Hold the centers as float whatever dtype pandas inferred, so that
        # smoothed values can be assigned back without coercion.
        try:
            df = df.astype({"center_x": float, "center_y": float})
        except (ValueError, TypeError) as e:
            print(f"center_x / center_y must be numeric: {e}")
            return

        data_idx_all = df["data_index"].to_numpy()
        cx_all = df["center_x"].to_numpy()
        cy_all = df["center_y"].to_numpy()

        # Fit only on rows where both coordinates are finite (Section 1
        # stores NaN for skipped or rejected frames).
        valid_mask = np.isfinite(cx_all) & np.isfinite(cy_all)
        idx_valid = data_idx_all[valid_mask]
        cx_valid = cx_all[valid_mask]
        cy_valid = cy_all[valid_mask]

        frac_val = lowess_frac_widget.value

        if len(idx_valid) < 2:
            print("Too few valid points for LOWESS. The final centers remain as-is.")
            df_smoothed = df.copy()  # No changes
        else:
            # lowess(endog, exog) returns an (n, 2) array sorted by exog:
            # column 0 = data_index, column 1 = smoothed value.
            lowess_x = lowess(cx_valid, idx_valid, frac=frac_val, return_sorted=True)
            lowess_y = lowess(cy_valid, idx_valid, frac=frac_val, return_sorted=True)

            # Evaluate the smoothed curves directly at each row's data_index.
            # This works for any numeric index and needs no lookup table
            # spanning the whole index range.
            in_range = (data_idx_all >= idx_valid.min()) & (data_idx_all <= idx_valid.max())
            query_idx = data_idx_all[in_range]
            smooth_x = np.interp(query_idx, lowess_x[:, 0], lowess_x[:, 1]) + shift_x_widget.value
            smooth_y = np.interp(query_idx, lowess_y[:, 0], lowess_y[:, 1]) + shift_y_widget.value

            # Rows outside the fitted range are necessarily non-finite rows;
            # they keep their original values.
            df_smoothed = df.copy()
            df_smoothed.loc[in_range, "center_x"] = smooth_x
            df_smoothed.loc[in_range, "center_y"] = smooth_y

        # Compare original vs. smoothed values at the fitted (valid) points only.
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))

        axs[0].plot(idx_valid, cx_valid, 'o--', label='Original X (valid)', markersize=4)
        axs[1].plot(idx_valid, cy_valid, 'o--', label='Original Y (valid)', markersize=4)

        s_cx_valid = df_smoothed.loc[valid_mask, "center_x"].values
        s_cy_valid = df_smoothed.loc[valid_mask, "center_y"].values

        axs[0].plot(idx_valid, s_cx_valid, 'o-', label='Smoothed+Shift X', markersize=4)
        axs[1].plot(idx_valid, s_cy_valid, 'o-', label='Smoothed+Shift Y', markersize=4)

        axs[0].set_title('Center X vs. data_index')
        axs[1].set_title('Center Y vs. data_index')
        axs[0].legend()
        axs[1].legend()
        plt.show()

        # Same number of rows as the input CSV (sorted by data_index).
        shifted_csv_path = os.path.join(
            os.path.dirname(input_csv),
            f"centers_lowess_{frac_val:.2f}_shifted.csv"
        )
        df_smoothed.to_csv(shifted_csv_path, index=False)
        print(f"Final CSV with {len(df_smoothed)} rows saved to:\n{shifted_csv_path}")

# Run on_process_csv_clicked whenever the "Lowess & Save CSV" button is pressed.
process_csv_button.on_click(on_process_csv_clicked)

# Section 2A layout: CSV picker, shift offsets, LOWESS fraction, run button, log area.
lowess_ui = widgets.VBox([
    widgets.HTML("<h2>Section 2A: Lowess-Fit & Shift (Same Frame Count)</h2>"),
    csv_file_chooser,
    widgets.HBox([shift_x_widget, shift_y_widget]),
    lowess_frac_widget,
    process_csv_button,
    csv_output
])

# PART B: UPDATE H5
# H5 file handed to create_updated_h5 as its input. The output file is
# written to the same directory, named after the smoothed CSV.
image_file_chooser_h5 = FileChooser(os.getcwd())
image_file_chooser_h5.title = "Select H5 File to Update"
image_file_chooser_h5.filter_pattern = "*.h5"

update_h5_button = widgets.Button(description="Update H5 with Smoothed Centers", button_style="primary")
# Captures everything the Section 2B callback prints.
h5_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_update_h5_clicked(b):
    """
    Callback for the "Update H5 with Smoothed Centers" button (Section 2B).

    Calls ``create_updated_h5(image_file, new_h5_path, shifted_csv_path)``.
    ``shifted_csv_path`` is the CSV last written by the "Lowess & Save CSV"
    step. ``new_h5_path`` is that CSV's base name with an ``.h5`` extension,
    placed in the selected H5 file's directory. All messages go to
    ``h5_output``.

    The callback refuses to run if the smoothed CSV no longer exists, or if
    the output path would be the selected input file itself.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    with h5_output:
        clear_output()
        if shifted_csv_path is None:
            print("No shifted CSV available. Please run the 'Lowess & Save CSV' step first.")
            return
        if not os.path.isfile(shifted_csv_path):
            print(f"Shifted CSV not found:\n{shifted_csv_path}\n"
                  "Please re-run the 'Lowess & Save CSV' step.")
            return

        image_file = image_file_chooser_h5.selected
        if not image_file:
            print("Please select an H5 file to update.")
            return

        csv_stem = os.path.splitext(os.path.basename(shifted_csv_path))[0]
        new_h5_path = os.path.join(os.path.dirname(image_file), csv_stem + '.h5')

        # Re-selecting a previously generated centers_lowess_*.h5 would make
        # the output path equal the input path. Writing onto the file being
        # read is unsafe: how create_updated_h5 opens its output is not
        # visible here, but e.g. 'w' mode would truncate the input first.
        same_file = (os.path.normcase(os.path.realpath(new_h5_path))
                     == os.path.normcase(os.path.realpath(image_file)))
        if same_file:
            print("The output H5 path is the same as the selected input file:\n"
                  f"{image_file}\nPlease select the original (un-updated) H5 file.")
            return

        try:
            create_updated_h5(image_file, new_h5_path, shifted_csv_path)
            print(f"Updated H5 file created at:\n{new_h5_path}")
        except Exception as e:
            print("Error updating H5 file:", e)

# Run on_update_h5_clicked whenever the "Update H5 with Smoothed Centers" button is pressed.
update_h5_button.on_click(on_update_h5_clicked)

# Section 2B layout: H5 picker, run button, log area.
h5_ui = widgets.VBox([
    widgets.HTML("<h2>Section 2B: Update H5 with Smoothed Centers</h2>"),
    image_file_chooser_h5,
    update_h5_button,
    h5_output
])

# Combine Section 2 UI
# Part A (LOWESS + shift) is stacked above Part B (H5 update) in one tab.
csv_h5_ui = widgets.VBox([lowess_ui, h5_ui])

##########################################################################
# FINAL: COMBINED UI
##########################################################################
# Tab 0 = Section 1 (center finding -> CSV); tab 1 = Section 2 (smoothing -> H5).
tab = widgets.Tab(children=[process_images_ui, csv_h5_ui])
tab.set_title(0, "Process Images")
tab.set_title(1, "Lowess & H5 Update")
display(tab)


Tab(children=(VBox(children=(HTML(value='<h2>Process Images from H5 File</h2>'), FileChooser(path='/Users/xiao…