# In [None]:
#!/usr/bin/env python3
import os
import time
import h5py
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
from multiprocessing import Pool
from statsmodels.nonparametric.smoothers_lowess import lowess
from tqdm import tqdm
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore', message='invalid value encountered in subtract', category=RuntimeWarning)

# ------------------------------------------------------------------------
# Custom modules (adjust or remove if not needed):
from image_processing import process_single_image
from update_h5 import create_updated_h5

# ------------------------------------------------------------------------
# HELPER FUNCTION (TOP LEVEL, so it's picklable)
# ------------------------------------------------------------------------
def compute_center_for_frame(args):
    """
    Worker entry point for the multiprocessing pool (defined at module
    level so it can be pickled).

    ``args`` is a 12-tuple:
    (frame_num, image_file, mask, n_wedges, n_rad_bins,
     xatol, fatol, verbose, xmin, xmax, ymin, ymax).

    The frame ``frame_num`` is read from '/entry/data/images' in
    ``image_file`` and cast to float32. It is handed to
    ``process_single_image`` together with the remaining options.

    Returns ``(frame_num, center_x, center_y)``. Both coordinates become NaN
    unless both are finite and the point lies inside the half-open box
    [xmin, xmax) x [ymin, ymax).
    """
    (idx, h5_path, pixel_mask, wedges, rad_bins,
     x_tol, f_tol, chatty, x_lo, x_hi, y_lo, y_hi) = args

    # Each worker opens the file itself; only the path crosses the process boundary.
    with h5py.File(h5_path, 'r') as h5:
        frame = h5['/entry/data/images'][idx].astype(np.float32)

    cx, cy = process_single_image(frame, pixel_mask, wedges, rad_bins, x_tol, f_tol, chatty)

    # Short-circuits: the bounds comparisons only run on finite coordinates.
    accepted = (np.isfinite(cx) and np.isfinite(cy)
                and x_lo <= cx < x_hi and y_lo <= cy < y_hi)
    if accepted:
        return idx, cx, cy
    return idx, np.nan, np.nan

