# In [None]:
import os
import h5py
import numpy as np
import pandas as pd
from multiprocessing import Pool
from tqdm.notebook import tqdm  # For progress bars in Jupyter

# Suppose you have a function from your 'helpers' module:
from helpers import process_image_full

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output

# ----------------------------
#  HELPER FUNCTIONS
# ----------------------------

def compute_center_for_frame(frame_number, dataset_images, mask, 
                             threshold, max_iters, step_size, n_steps, 
                             n_wedges, n_rad_bins, plot_profiles):
    """
    Compute the center of one frame.

    The frame is read from ``dataset_images`` by index, cast to float32 and
    passed to ``process_image_full`` together with the mask and all tuning
    parameters, packed into a single tuple argument.

    Returns (center_x, center_y).
    """
    frame = dataset_images[frame_number].astype(np.float32)
    job = (
        frame_number, frame, mask, threshold, max_iters, step_size,
        n_steps, n_wedges, n_rad_bins, plot_profiles,
    )
    result = process_image_full(job)

    # The helper is expected to hand back (frame_number, center) with the
    # center given as (row, col); callers of this function want (x, y),
    # so the two components are swapped here.
    center = result[1]
    return (center[1], center[0])

def fit_polynomial(sample_frames, sample_centers, degree=2):
    """
    Least-squares fit of center_x(frame) and center_y(frame).

    sample_centers is a sequence of (x, y) pairs, one per entry of
    sample_frames. Each coordinate is fitted independently with
    np.polyfit using the requested degree.

    Returns (px, py): the np.polyfit coefficient arrays for x and y.
    """
    t = np.array(sample_frames, dtype=float)
    xs = np.array([pt[0] for pt in sample_centers])
    ys = np.array([pt[1] for pt in sample_centers])

    # One independent fit per coordinate, both against the frame number.
    return np.polyfit(t, xs, degree), np.polyfit(t, ys, degree)

def evaluate_polynomial(pcoeffs, frames):
    """
    Evaluate a fitted polynomial at the given frame number(s).

    pcoeffs are coefficients in np.polyfit order (highest power first);
    frames may be a scalar or an array. The result comes straight from
    np.polyval.
    """
    values = np.polyval(pcoeffs, frames)
    return values

def compute_residuals(sample_frames, sample_centers, px, py):
    """
    Distance between each sampled center and the polynomial fit.

    The fit (px, py) is evaluated at every sample frame and compared with
    the matching (x, y) entry of sample_centers.

    Returns an array of Euclidean distances, one per sample.
    """
    t = np.array(sample_frames, dtype=float)

    # Fitted position at each sample frame.
    fit_x = evaluate_polynomial(px, t)
    fit_y = evaluate_polynomial(py, t)

    # Measured position at each sample frame.
    obs_x = np.array([pt[0] for pt in sample_centers])
    obs_y = np.array([pt[1] for pt in sample_centers])

    dx = fit_x - obs_x
    dy = fit_y - obs_y
    return np.sqrt(dx**2 + dy**2)

def remove_outliers(sample_frames, sample_centers, threshold=5.0):
    """
    Basic outlier removal: 
      1) Compute the per-coordinate median of all sample centers
      2) Keep only the points whose Euclidean distance to that median is
         strictly less than threshold (a point exactly at the threshold
         distance is dropped)
    You could replace this with something more advanced if desired.

    sample_frames and sample_centers are parallel sequences; neither is
    modified. An empty sample yields two empty lists.

    Returns (kept_frames, kept_centers) as new lists, in the original order.
    """
    # Nothing to filter. This also avoids reducing an empty 1-D array along
    # axis 1 further down, which would raise.
    if len(sample_centers) == 0:
        return [], []

    arr = np.array(sample_centers)
    median_center = np.median(arr, axis=0)
    dist = np.sqrt(np.sum((arr - median_center)**2, axis=1))
    keep_idx = np.where(dist < threshold)[0]
    kept_frames = [sample_frames[i] for i in keep_idx]
    kept_centers = [sample_centers[i] for i in keep_idx]
    return kept_frames, kept_centers

