# In [None]:
import os
import numpy as np
import pandas as pd
import h5py
from multiprocessing import Pool
from tqdm.notebook import tqdm

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, clear_output

# -- Your helper function that computes center for a single frame --
from helpers import process_image_full


# =============================================================================
# IRLS FIT HELPERS
# =============================================================================
def poly_design_matrix(x, degree=3):
    """
    Build the polynomial design (Vandermonde) matrix for ``x``.

    For ``degree=3`` each row is ``[1, x, x**2, x**3]``.

    Parameters
    ----------
    x : array_like, shape (N,)
        Sample locations (e.g. frame indices). Any numeric dtype is accepted.
    degree : int
        Polynomial degree D.

    Returns
    -------
    np.ndarray, shape (N, D+1), dtype float64
        Column ``d`` holds ``x**d`` (lowest order first).
    """
    # Convert to float64 *before* raising to powers: with integer input
    # (int32 on some platforms) x**3 overflows once x exceeds ~1290, and with
    # float32 input the powers lose precision before reaching the float64 matrix.
    x = np.asarray(x, dtype=np.float64)
    return np.vander(x, N=degree + 1, increasing=True)

def irls_fit_poly(x, y, degree=3, max_iter=50, tol=1e-4):
    """
    Fit a polynomial of the given degree to (x, y) with Huber-style IRLS.

    Starts from an ordinary least-squares fit, then repeatedly re-weights the
    samples: residuals within ``epsilon = 1.5 * median(|residual|)`` get
    weight 1, larger residuals get ``epsilon / |residual|``. Outliers
    therefore pull on the fit less than in plain least squares.

    Parameters
    ----------
    x, y : array_like, shape (N,)
        Sample locations and observed values.
    degree : int
        Polynomial degree D.
    max_iter : int
        Maximum number of re-weighting iterations.
    tol : float
        Stop once the Euclidean change of the coefficient vector is below this.

    Returns
    -------
    coeffs : np.ndarray, shape (degree+1,)
        Coefficients ``[c0, c1, ..., cD]`` (lowest order first), in the
        layout expected by ``poly_predict``.
    """
    X = poly_design_matrix(x, degree=degree)
    y = np.asarray(y, dtype=np.float64)

    # Initial guess: ordinary least squares (all weights equal to 1).
    beta = np.linalg.lstsq(X, y, rcond=None)[0]

    for _ in range(max_iter):
        abs_res = np.abs(y - X @ beta)

        # Robust scale estimate from the median absolute residual.
        epsilon = 1.5 * np.median(abs_res)
        if epsilon < 1e-9:
            # Residuals are essentially zero (degenerate / exact data).
            break

        # Huber weights: 1 where |r| <= epsilon, epsilon/|r| otherwise.
        # Dividing by max(|r|, epsilon) yields exactly that and never
        # divides by zero.
        w = epsilon / np.maximum(abs_res, epsilon)

        # Weighted least squares via row scaling by sqrt(w). Mathematically
        # equivalent to solving (X^T W X) b = X^T W y, but it never builds
        # the N x N diagonal weight matrix and does not square the condition
        # number of the (often badly scaled) polynomial design matrix.
        sqrt_w = np.sqrt(w)
        new_beta = np.linalg.lstsq(X * sqrt_w[:, None], y * sqrt_w, rcond=None)[0]

        converged = np.linalg.norm(new_beta - beta) < tol
        beta = new_beta
        if converged:
            break

    return beta

def poly_predict(x, coeffs):
    """
    Evaluate ``sum_d coeffs[d] * x**d`` at every point of ``x``.

    ``coeffs`` uses the lowest-order-first layout produced by ``irls_fit_poly``;
    the polynomial degree is inferred from its length.
    """
    design = poly_design_matrix(x, degree=len(coeffs) - 1)
    return design @ coeffs


# =============================================================================
# ADAPTIVE SAMPLING LOGIC
# =============================================================================
def _initial_frames(n_frames, init_stride):
    """Return every ``init_stride``-th frame index plus the last frame (sorted, unique)."""
    frames = list(range(0, n_frames, init_stride))
    if frames[-1] != n_frames - 1:
        frames.append(n_frames - 1)  # ensure the last frame is always sampled
    return frames