# ------------------------------------------------------------------------
# SECTION 1: PROCESS IMAGES (ONE ROW PER FRAME)
# ------------------------------------------------------------------------
def process_images_no_chunk(
    image_file,
    mask,
    frame_interval=10,
    xatol=0.01,
    fatol=10,
    n_wedges=4,
    n_rad_bins=100,
    xmin=0,
    xmax=9999999,
    ymin=0,
    ymax=9999999,
    verbose=False
):
    """
    Find centers for a subset of frames and write a CSV with one row per frame.

    Frames are processed in parallel (``compute_center_for_frame`` in a
    ``multiprocessing.Pool``). Only these frames are loaded:
      - the first frame (0),
      - the last frame (n_images - 1),
      - every frame number that is a multiple of ``frame_interval``.

    If '/entry/data/index' exists, its values are stored in 'data_index';
    otherwise 'data_index = frame_number'.

    The CSV is written next to ``image_file`` as
    ``centers_xatol_<xatol>_frameinterval_<frame_interval>.csv``. It has
    columns [frame_number, data_index, center_x, center_y] and exactly
    n_images rows. center_x / center_y stay NaN for unprocessed frames and
    for processed frames whose center was rejected.

    Parameters
    ----------
    image_file : str
        H5 file containing '/entry/data/images' (frames indexed along axis 0).
    mask :
        Passed unchanged to ``process_single_image`` for every frame.
    frame_interval : int
        Stride between processed frames; must be non-zero.
    xatol, fatol, n_wedges, n_rad_bins, verbose :
        Passed through to ``process_single_image``.
    xmin, xmax, ymin, ymax :
        A center is kept only if xmin <= cx < xmax and ymin <= cy < ymax.

    Returns
    -------
    str
        Path of the CSV that was written.

    Raises
    ------
    ValueError
        If ``frame_interval`` is 0 (it would otherwise fail as a modulo by zero).
    """
    if frame_interval == 0:
        raise ValueError("frame_interval must be non-zero")

    # 1) Determine how many frames, plus read /entry/data/index if present
    with h5py.File(image_file, 'r') as f:
        n_images = f['/entry/data/images'].shape[0]
        index_dset = f.get('/entry/data/index')
        data_index_all = index_dset[:] if index_dset is not None else np.arange(n_images)

    # 2) Create a DataFrame with n_images rows, initialized to NaN centers
    df = pd.DataFrame({
        "frame_number": np.arange(n_images),
        "data_index": data_index_all,
        "center_x": np.full(n_images, np.nan, dtype=float),
        "center_y": np.full(n_images, np.nan, dtype=float),
    })

    csv_file = os.path.join(
        os.path.dirname(image_file),
        f"centers_xatol_{xatol}_frameinterval_{frame_interval}.csv"
    )

    # An empty dataset would otherwise yield frames {0, -1} and make the
    # workers index into a zero-length dataset.
    if n_images == 0:
        df.to_csv(csv_file, index=False)
        print(f"No frames found in {image_file}; wrote empty CSV:\n{csv_file}")
        return csv_file

    # 3) Identify frames we actually process (n_images >= 1 here)
    frames_to_process = sorted(
        {0, n_images - 1} | {i for i in range(n_images) if i % frame_interval == 0}
    )

    # 4) Build argument tuples for multiprocessing (pass the path, not pixel data)
    tasks = [
        (fn, image_file, mask, n_wedges, n_rad_bins,
         xatol, fatol, verbose, xmin, xmax, ymin, ymax)
        for fn in frames_to_process
    ]

    # 5) Parallel center-finding
    start_time = time.time()
    with Pool() as pool:
        results = list(
            tqdm(
                pool.imap(compute_center_for_frame, tasks),
                total=len(tasks),
                desc="Processing frames"
            )
        )

    # 6) Place the results back into df (RangeIndex, so label == frame number)
    for fn, cx, cy in results:
        df.at[fn, "center_x"] = cx
        df.at[fn, "center_y"] = cy

    # 7) Write CSV
    df.to_csv(csv_file, index=False)

    elapsed = time.time() - start_time
    print(f"Created CSV with {len(df)} rows in {elapsed:.1f}s:\n{csv_file}")
    return csv_file

# ------------------------------------------------------------------------
# SECTION 1: UI
# ------------------------------------------------------------------------
def _h5_chooser(title):
    """Build a FileChooser that starts in the CWD, carries *title*, and lists only *.h5 files."""
    chooser = FileChooser(os.getcwd())
    chooser.title = title
    chooser.filter_pattern = "*.h5"
    return chooser


def _fixed_width(width="140px"):
    """Return a new Layout of the given width (one fresh object per widget)."""
    return widgets.Layout(width=width)


# File selection
image_file_chooser = _h5_chooser("Select H5 Image File")
mask_file_chooser = _h5_chooser("Select Mask H5 File")
use_mask_checkbox = widgets.Checkbox(value=True, description="Use Mask")

# Center-finding options
xatol_widget = widgets.FloatText(value=0.01, description="xatol:", layout=_fixed_width())
frame_interval_widget = widgets.IntText(value=10, description="Interval:", layout=_fixed_width())
verbose_checkbox = widgets.Checkbox(value=False, description="Verbose")

# Bounds a found center must fall within to be kept
xmin_widget = widgets.IntText(value=0, description="xmin:", layout=_fixed_width())
xmax_widget = widgets.IntText(value=2048, description="xmax:", layout=_fixed_width())
ymin_widget = widgets.IntText(value=0, description="ymin:", layout=_fixed_width())
ymax_widget = widgets.IntText(value=2048, description="ymax:", layout=_fixed_width())

# Run button and bordered log area
process_images_button = widgets.Button(description="Process Images", button_style="primary")
output_area = widgets.Output(layout=dict(border='1px solid black', padding='5px'))

