In [None]:
# FASTER AND CHUNKED WITH CHUNK-WISE CSV SAVING

import os
import h5py
import numpy as np
import pandas as pd
from multiprocessing import Pool
from find_diffraction_center import find_diffraction_center

# Input files (absolute paths on the analysis machine; edit before running elsewhere).
# image_file: HDF5 file whose '/entry/data/images' dataset is read chunk by chunk below;
#             centers.csv is written into the same directory as this file.
# mask_file:  HDF5 file whose '/mask' dataset is loaded in full and cast to bool.
image_file = "/home/bubl3932/files/UOX1/UOX1_original_IQM/UOX1.h5"
mask_file = "/home/bubl3932/mask/pxmask.h5"

def process_image(args):
    """Find the center of a single frame; worker function for ``Pool.map``.

    Parameters
    ----------
    args : tuple
        A ``(frame_number, image)`` pair, packed into one argument because
        ``Pool.map`` hands each worker call exactly one item.

    Returns
    -------
    tuple
        ``(frame_number, center)``, where ``center`` is whatever
        ``find_diffraction_center`` returns for that frame.

    Notes
    -----
    ``masks`` and the tuning parameters (``threshold``, ``max_iters``,
    ``step_size``, ``n_steps``, ``n_wedges``, ``n_rad_bins``,
    ``plot_profiles``) are not passed in; they are looked up as module-level
    globals when the function runs, so they must exist before it is called.
    """
    frame_number, image = args
    # Gather the globally defined tuning knobs into one keyword dict.
    search_options = {
        "threshold": threshold,
        "max_iters": max_iters,
        "step_size": step_size,
        "n_steps": n_steps,
        "n_wedges": n_wedges,
        "n_rad_bins": n_rad_bins,
        "plot_profiles": plot_profiles,
        "frame_number": frame_number,
    }
    return frame_number, find_diffraction_center(image, masks, **search_options)

# ---------------------------------------------------------------------------
# Settings forwarded to find_diffraction_center() by process_image(), which
# reads them as module-level globals.
# ---------------------------------------------------------------------------
threshold = 0.1
max_iters = 10
step_size = 1
n_steps = 5
n_wedges = 4
n_rad_bins = 100
plot_profiles = False

chunk_size = 100  # Frames read and dispatched per iteration; adjust based on available memory
csv_file = os.path.join(os.path.dirname(image_file), "centers.csv")

# Remove existing CSV file if present so results from an earlier run are not mixed in
if os.path.exists(csv_file):
    os.remove(csv_file)
header_written = False

with h5py.File(image_file, 'r') as f_img, h5py.File(mask_file, 'r') as f_mask:
    # Adjust dataset names to those in your HDF5 files
    dataset_images = f_img['/entry/data/images']
    # The mask is loaded entirely into memory; process_image() reads it as a global.
    masks = f_mask['/mask'][:].astype(bool)

    n_images = dataset_images.shape[0]
    print(f"Total images: {n_images}")

    # Create the worker pool ONCE and reuse it for every chunk. Starting and
    # tearing down a pool per chunk repeats process start-up for every
    # `chunk_size` frames. The pool is created only after `masks` and the
    # settings above exist, so workers started by fork inherit them.
    with Pool() as pool:
        for i in range(0, n_images, chunk_size):
            # Read a chunk of images and convert them to float32
            images_chunk = dataset_images[i:i+chunk_size].astype(np.float32)
            # Create list of tuples: (frame_number, image); frame numbers are global indices
            args = [(i + idx, img) for idx, img in enumerate(images_chunk)]

            # Process the current chunk in parallel.
            # NOTE: `centers` holds only this chunk's (frame_number, center) tuples;
            # after the loop it contains the last chunk, not all frames.
            centers = pool.map(process_image, args)

            # Create a DataFrame for this chunk and split the center into two columns.
            # NOTE(review): center[0] is saved as center_x and center[1] as center_y,
            # while the plotting cells treat index 0 as the row (y) and index 1 as
            # the column (x) -- confirm the order returned by find_diffraction_center.
            df_chunk = pd.DataFrame(centers, columns=["frame_number", "center"])
            df_chunk[['center_x', 'center_y']] = pd.DataFrame(df_chunk['center'].tolist(), index=df_chunk.index)
            df_chunk = df_chunk.drop(columns=["center"])

            # Save the current chunk to CSV in append mode (write header only for the first chunk)
            if not header_written:
                df_chunk.to_csv(csv_file, index=False, mode='w')
                header_written = True
            else:
                df_chunk.to_csv(csv_file, index=False, mode='a', header=False)

            print(f"Processed images {i} to {min(i+chunk_size, n_images)}")