def refine_sampling(dataset_images, mask, sample_frames, sample_centers,
                    threshold, max_iters, step_size, n_steps, n_wedges, 
                    n_rad_bins, plot_profiles, px, py, max_error_per_interval):
    """
    Add midpoint samples wherever the current fit (px, py) is inaccurate.

    Each consecutive pair (f1, f2) of sample_frames is examined. When the
    gap f2 - f1 is larger than 3 frames, the center of the midpoint frame
    (f1 + f2) // 2 is actually computed and compared with the polynomial's
    prediction for that frame. If the two differ by more than
    max_error_per_interval (Euclidean distance), the midpoint becomes a new
    sample. Only a single midpoint per interval is tried per call.

    sample_frames is presumably sorted in ascending order (the caller builds
    it that way); pairs whose gap is 3 or less, including non-ascending
    pairs, are skipped.

    The input lists are never modified. If nothing was added, the original
    sample_frames and sample_centers objects are returned as-is; otherwise
    new lists, sorted by frame number, are returned.

    Return updated sample_frames, sample_centers.
    """
    new_frames = []
    new_centers = []

    # Consecutive (f1, f2) pairs; materialized so the progress bar has a total.
    intervals = list(zip(sample_frames, sample_frames[1:]))

    with tqdm(total=len(intervals), desc="Refining intervals", leave=False) as pbar:
        for f1, f2 in intervals:
            pbar.update(1)

            # Only intervals wider than 3 frames are worth subdividing.
            if f2 - f1 <= 3:
                continue
            mid = (f1 + f2) // 2

            # We cannot know the fit error at the midpoint without measuring
            # it, so compute the real center there and compare it with the
            # polynomial's prediction.
            actual_center = compute_center_for_frame(mid, dataset_images, mask, threshold, 
                                                     max_iters, step_size, n_steps, 
                                                     n_wedges, n_rad_bins, plot_profiles)
            pred_x = evaluate_polynomial(px, float(mid))
            pred_y = evaluate_polynomial(py, float(mid))
            dist = np.sqrt((actual_center[0] - pred_x)**2 + 
                           (actual_center[1] - pred_y)**2)
            if dist > max_error_per_interval:
                new_frames.append(mid)
                new_centers.append(actual_center)

    if not new_frames:
        # No updates
        return sample_frames, sample_centers

    # Merge old + new into fresh lists (the caller's lists stay untouched)
    # and sort them in ascending frame order.
    merged = sorted(
        zip(list(sample_frames) + new_frames, list(sample_centers) + new_centers),
        key=lambda pair: pair[0],
    )
    merged_frames, merged_centers = zip(*merged)
    return list(merged_frames), list(merged_centers)

# -----------------------------
#   WIDGET-BASED UI
# -----------------------------
def _h5_chooser(start_dir, default_name, title):
    """Create a FileChooser for start_dir/default_name, titled and filtered to *.h5."""
    chooser = FileChooser(start_dir, filename=default_name)
    chooser.title = title
    chooser.filter_pattern = "*.h5"
    return chooser


# Chooser for the HDF5 file holding the image stack.
image_file_chooser = _h5_chooser(
    "/Users/xiaodong/Desktop/UOX-data/UOX1_sub/",
    "UOX1_sub.h5",
    "Select H5 Image File",
)

# Chooser for the HDF5 file holding the pixel mask.
mask_file_chooser = _h5_chooser(
    "/Users/xiaodong/mask/",
    "pxmask.h5",
    "Select H5 Mask File",
)


def _narrow_field(widget_cls, value, description):
    """Build a 200px-wide input widget with description_width set to 'initial'."""
    return widget_cls(
        value=value,
        description=description,
        layout=widgets.Layout(width='200px'),
        style={'description_width': 'initial'},
    )


# Parameters forwarded to the per-frame center computation.
threshold_widget = _narrow_field(widgets.FloatText, 0.1, "Threshold:")
max_iters_widget = _narrow_field(widgets.IntText, 10, "Max Iters:")
step_size_widget = _narrow_field(widgets.IntText, 1, "Step Size:")
n_steps_widget = _narrow_field(widgets.IntText, 5, "n_steps:")
n_wedges_widget = _narrow_field(widgets.IntText, 4, "n_wedges:")
n_rad_bins_widget = _narrow_field(widgets.IntText, 100, "n_rad_bins:")
plot_profiles_widget = widgets.Checkbox(
    description="Plot Profiles",
    value=False,
    style={'description_width': 'initial'},
)
chunk_size_widget = _narrow_field(widgets.IntText, 100, "Chunk Size:")