def on_process_images_clicked(b):
    """
    Button callback for Section 1.

    Reads the widget values, builds the mask, and runs
    ``process_images_no_chunk`` on the selected image file. All messages,
    including validation errors, go to ``output_area``.

    When "Use Mask" is checked, the mask is the full '/mask' dataset of the
    selected mask file, cast to bool. When unchecked, no mask file is needed.
    An all-True mask is built with the shape of one frame of
    '/entry/data/images', which is the shape of the frame the worker
    processes alongside it.
    """
    with output_area:
        clear_output()
        image_file = image_file_chooser.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        frame_interval_val = frame_interval_widget.value
        if frame_interval_val < 1:
            print("Interval must be a positive integer.")
            return

        if use_mask_checkbox.value:
            mask_file = mask_file_chooser.selected
            if not mask_file:
                print("Please select a mask H5 file.")
                return
            try:
                with h5py.File(mask_file, 'r') as f_mask:
                    mask = f_mask['/mask'][:].astype(bool)
            except Exception as e:
                print("Error loading mask file:", e)
                return
        else:
            # Accept every pixel. Take the shape from one image frame: indexing
            # '/mask'[0] would give a single row for a 2-D mask, i.e. the wrong shape.
            try:
                with h5py.File(image_file, 'r') as f_img:
                    frame_shape = f_img['/entry/data/images'].shape[1:]
            except Exception as e:
                print("Error reading image file:", e)
                return
            mask = np.ones(frame_shape, dtype=bool)

        xatol_val = xatol_widget.value
        verbose_val = verbose_checkbox.value

        xmin_val = xmin_widget.value
        xmax_val = xmax_widget.value
        ymin_val = ymin_widget.value
        ymax_val = ymax_widget.value

        print("Processing images to create a CSV with one row per frame...")
        process_images_no_chunk(
            image_file=image_file,
            mask=mask,
            frame_interval=frame_interval_val,
            xatol=xatol_val,
            fatol=10,
            n_wedges=4,
            n_rad_bins=100,
            xmin=xmin_val,
            xmax=xmax_val,
            ymin=ymin_val,
            ymax=ymax_val,
            verbose=verbose_val
        )
        print("Done.")

process_images_button.on_click(on_process_images_clicked)

# Section 1 panel: header, file pickers, options, run button, log output.
_section1_rows = [
    widgets.HTML("<h2>Section 1: Process Images (One Row Per Frame)</h2>"),
    image_file_chooser,
    mask_file_chooser,
    use_mask_checkbox,
    widgets.HBox([xatol_widget, frame_interval_widget, verbose_checkbox]),
    widgets.HBox([xmin_widget, xmax_widget, ymin_widget, ymax_widget]),
    process_images_button,
    output_area,
]
process_images_ui = widgets.VBox(children=_section1_rows)

# ------------------------------------------------------------------------
# SECTION 2: LOWESS-FIT & H5 UPDATE
# ------------------------------------------------------------------------
csv_file_chooser = FileChooser(os.getcwd())
csv_file_chooser.title, csv_file_chooser.filter_pattern = (
    "Select CSV From Section 1", "*.csv",
)


def _shift_input(label):
    """FloatText (initially 0, 150px wide) whose value is added to a smoothed center."""
    return widgets.FloatText(value=0, description=label, layout=widgets.Layout(width="150px"))


shift_x_widget = _shift_input("Shift X:")
shift_y_widget = _shift_input("Shift Y:")

lowess_frac_widget = widgets.FloatSlider(
    description="Lowess frac:",
    value=0.1,
    min=0.01,
    max=1.0,
    step=0.01,
    continuous_update=False,
    layout=widgets.Layout(width="300px"),
)

process_csv_button = widgets.Button(description="Lowess & Save CSV", button_style="primary")
csv_output = widgets.Output(layout=dict(border='1px solid black', padding='5px'))

