# In [None]:
import os
import time
import h5py
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
from multiprocessing import Pool
from tqdm import tqdm
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore', message='invalid value encountered in subtract', category=RuntimeWarning)

# Custom functions you may need:
from image_processing import process_single_image
from update_h5 import create_updated_h5
from statsmodels.nonparametric.smoothers_lowess import lowess


##########################################################################
# SECTION 1: PROCESS IMAGES + CREATE CSV
##########################################################################

def load_chunk(image_file, start, end):
    """
    Read frames ``start`` .. ``end`` (end exclusive) from an H5 image file.

    Returns a tuple ``(chunk_images, indices)``:
      - chunk_images: the slice of '/entry/data/images', cast to float32.
      - indices: the matching slice of '/entry/data/index' when that dataset
        exists; otherwise ``np.arange(start, end)``, i.e. the frame numbers.
    """
    with h5py.File(image_file, 'r') as h5:
        frames = h5['/entry/data/images'][start:end].astype(np.float32)

        # The stored index (when present) may differ from the frame number,
        # e.g. when frames were skipped during acquisition.
        index_ds = h5.get('/entry/data/index')
        indices = np.arange(start, end) if index_ds is None else index_ds[start:end]

    return frames, indices

def process_images(image_file,
                   mask,
                   n_wedges=4,
                   n_rad_bins=100,
                   xatol=0.01,
                   fatol=10,
                   chunk_size=1000,
                   frame_interval=10,
                   verbose=False,
                   xmin=0,
                   xmax=9999999,
                   ymin=0,
                   ymax=9999999):
    """
    Find the center of selected frames of an H5 image stack and write a CSV.

    Processed frames: frame 0, the last frame, and every frame whose number is
    a multiple of ``frame_interval``. Frames are read ``chunk_size`` at a time
    via ``load_chunk``. The selected frames are passed to
    ``process_single_image`` through a multiprocessing pool.

    A row is written for every processed frame. If the returned center is not
    finite, or lies outside ``xmin <= x < xmax`` / ``ymin <= y < ymax``, the
    row is still written with ``center_x = center_y = NaN``.

    CSV columns: [frame_number, data_index, center_x, center_y]
      - frame_number: position of the frame in '/entry/data/images'
      - data_index: value from '/entry/data/index' if present; else frame_number

    The CSV is written next to ``image_file`` as
    ``centers_xatol_{xatol}_frameinterval_{frame_interval}.csv``. An existing
    file with that name is replaced.

    Raises:
        ValueError: if ``frame_interval`` or ``chunk_size`` is smaller than 1.
    """
    # A non-positive interval would divide by zero / select nothing sensible.
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    with h5py.File(image_file, 'r') as f_img:
        n_images = f_img['/entry/data/images'].shape[0]

    if n_images == 0:
        print("No images found in '/entry/data/images'; nothing to process.")
        return

    # Multiples of frame_interval (this includes frame 0), plus the last frame.
    valid_frames = set(range(0, n_images, frame_interval))
    valid_frames.add(n_images - 1)
    total_centers = len(valid_frames)

    csv_file = os.path.join(
        os.path.dirname(image_file),
        f"centers_xatol_{xatol}_frameinterval_{frame_interval}.csv"
    )
    if os.path.exists(csv_file):
        os.remove(csv_file)

    columns = ["frame_number", "data_index", "center_x", "center_y"]
    header_written = False
    start_time = time.time()

    # Context managers make sure the worker pool and progress bar are cleaned
    # up even if a chunk raises.
    with Pool() as pool, tqdm(total=total_centers, desc="Calculating centers") as pbar:
        for start_idx in range(0, n_images, chunk_size):
            end_idx = min(start_idx + chunk_size, n_images)
            # Which frames in [start_idx, end_idx) do we actually want to process?
            chunk_frame_numbers = [fn for fn in range(start_idx, end_idx) if fn in valid_frames]
            if not chunk_frame_numbers:
                continue

            current_chunk, current_indices = load_chunk(image_file, start_idx, end_idx)

            args = [
                (current_chunk[fn - start_idx], mask, n_wedges, n_rad_bins, xatol, fatol, verbose)
                for fn in chunk_frame_numbers
            ]
            results = pool.starmap(process_single_image, args)

            rows = []
            for fn, (cx, cy) in zip(chunk_frame_numbers, results):
                in_bounds = (np.isfinite(cx) and np.isfinite(cy)
                             and xmin <= cx < xmax and ymin <= cy < ymax)
                if not in_bounds:
                    cx, cy = np.nan, np.nan
                # data_index might be smaller/larger than fn if there's skipping
                rows.append([fn, current_indices[fn - start_idx], cx, cy])

            df_chunk = pd.DataFrame(rows, columns=columns)
            # First chunk creates the file with a header; later chunks append.
            df_chunk.to_csv(csv_file, index=False,
                            mode="a" if header_written else "w",
                            header=not header_written)
            header_written = True

            pbar.update(len(chunk_frame_numbers))

    elapsed = time.time() - start_time
    print(f"Processing complete in {elapsed:.1f}s")
    print("CSV file written to:", csv_file)