# Initial sampling: every N frames
compute_step_widget = _narrow_field(widgets.IntText, 50, "Initial Step:")

# Polynomial degree
poly_degree_widget = _narrow_field(widgets.IntText, 2, "Poly Degree:")

# Outlier threshold
outlier_thresh_widget = _narrow_field(widgets.FloatText, 10.0, "Outlier Dist Thresh:")

# max_error_per_interval – how big an error (in pixels) we allow before subdividing
max_error_per_interval_widget = _narrow_field(
    widgets.FloatText, 5.0, "Interval Error Thresh:"
)

# Maximum number of refinement iterations
max_refine_iter_widget = _narrow_field(widgets.IntText, 5, "Max Refine Iters:")

# Button that starts a processing run.
process_button = widgets.Button(
    button_style="primary",
    description="Process Images",
)

# Fixed-height, bordered, scrollable area that receives the run's printed output.
processing_output = widgets.Output(
    layout={
        'border': '1px solid black',
        'padding': '5px',
        'height': '400px',
        'overflow_y': 'auto',
    }
)

def on_process_button_clicked(b):
    """
    Click handler for the "Process Images" button.

    Reads the selected files and the parameter widgets, computes centers for
    a coarse subset of frames, then repeatedly removes outliers, fits a
    polynomial and samples extra midpoints where the fit is poor. Finally the
    fitted center of every frame is written to ``centers.csv`` in the same
    directory as the image file. All messages are printed into
    ``processing_output``, which is cleared at the start of each run.

    The run stops early, with a printed message and without writing a CSV,
    when a file is not selected, a parameter is invalid, the image dataset
    is empty, or too few sample points are left to fit the polynomial.

    b is the argument supplied by the button's click callback; it is unused.
    """
    with processing_output:
        clear_output()
        
        # 1) Check selected files
        image_file = image_file_chooser.selected
        mask_file = mask_file_chooser.selected
        if not image_file:
            print("Please select an image file.")
            return
        if not mask_file:
            print("Please select a mask file.")
            return
        
        # 2) Gather parameters
        # (The "Chunk Size" widget is not read: nothing in this handler uses it.)
        threshold = threshold_widget.value
        max_iters = max_iters_widget.value
        step_size = step_size_widget.value
        n_steps = n_steps_widget.value
        n_wedges = n_wedges_widget.value
        n_rad_bins = n_rad_bins_widget.value
        plot_profiles = plot_profiles_widget.value
        
        compute_step = compute_step_widget.value
        poly_degree = poly_degree_widget.value
        outlier_dist_thresh = outlier_thresh_widget.value
        max_error_per_interval = max_error_per_interval_widget.value
        max_refine_iters = max_refine_iter_widget.value
        
        # Remove any previous output first, so a run that stops early cannot
        # leave behind a CSV from an earlier run.
        csv_file = os.path.join(os.path.dirname(image_file), "centers.csv")
        if os.path.exists(csv_file):
            os.remove(csv_file)
        
        # Reject values that would otherwise fail deep inside range() / polyfit.
        if compute_step < 1:
            print("Initial Step must be a positive integer.")
            return
        if poly_degree < 0:
            print("Poly Degree must be zero or greater.")
            return
        
        # A degree-d polynomial needs at least d + 1 points; never go below
        # the 3-point minimum this tool has always required during refinement.
        min_points = max(3, poly_degree + 1)
        
        print(f"Opening image file: {image_file}")
        print(f"Opening mask file:  {mask_file}")
        
        with h5py.File(image_file, 'r') as f_img, h5py.File(mask_file, 'r') as f_mask:
            dataset_images = f_img['/entry/data/images']
            mask = f_mask['/mask'][:].astype(bool)
            n_images = dataset_images.shape[0]
            print(f"Total images: {n_images}")
            
            if n_images < 1:
                print("The image dataset is empty; nothing to process.")
                return
            
            # 3) Initial sampling frames
            sample_frames = list(range(0, n_images, compute_step))
            if sample_frames[-1] != (n_images - 1):
                sample_frames.append(n_images - 1)  # ensure last frame is included
            
            # 4) Compute centers for initial sample
            print("Computing initial sample centers...")
            sample_centers = []
            with tqdm(total=len(sample_frames), desc="Coarse Sampling") as pbar:
                for frm in sample_frames:
                    c = compute_center_for_frame(frm, dataset_images, mask, threshold, 
                                                 max_iters, step_size, n_steps, 
                                                 n_wedges, n_rad_bins, plot_profiles)
                    sample_centers.append(c)
                    pbar.update(1)
            
            # 5) Iterative refinement
            print("Refining & Fitting Iterations...")
            
            for iteration in range(max_refine_iters):
                print(f"\nIteration #{iteration+1}")
                
                # a) Remove outliers
                sample_frames, sample_centers = remove_outliers(sample_frames, sample_centers, 
                                                                threshold=outlier_dist_thresh)
                
                # Without enough points no meaningful fit is possible, so
                # stop the whole run instead of falling through to the final fit.
                if len(sample_frames) < min_points:
                    print("Too few sample points remain after outlier removal.")
                    print("Aborting: no CSV was written.")
                    return
                
                # b) Fit polynomial
                px, py = fit_polynomial(sample_frames, sample_centers, degree=poly_degree)
                
                # c) Compute residuals & show basic stats
                resid = compute_residuals(sample_frames, sample_centers, px, py)
                print(f"  Sample Points: {len(sample_frames)}")
                print(f"  Residual - min: {resid.min():.2f}, max: {resid.max():.2f}, mean: {resid.mean():.2f}")
                
                # d) Refine sampling if needed
                old_count = len(sample_frames)
                sample_frames, sample_centers = refine_sampling(
                    dataset_images, mask, sample_frames, sample_centers,
                    threshold, max_iters, step_size, n_steps, n_wedges, 
                    n_rad_bins, plot_profiles, px, py, max_error_per_interval
                )
                new_count = len(sample_frames)
                
                if new_count == old_count:
                    # No new frames added -> we've converged
                    print("No new intervals refined. Convergence reached.")
                    break
            
            # The loop may not have run at all (zero refinement iterations),
            # so make sure the final fit is actually determined.
            if len(sample_frames) <= poly_degree:
                print("Too few sample points for the requested polynomial degree.")
                print("Aborting: no CSV was written.")
                return
            
            # 6) Final fit & produce *all* frames
            print("\nFinal fit & interpolation across all frames...")
            px, py = fit_polynomial(sample_frames, sample_centers, degree=poly_degree)
            
            # Evaluate for every frame
            all_frames = np.arange(n_images, dtype=float)
            all_x = evaluate_polynomial(px, all_frames)
            all_y = evaluate_polynomial(py, all_frames)
            
            # 7) Save to CSV
            df = pd.DataFrame({
                'frame_number': all_frames.astype(int),
                'center_x': all_x,
                'center_y': all_y
            })
            df.to_csv(csv_file, index=False)
            print("Done. Wrote CSV to:", csv_file)