def _compute_centers(frames_needed, dataset_images, mask, chunk_size, center_args):
    """Run ``process_image_full`` on the requested frames and return ``{frame: (x, y)}``.

    Frames are read in windows spanning at most ``chunk_size`` consecutive frame
    indices (bounding memory). Only the part of each window up to the last
    requested frame is read, and one worker pool is shared by all windows.
    ``center_args`` is the tuple of per-frame parameters appended after
    ``(frame_index, image, mask)`` in each ``process_image_full`` argument tuple.
    """
    frames_needed = sorted(set(frames_needed))
    results = {}
    if not frames_needed:
        return results

    with Pool() as pool, tqdm(total=len(frames_needed), desc="Computing new centers") as pbar:
        i_next = 0  # index into frames_needed
        while i_next < len(frames_needed):
            cstart = frames_needed[i_next]
            cend = cstart + chunk_size  # exclusive end of this read window

            # Gather the requested frames inside [cstart, cend); always >= 1 frame.
            frames_in_chunk = []
            while i_next < len(frames_needed) and frames_needed[i_next] < cend:
                frames_in_chunk.append(frames_needed[i_next])
                i_next += 1

            # Read only up to the last frame we actually need in this window.
            images_chunk = dataset_images[cstart:frames_in_chunk[-1] + 1].astype(np.float32)

            arg_list = [
                (fidx, images_chunk[fidx - cstart], mask, *center_args)
                for fidx in frames_in_chunk
            ]
            for frame_num, (cy, cx) in pool.imap_unordered(process_image_full, arg_list):
                # Results are unpacked as (frame, (cy, cx)); stored here as (x, y).
                results[frame_num] = (cx, cy)
                pbar.update(1)

    return results


def _fit_centers(frames, known_centers, poly_degree):
    """Fit x(t) and y(t) polynomials to the sampled centers; return (poly_x, poly_y, errors)."""
    # float64 keeps full precision in the polynomial powers of the frame index.
    t = np.asarray(frames, dtype=np.float64)
    centers = np.asarray([known_centers[f] for f in frames], dtype=np.float64)
    poly_x = irls_fit_poly(t, centers[:, 0], degree=poly_degree)
    poly_y = irls_fit_poly(t, centers[:, 1], degree=poly_degree)
    # Euclidean distance between each measured center and the fitted curve.
    errors = np.hypot(centers[:, 0] - poly_predict(t, poly_x),
                      centers[:, 1] - poly_predict(t, poly_y))
    return poly_x, poly_y, errors


def _frames_to_refine(frames, errors, desired_accuracy):
    """Return midpoints of neighbouring sample pairs where either end exceeds ``desired_accuracy``."""
    existing = set(frames)
    new_frames = []
    for f1, f2, e1, e2 in zip(frames, frames[1:], errors, errors[1:]):
        if e1 > desired_accuracy or e2 > desired_accuracy:
            mid = (f1 + f2) // 2
            # mid < f2 always; mid == f1 (adjacent frames) is excluded because f1 is in `existing`.
            if mid not in existing:
                existing.add(mid)
                new_frames.append(mid)
    return new_frames


