In [101]:
import os
import matplotlib.pyplot as plt
import numpy as np
%matplotlib qt
from matplotlib.widgets import Slider
import h5py
from pystackreg import StackReg
import skimage.io as io

In [41]:
# Interactive stack viewer: a vertical slider steps through the frames of
# `img_reg`, drawing each frame at its entry in `shift_list` so the alignment
# can be judged by eye. Each shift is indexed as (row/y, col/x).
fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

# Data-space footprint [left, right, bottom, top] of frame 0. imshow's default
# origin='upper' draws row 0 at the top, so a positive row shift moves the
# frame *down* in data coordinates (hence the minus sign on the y terms).
extent = [0 + shift_list[0][1],
          ref_img.shape[1] + shift_list[0][1],
          0 - shift_list[0][0],
          ref_img.shape[0] - shift_list[0][0]]

element_img = ax.imshow(img_reg[0], extent=extent)

# Restrict the view to the region covered by every frame. x spans
# [dx, W + dx] -> overlap [xmax, W + xmin]; y spans [-dy, H - dy] because of
# the sign flip above -> overlap [-ymin, H - ymax]. Setting explicit limits
# also turns autoscaling off, so later extent changes do not move the view.
ymin, xmin = np.min(np.asarray(shift_list), axis=0)
ymax, xmax = np.max(np.asarray(shift_list), axis=0)
ax.set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax.set_ylim(0 - ymin, ref_img.shape[0] - ymax)

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    # NOTE(review): assumes len(energy_list) == len(shift_list) == len(img_reg)
    valmax=len(energy_list) - 1,
    valinit=0,
    valstep=1,  # frames are integer-indexed; avoid redraws at fractional positions
    orientation='vertical'
)
# Matplotlib widgets only stay responsive while something holds a reference to
# them; pin the slider to its figure so a later cell rebinding `index_slider`
# does not silently disable this one.
fig.index_slider = index_slider


def update(val, *, fig=fig, image=element_img):
    """Show frame ``round(val)`` of ``img_reg`` at its shifted position.

    Registered with ``index_slider.on_changed``. ``fig`` and ``image`` are bound
    as defaults when this cell runs, so the callback keeps driving *this*
    figure even after later cells rebind the global ``fig``/``ax`` names.
    """
    val_ind = int(np.round(val, 0))
    dy, dx = shift_list[val_ind][0], shift_list[val_ind][1]
    # Update the existing image in place rather than removing and re-creating
    # it; the axis limits stay fixed because autoscaling is off.
    image.set_data(img_reg[val_ind])
    image.set_extent([0 + dx, ref_img.shape[1] + dx,
                      0 - dy, ref_img.shape[0] - dy])
    # Rescale the colour limits to this frame's own data range.
    image.autoscale()
    fig.canvas.draw_idle()


index_slider.on_changed(update)

fig.show()

In [43]:
# Two-panel viewer: the top panel shows `img_reg`, the bottom panel `out`; both
# are drawn at the current frame's entry in `shift_list`, and one vertical
# slider selects the frame. Each shift is indexed as (row/y, col/x).
fig, ax = plt.subplots(2, 1, figsize=(5, 5), dpi=200)

# [left, right, bottom, top] of frame 0. y uses -dy because imshow's default
# origin='upper' puts row 0 at the top, so a positive row shift moves down.
extent = [0 + shift_list[0][1],
          ref_img.shape[1] + shift_list[0][1],
          0 - shift_list[0][0],
          ref_img.shape[0] - shift_list[0][0]]

rst_img = ax[0].imshow(img_reg[0], extent=extent)
reg_img = ax[1].imshow(out[0], extent=extent)

# Crop the top panel to the region covered by every frame: the x overlap is
# [xmax, W + xmin] and, given the -dy convention above, the y overlap is
# [-ymin, H - ymax]. The bottom panel keeps autoscaling to each frame's extent.
ymin, xmin = np.min(np.asarray(shift_list), axis=0)
ymax, xmax = np.max(np.asarray(shift_list), axis=0)
ax[0].set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax[0].set_ylim(0 - ymin, ref_img.shape[0] - ymax)

slider_ax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    # NOTE(review): assumes len(energy_list) matches len(img_reg), len(out)
    # and len(shift_list)
    valmax=len(energy_list) - 1,
    valinit=0,
    valstep=1,  # frames are integer-indexed; avoid redraws at fractional positions
    orientation='vertical'
)
# Widgets stop responding once unreferenced; keep this slider alive via its
# figure so rebinding the global `index_slider` elsewhere cannot disable it.
fig.index_slider = index_slider


def update(val, *, fig=fig, top_image=rst_img, bottom_image=reg_img):
    """Show frame ``round(val)`` of ``img_reg`` (top) and ``out`` (bottom).

    Registered with ``index_slider.on_changed``. The figure and both images
    are bound as defaults when this cell runs, so the callback is unaffected by
    later cells rebinding ``fig``, ``ax``, ``rst_img`` or ``reg_img``.
    """
    val_ind = int(np.round(val, 0))
    dy, dx = shift_list[val_ind][0], shift_list[val_ind][1]
    frame_extent = [0 + dx, ref_img.shape[1] + dx,
                    0 - dy, ref_img.shape[0] - dy]
    for image, frame in ((top_image, img_reg[val_ind]),
                         (bottom_image, out[val_ind])):
        # In-place update instead of remove + imshow; the bottom axes still
        # re-fits its limits to the new extent because its autoscale is on.
        image.set_data(frame)
        image.set_extent(frame_extent)
        # Rescale the colour limits to this frame's own data range.
        image.autoscale()
    fig.canvas.draw_idle()


