In [5]:
# %matplotlib qt
import os
import numpy as np
import pandas as pd
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output

# Lowess from statsmodels
from statsmodels.nonparametric.smoothers_lowess import lowess

# Your custom function for updating .h5 (unchanged)
from update_h5 import create_updated_h5

# -------------------------
# Section 1: Lowess-Fit Centers in CSV (fill missing frames) + Shift
# -------------------------

# File picker for the input CSV; starts in the current working directory
# and only lists *.csv files.
csv_file_chooser = FileChooser(os.getcwd())
csv_file_chooser.title = "Select Input CSV File"
csv_file_chooser.filter_pattern = "*.csv"

# Numeric inputs for the constant offsets added to the smoothed X / Y centers.
# Each widget gets its own Layout instance.
shift_x_widget, shift_y_widget = (
    widgets.FloatText(
        description=label,
        value=0,
        layout=widgets.Layout(width="200px"),
    )
    for label in ("Shift X:", "Shift Y:")
)

# Slider for the Lowess smoothing fraction (passed to lowess as `frac`).
# continuous_update is off so the value only changes when the drag ends.
lowess_frac_widget = widgets.FloatSlider(
    description="Lowess frac:",
    min=0.01,
    max=1.0,
    step=0.01,
    value=0.1,
    continuous_update=False,
    layout=widgets.Layout(width="300px"),
)

# Button that triggers the smoothing + save callback.
process_csv_button = widgets.Button(
    button_style="primary",
    description="Lowess & Save CSV",
)

# Output area that receives the messages and plot of the CSV step.
csv_output = widgets.Output(
    layout={
        'border': '1px solid black',
        'padding': '5px',
    }
)

# Path of the most recently written smoothed/shifted CSV (None until one exists).
shifted_csv_path = None