# --- Widgets for "Process Images" UI ---

# H5 pickers: one for the image stack, one for the mask file.
image_file_chooser = FileChooser(os.getcwd())
image_file_chooser.title = "Select H5 Image File"
image_file_chooser.filter_pattern = "*.h5"

mask_file_chooser = FileChooser(os.getcwd())
mask_file_chooser.title = "Select Mask H5 File"
mask_file_chooser.filter_pattern = "*.h5"

# When unchecked, the click handler substitutes an all-True mask.
use_mask_checkbox = widgets.Checkbox(description="Use Mask", value=True)

# Center-finding parameters forwarded to process_images.
xatol_widget = widgets.FloatText(
    description="xatol:", value=0.01, layout=widgets.Layout(width="140px")
)
frame_interval_widget = widgets.IntText(
    description="Interval:", value=10, layout=widgets.Layout(width="140px")
)
verbose_checkbox = widgets.Checkbox(description="Verbose", value=False)

# Accepted center window: a center is kept only if xmin <= x < xmax and
# ymin <= y < ymax; otherwise process_images stores NaN for that frame.
xmin_widget = widgets.IntText(
    description="xmin:", value=450, layout=widgets.Layout(width="140px")
)
xmax_widget = widgets.IntText(
    description="xmax:", value=550, layout=widgets.Layout(width="140px")
)
ymin_widget = widgets.IntText(
    description="ymin:", value=450, layout=widgets.Layout(width="140px")
)
ymax_widget = widgets.IntText(
    description="ymax:", value=550, layout=widgets.Layout(width="140px")
)

