In [1]:
import os
import h5py
import numpy as np
import pandas as pd
import threading
import time
from multiprocessing import Pool

import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display

# Import the updated functions.
from ICFTOTAL import center_of_mass_initial_guess, find_diffraction_center

# Module-level flag: set by the Stop button, polled by process_images()
# before each chunk.
stop_processing = False

# -------------------------------
# File choosers for the two input HDF5 files.
# -------------------------------
image_file_chooser = FileChooser(
    "/Users/xiaodong/Desktop/UOX-data/UOX1_sub/",
    filename="UOX1_sub.h5",
)
mask_file_chooser = FileChooser(
    "/Users/xiaodong/mask/",
    filename="pxmask.h5",
)

# Both choosers list only .h5 files; the titles tell them apart.
for _chooser, _title in (
    (image_file_chooser, "Select H5 Image File"),
    (mask_file_chooser, "Select H5 Mask File"),
):
    _chooser.title = _title
    _chooser.filter_pattern = "*.h5"
del _chooser, _title

# -------------------------------
# Processing parameters.
# -------------------------------
# Radial-profile settings forwarded to find_diffraction_center().
n_wedges_widget = widgets.IntText(
    description="n_wedges:",
    value=4,
    layout=widgets.Layout(width='200px'),
)
n_rad_bins_widget = widgets.IntText(
    description="n_rad_bins:",
    value=100,
    layout=widgets.Layout(width='200px'),
)

# Optimizer tolerances forwarded to find_diffraction_center().
xatol_widget = widgets.FloatText(
    description='xatol:',
    value=1,
    layout=widgets.Layout(width='200px'),
)
fatol_widget = widgets.FloatText(
    description='fatol:',
    value=10,
    layout=widgets.Layout(width='200px'),
)

# Number of frames read from the H5 file and sent to the pool per batch.
chunk_size_widget = widgets.IntText(
    description="Chunk Size:",
    value=10,
    layout=widgets.Layout(width='200px'),
)

# -------------------------------
# Advanced parameters: center specification.
# -------------------------------
# Kept for display only. The processing code does not read these values;
# each image starts from its own center-of-mass guess instead.
use_default_center_widget = widgets.Checkbox(
    description="(Ignored) Use default center-of-mass guess",
    value=True,
)
initial_center_x_widget = widgets.FloatText(
    description="(Ignored) Initial Center X:",
    value=0.0,
    layout=widgets.Layout(width='200px'),
)
initial_center_y_widget = widgets.FloatText(
    description="(Ignored) Initial Center Y:",
    value=0.0,
    layout=widgets.Layout(width='200px'),
)
def get_initial_center():
    """Return the preset starting center, which is deliberately always None.

    The center widgets are ignored. Per-image center-of-mass guesses are
    used instead, so there is never a user-specified center to return.
    """
    preset_center = None
    return preset_center

verbose_widget = widgets.Checkbox(description="Verbose:", value=True)

# -------------------------------
# Control buttons.
# -------------------------------
process_button = widgets.Button(button_style="primary", description="Process Images")
stop_button = widgets.Button(button_style="danger", description="Stop Processing")

# -------------------------------
# Progress bar widget.
# -------------------------------
# max=100 is only a placeholder; process_images() resets it to the frame count.
# The wide description area holds the percent / elapsed / remaining text.
progress_bar = widgets.FloatProgress(
    min=0,
    max=100,
    value=0,
    description="Progress:",
    layout=widgets.Layout(width='600px'),
    style={'description_width': '250px'},
)

def update_ui_state(processing=False):
    """Lock (processing=True) or unlock the run button and file choosers.

    The stop button is intentionally left untouched so it stays clickable
    during a run.
    """
    for control in (process_button, image_file_chooser, mask_file_chooser):
        control.disabled = processing

def load_chunk(image_file, start, end):
    """
    Return frames ``start`` (inclusive) to ``end`` (exclusive) of the
    ``/entry/data/images`` dataset as a float32 array.

    Slicing the h5py dataset reads just that subset from disk, and the
    file is closed again before the function returns.
    """
    with h5py.File(image_file, 'r') as h5:
        raw_frames = h5['/entry/data/images'][start:end]
        return raw_frames.astype(np.float32)