def on_process_csv_clicked(b):
    """Lowess-smooth the centers of the selected CSV, fill every frame, shift, plot and save.

    Button callback. Steps:

    1. Read the CSV selected in ``csv_file_chooser``; it must contain the
       columns ``frame_number``, ``center_x`` and ``center_y``.
    2. Fit a Lowess curve to ``center_x`` and to ``center_y`` against
       ``frame_number`` using the fraction from ``lowess_frac_widget``.
    3. Evaluate both fits at every integer step from the smallest to the
       largest frame number (``np.interp`` interpolates linearly between the
       fitted points), so frames missing from the input get a value.
    4. Add the offsets from ``shift_x_widget`` / ``shift_y_widget``.
    5. Plot original vs. smoothed values and write the result next to the
       input CSV.

    The module-level ``shifted_csv_path`` is updated only after the file has
    been written successfully; on a write error it keeps its previous value.
    All messages and the plot are rendered inside ``csv_output``.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    global shifted_csv_path
    with csv_output:
        clear_output()

        input_csv = csv_file_chooser.selected
        if not input_csv:
            print("Please select an input CSV file.")
            return

        # Any read/parse failure is reported in the output area instead of raising.
        try:
            df = pd.read_csv(input_csv)
        except Exception as e:
            print(f"Error reading CSV: {e}")
            return

        # Ensure CSV has these columns
        for col in ['frame_number', 'center_x', 'center_y']:
            if col not in df.columns:
                print(f"CSV must contain '{col}' column.")
                return

        # Rows without a frame number cannot be placed on the frame axis; a
        # NaN here would also make min()/max() NaN and break np.arange below.
        df = df.dropna(subset=['frame_number'])
        if df.empty:
            print("CSV contains no rows with a frame_number; nothing to smooth.")
            return

        # Sort by frame_number just in case
        df = df.sort_values('frame_number').reset_index(drop=True)

        frames = df['frame_number'].values
        original_x = df['center_x'].values
        original_y = df['center_y'].values

        frac_val = lowess_frac_widget.value
        shift_x = shift_x_widget.value
        shift_y = shift_y_widget.value

        # 1) Fit Lowess to the existing data only.
        #    With return_sorted=True each result is an (N, 2) array of
        #    (frame, fitted value) rows sorted by frame.
        lowess_x = lowess(endog=original_x, exog=frames, frac=frac_val, return_sorted=True)
        lowess_y = lowess(endog=original_y, exog=frames, frac=frac_val, return_sorted=True)

        # 2) Every frame from the first to the last one present (inclusive).
        all_frames = np.arange(frames.min(), frames.max() + 1)

        # 3) + 4) Evaluate the fit at every frame, then apply the shifts.
        smoothed_x = np.interp(all_frames, lowess_x[:, 0], lowess_x[:, 1]) + shift_x
        smoothed_y = np.interp(all_frames, lowess_y[:, 0], lowess_y[:, 1]) + shift_y

        # NOTE(review): NaNs here presumably mean the Lowess window was too
        # small for the number of points — confirm. The file is still written
        # so the behavior matches earlier runs; the warning makes it visible.
        if np.isnan(smoothed_x).any() or np.isnan(smoothed_y).any():
            print("Warning: smoothed centers contain NaN values; consider a larger Lowess frac.")

        # 5) One row per frame with the (shifted) smoothed centers.
        output_df = pd.DataFrame({
            'frame_number': all_frames,
            'center_x': smoothed_x,
            'center_y': smoothed_y
        })

        # 6) Plot original data and new smoothed results side by side.
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))

        # Left: X
        axs[0].plot(frames, original_x, 'o--', label='Original X', markersize=4)
        axs[0].plot(all_frames, smoothed_x, 'o-', label='Lowess + shift X', markersize=4)
        axs[0].set_title('Center X')
        axs[0].legend()

        # Right: Y
        axs[1].plot(frames, original_y, 'o--', label='Original Y', markersize=4)
        axs[1].plot(all_frames, smoothed_y, 'o-', label='Lowess + shift Y', markersize=4)
        axs[1].set_title('Center Y')
        axs[1].legend()

        plt.show()

        # 7) Save next to the input CSV. The global is only updated once the
        #    write succeeded, so the H5 step never sees a path to a file that
        #    was not written.
        output_path = os.path.join(
            os.path.dirname(input_csv),
            f"centers_lowess_{frac_val:.3f}_shifted_{shift_x}_{shift_y}.csv"
        )
        try:
            output_df.to_csv(output_path, index=False)
        except OSError as e:
            print(f"Error writing CSV: {e}")
            return

        shifted_csv_path = output_path
        print(f"Created CSV with smoothed centers for every frame:\n{shifted_csv_path}")

# Run the smoothing callback whenever the button is clicked.
process_csv_button.on_click(on_process_csv_clicked)

# Section 1 layout: heading, file picker, shift inputs side by side,
# smoothing slider, action button and the output area.
csv_ui = widgets.VBox(
    children=[
        widgets.HTML("<h2>Lowess-Fit (Fill Missing Frames) + Shift</h2>"),
        csv_file_chooser,
        widgets.HBox(children=[shift_x_widget, shift_y_widget]),
        lowess_frac_widget,
        process_csv_button,
        csv_output,
    ]
)

# -------------------------
# Section 2: Update H5 with New Centers
# -------------------------

# File picker for the H5 image file; starts in the current working directory
# and only lists *.h5 files.
image_file_chooser_h5 = FileChooser(os.getcwd())
image_file_chooser_h5.title = "Select H5 Image File for Updating"
image_file_chooser_h5.filter_pattern = "*.h5"

# Button that triggers the H5 update callback.
update_h5_button = widgets.Button(
    button_style="primary",
    description="Update H5 with Shifted Centers",
)

# Output area that receives the messages of the H5 step.
h5_output = widgets.Output(
    layout={
        'border': '1px solid black',
        'padding': '5px',
    }
)

def on_update_h5_clicked(b):
    """Create an updated copy of the selected H5 file using the last smoothed CSV.

    Button callback. Requires that the CSV step has stored a path in the
    module-level ``shifted_csv_path`` and that an H5 file is selected in
    ``image_file_chooser_h5``. Calls ``create_updated_h5(image_file,
    new_h5_path, shifted_csv_path)`` and reports success or the raised error
    inside ``h5_output``.

    The new file is placed next to the selected H5 file and named
    ``<image stem>_<csv stem>.h5``. Previously the CSV file name, including
    its ``.csv`` extension, was appended directly to the image stem, which
    gave the HDF5 output a ``.csv`` extension.

    Args:
        b: The button instance passed by ipywidgets (unused).
    """
    with h5_output:
        clear_output()
        if shifted_csv_path is None:
            print("No shifted CSV available. Please run the CSV shifting process first.")
            return

        image_file = image_file_chooser_h5.selected
        if not image_file:
            print("Please select an H5 image file.")
            return

        # Combine both base names (without their extensions) so the output
        # name records which image and which CSV produced it.
        image_stem = os.path.splitext(os.path.basename(image_file))[0]
        csv_stem = os.path.splitext(os.path.basename(shifted_csv_path))[0]
        new_h5_path = os.path.join(
            os.path.dirname(image_file),
            f"{image_stem}_{csv_stem}.h5"
        )

        # UI boundary: report any failure in the output area instead of raising.
        try:
            create_updated_h5(image_file, new_h5_path, shifted_csv_path)
        except Exception as e:
            print("Error updating H5 file:", e)
        else:
            print(f"Updated H5 file created at:\n{new_h5_path}")

# Run the H5 update callback whenever the button is clicked.
update_h5_button.on_click(on_update_h5_clicked)

# Section 2 layout: heading, file picker, action button and the output area.
h5_ui = widgets.VBox(
    children=[
        widgets.HTML("<h2>Update H5 with Shifted Centers</h2>"),
        image_file_chooser_h5,
        update_h5_button,
        h5_output,
    ]
)

# -------------------------
# Final UI: CSV section stacked above the H5 section
# -------------------------
ui = widgets.VBox(children=[csv_ui, h5_ui])
display(ui)


VBox(children=(VBox(children=(HTML(value='<h2>Lowess-Fit (Fill Missing Frames) + Shift</h2>'), FileChooser(pat…