process_images_button = widgets.Button(button_style="primary", description="Process Images")
output_area = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_process_images_clicked(b):
    """
    Button callback for "Process Images".

    Reads the selected image/mask files and the parameter widgets, then calls
    ``process_images``. When "Use Mask" is checked, '/mask' from the mask file
    is used (cast to bool). When unchecked, an all-True mask with the same
    shape as '/mask' is used instead. All messages go to ``output_area``.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    with output_area:
        clear_output()
        image_file = image_file_chooser.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        mask_file = mask_file_chooser.selected
        if not mask_file:
            print("Please select a mask H5 file.")
            return

        try:
            with h5py.File(mask_file, 'r') as f_mask:
                mask_ds = f_mask['/mask']
                if use_mask_checkbox.value:
                    mask = mask_ds[:].astype(bool)
                else:
                    # Use the full dataset shape, so the mask has the same shape
                    # whether the box is checked or not. Indexing '/mask'[0]
                    # would drop the leading axis (a single row of a 2-D mask).
                    mask = np.ones(mask_ds.shape, dtype=bool)
        except Exception as e:
            print("Error loading mask file:", e)
            return

        xatol_val = xatol_widget.value
        frame_interval_val = frame_interval_widget.value
        verbose_val = verbose_checkbox.value
        # Interval is used as a step/modulus when selecting frames.
        if frame_interval_val < 1:
            print("Interval must be a positive integer.")
            return

        xmin_val = xmin_widget.value
        xmax_val = xmax_widget.value
        ymin_val = ymin_widget.value
        ymax_val = ymax_widget.value

        print("Starting image processing...")
        print(f"Image file: {image_file}")
        print(f"Mask file: {mask_file}")
        print(f"xatol={xatol_val}, Interval={frame_interval_val}, Verbose={verbose_val}")
        print(f"Bounds: xmin={xmin_val}, xmax={xmax_val}, ymin={ymin_val}, ymax={ymax_val}")

        process_images(
            image_file,
            mask,
            xatol=xatol_val,
            frame_interval=frame_interval_val,
            verbose=verbose_val,
            xmin=xmin_val,
            xmax=xmax_val,
            ymin=ymin_val,
            ymax=ymax_val
        )

process_images_button.on_click(on_process_images_clicked)

# "Process Images" tab, top to bottom: header, file pickers, mask toggle,
# parameters, bounds, run button, log output.
process_images_ui = widgets.VBox(children=[
    widgets.HTML("<h2>Process Images from H5 File</h2>"),
    image_file_chooser,
    mask_file_chooser,
    use_mask_checkbox,
    widgets.HBox(children=[xatol_widget, frame_interval_widget, verbose_checkbox]),
    widgets.HBox(children=[xmin_widget, xmax_widget, ymin_widget, ymax_widget]),
    process_images_button,
    output_area,
])

##########################################################################
# SECTION 2: LOWESS-FIT + WRITE FULL CSV + UPDATE H5
##########################################################################

# PART A: Lowess-Fit & SHIFT
# Input: the centers CSV written by Section 1.
csv_file_chooser = FileChooser(os.getcwd())
csv_file_chooser.title = "Select CSV from Section 1"
csv_file_chooser.filter_pattern = "*.csv"

# Constant offsets added to the smoothed center_x / center_y.
shift_x_widget = widgets.FloatText(
    description="Shift X:", value=0, layout=widgets.Layout(width="150px")
)
shift_y_widget = widgets.FloatText(
    description="Shift Y:", value=0, layout=widgets.Layout(width="150px")
)

# Passed as ``frac`` to statsmodels lowess.
lowess_frac_widget = widgets.FloatSlider(
    description="Lowess frac:",
    value=0.1,
    min=0.01,
    max=1.0,
    step=0.01,
    continuous_update=False,
    layout=widgets.Layout(width="300px"),
)

process_csv_button = widgets.Button(button_style="primary", description="Lowess & Save FULL CSV")
csv_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

# Path of the full (one row per data_index) CSV written by the Lowess step;
# read by the "Update H5" step. None until the Lowess step has run.
shifted_csv_path = None

def _lowess_on_grid(x, y, frac, grid):
    """
    LOWESS-smooth ``y`` as a function of ``x`` and interpolate onto ``grid``.

    ``lowess(..., return_sorted=True)`` yields rows of (sorted x, fitted y).
    These rows are linearly interpolated onto ``grid`` with ``np.interp``,
    which holds the end values outside the fitted x-range. Non-finite fitted
    values are dropped first, so a NaN in the fit does not spread to
    neighbouring grid points. If none remain, the result is all NaN.
    """
    fit = lowess(endog=y, exog=x, frac=frac, return_sorted=True)
    finite = np.isfinite(fit[:, 1])
    if not finite.any():
        return np.full(len(grid), np.nan)
    return np.interp(grid, fit[finite, 0], fit[finite, 1])


def on_process_csv_clicked(b):
    """
    Button callback for "Lowess & Save FULL CSV".

    Steps:
      1. Read the centers CSV from Section 1.
      2. LOWESS-smooth center_x and center_y against data_index. Non-finite
         centers are excluded from the fit.
      3. Interpolate the smoothed curves onto every integer data_index in
         [min(data_index), max(data_index)] and add the user shifts.
      4. Write the result as a new CSV next to the input and plot the original
         vs. smoothed centers.

    On success the module-level ``shifted_csv_path`` is set to the new CSV for
    the "Update H5" step. All messages go to ``csv_output``.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    global shifted_csv_path
    with csv_output:
        clear_output()

        input_csv = csv_file_chooser.selected
        if not input_csv:
            print("Please select the CSV file with centers.")
            return

        try:
            df = pd.read_csv(input_csv)
        except Exception as e:
            print(f"Error reading CSV: {e}")
            return

        # Ensure required columns exist
        for col in ["frame_number", "data_index", "center_x", "center_y"]:
            if col not in df.columns:
                print(f"CSV must contain '{col}' column.")
                return

        # Rows without a data_index cannot be placed on the index axis.
        # Dropping them also avoids int(NaN) below for an empty or all-NaN column.
        df = df.dropna(subset=["data_index"])
        if df.empty:
            print("CSV has no rows with a data_index; nothing to fit.")
            return

        min_idx = int(df["data_index"].min())
        max_idx = int(df["data_index"].max())
        print(f"Data index range: [{min_idx}..{max_idx}]")

        df = df.sort_values("data_index").reset_index(drop=True)
        idx_existing = df["data_index"].to_numpy()
        cx_existing = df["center_x"].to_numpy()
        cy_existing = df["center_y"].to_numpy()
        frame_existing = df["frame_number"].to_numpy()  # may not match data_index

        # Only finite centers enter the fit (Section 1 writes NaN for rejected frames).
        valid = np.isfinite(cx_existing) & np.isfinite(cy_existing)
        idx_valid = idx_existing[valid]
        cx_valid = cx_existing[valid]
        cy_valid = cy_existing[valid]

        frac_val = lowess_frac_widget.value
        all_indices = np.arange(min_idx, max_idx + 1)
        if len(idx_valid) < 2:
            print("Warning: too few valid points for a meaningful LOWESS fit. Will fill all with NaN.")
            smoothed_x_all = np.full(len(all_indices), np.nan)
            smoothed_y_all = np.full(len(all_indices), np.nan)
        else:
            smoothed_x_all = _lowess_on_grid(idx_valid, cx_valid, frac_val, all_indices)
            smoothed_y_all = _lowess_on_grid(idx_valid, cy_valid, frac_val, all_indices)

        # Apply user shifts
        shift_x = shift_x_widget.value
        shift_y = shift_y_widget.value
        smoothed_x_all += shift_x
        smoothed_y_all += shift_y

        # data_index values present in the input keep their frame_number;
        # rows that exist only through interpolation get -1.
        idx2frame = dict(zip(idx_existing, frame_existing))
        full_frame_nums = [idx2frame.get(i, -1) for i in all_indices]

        full_df = pd.DataFrame({
            "frame_number": full_frame_nums,
            "data_index": all_indices,
            "center_x": smoothed_x_all,
            "center_y": smoothed_y_all
        })

        out_path = os.path.join(
            os.path.dirname(input_csv),
            f"centers_lowess_full_{frac_val:.2f}_shifted_{shift_x}_{shift_y}.csv"
        )
        try:
            full_df.to_csv(out_path, index=False)
        except OSError as e:
            print(f"Error writing CSV: {e}")
            return
        # Publish the path only after the file exists, so the H5 step never
        # uses a CSV that failed to write.
        shifted_csv_path = out_path

        # Quick plot: compare the valid points vs. the new full interpolation
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        axs[0].plot(idx_valid, cx_valid, 'o--', label='Original X (valid)', markersize=4)
        axs[0].plot(all_indices, smoothed_x_all, 'o-', label='Lowess+Shift X (full)', markersize=4)
        axs[0].set_title('Center X vs data_index')
        axs[0].legend()

        axs[1].plot(idx_valid, cy_valid, 'o--', label='Original Y (valid)', markersize=4)
        axs[1].plot(all_indices, smoothed_y_all, 'o-', label='Lowess+Shift Y (full)', markersize=4)
        axs[1].set_title('Center Y vs data_index')
        axs[1].legend()
        plt.show()

        print(f"\nCreated full CSV for all data_index in [{min_idx}..{max_idx}]:\n{shifted_csv_path}")