# Path of the most recent smoothed CSV: assigned by on_process_csv_clicked,
# read by on_update_h5_clicked.
shifted_csv_path = None

def on_process_csv_clicked(b):
    """
    Button callback for Section 2A.

    LOWESS-smooths the centers from a Section 1 CSV, applies the user shift,
    plots the result, and saves a new CSV.

    The CSV chosen in ``csv_file_chooser`` is sorted by ``data_index``.
    ``center_x`` and ``center_y`` are each fitted against ``data_index`` with
    ``lowess(frac=lowess_frac_widget.value)``, using only rows where both
    centers are non-NaN.

    Every row whose ``data_index`` lies within [min, max] of those valid
    indices gets the fitted curve's value, linearly interpolated between
    fitted points, plus the Shift X / Shift Y values. Other rows are left
    unchanged. With fewer than two valid rows, no fit and no shift is applied.

    The result is written next to the input CSV as
    ``centers_lowess_<frac>_shifted.csv``. Its path is stored in the global
    ``shifted_csv_path`` for Section 2B.
    """
    global shifted_csv_path
    with csv_output:
        clear_output()

        input_csv = csv_file_chooser.selected
        if not input_csv:
            print("Please select a CSV from Section 1.")
            return

        # Read CSV
        try:
            df = pd.read_csv(input_csv)
        except Exception as e:
            print(f"Error reading CSV: {e}")
            return

        # Must contain these columns
        for col in ["frame_number", "data_index", "center_x", "center_y"]:
            if col not in df.columns:
                print(f"CSV must contain '{col}' column.")
                return

        # Sort by data_index for the smoothing
        df = df.sort_values("data_index").reset_index(drop=True)
        idx_all = df["data_index"].values
        cx_all = df["center_x"].values
        cy_all = df["center_y"].values

        valid_mask = ~np.isnan(cx_all) & ~np.isnan(cy_all)
        idx_valid = idx_all[valid_mask]
        cx_valid = cx_all[valid_mask]
        cy_valid = cy_all[valid_mask]

        frac_val = lowess_frac_widget.value

        df_smoothed = df.copy()
        if len(idx_valid) < 2:
            print("Too few valid points for a LOWESS fit. We'll leave all centers as-is.")
        else:
            min_idx, max_idx = idx_valid.min(), idx_valid.max()
            lowess_x = lowess(cx_valid, idx_valid, frac=frac_val, return_sorted=True)
            lowess_y = lowess(cy_valid, idx_valid, frac=frac_val, return_sorted=True)

            shift_x = shift_x_widget.value
            shift_y = shift_y_widget.value

            # Evaluate the fit directly at each in-range row's own data_index.
            # A lookup table over every integer in [min_idx, max_idx] would grow
            # with the index span rather than the row count, and could not
            # serve non-integer data_index values. NaN indices compare False
            # and are left untouched.
            in_range = (idx_all >= min_idx) & (idx_all <= max_idx)
            idx_in = idx_all[in_range]
            df_smoothed.loc[in_range, "center_x"] = (
                np.interp(idx_in, lowess_x[:, 0], lowess_x[:, 1]) + shift_x
            )
            df_smoothed.loc[in_range, "center_y"] = (
                np.interp(idx_in, lowess_y[:, 0], lowess_y[:, 1]) + shift_y
            )

        # Plot the original vs. smoothed (for valid points)
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        axs[0].plot(idx_valid, cx_valid, 'o--', label='Original X (valid)', markersize=4)
        axs[1].plot(idx_valid, cy_valid, 'o--', label='Original Y (valid)', markersize=4)

        s_cx_valid = df_smoothed.loc[valid_mask, "center_x"].values
        s_cy_valid = df_smoothed.loc[valid_mask, "center_y"].values

        axs[0].plot(idx_valid, s_cx_valid, 'o-', label='Smoothed X', markersize=4)
        axs[1].plot(idx_valid, s_cy_valid, 'o-', label='Smoothed Y', markersize=4)
        axs[0].set_title("Center X vs data_index")
        axs[1].set_title("Center Y vs data_index")
        axs[0].legend()
        axs[1].legend()
        plt.show()

        # Write final CSV
        shifted_csv_path = os.path.join(
            os.path.dirname(input_csv),
            f"centers_lowess_{frac_val:.2f}_shifted.csv"
        )
        df_smoothed.to_csv(shifted_csv_path, index=False)
        print(f"Smoothed CSV saved:\n{shifted_csv_path}")

