In [None]:
from load_image_and_mask import load_image_and_mask
from find_diffraction_center import find_diffraction_center

# Example usage:
image_file = "/home/bubl3932/files/UOX1/UOX1_original/UOX1_subset.h5"
mask_file = "/home/bubl3932/mask/pxmask.h5"

threshold=0.1    # Tune as needed
max_iters=10
step_size=1
n_steps=5
n_wedges=4
n_rad_bins=100
plot_profiles=False  # disable wedge profile plots

# Load the full stack of images and corresponding masks.
images, masks = load_image_and_mask(image_file, mask_file, select_first=False)

# If multiple images are present (3D array), process each one.
if images.ndim == 3:
    centers = []
    n_images = images.shape[0]
    for i in range(n_images):
        print(f"Processing image {i+1}/{n_images}")
        center = find_diffraction_center(
            images[i],
            masks,
            threshold=threshold,    # Tune as needed
            max_iters=max_iters,
            step_size=step_size,
            n_steps=n_steps,
            n_wedges=n_wedges,
            n_rad_bins=n_rad_bins,
            plot_profiles=plot_profiles  # disable wedge profile plots
        )
        centers.append(center)
    print("Refined centers for all images:")
    for idx, c in enumerate(centers):
        print(f"Image {idx+1}: Center = {c}")
else:
    center = find_diffraction_center(
        images,
        masks,
        threshold=threshold,
        max_iters=max_iters,
        step_size=step_size,
        n_steps=n_steps,
        n_wedges=n_wedges,
        n_rad_bins=n_rad_bins,
        plot_profiles=plot_profiles
    )
    print(f"Refined diffraction center: {center}")


In [None]:
#FASTER!
from multiprocessing import Pool
from load_image_and_mask import load_image_and_mask
from find_diffraction_center import find_diffraction_center

# Example usage:
image_file = "/home/bubl3932/files/UOX1/UOX1_original/UOX1_subset.h5"
mask_file = "/home/bubl3932/mask/pxmask.h5"

threshold = 0.1    # Tune as needed
max_iters = 10
step_size = 1
n_steps = 5
n_wedges = 4
n_rad_bins = 100
plot_profiles = False  # disable wedge profile plots

# Load the full stack of images and corresponding masks.
images, masks = load_image_and_mask(image_file, mask_file, select_first=False)

def process_image(image):
    # Call the function for each image.
    return find_diffraction_center(
        image,
        masks,
        threshold=threshold,
        max_iters=max_iters,
        step_size=step_size,
        n_steps=n_steps,
        n_wedges=n_wedges,
        n_rad_bins=n_rad_bins,
        plot_profiles=plot_profiles
    )

if images.ndim == 3:
    n_images = images.shape[0]
    print(f"Processing {n_images} images with multiprocessing...")
    
    with Pool() as pool:
        centers = pool.map(process_image, [images[i] for i in range(n_images)])
    
    print("Refined centers for all images:")
    for idx, c in enumerate(centers):
        print(f"Image {idx+1}: Center = {c}")
else:
    center = find_diffraction_center(
        images,
        masks,
        threshold=threshold,
        max_iters=max_iters,
        step_size=step_size,
        n_steps=n_steps,
        n_wedges=n_wedges,
        n_rad_bins=n_rad_bins,
        plot_profiles=plot_profiles
    )
    print(f"Refined diffraction center: {center}")


Processing 100 images with multiprocessing...
Initial guess center: (np.float64(513.4013368936525), np.float64(515.6643540693383))
Initial guess center: (np.float64(512.6576707376385), np.float64(512.6064878908342))
Initial guess center: (np.float64(518.6861788045974), np.float64(512.9514922156633))
Initial guess center: (np.float64(517.7431852961182), np.float64(513.5386161803308))
Initial guess center: (np.float64(515.6430817595906), np.float64(508.55458404374986))
Initial guess center: (np.float64(517.96149213769), np.float64(510.33380252653086))
Initial guess center: (np.float64(516.9138198478753), np.float64(511.90488632800873))
Initial guess center: (np.float64(515.9981993476318), np.float64(516.5679639519387))
Initial guess center: (np.float64(518.8802865741899), np.float64(510.0880514757972))
Initial guess center: (np.float64(518.3587522170816), np.float64(511.6645263775953))
Initial guess center: (np.float64(518.4206882846483), np.float64(517.6266863872868))
Initial guess cent