process_csv_button.on_click(on_process_csv_clicked)

# Lowess panel, top to bottom: header, CSV picker, shifts, smoothing
# fraction, run button, log output.
lowess_ui = widgets.VBox(children=[
    widgets.HTML("<h2>Lowess-Fit (Fill ALL data_index) + Shift</h2>"),
    csv_file_chooser,
    widgets.HBox(children=[shift_x_widget, shift_y_widget]),
    lowess_frac_widget,
    process_csv_button,
    csv_output,
])

# PART B: UPDATE H5
# Passed to create_updated_h5 as its first argument; the new H5 file is
# written in the same folder as this selection.
image_file_chooser_h5 = FileChooser(os.getcwd())
image_file_chooser_h5.title = "Select H5 File to Update (original or copy)"
image_file_chooser_h5.filter_pattern = "*.h5"

update_h5_button = widgets.Button(
    button_style="primary", description="Update H5 with Shifted Centers"
)
h5_output = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})

def on_update_h5_clicked(b):
    """
    Button callback for "Update H5 with Shifted Centers".

    Builds the output path ``<folder of selected H5>/<shifted CSV stem>.h5``
    and calls ``create_updated_h5(image_file, new_h5_path, shifted_csv_path)``.
    Requires the Lowess step to have produced ``shifted_csv_path`` first.
    Refuses to run when the output path resolves to the selected input file.
    All messages go to ``h5_output``.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    with h5_output:
        clear_output()
        if shifted_csv_path is None:
            print("No shifted CSV available. Please run the 'Lowess & Save FULL CSV' step first.")
            return
        image_file = image_file_chooser_h5.selected
        if not image_file:
            print("Please select an H5 file to update.")
            return

        new_h5_path = os.path.join(
            os.path.dirname(image_file),
            os.path.splitext(os.path.basename(shifted_csv_path))[0] + '.h5'
        )

        # Selecting a previous output of this step (same CSV name) as the input
        # would make the source and destination the same file.
        if os.path.realpath(new_h5_path) == os.path.realpath(image_file):
            print("Output path would overwrite the selected input file:\n"
                  f"{image_file}\n"
                  "Please select a different H5 file (e.g. the original) to update.")
            return

        try:
            create_updated_h5(image_file, new_h5_path, shifted_csv_path)
            print(f"Updated H5 file created at:\n{new_h5_path}")
        except Exception as e:
            print("Error updating H5 file:", e)

update_h5_button.on_click(on_update_h5_clicked)

# H5-update panel: header, H5 picker, run button, log output.
h5_ui = widgets.VBox(children=[
    widgets.HTML("<h2>Update H5 with Full Shifted Centers</h2>"),
    image_file_chooser_h5,
    update_h5_button,
    h5_output,
])

# Combine the Lowess & H5 Update UIs
csv_h5_ui = widgets.VBox(children=[lowess_ui, h5_ui])

##########################################################################
# FINAL: COMBINED TABBED INTERFACE
##########################################################################
tab = widgets.Tab(children=[process_images_ui, csv_h5_ui])
for tab_pos, tab_title in enumerate(("Process Images", "Lowess & H5 Update")):
    tab.set_title(tab_pos, tab_title)
display(tab)


# Out: Tab(children=(VBox(children=(HTML(value='<h2>Process Images from H5 File</h2>'), FileChooser(path='/Users/xiao…