def process_single_image(img, mask, n_wedges, n_rad_bins, xatol, fatol, verbose):
    """
    Return the refined diffraction center of one image.

    The image's own center-of-mass estimate is the starting point. It is
    then refined by find_diffraction_center() with the given wedge/bin
    counts, optimizer tolerances and verbosity.
    """
    solver_options = dict(
        n_wedges=n_wedges,
        n_rad_bins=n_rad_bins,
        xatol=xatol,
        fatol=fatol,
        verbose=verbose,
    )
    return find_diffraction_center(
        img,
        mask,
        initial_center=center_of_mass_initial_guess(img, mask),
        **solver_options,
    )

def _report_chunk_progress(frame_counter, n_images, start_time, first_frame, last_frame):
    """
    Update the progress bar and print a status line after a chunk is written.

    frame_counter is the number of frames processed so far. first_frame and
    last_frame are the inclusive frame indices of the chunk just finished.
    """
    progress_bar.value = frame_counter
    percent = frame_counter / n_images * 100
    elapsed = time.time() - start_time
    # Linear ETA: average time per frame so far times the frames still left.
    remaining = (elapsed / frame_counter) * (n_images - frame_counter) if frame_counter > 0 else 0
    progress_bar.description = f"{percent:.1f}% - Elapsed: {elapsed:.1f}s, Rem: {remaining:.1f}s"
    print(f"[{time.strftime('%H:%M:%S')}] Processed frames {first_frame} to {last_frame} ({percent:.1f}%)")


def process_images():
    """
    Refine the diffraction center of every frame in the selected H5 file.

    Workflow:
      1. Read the mask from ``/mask`` and the frame count from
         ``/entry/data/images``.
      2. Process frames in chunks of ``chunk_size_widget.value`` on a
         multiprocessing pool.
      3. Append one row per frame (``frame_number, center_x, center_y``) to
         ``centers_xatol_<xatol>_fatol_<fatol>.csv`` in the image file's
         directory. Any existing CSV with that name is deleted first.

    This function is meant to run on a background thread (see
    on_process_button_clicked). The global ``stop_processing`` flag is
    checked before each chunk, so a stop request takes effect once the
    current chunk finishes.

    Errors are printed with their traceback rather than raised. The UI is
    always re-enabled on exit.
    """
    import traceback  # local import: only needed for error reporting

    global stop_processing
    stop_processing = False  # Reset the stop flag from any previous run.
    update_ui_state(processing=True)

    try:
        # Retrieve selected file paths.
        image_file = image_file_chooser.selected
        mask_file = mask_file_chooser.selected

        if not image_file or not mask_file:
            print("Please select both an image file and a mask file.")
            return

        # Load the mask and convert to boolean.
        with h5py.File(mask_file, 'r') as f_mask:
            mask = f_mask['/mask'][:].astype(bool)

        # Read parameter values once, so edits made mid-run have no effect.
        n_wedges = n_wedges_widget.value
        n_rad_bins = n_rad_bins_widget.value
        xatol = xatol_widget.value
        fatol = fatol_widget.value
        verbose = verbose_widget.value
        chunk_size = chunk_size_widget.value

        # range() raises on a zero step and silently yields nothing for a
        # negative one. Reject both up front with a clear message.
        if chunk_size < 1:
            print(f"Chunk size must be at least 1 (got {chunk_size}).")
            return

        # Define output CSV file path.
        csv_file = os.path.join(os.path.dirname(image_file), f"centers_xatol_{xatol}_fatol_{fatol}.csv")
        if os.path.exists(csv_file):
            os.remove(csv_file)
        header_written = False

        print("Opening image file...")
        # Only the dataset shape is needed here; frames are read per chunk.
        with h5py.File(image_file, 'r') as f_img:
            n_images = f_img['/entry/data/images'].shape[0]

        if n_images == 0:
            print("The image file contains no frames; nothing to process.")
            return

        progress_bar.max = n_images
        progress_bar.value = 0
        start_time = time.time()

        frame_counter = 0
        stopped = False

        # NOTE(review): where the multiprocessing start method is 'spawn'
        # (the default on macOS and Windows in recent Pythons), workers must
        # import process_single_image by name. Functions defined in a
        # notebook's __main__ usually cannot be found there, which can make
        # the pool error out or hang. If so, move process_single_image into
        # an importable module (e.g. next to ICFTOTAL). Confirm on the
        # target machine.
        with Pool() as pool:
            # Process images chunk-by-chunk.
            for start_idx in range(0, n_images, chunk_size):
                if stop_processing:
                    stopped = True
                    print("\nProcessing stopped by user.")
                    break

                end_idx = min(start_idx + chunk_size, n_images)
                current_chunk = load_chunk(image_file, start_idx, end_idx)

                # One argument tuple per image in the chunk.
                args = [(img, mask, n_wedges, n_rad_bins, xatol, fatol, verbose) for img in current_chunk]
                results = pool.starmap(process_single_image, args)

                # Write the results incrementally. The first chunk creates the
                # file with a header; later chunks append without one.
                df_chunk = pd.DataFrame(
                    [[frame_counter + idx, res[0], res[1]] for idx, res in enumerate(results)],
                    columns=["frame_number", "center_x", "center_y"],
                )
                mode = "a" if header_written else "w"
                df_chunk.to_csv(csv_file, index=False, mode=mode, header=not header_written)
                header_written = True
                frame_counter += len(results)

                # end_idx is exclusive, so this chunk's last frame is end_idx - 1.
                _report_chunk_progress(frame_counter, n_images, start_time, start_idx, end_idx - 1)

        if not stopped:
            print("\nProcessing complete.")
            print("CSV file written to:", csv_file)
        elif header_written:
            print(f"Partial results ({frame_counter} of {n_images} frames) written to:", csv_file)
        else:
            print("No frames were processed; no CSV file was written.")
    except Exception:
        # Print the full traceback, not just str(e), so the failing line is visible.
        print("An error occurred during processing:")
        print(traceback.format_exc())
    finally:
        update_ui_state(processing=False)