# Layout the UI
file_chooser_box = widgets.HBox(children=[image_file_chooser, mask_file_chooser])

# All parameter inputs stacked vertically, in display order.
param_box = widgets.VBox(
    children=[
        threshold_widget,
        max_iters_widget,
        step_size_widget,
        n_steps_widget,
        n_wedges_widget,
        n_rad_bins_widget,
        plot_profiles_widget,
        chunk_size_widget,
        compute_step_widget,
        poly_degree_widget,
        outlier_thresh_widget,
        max_error_per_interval_widget,
        max_refine_iter_widget,
    ]
)

# Top-level container: title, file pickers, parameters, run button, log area.
ui = widgets.VBox(
    children=[
        widgets.HTML("<h2>Adaptive, Iterative, & Robust Center-Fitting Tool</h2>"),
        file_chooser_box,
        widgets.HTML("<h3>Processing Parameters</h3>"),
        param_box,
        process_button,
        widgets.HTML("<h3>Logs & Feedback</h3>"),
        processing_output,
    ]
)

# Run the processing handler whenever the button is clicked.
process_button.on_click(on_process_button_clicked)

display(ui)


# Out: VBox(children=(HTML(value='<h2>Adaptive, Iterative, & Robust Center-Fitting Tool</h2>'), HBox(children=(FileCh…