def adaptive_sampling(
    dataset_images, mask,
    threshold,
    max_iters,
    step_size,
    n_steps,
    n_wedges,
    n_rad_bins,
    plot_profiles,
    desired_accuracy=0.5,
    max_refine_iterations=5,
    init_stride=100,
    poly_degree=3,
    chunk_size=100
):
    """
    Estimate a smooth center trajectory by adaptively sampling frames.

    1) Start with frames ``range(0, n_frames, init_stride)`` plus the last frame.
    2) Compute centers for those frames with ``process_image_full``.
    3) Fit IRLS polynomials x(t) and y(t).
    4) If the largest residual among sampled frames is >= ``desired_accuracy``,
       add the midpoint of every neighbouring pair where either end misses the
       target, compute the new centers and re-fit.
    5) Stop when the accuracy is reached, no new frames can be added, or
       ``max_refine_iterations`` subdivision rounds have been done. The returned
       fit always covers every frame in ``frames_sampled``.

    Parameters
    ----------
    dataset_images : array-like indexed by frame along axis 0
        E.g. an open h5py dataset; only the needed slices are read.
    mask, threshold, max_iters, step_size, n_steps, n_wedges, n_rad_bins, plot_profiles
        Forwarded unchanged to ``process_image_full`` for each frame.
    desired_accuracy : float
        Target for the maximum center-vs-fit distance on sampled frames.
    max_refine_iterations : int
        Maximum number of subdivision rounds (0 = single fit, no refinement).
    init_stride : int
        Spacing (>= 1) of the initial frame grid.
    poly_degree : int
        Degree of the fitted polynomials.
    chunk_size : int
        Maximum span (>= 1) of consecutive frame indices read at once.

    Returns
    -------
    dict with keys ``frames_sampled`` (sorted list), ``centers_sampled``
    (array of (x, y) per sampled frame), ``poly_x`` and ``poly_y``
    (lowest-order-first coefficients).

    Raises
    ------
    ValueError
        If the dataset is empty or ``init_stride`` / ``chunk_size`` is < 1.
    """
    n_frames = dataset_images.shape[0]
    if n_frames < 1:
        raise ValueError("dataset_images contains no frames.")
    if init_stride < 1:
        raise ValueError(f"init_stride must be >= 1, got {init_stride}.")
    if chunk_size < 1:
        # A zero window would never advance through the frame list.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}.")

    center_args = (threshold, max_iters, step_size, n_steps,
                   n_wedges, n_rad_bins, plot_profiles)
    frames_chosen = _initial_frames(n_frames, init_stride)
    known_centers = {}  # frame index -> (x, y)

    refine_iteration = 0
    while True:
        # Only measure frames not already computed in an earlier round.
        frames_to_compute = [f for f in frames_chosen if f not in known_centers]
        if frames_to_compute:
            known_centers.update(_compute_centers(
                frames_to_compute, dataset_images, mask, chunk_size, center_args))

        x_poly, y_poly, errors = _fit_centers(frames_chosen, known_centers, poly_degree)
        max_error = errors.max()
        print(f"[Refine #{refine_iteration}] Max error among sampled frames: {max_error:.3f}")

        if max_error < desired_accuracy:
            print("Desired accuracy reached. Stopping refinement.")
            break
        if refine_iteration >= max_refine_iterations:
            # Stop *before* adding frames we would never measure or fit.
            print("Maximum number of refinements reached; returning current fit.")
            break

        new_frames = _frames_to_refine(frames_chosen, errors, desired_accuracy)
        if not new_frames:
            print("No additional frames to add, but error still above threshold. Stopping.")
            break

        frames_chosen = sorted(set(frames_chosen).union(new_frames))
        refine_iteration += 1

    return {
        "frames_sampled": frames_chosen,
        "centers_sampled": np.array([known_centers[f] for f in frames_chosen]),
        "poly_x": x_poly,
        "poly_y": y_poly
    }


# =============================================================================
# WIDGETS FOR UI
# =============================================================================

# Create file chooser widgets for the image and mask files.
def _existing_dir_or_cwd(path):
    """Return ``path`` if it is an existing directory, otherwise the current working directory."""
    return path if os.path.isdir(path) else os.getcwd()


def _param_widget(widget_cls, value, description, width='200px'):
    """Create a labelled parameter input using the layout shared by this panel."""
    return widget_cls(
        value=value,
        description=description,
        layout=widgets.Layout(width=width),
        style={'description_width': 'initial'},
    )


# File choosers for the image stack and the mask. The preferred directories
# are machine-specific, so fall back to the current working directory when
# they do not exist (e.g. when the notebook is run on another machine).
image_file_chooser = FileChooser(
    _existing_dir_or_cwd("/Users/xiaodong/Desktop/UOX-data/UOX1_sub/"),
    filename="UOX1_sub.h5",
)
image_file_chooser.title = "Select H5 Image File"
image_file_chooser.filter_pattern = "*.h5"

mask_file_chooser = FileChooser(
    _existing_dir_or_cwd("/Users/xiaodong/mask/"),
    filename="pxmask.h5",
)
mask_file_chooser.title = "Select H5 Mask File"
mask_file_chooser.filter_pattern = "*.h5"

# Per-frame center-finding parameters (the callback forwards these to
# adaptive_sampling, which passes them on to process_image_full).
threshold_widget = _param_widget(widgets.FloatText, 0.1, "Threshold:")
max_iters_widget = _param_widget(widgets.IntText, 10, "Max Iters:")
step_size_widget = _param_widget(widgets.IntText, 1, "Step Size:")
n_steps_widget = _param_widget(widgets.IntText, 5, "n_steps:")
n_wedges_widget = _param_widget(widgets.IntText, 4, "n_wedges:")
n_rad_bins_widget = _param_widget(widgets.IntText, 100, "n_rad_bins:")
plot_profiles_widget = widgets.Checkbox(
    value=False,
    description="Plot Profiles"
)

# Read-window size and adaptive-sampling / polynomial-fit parameters.
chunk_size_widget = _param_widget(widgets.IntText, 100, "Chunk Size:")
desired_accuracy_widget = _param_widget(
    widgets.FloatText, 0.5, "Desired Accuracy (pixels):", width='250px'
)
max_refine_iters_widget = _param_widget(widgets.IntText, 5, "Max Refinements:")
initial_stride_widget = _param_widget(widgets.IntText, 100, "Initial Stride:")
poly_degree_widget = _param_widget(widgets.IntText, 3, "Poly Degree:")