def on_process_button_clicked(b):
    """
    Start process_images() on a daemon thread so the notebook stays responsive.

    The button is disabled here, synchronously, rather than only inside the
    worker thread. Otherwise a second click handled before the thread runs
    update_ui_state() could start a second, concurrent run that deletes and
    rewrites the same CSV file. process_images() re-enables the button when
    it finishes.
    """
    if process_button.disabled:
        return
    process_button.disabled = True
    threading.Thread(target=process_images, daemon=True).start()

def on_stop_button_clicked(b):
    """Ask the running job to halt; it checks the flag before its next chunk."""
    global stop_processing
    stop_processing = True

# Connect each button to its click handler.
for _button, _handler in (
    (process_button, on_process_button_clicked),
    (stop_button, on_stop_button_clicked),
):
    _button.on_click(_handler)
del _button, _handler

# -------------------------------
# Lay out the UI: file pickers, parameter groups, controls, progress.
# -------------------------------
file_chooser_box = widgets.HBox([image_file_chooser, mask_file_chooser])
basic_params_box = widgets.VBox(
    [n_wedges_widget, n_rad_bins_widget, chunk_size_widget]
)
center_box = widgets.HBox(
    [use_default_center_widget, initial_center_x_widget, initial_center_y_widget]
)
advanced_params_box = widgets.VBox(
    [center_box, verbose_widget, xatol_widget, fatol_widget]
)
button_box = widgets.HBox([process_button, stop_button])

ui = widgets.VBox(
    children=[
        widgets.HTML("<h2>Interactive Image Processing Tool</h2>"),
        file_chooser_box,
        widgets.HTML("<h3>Processing Parameters</h3>"),
        basic_params_box,
        widgets.HTML("<h3>Advanced Parameters</h3>"),
        advanced_params_box,
        button_box,
        progress_bar,
    ]
)

display(ui)