index_slider.on_changed(update)

fig.show()

In [44]:
# Integer pixel-coordinate grids of the reference image: `xx` holds each
# pixel's column index and `yy` its row index, both with shape
# (ref_img.shape[0], ref_img.shape[1]). `x` / `y` are the 1-D column / row
# index vectors.
y = np.arange(ref_img.shape[0])
x = np.arange(ref_img.shape[1])
yy, xx = np.indices((y.size, x.size))

In [45]:
import matplotlib
from matplotlib import cm

# Scatter every frame's pixel grid at its registration shift, coloured by frame
# index, to show how the frames are distributed relative to one another.
# Coordinates here are row-index-up: x = col + dx, y = row + dy.
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

norm = matplotlib.colors.Normalize(vmin=0, vmax=(len(shift_list)))
# The class is `ScalarMappable`; the misspelt `cm.scalerMappable` raised
# AttributeError before anything was drawn.
mapper = cm.ScalarMappable(norm=norm, cmap='jet')
# One (r, g, b) tuple per frame; the alpha channel is dropped.
grid_colors = [(r, g, b) for r, g, b, a in mapper.to_rgba(range(len(shift_list)))]

for i in range(len(shift_list)):
    xx_plot = xx + shift_list[i][1]
    yy_plot = yy + shift_list[i][0]

    ax.scatter(xx_plot, yy_plot, s=1, label=i, color=grid_colors[i])

# A colorbar is more readable than a legend with one entry per frame.
fig.colorbar(mapper, ax=ax, label='Frame index')
ax.set_xlabel('x (pixels, shifted)')
ax.set_ylabel('y (pixels, shifted)')
ax.set_aspect('equal')
fig.show()

In [19]:
import matplotlib
from matplotlib import cm

# All frame pixel grids (black) with the grid displaced by the mean shift
# overlaid in red; `mean_shift` defines the virtual grid used further down.
# Coordinates are row-index-up: x = col + dx, y = row + dy. (No colormap is
# needed here: every frame is drawn in black, and the misspelt
# `cm.scalerMappable` that used to be computed here crashed the cell.)
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

for i in range(len(shift_list)):
    xx_plot = xx + shift_list[i][1]
    yy_plot = yy + shift_list[i][0]

    ax.scatter(xx_plot, yy_plot, s=1, label=i, color='k')

# Mean (row, col) shift over all frames.
mean_shift = np.mean(np.asarray(shift_list), axis=0)
xx_plot = xx + mean_shift[1]
yy_plot = yy + mean_shift[0]
# Explicit label instead of the leftover loop index.
ax.scatter(xx_plot, yy_plot, s=2, label='mean shift', color='r')

ax.set_aspect('equal')
ax.set_xlabel('x (pixels, shifted)')
ax.set_ylabel('y (pixels, shifted)')
# Zoom to the region covered by every frame in this +dy coordinate frame.
ax.set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax.set_ylim(0 + ymax, ref_img.shape[0] + ymin)
fig.show()

In [82]:
# Single-frame distance check. The "virtual" grid is the reference pixel grid
# displaced by `mean_shift`; only virtual points strictly inside the region
# covered by every frame are kept. `dist[v, p]` is then the Euclidean distance
# from virtual point v to pixel p of frame `i` (at its own shift).
i = 20  # frame to inspect

xx_virt = np.ravel(xx + mean_shift[1])
yy_virt = np.ravel(yy + mean_shift[0])

mask = ((xx_virt > 0 + xmax) & (xx_virt < ref_img.shape[1] + xmin)
        & (yy_virt > 0 + ymax) & (yy_virt < ref_img.shape[0] + ymin))
xx_virt, yy_virt = xx_virt[mask], yy_virt[mask]

# (rows, cols) of the retained virtual grid. NOTE(review): this assumes the
# kept points form a full rectangle — TODO confirm.
virt_shape = (np.unique(yy_virt).size, np.unique(xx_virt).size)

xx_i = np.ravel(xx + shift_list[i][1])
yy_i = np.ravel(yy + shift_list[i][0])

# Pairwise offsets of shape (n_virtual, n_pixels); memory scales with the
# product of the two, so this gets large quickly for big images.
xx_diff = np.subtract.outer(xx_virt, xx_i)
yy_diff = np.subtract.outer(yy_virt, yy_i)

dist = np.sqrt(np.square(xx_diff) + np.square(yy_diff))

In [84]:
# Visual check of the pixel selection for one virtual grid point: frame pixels
# as semi-transparent black dots, the selected pixel(s) in red, and the virtual
# point itself in black.
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

virt_ind = -1  # index into the retained virtual points (here: the last one)