print("Finished processing and saved centers to", csv_file)


Total images: 27025
Frame 42: Convergence reached with center found (639.1504535502033, 5947.113934939005).
Frame 38: Convergence reached with center found (550.8118896925859, -4407.283996383363).
Frame 40: Convergence reached with center found (591.076351580272, 1284.7985576690382).
Frame 44: Convergence reached with center found (1220.8202890944476, 2109.5328238098364).
Frame 2: Convergence reached with center found (701.3522746201035, 1228.5320074432593).
Frame 28: Convergence reached with center found (677.3694888863673, 1134.8312436173835).
Frame 24: Convergence reached with center found (334.43094050743656, -164.57357392825898).
Frame 32: Convergence reached with center found (541.6629018383725, 668.7007739885559).
Frame 10: Convergence reached with center found (545.4818295103632, 768.9169079599466).
Frame 0: Convergence reached with center found (478.0207882423268, 738.437138498959).
Frame 34: Convergence reached with center found (581.5671387001387, 673.6801794996065).
Frame 2

In [None]:
# Plotting
%matplotlib qt
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

# Determine if there are multiple images: a 3-D array is a stack of frames,
# anything else is treated as one 2-D image.
if images.ndim == 3:
    n_images = images.shape[0]
else:
    n_images = 1

# Create the figure and axis.
fig, ax = plt.subplots(figsize=(8, 8))
plt.subplots_adjust(bottom=0.25)  # leave space at the bottom for the slider

# Initial image index.
init_index = 0

# Display the initial image.
# Branch on dimensionality rather than on n_images > 1: a stack holding a
# single frame (shape (1, H, W)) must still be indexed, otherwise imshow is
# handed a 3-D array it cannot draw as a grayscale image.
if images.ndim == 3:
    im_handle = ax.imshow(images[init_index], cmap='gray', origin='lower')
    refined_center = centers[init_index]
else:
    im_handle = ax.imshow(images, cmap='gray', origin='lower')
    refined_center = center  # single image case

# Plot the refined center marker. Index 1 is used as x (column) and
# index 0 as y (row).
center_handle, = ax.plot(refined_center[1], refined_center[0], 'ro', markersize=10, label='Refined Center')
ax.set_title('Diffraction Image with Refined Center')
ax.set_xlabel('Column (pixels)')
ax.set_ylabel('Row (pixels)')
ax.legend()

# Create a slider axis below the image (rect is [left, bottom, width, height]
# in figure coordinates).
ax_slider = plt.axes([0.2, 0.1, 0.65, 0.03])
slider = Slider(ax_slider, 'Image Index', 0, n_images - 1, valinit=init_index, valstep=1)

# Update function to be called whenever the slider value changes.
def update(val):
    """Redraw the image and center marker for the frame selected by the slider.

    Parameters
    ----------
    val : float
        New slider value supplied by Matplotlib; the frame index is taken
        from ``slider.val`` and truncated to an int.

    Relies on the module-level ``images``, ``centers``/``center``,
    ``im_handle``, ``center_handle``, ``slider`` and ``fig``.
    """
    idx = int(slider.val)
    # Branch on dimensionality so a one-frame stack is still indexed.
    if images.ndim == 3:
        im_handle.set_data(images[idx])
        refined_center = centers[idx]
    else:
        refined_center = center
    # Update center marker. Line2D.set_data needs sequences, not scalars
    # (scalars are deprecated since Matplotlib 3.7 and later rejected), so
    # wrap each coordinate in a one-element list.
    center_handle.set_data([refined_center[1]], [refined_center[0]])
    fig.canvas.draw_idle()

slider.on_changed(update)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Static view of one frame with its refined center: the first frame when
# `images` is a 3-D stack, otherwise the single image itself.
plt.figure(figsize=(8, 8))
if images.ndim == 3:
    shown_image = images[0]
    refined_center = centers[0]
else:
    shown_image = images
    refined_center = center
plt.imshow(shown_image, cmap='gray', origin='lower')
# The marker uses index 1 as x (column) and index 0 as y (row).
plt.plot(refined_center[1], refined_center[0], 'ro', markersize=10, label='Refined Center')
plt.title('Diffraction Image with Refined Center')
plt.xlabel('Column (pixels)')
plt.ylabel('Row (pixels)')
plt.legend()
plt.show()