VBox(children=(HTML(value='<h2>Interactive Image Processing Tool</h2>'), HBox(children=(FileChooser(path='/Use…

Opening image file...


In [None]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

def compare_centers(file1, file2, output_file="centers_comparison.csv"):
    """
    Compare per-frame centers stored in two CSV files.

    Both CSV files must have the columns ``frame_number``, ``center_x`` and
    ``center_y``.

    Rows are inner-joined on ``frame_number``. Frames present in only one
    file are dropped, and the number dropped from each side is printed.
    For every matched frame the function computes ``delta_x`` and
    ``delta_y`` (file1 minus file2) and their Euclidean norm. It then prints
    summary statistics, writes the merged table to ``output_file`` and plots
    the differences by frame. If no frame numbers match, it prints a warning
    and skips the statistics and plots.

    Parameters
    ----------
    file1, file2 : str or path-like
        CSV files to compare.
    output_file : str or path-like, optional
        Destination for the merged comparison table.

    Returns
    -------
    pandas.DataFrame
        The merged table (sorted by ``frame_number``) with the delta columns.
    """
    # Load CSV files into DataFrames.
    df1 = pd.read_csv(file1)
    df2 = pd.read_csv(file2)

    # Inner join on frame_number; sort so line plots connect frames in order
    # even if an input file was not sorted.
    merged = pd.merge(df1, df2, on="frame_number", how="inner", suffixes=('_file1', '_file2'))
    merged = merged.sort_values("frame_number").reset_index(drop=True)

    # The inner join silently drops unmatched frames, so count them for the report.
    only_in_file1 = int((~df1["frame_number"].isin(df2["frame_number"])).sum())
    only_in_file2 = int((~df2["frame_number"].isin(df1["frame_number"])).sum())

    # Calculate the differences between the centers.
    merged["delta_x"] = merged["center_x_file1"] - merged["center_x_file2"]
    merged["delta_y"] = merged["center_y_file1"] - merged["center_y_file2"]
    merged["euclidean_distance"] = np.sqrt(merged["delta_x"]**2 + merged["delta_y"]**2)

    print(f"Frames compared: {len(merged)} (only in file1: {only_in_file1}, only in file2: {only_in_file2})")

    if merged.empty:
        # Means would be NaN and the plots empty; save the (empty) table and stop.
        print("No common frame numbers between the two files; nothing to compare.")
        merged.to_csv(output_file, index=False)
        print("Comparison results saved to:", output_file)
        return merged

    # Print summary statistics.
    print("Summary of differences:")
    print("Mean difference in x: {:.3f}".format(merged["delta_x"].mean()))
    print("Mean difference in y: {:.3f}".format(merged["delta_y"].mean()))
    # Signed means can cancel out; absolute means and the maximum show the real spread.
    print("Mean |difference| in x: {:.3f}".format(merged["delta_x"].abs().mean()))
    print("Mean |difference| in y: {:.3f}".format(merged["delta_y"].abs().mean()))
    print("Mean Euclidean distance: {:.3f}".format(merged["euclidean_distance"].mean()))
    print("Max Euclidean distance: {:.3f}".format(merged["euclidean_distance"].max()))

    # Save the merged DataFrame to a CSV file.
    merged.to_csv(output_file, index=False)
    print("Comparison results saved to:", output_file)

    # Plot Euclidean distance versus frame number.
    plt.figure(figsize=(10, 6))
    plt.plot(merged["frame_number"], merged["euclidean_distance"], marker='o', linestyle='-')
    plt.xlabel("Frame Number")
    plt.ylabel("Euclidean Distance")
    plt.title("Euclidean Distance between Centers by Frame")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Plot delta_x and delta_y versus frame number.
    plt.figure(figsize=(10, 6))
    plt.plot(merged["frame_number"], merged["delta_x"], marker='o', linestyle='-', label="Delta X")
    plt.plot(merged["frame_number"], merged["delta_y"], marker='o', linestyle='-', label="Delta Y")
    plt.xlabel("Frame Number")
    plt.ylabel("Difference (pixels)")
    plt.title("Differences in Center Coordinates by Frame")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    return merged

if __name__ == "__main__":
    # Both center files and the comparison output live in the same run folder.
    results_dir = "/Users/xiaodong/Desktop/UOX-data/UOX1_sub"
    file1 = os.path.join(results_dir, "centers_xatol_0.01_fatol_10.0.csv")
    # NOTE(review): process_images() formats fatol from a FloatText value,
    # which yields "fatol_10.0", not "fatol_10". This file may come from an
    # earlier version of the notebook; confirm it exists before relying on it.
    file2 = os.path.join(results_dir, "centers_xatol_0.1_fatol_10.csv")
    output_file = os.path.join(results_dir, "centers_comparison.csv")
    compare_centers(file1, file2, output_file=output_file)