process_button = widgets.Button(
    description="Process Images (Adaptive + IRLS)",
    button_style="primary"
)

# Bordered, scrollable area that receives all processing logs.
processing_output = widgets.Output(layout={
    'border': '1px solid black',
    'padding': '5px',
    'height': '400px',
    'overflow_y': 'auto'
})


# =============================================================================
# MAIN CALLBACK
# =============================================================================
def on_process_button_clicked(b):
    """
    Button callback: run adaptive sampling + IRLS and write per-frame centers to CSV.

    Reads images from ``/entry/data/images`` in the selected image file and the
    mask from ``/mask`` in the selected mask file. It fits x(t), y(t) polynomials
    via ``adaptive_sampling`` and evaluates them at every frame. The result is
    written to ``centers_adaptive_irls.csv`` next to the image file. All printed
    output goes to ``processing_output``.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    with processing_output:
        clear_output()
        image_file = image_file_chooser.selected
        mask_file = mask_file_chooser.selected

        if not image_file:
            print("Please select an image file.")
            return
        if not mask_file:
            print("Please select a mask file.")
            return

        # Snapshot every parameter at click time.
        threshold = threshold_widget.value
        max_iters = max_iters_widget.value
        step_size = step_size_widget.value
        n_steps = n_steps_widget.value
        n_wedges = n_wedges_widget.value
        n_rad_bins = n_rad_bins_widget.value
        plot_profiles = plot_profiles_widget.value
        chunk_size = chunk_size_widget.value

        desired_accuracy = desired_accuracy_widget.value
        max_refine_iters = max_refine_iters_widget.value
        init_stride = initial_stride_widget.value
        poly_degree = poly_degree_widget.value

        # An existing CSV is NOT deleted up front: it is only overwritten once
        # processing has succeeded, so a failed run keeps the previous results.
        csv_file = os.path.join(os.path.dirname(image_file), "centers_adaptive_irls.csv")

        print(f"Opening H5: {image_file}")
        with h5py.File(image_file, 'r') as f_img, h5py.File(mask_file, 'r') as f_mask:
            # The image dataset is read lazily, so sampling must run while open.
            dataset_images = f_img['/entry/data/images']
            mask = f_mask['/mask'][:].astype(bool)
            n_frames = dataset_images.shape[0]
            print(f"Total frames: {n_frames}")

            results = adaptive_sampling(
                dataset_images, mask,
                threshold=threshold,
                max_iters=max_iters,
                step_size=step_size,
                n_steps=n_steps,
                n_wedges=n_wedges,
                n_rad_bins=n_rad_bins,
                plot_profiles=plot_profiles,
                desired_accuracy=desired_accuracy,
                max_refine_iterations=max_refine_iters,
                init_stride=init_stride,
                poly_degree=poly_degree,
                chunk_size=chunk_size
            )

        print(f"Measured centers on {len(results['frames_sampled'])} of {n_frames} frames.")

        # Evaluate the fitted polynomials at every frame; float64 keeps full
        # precision in the powers of large frame indices.
        all_frames = np.arange(n_frames, dtype=np.float64)
        pred_x = poly_predict(all_frames, results["poly_x"])
        pred_y = poly_predict(all_frames, results["poly_y"])

        df = pd.DataFrame({
            "frame_number": np.arange(n_frames),
            "center_x": pred_x,
            "center_y": pred_y
        })
        df.to_csv(csv_file, index=False)
        print(f"Final centers written to {csv_file}")


# Setup UI: hook the button to its callback, then lay out and show the panel.
process_button.on_click(on_process_button_clicked)

# Processing parameters, stacked top-to-bottom in display order.
param_box = widgets.VBox(children=[
    threshold_widget,
    max_iters_widget,
    step_size_widget,
    n_steps_widget,
    n_wedges_widget,
    n_rad_bins_widget,
    plot_profiles_widget,
    chunk_size_widget,
    desired_accuracy_widget,
    max_refine_iters_widget,
    initial_stride_widget,
    poly_degree_widget,
])

# Image and mask pickers side by side.
file_chooser_box = widgets.HBox(children=[image_file_chooser, mask_file_chooser])

# Full panel: title, file pickers, parameters, action button, log area.
ui = widgets.VBox(children=[
    widgets.HTML(value="<h2>Adaptive Sampling + IRLS</h2>"),
    file_chooser_box,
    widgets.HTML(value="<h3>Processing Parameters</h3>"),
    param_box,
    process_button,
    widgets.HTML(value="<h3>Logs & Feedback</h3>"),
    processing_output,
])

display(ui)