process_csv_button.on_click(on_process_csv_clicked)

# Section 2A panel: CSV picker, shift inputs, LOWESS frac, run button, log.
lowess_ui = widgets.VBox(children=[
    widgets.HTML("<h2>Section 2A: Lowess-Fit & Shift</h2>"),
    csv_file_chooser,
    widgets.HBox(children=[shift_x_widget, shift_y_widget]),
    lowess_frac_widget,
    process_csv_button,
    csv_output,
])

# Section 2B controls: target H5 picker, run button, log output.
image_file_chooser_h5 = FileChooser(os.getcwd())
image_file_chooser_h5.title, image_file_chooser_h5.filter_pattern = (
    "Select H5 File to Update", "*.h5",
)

update_h5_button = widgets.Button(
    description="Update H5 with Smoothed Centers",
    button_style="primary",
)
h5_output = widgets.Output(layout=dict(border='1px solid black', padding='5px'))

def on_update_h5_clicked(b):
    """
    Button callback for Section 2B.

    Writes a new H5 file from the selected one using the smoothed CSV
    produced in Section 2A. The output is placed in the selected file's
    directory and named ``<smoothed CSV stem>.h5``. The work is delegated
    to ``create_updated_h5(image_file, new_h5_path, shifted_csv_path)``.

    Refuses to run when that output path resolves to the selected input
    file. This happens when a previously generated file is re-selected with
    the same LOWESS frac. Errors are printed to ``h5_output``.
    """
    with h5_output:
        clear_output()
        if shifted_csv_path is None:
            print("No shifted CSV available. Please run the 'Lowess & Save CSV' step first.")
            return
        image_file = image_file_chooser_h5.selected
        if not image_file:
            print("Please select an H5 file to update.")
            return

        csv_stem = os.path.splitext(os.path.basename(shifted_csv_path))[0]
        new_h5_path = os.path.join(os.path.dirname(image_file), csv_stem + '.h5')

        # Otherwise create_updated_h5 would get the same path as both input
        # and output; refuse rather than risk clobbering the source file.
        def _canonical(path):
            return os.path.normcase(os.path.realpath(path))

        if _canonical(new_h5_path) == _canonical(image_file):
            print("Output path would overwrite the selected input file:\n"
                  f"{image_file}\n"
                  "Select a different H5 file or change the Lowess frac.")
            return

        try:
            create_updated_h5(image_file, new_h5_path, shifted_csv_path)
            print(f"Updated H5 file created at:\n{new_h5_path}")
        except Exception as e:
            print("Error updating H5 file:", e)

update_h5_button.on_click(on_update_h5_clicked)

# Section 2B panel: target H5 picker, run button, log output.
h5_ui = widgets.VBox(children=[
    widgets.HTML("<h2>Section 2B: Update H5</h2>"),
    image_file_chooser_h5,
    update_h5_button,
    h5_output,
])

# Second tab stacks the LOWESS panel above the H5-update panel.
csv_h5_ui = widgets.VBox(children=[lowess_ui, h5_ui])

# ------------------------------------------------------------------------
# FINAL COMBINED UI
# ------------------------------------------------------------------------
_TAB_TITLES = ("Process Images", "Lowess & H5 Update")
tab = widgets.Tab(children=[process_images_ui, csv_h5_ui])
for _tab_pos, _tab_title in enumerate(_TAB_TITLES):
    tab.set_title(_tab_pos, _tab_title)
display(tab)