In [4]:
%matplotlib qt
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

# Determine if there are multiple images.
if images.ndim == 3:
    n_images = images.shape[0]
else:
    n_images = 1

# Create the figure and axis.
fig, ax = plt.subplots(figsize=(8, 8))
plt.subplots_adjust(bottom=0.25)  # leave space at the bottom for the slider

# Initial image index.
init_index = 0

# Display the initial image.
if n_images > 1:
    im_handle = ax.imshow(images[init_index], cmap='gray', origin='lower')
    refined_center = centers[init_index]
else:
    im_handle = ax.imshow(images, cmap='gray', origin='lower')
    refined_center = center  # single image case

# Plot the refined center marker.
center_handle, = ax.plot(refined_center[1], refined_center[0], 'ro', markersize=10, label='Refined Center')
ax.set_title('Diffraction Image with Refined Center')
ax.set_xlabel('Column (pixels)')
ax.set_ylabel('Row (pixels)')
ax.legend()

# Create a slider axis below the image.
ax_slider = plt.axes([0.2, 0.1, 0.65, 0.03])
slider = Slider(ax_slider, 'Image Index', 0, n_images - 1, valinit=init_index, valstep=1)

# Update function to be called whenever the slider value changes.
def update(val):
    idx = int(slider.val)
    # Update image.
    if n_images > 1:
        im_handle.set_data(images[idx])
        refined_center = centers[idx]
    else:
        refined_center = center
    # Update center marker.
    center_handle.set_data(refined_center[1], refined_center[0])
    fig.canvas.draw_idle()

slider.on_changed(update)
plt.show()


MESA: error: ZINK: failed to choose pdev
glx: failed to create drisw screen


Traceback (most recent call last):
  File "/home/bubl3932/anaconda3/envs/pyxem-env/lib/python3.10/site-packages/matplotlib/cbook.py", line 361, in process
    func(*args, **kwargs)
  File "/home/bubl3932/anaconda3/envs/pyxem-env/lib/python3.10/site-packages/matplotlib/widgets.py", line 592, in <lambda>
    return self._observers.connect('changed', lambda val: func(val))
  File "/tmp/ipykernel_1251008/3282708397.py", line 47, in update
    center_handle.set_data(refined_center[1], refined_center[0])
  File "/home/bubl3932/anaconda3/envs/pyxem-env/lib/python3.10/site-packages/matplotlib/lines.py", line 666, in set_data
    self.set_xdata(x)
  File "/home/bubl3932/anaconda3/envs/pyxem-env/lib/python3.10/site-packages/matplotlib/lines.py", line 1290, in set_xdata
    raise RuntimeError('x must be a sequence')
RuntimeError: x must be a sequence
Traceback (most recent call last):
  File "/home/bubl3932/anaconda3/envs/pyxem-env/lib/python3.10/site-packages/matplotlib/cbook.py", line 361, in p

In [5]:
import matplotlib.pyplot as plt

# For a single image case (or for example, the first image in the stack):
plt.figure(figsize=(8, 8))
plt.imshow(images[0] if images.ndim == 3 else images, cmap='gray', origin='lower')
# Note: best_center is either 'center' or centers[0] if multiple images were processed.
refined_center = centers[0] if images.ndim == 3 else center
plt.plot(refined_center[1], refined_center[0], 'ro', markersize=10, label='Refined Center')
plt.title('Diffraction Image with Refined Center')
plt.xlabel('Column (pixels)')
plt.ylabel('Row (pixels)')
plt.legend()
plt.show()