# Pixel-selection rule:
#   False -> the single nearest pixel (argmin over the distance row)
#   True  -> every pixel within NEIGHBOUR_RADIUS of the virtual point
# NOTE(review): on a unit pixel grid a radius of 1.25 does not always include
# all four pixels surrounding the point (the far corner can be up to sqrt(2)
# away), so it does not strictly guarantee bilinear-interpolation support.
USE_NEIGHBOURHOOD = False
NEIGHBOUR_RADIUS = 1.25

# Both branches yield *flat* indices. These match the ravel()ed coordinate
# arrays below whether dist[virt_ind] is 1-D (single-frame dist) or 2-D
# (all-frames dist, where xx_i / yy_i are (n_pixels, n_frames)).
if USE_NEIGHBOURHOOD:
    dist_mask = np.flatnonzero(dist[virt_ind] < NEIGHBOUR_RADIUS)
else:
    dist_mask = np.argmin(dist[virt_ind])

ax.scatter(xx_i.ravel(), yy_i.ravel(), s=1, c='k', alpha=0.5)
ax.scatter(xx_i.ravel()[dist_mask], yy_i.ravel()[dist_mask], s=2, c='r')
ax.scatter(xx_virt[virt_ind], yy_virt[virt_ind], s=3, c='k')

ax.set_xlabel('x (pixels, shifted)')
ax.set_ylabel('y (pixels, shifted)')
ax.set_aspect('equal')
fig.show()

In [86]:
# All-frames distance computation: distance from every retained virtual grid
# point (reference grid displaced by `mean_shift`, restricted to the region
# covered by every frame) to every pixel of every shifted frame.
xx_virt = (xx + mean_shift[1]).flatten()
yy_virt = (yy + mean_shift[0]).flatten()

mask = np.all([xx_virt > 0 + xmax,
               xx_virt < ref_img.shape[1] + xmin,
               yy_virt > 0 + ymax,
               yy_virt < ref_img.shape[0] + ymin], axis=0)

xx_virt = xx_virt[mask]
yy_virt = yy_virt[mask]

# (rows, cols) of the retained virtual grid. NOTE(review): assumes the kept
# points form a full rectangle — TODO confirm.
virt_shape = (len(np.unique(yy_virt)), len(np.unique(xx_virt)))

# Pixel coordinates of every frame, shape (n_pixels, n_frames). This depends
# only on `shift_list`, not on any leftover single-frame index `i`.
xx_i = xx.flatten()[:, np.newaxis] + np.asarray(shift_list)[np.newaxis, :, 1]
yy_i = yy.flatten()[:, np.newaxis] + np.asarray(shift_list)[np.newaxis, :, 0]

# Offsets and distances have shape (n_virtual, n_pixels, n_frames). Memory is
# n_virtual * n_pixels * n_frames entries (8 bytes each for float64), which
# grows very quickly with image size.
xx_diff = xx_virt[:, np.newaxis, np.newaxis] - xx_i[np.newaxis, :]
yy_diff = yy_virt[:, np.newaxis, np.newaxis] - yy_i[np.newaxis, :]

dist = np.sqrt(xx_diff**2 + yy_diff**2)

In [214]:
# Directory holding the per-scan `scan{id}_xrd.h5` files. Can be overridden
# with the XRD_MAPS_DIR environment variable instead of editing the notebook.
file_path = os.environ.get(
    'XRD_MAPS_DIR', 'D:\\Musterman_postdoc\\20240223_Musterman\\xrd_maps\\')

# Virtual grid point (row, col) whose pattern is pulled from each scan.
# Previously inspected point: (8, 84).
virt_indices = (15, 137)

# The per-frame argmin below needs the all-frames distance array
# (n_virtual, n_pixels, n_frames); the single-frame variant is 2-D and would
# yield a single pixel instead of one per frame.
if dist.ndim != 3:
    raise ValueError(
        f'dist must have shape (n_virtual, n_pixels, n_frames); got ndim={dist.ndim}')

ravel_index = np.ravel_multi_index(virt_indices, virt_shape)
# For each frame (axis 1 of dist[ravel_index]): flat index of the pixel nearest
# to the chosen virtual point, converted to (row, col) in the reference image.
pixels_index = np.argmin(dist[ravel_index], axis=0)
pixel_indices = np.asarray(np.unravel_index(pixels_index, ref_img.shape)).T

scan_ids = base_scan + np.array(FXI_first)
# Scan k is paired with frame k; fail before any file I/O if there are more
# scans than frames.
if len(scan_ids) > len(pixel_indices):
    raise ValueError(
        f'{len(scan_ids)} scans but only {len(pixel_indices)} frames in dist')

img_stack = []
en_stack = []

for i, scanid in enumerate(scan_ids):
    with h5py.File(os.path.join(file_path, f'scan{scanid}_xrd.h5'), 'r') as f:
        base_grp = f['xrdmap/image_data']
        # Read only the single pattern at this frame's nearest pixel.
        img = base_grp['raw_images'][tuple(pixel_indices[i])]

        img_stack.append(img.astype(np.float32))
        en_stack.append(f['xrdmap'].attrs['energy'])