In [101]:
import os
import matplotlib.pyplot as plt
import numpy as np
%matplotlib qt
from matplotlib.widgets import Slider
import h5py
from pystackreg import StackReg
import skimage.io as io

In [2]:
def plot_image(image, return_plot=False, **kwargs):

    fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

    im = ax.imshow(image, **kwargs)
    fig.colorbar(im, ax=ax)

    if return_plot:
            return fig, ax
    else:
        fig.show()   

In [3]:
base = 'D:\\Musterman_postdoc\\20240223_Musterman\\'
xrf_dir = f'{base}xrf_maps\\'

In [4]:
base_scan = 153000
FXI_first = [102, 104, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 130, 132, 134, 136, 138, 140, 143, 145]
FXI_second = [157, 159, 161, 163, 165, 167, 169, 171, 173, 175, 177, 179, 181, 183, 185, 187, 189, 191, 193, 195, 197, 199, 201, 203, 205, 207, 209, 211, 213, 215]
LiNbO = [253, 255, 257, 259, 261, 263, 265, 267, 269, 271, 273, 275, 277, 279, 281, 283, 285, 287, 289, 291, 293, 295, 297]
BNL_inner = [485, 487, 489, 491, 493, 495, 502, 504, 506, 508, 510]

In [28]:
energy_list = []
data_list = []

for scan in base_scan + np.array(FXI_first):
    with h5py.File(f'{xrf_dir}scan2D_{scan}_xs_sum8ch.h5', 'r') as f:

        xrf_fit_names = [d.decode('utf-8') for d in f['xrfmap/detsum/xrf_fit_name'][:]]
        xrf_fit = f['xrfmap/detsum/xrf_fit'][:]

        i0 = f['xrfmap/scalers/val'][..., 0]
        xrf_fit = np.concatenate((xrf_fit, np.expand_dims(i0, axis=0)), axis=0)
        xrf_fit = np.transpose(xrf_fit, axes=(0, 2, 1))
        xrf_fit_names.append('i0')

        xrf_dict = dict(zip(xrf_fit_names, xrf_fit))
        energy = f['xrfmap/scan_metadata'].attrs['instrument_mono_incident_energy']

    energy_list.append(energy)
    data_list.append(xrf_dict)

#sort_data = sorted(zip(energy_list, data_list), key=lambda pair: pair[0])
#energy_list = [y for (y, x) in sort_data]
#data_list = [x for (y, x) in sort_data]

In [6]:
'''save_dir = f'{base}plots//'

element_key = 'Er_L'
for ind in range(len(data_list)):
    fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

    scaler = data_list[ind]['i0'].copy()
    scaler[scaler < 1000] = np.mean(scaler)
    scaler[scaler > 1e6] = np.median(scaler)

    ax.imshow(data_list[ind][element_key] / scaler, vmin=0.2, vmax=0.35)

    plt.savefig(f'{save_dir}LiNbO_{ind}.png', transparent=True)
    plt.close()''';

In [29]:
element_key = 'Pt_L'

In [30]:
fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

element_img = ax.imshow(data_list[0][element_key] / data_list[0]['i0'])

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    valmax=len(energy_list) - 1,
    valinit=0,
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global element_img
    val_ind = int(np.round(val, 0))
    element_img.remove()
    element_img = ax.imshow(data_list[val_ind][element_key] / data_list[val_ind]['i0'])
    fig.canvas.draw_idle()
    
index_slider.on_changed(update)

fig.show()

In [9]:
fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

element_img = ax.imshow(data_list[0][element_key] / data_list[0]['i0'])

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
energy_slider = Slider(
    ax=slider_ax,
    label='Energy [keV]',
    valmin=np.min(energy_list),
    valmax=np.max(energy_list),
    valinit=np.min(energy_list),
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global element_img
    diff_arr = np.abs(np.array(energy_list) - val)
    val_ind = np.argmin(diff_arr)
    element_img.remove()
    element_img = ax.imshow(data_list[val_ind][element_key] / data_list[val_ind]['i0'])
    fig.canvas.draw_idle()
    
energy_slider.on_changed(update)

fig.show()

In [31]:
sr = StackReg(StackReg.RIGID_BODY)

img_reg = np.array([index[element_key] / index['i0'] for index in data_list])

out = sr.register_transform_stack(img_reg, reference='first')

In [11]:
fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

element_img = ax.imshow(out[0])

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    valmax=len(energy_list) - 1,
    valinit=0,
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global element_img
    val_ind = int(np.round(val, 0))
    element_img.remove()
    element_img = ax.imshow(out[val_ind])
    fig.canvas.draw_idle()
    
index_slider.on_changed(update)

fig.show()

In [12]:
def rotation_scale_translation(ref_img, mov_img, rotation_upsample=1000, shift_upsample=1000, bandpass=(0, 1), fix_rotation=False):
    from skimage.filters import window, difference_of_gaussians
    from scipy.fft import fft2, fftshift
    from skimage.transform import warp_polar, rotate
    from skimage.registration import phase_cross_correlation
    
    if not fix_rotation:
        # Image bandpass
        ref_img = difference_of_gaussians(ref_img, *bandpass)
        mov_img = difference_of_gaussians(mov_img, *bandpass)

        # Window images
        ref_wimage = ref_img * window('hann', ref_img.shape)
        mov_wimage = mov_img * window('hann', ref_img.shape)

        # work with shifted FFT magnitudes
        ref_fs = np.abs(fftshift(fft2(ref_wimage)))
        mov_fs = np.abs(fftshift(fft2(mov_wimage)))

        # Create log-polar transform
        shape = ref_fs.shape
        radius = shape[0] // 8  # only take lower frequencies
        warped_ref_fs = warp_polar(ref_fs, radius=radius, output_shape=shape,
                                scaling='log', order=0)
        warped_mov_fs = warp_polar(mov_fs, radius=radius, output_shape=shape,
                                scaling='log', order=0)
        
        # Register shift in polar space with half of FFT
        warped_ref_fs = warped_ref_fs[:shape[0] // 2, :]  # only use half of FFT
        warped_mov_fs = warped_mov_fs[:shape[0] // 2, :]
        polar_shift, polar_error, polar_phasediff = phase_cross_correlation(warped_ref_fs,
                                                                            warped_mov_fs,
                                                                            upsample_factor=rotation_upsample,
                                                                            normalization=None)
        
        # Convert to useful parameters
        shiftr, shiftc = polar_shift[:2]
        angle = (360 / shape[0]) * shiftr
        klog = shape[1] / np.log(radius)
        scale = np.exp(shiftc / klog)

        # Correct image and register translation
        # adj_image = rescale(moving_image, 2 - scale) # this seem wrong
        adj_img = rotate(mov_img, -angle)
        # adj_image = rescale
    else:
        angle = 0
        scale = 0
        adj_img = mov_img

    shift, trans_error, trans_phasediff = phase_cross_correlation(ref_img,
                                                                        adj_img,
                                                                        upsample_factor=shift_upsample,
                                                                        normalization=None)
    
    return angle, scale, shift

In [32]:
ref_img = img_reg[0]

angle_list, scale_list, shift_list = [], [], []

for mov_img in img_reg:

    angle, scale, shift = rotation_scale_translation(ref_img, mov_img, bandpass=(0, 1), fix_rotation=True)
    print(f'{angle:.3f}', f'{scale:.2f}', shift)
    angle_list.append(angle)
    scale_list.append(scale)
    shift_list.append(shift)

0.000 0.00 [0. 0.]
0.000 0.00 [ 0.138 -0.294]
0.000 0.00 [ 0.474 -0.455]
0.000 0.00 [ 0.332 -1.111]
0.000 0.00 [ 0.483 -1.604]
0.000 0.00 [ 0.228 -2.015]
0.000 0.00 [ 0.372 -2.675]
0.000 0.00 [ 0.514 -3.327]
0.000 0.00 [ 0.281 -3.671]
0.000 0.00 [ 0.235 -4.453]
0.000 0.00 [ 0.157 -4.942]
0.000 0.00 [ 0.095 -5.724]
0.000 0.00 [ 0.117 -6.371]
0.000 0.00 [ 0.029 -7.08 ]
0.000 0.00 [ 0.243 -7.762]
0.000 0.00 [ 0.085 -8.211]
0.000 0.00 [ 0.152 -5.545]
0.000 0.00 [ 0.024 -4.999]
0.000 0.00 [-0.265 -5.1  ]
0.000 0.00 [-0.154 -4.97 ]
0.000 0.00 [-0.028 -4.954]


In [33]:
from skimage.transform import AffineTransform, warp

def shift(image, vector):
    transform = AffineTransform(translation=vector)
    shifted = warp(image, transform, mode='constant', preserve_range=True, cval=0)
    shifted = shifted.astype(image.dtype)
    return shifted

shifted = shift(img_reg[1], shift_list[1])

In [41]:
fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

extent=[0 + shift_list[0][1],
        ref_img.shape[1] + shift_list[0][1],
        0 - shift_list[0][0],
        ref_img.shape[0] - shift_list[0][0]]

ymin, xmin = np.min(np.asarray(shift_list), axis=0)
ymax, xmax = np.max(np.asarray(shift_list), axis=0)
ax.set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax.set_ylim(0 + ymax, ref_img.shape[0] + ymin)

element_img = ax.imshow(img_reg[0], extent=extent)

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    valmax=len(energy_list) - 1,
    valinit=0,
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global element_img
    val_ind = int(np.round(val, 0))
    extent=[0 + shift_list[val_ind][1],
            ref_img.shape[1] + shift_list[val_ind][1],
            0 - shift_list[val_ind][0],
            ref_img.shape[0] - shift_list[val_ind][0]]
    
    ax.set_xlim(0 + xmax, ref_img.shape[1] + xmin)
    ax.set_ylim(0 + ymax, ref_img.shape[0] + ymin)

    element_img.remove()
    #element_img = ax.imshow(shift(img_reg[val_ind], shift_list[val_ind][::-1]), extent=extent)
    element_img = ax.imshow(img_reg[val_ind], extent=extent)
    
    fig.canvas.draw_idle()
    
index_slider.on_changed(update)

fig.show()

In [43]:
fig, ax = plt.subplots(2, 1, figsize=(5, 5), dpi=200)

extent=[0 + shift_list[0][1],
        ref_img.shape[1] + shift_list[0][1],
        0 - shift_list[0][0],
        ref_img.shape[0] - shift_list[0][0]]

rst_img = ax[0].imshow(img_reg[0], extent=extent)
reg_img = ax[1].imshow(out[0], extent=extent)

ymin, xmin = np.min(np.asarray(shift_list), axis=0)
ymax, xmax = np.max(np.asarray(shift_list), axis=0)
ax[0].set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax[0].set_ylim(0 + ymax, ref_img.shape[0] + ymin)

slider_ax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
index_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    valmax=len(energy_list) - 1,
    valinit=0,
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global rst_img, reg_img
    val_ind = int(np.round(val, 0))
    extent=[0 + shift_list[val_ind][1],
            ref_img.shape[1] + shift_list[val_ind][1],
            0 - shift_list[val_ind][0],
            ref_img.shape[0] - shift_list[val_ind][0]]
    
    ax[0].set_xlim(0 + xmax, ref_img.shape[1] + xmin)
    ax[0].set_ylim(0 + ymax, ref_img.shape[0] + ymin)

    rst_img.remove()
    rst_img = ax[0].imshow(img_reg[val_ind], extent=extent)

    reg_img.remove()
    reg_img = ax[1].imshow(out[val_ind], extent=extent)
    fig.canvas.draw_idle()
    
index_slider.on_changed(update)

fig.show()

In [44]:
x = np.arange(0, ref_img.shape[1])
y = np.arange(0, ref_img.shape[0])
xx, yy = np.meshgrid(x, y)

In [45]:
import matplotlib
from matplotlib import cm

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

norm = matplotlib.colors.Normalize(vmin=0, vmax=(len(shift_list)))
mapper = cm.scalerMappable(norm=norm, cmap='jet')
grid_colors = [(r, g, b) for r, g, b, a in mapper.to_rgba(range(len(shift_list)))]

for i in range(len(shift_list)):
    xx_plot = xx + shift_list[i][1]
    yy_plot = yy + shift_list[i][0]

    ax.scatter(xx_plot, yy_plot, s=1, label=i, color=grid_colors[i])

ax.set_aspect('equal')
fig.show()

In [19]:
import matplotlib
from matplotlib import cm

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

norm = matplotlib.colors.Normalize(vmin=0, vmax=(len(shift_list)))
mapper = cm.scalerMappable(norm=norm, cmap='jet')
grid_colors = [(r, g, b) for r, g, b, a in mapper.to_rgba(range(len(shift_list)))]

for i in range(len(shift_list)):
    xx_plot = xx + shift_list[i][1]
    yy_plot = yy + shift_list[i][0]

    ax.scatter(xx_plot, yy_plot, s=1, label=i, color='k')

mean_shift = np.mean(np.asarray(shift_list), axis=0)
xx_plot = xx + mean_shift[1]
yy_plot = yy + mean_shift[0]
ax.scatter(xx_plot, yy_plot, s=2, label=i, color='r')

ax.set_aspect('equal')
ax.set_xlim(0 + xmax, ref_img.shape[1] + xmin)
ax.set_ylim(0 + ymax, ref_img.shape[0] + ymin)
fig.show()

In [82]:
i = 20

xx_virt = (xx + mean_shift[1]).flatten()
yy_virt = (yy + mean_shift[0]).flatten()

mask = np.all([xx_virt > 0 + xmax,
               xx_virt < ref_img.shape[1] + xmin,
               yy_virt > 0 + ymax,
               yy_virt < ref_img.shape[0] + ymin], axis=0)

xx_virt = xx_virt[mask]
yy_virt = yy_virt[mask]

virt_shape = (len(np.unique(yy_virt)), len(np.unique(xx_virt)))

xx_i = (xx + shift_list[i][1]).flatten()
yy_i = (yy + shift_list[i][0]).flatten()

xx_diff = xx_virt[:, np.newaxis] - xx_i[np.newaxis, :]
yy_diff = yy_virt[:, np.newaxis] - yy_i[np.newaxis, :]

dist = np.sqrt(xx_diff**2 + yy_diff**2)

In [84]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

virt_ind = -1

# Nearest neighbors - guarantees interpolation
dist_mask = np.where(dist[virt_ind] < 1.25)[0]

# Nearest pixel
dist_mask = np.argmin(dist[virt_ind])


ax.scatter(xx_i.flatten(), yy_i.flatten(), s=1, c='k', alpha=0.5)
ax.scatter(xx_i.flatten()[dist_mask], yy_i.flatten()[dist_mask], s=2, c='r')
ax.scatter(xx_virt[virt_ind], yy_virt[virt_ind], s=3, c='k')

ax.set_aspect('equal')
fig.show()

In [86]:
xx_virt = (xx + mean_shift[1]).flatten()
yy_virt = (yy + mean_shift[0]).flatten()

mask = np.all([xx_virt > 0 + xmax,
               xx_virt < ref_img.shape[1] + xmin,
               yy_virt > 0 + ymax,
               yy_virt < ref_img.shape[0] + ymin], axis=0)

xx_virt = xx_virt[mask]
yy_virt = yy_virt[mask]

virt_shape = (len(np.unique(yy_virt)), len(np.unique(xx_virt)))

xx_i = (xx + shift_list[i][1]).flatten()
yy_i = (yy + shift_list[i][0]).flatten()

xx_i = xx.flatten()[:, np.newaxis] + np.array(shift_list)[np.newaxis, :, 1]
yy_i = yy.flatten()[:, np.newaxis] + np.array(shift_list)[np.newaxis, :, 0]

xx_diff = xx_virt[:, np.newaxis, np.newaxis] - xx_i[np.newaxis, :]
yy_diff = yy_virt[:, np.newaxis, np.newaxis] - yy_i[np.newaxis, :]

dist = np.sqrt(xx_diff**2 + yy_diff**2)

In [214]:
file_path = 'D:\\Musterman_postdoc\\20240223_Musterman\\xrd_maps\\'


virt_indices = (8, 84)
virt_indices = (15, 137)
ravel_index = np.ravel_multi_index(virt_indices, virt_shape)
pixels_index = np.argmin(dist[ravel_index], axis=0)
pixel_indices = np.asarray(np.unravel_index(pixels_index, ref_img.shape)).T


img_stack = []
en_stack = []

for i, scanid in enumerate(base_scan + np.array(FXI_first)):
    with h5py.File(f'{file_path}scan{scanid}_xrd.h5', 'r') as f:
        base_grp = f['xrdmap/image_data']
        img = base_grp['raw_images'][tuple(pixel_indices[i])]

        img_stack.append(img.astype(np.float32))
        en_stack.append(f['xrdmap'].attrs['energy'])

In [166]:
dark_id = 153086
dark_dir = 'D:\\Musterman_postdoc\\20240223_Musterman\\dark_fields\\'
dir_mask = [str(dark_id) in d for d in os.listdir(dark_dir)]

dark_field = io.imread(f'{dark_dir}{np.array(os.listdir(dark_dir))[dir_mask][0]}').astype(np.float32)

In [168]:
test.plot_image(img_stack[-1])

In [104]:
from xrdmaptools.XRDMap import XRDMap

scanid = 153126
filedir = 'D:\\Musterman_postdoc\\20240223_Musterman\\xrd_maps\\'
h5_filename = f'scan{scanid}_xrd.h5'
test = XRDMap.from_hdf(h5_filename, wd=filedir, save_hdf=False, dask_enabled=False)

Connecting to databrokers...failed.
Loading data from hdf file...
Loading most recent images (raw_images)...done!
Loading reciprocal positions...done!
Loading scalers...done!
Loading positions...done!
Loading reflection spots...done!
Instantiating ImageMap...done!
Setting detector calibration...
XRD Map loaded!


In [139]:
test.map.images = np.array([])
test.map.dtype = np.float32

In [120]:
filedir = 'D:\\Musterman_postdoc\\20240223_Musterman\\calibrations\\new\\'

test.set_calibration('scan153043_dexela_calibration.poni', filedir=filedir)
#test.set_calibration('scan153220_dexela_calibration_ext.poni', filedir=filedir)

Setting detector calibration...
Calibration performed under different settings. Adjusting calibration.


In [215]:
for i in range(len(img_stack)):
    img_stack[i] -= dark_field

In [216]:
test.map.apply_polarization_correction(apply=False)
test.map.apply_solidangle_correction(apply=False)

for i in range(len(img_stack)):
    img_stack[i] /= test.map.polarization_correction
    img_stack[i] /= test.map.solidangle_correction
    img_stack[i] /= 1 / np.sin(np.radians(test.tth_arr / 2))

In [217]:
# Trim edges
for i in range(len(img_stack)):
    img_stack[i][0] = 0
    img_stack[i][-1] = 0
    img_stack[i][:, 0] = 0
    img_stack[i][:, -1] = 0

In [218]:
# This will keep adding values to img_stack. Maybe fix that...

q_coord_list = []

start_en = test.energy

# Add blank images at extrema
en_step = np.median(np.abs(np.diff(sorted(en_stack))))
en_min = np.min(en_stack) - en_step
en_max = np.max(en_stack) + en_step

en_stack.append(en_min)
en_stack.append(en_max)
img_stack.append(np.zeros_like(img_stack[0]))
img_stack.append(np.zeros_like(img_stack[0]))

for i in range(len(img_stack)):
    test._del_arr()
    test.energy = en_stack[i]
    #q_coord_list.append(test.q_arr)
    q_coord_list.append(q_arr(test))

test._del_arr()
test.energy = start_en

In [219]:
all_qx = []
all_qy = []
all_qz = []
all_int = []

for i in range(len(img_stack)):
    all_qx.extend(q_coord_list[i][0].flatten())
    all_qy.extend(q_coord_list[i][1].flatten())
    all_qz.extend(q_coord_list[i][2].flatten())
    all_int.extend(img_stack[i].flatten())


In [220]:
edges = []

q_coord_arr = np.asarray(q_coord_list)

# Low energy image edges
edges.append(q_coord_arr[np.argmin(en_stack), :, 0])
edges.append(q_coord_arr[np.argmin(en_stack), :, -1])
edges.append(q_coord_arr[np.argmin(en_stack), :, :, 0])
edges.append(q_coord_arr[np.argmin(en_stack), :, :, -1])

# High energy image edges
edges.append(q_coord_arr[np.argmax(en_stack), :, 0])
edges.append(q_coord_arr[np.argmax(en_stack), :, -1])
edges.append(q_coord_arr[np.argmax(en_stack), :, :, 0])
edges.append(q_coord_arr[np.argmax(en_stack), :, :, -1])

# Corners
edges.append(q_coord_arr[:, :, 0, 0].T)
edges.append(q_coord_arr[:, :, -1, 0].T)
edges.append(q_coord_arr[:, :, 0, -1].T)
edges.append(q_coord_arr[:, :, -1, -1].T)

In [222]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

#q_plot = dset
#q_plot = np.array([i.flatten() for i in out])
q_plot = np.asarray([all_qx, all_qy, all_qz, all_int])
q_plot = q_plot[:, ::50]

qx_min, qx_max = np.min(q_plot[0]), np.max(q_plot[0])
qy_min, qy_max = np.min(q_plot[1]), np.max(q_plot[1])
qz_min, qz_max = np.min(q_plot[2]), np.max(q_plot[2])

int_mask = q_plot[-1] > 15

ax.scatter(q_plot[0][int_mask],
           q_plot[1][int_mask],
           q_plot[2][int_mask],
           c=q_plot[3][int_mask], s=1, alpha=0.1)

for edge in edges:
    ax.plot(*edge, c='k', alpha=0.5)

ax.set_xlim(qx_min, qx_max)
ax.set_ylim(qy_min, qy_max)
ax.set_zlim(qz_min, qz_max)

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')
fig.show()

In [202]:
test.plot_image(img_stack[0])

In [233]:
from xrdmaptools.utilities.utilities import label_nearest_spots
from tqdm import tqdm

q_plot = np.asarray([all_qx, all_qy, all_qz, all_int])
q_plot = q_plot[:, ::50]

int_mask = q_plot[-1] > 25

out = label_nearest_spots(q_plot[:-1][:, int_mask].T, max_dist=0.1).T

In [237]:
import plotly.graph_objects as go

data = []

expansion = 0.1

# TODO: trim data prior to for loop for more accurate time estimate
for label in tqdm(np.unique(out[-1])):

    label_mask = out[-1] == label

    # Ignore clusters of less than 5 pixels
    if np.sum(label_mask) < 5:
        continue
    
    # Create mask of bounding box
    qx_min_i = np.min(out[0][label_mask]) - expansion
    qx_max_i = np.max(out[0][label_mask]) + expansion
    qx_mask = np.all([all_qx > qx_min_i,
                      all_qx < qx_max_i], axis=0)

    qy_min_i = np.min(out[1][label_mask]) - expansion
    qy_max_i = np.max(out[1][label_mask]) + expansion
    qy_mask = np.all([all_qy > qy_min_i,
                      all_qy < qy_max_i], axis=0)

    qz_min_i = np.min(out[2][label_mask]) - expansion
    qz_max_i = np.max(out[2][label_mask]) + expansion
    qz_mask = np.all([all_qy > qy_min_i,
                      all_qy < qy_max_i], axis=0)

    # Combined mask without already plotted values
    full_mask = np.all([qx_mask, qy_mask, qz_mask], axis=0)

    # Interpolate data to regular grid within bounding box
    grid_data = map_2_grid(np.asarray([np.array(all_qx)[full_mask],
                                       np.array(all_qy)[full_mask],
                                       np.array(all_qz)[full_mask],
                                       np.array(all_int)[full_mask]]),
                                       gridstep=0.01)
    grid_data = np.stack([x.flatten() for x in grid_data])

    # Logscaling
    plot_log = np.array(all_int.copy())
    plot_log[plot_log < 1e-3] = 1e-3
    plot_log = np.log(plot_log)

    # Generate graph trace
    data.append(go.Volume(
            x=grid_data[0],
            y=grid_data[1],
            z=grid_data[2],
            #value=grid_data[3],
            #isomin=25,
            #isomax=100,
            value = plot_log,
            isomin = 3,
            isomax = 8,

            opacity=0.1, # needs to be small to see through all surfaces
            surface_count=5, # needs to be a large number for good volume rendering
            ))


100%|██████████| 35/35 [07:33<00:00, 12.97s/it]


In [238]:
for edge in edges:

    data.append(go.Scatter3d(
        x = edge[0],
        y = edge[1],
        z = edge[2],
        mode='lines',
        opacity=0.5,
        line=dict(
            color='black',
            width=5
        )
    ))

In [239]:
fig = go.Figure(data=data)
fig.show(renderer='browser')

MemoryError: 

In [184]:
from matplotlib.widgets import Slider

A = []
for i in sorted(en_stack):
    im_i = img_stack[np.where(np.array(en_stack) == i)[0][0]]
    A.append(im_i)

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

#A = [x for _, x in sorted(zip(img_stack, en_stack))]

im = ax.imshow(A[0], vmin=0, vmax=10)

axidx = plt.axes([0.9, 0.2, 0.03, 0.6])
#slidx = Slider(axidx, 'Index', 0, len(A) - 1, valinit = 0, valfmt='%d', orientation='vertical')
slidx = Slider(axidx, 'Energy', np.min(en_stack), np.max(en_stack), valinit=0, valfmt='%d', orientation='vertical')

def update(val):
    energy = slidx.val
    idx = np.argmin(np.abs(sorted(en_stack) - energy))
    im.set_data(A[idx])
    fig.canvas.draw_idle()
slidx.on_changed(update)

fig.show()

In [175]:
plot_image(img_reg[0])

In [188]:
from scipy.interpolate import griddata

def map_2_grid(q_dset, gridstep=0.005):
    all_qx = q_dset[0]
    all_qy = q_dset[1]
    all_qz = q_dset[2]
    all_int = q_dset[3]

    # Find bounds
    x_min = np.min(all_qx)
    x_max = np.max(all_qx)
    y_min = np.min(all_qy)
    y_max = np.max(all_qy)
    z_min = np.min(all_qz)
    z_max = np.max(all_qz)

    # Generate q-space grid
    xx = np.linspace(x_min, x_max, int((x_max - x_min) / gridstep))
    yy = np.linspace(y_min, y_max, int((y_max - y_min) / gridstep))
    zz = np.linspace(z_min, z_max, int((z_max - z_min) / gridstep))

    grid = np.array(np.meshgrid(xx, yy, zz, indexing='ij'))
    grid = grid.reshape(3, -1).T

    points = np.array([all_qx, all_qy, all_qz]).T

    int_grid = griddata(points, all_int, grid, method='nearest')
    #int_grid = int_grid.reshape(yy.shape[0], xx.shape[0], zz.shape[0]).T
    int_grid = int_grid.reshape(xx.shape[0], yy.shape[0], zz.shape[0])

    return np.array([*np.meshgrid(xx, yy, zz, indexing='ij'), int_grid])

In [187]:
ind = 13
element = 'Sb_L'
ref_image = data_list[0][element] / data_list[0]['i0']
moving_image = data_list[ind][element] / data_list[ind]['i0']

In [188]:
from skimage.registration import phase_cross_correlation

shift, error, phasediff = phase_cross_correlation(ref_image, moving_image, upsample_factor=1000, normalization=None)
print(shift)

[-6.618 -0.104]


In [189]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

ax.imshow(ref_image, extent=[0, ref_image.shape[1], 0, ref_image.shape[0]])

ax.set_xlim(0, ref_image.shape[1])
ax.set_ylim(0, ref_image.shape[0])

fig.show()

In [190]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

ax.imshow(moving_image, extent=[0 + shift[1],
                                ref_image.shape[1] + shift[1],
                                0 - shift[0],
                                ref_image.shape[0] - shift[0]])

ax.set_xlim(0, ref_image.shape[1])
ax.set_ylim(0, ref_image.shape[0])

fig.show()

In [108]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

index = 3

scaler = data_list[index]['i0'].copy()
scaler[scaler < 1000] = np.mean(scaler)

moving_image = data_list[index][element] / scaler

ax.imshow(moving_image)

fig.show()

In [183]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

ax.plot(range(len(angle_list[3:])), [x[0] for x in shift_list[3:]])
ax.plot(range(len(angle_list[3:])), [x[1] for x in shift_list[3:]])

fig.show()

In [155]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)


im = ax.imshow(moving_image)
fig.colorbar(im, ax=ax)

fig.show()

In [161]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

index=20

scaler = data_list[index]['i0'].copy()
scaler[scaler < 1000] = np.mean(scaler)
scaler[scaler > 1e6] = np.median(scaler)

moving_image = data_list[index][element] / scaler

im = ax.imshow(moving_image, vmin=0.88*np.max(moving_image), vmax=0.35)
fig.colorbar(im, ax=ax)

fig.show()

In [175]:
save_dir = f'{base}plots//'

element_key = 'Sb_L'

def _plot_image(index):
    scaler = data_list[index]['i0'].copy()
    scaler[scaler < 1000] = np.mean(scaler)
    scaler[scaler > 1e6] = np.median(scaler)

    plot_img = data_list[index][element_key] / scaler
    plot_img = rotate(plot_img, -angle_list[index])
    shift = shift_list[index]
    #print(shift)

    element_img = ax.imshow(plot_img, extent=[0 + shift[1],
                                              ref_image.shape[1] + shift[1],
                                              0 - shift[0],
                                              ref_image.shape[0] - shift[0]])
    
    return element_img

for ind in range(len(data_list)):
    fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

    ax.set_xlim(0, ref_image.shape[1])
    ax.set_ylim(0, ref_image.shape[0])

    _plot_image(ind)

    plt.savefig(f'{save_dir}BNL_inner_aligned_{ind}.png', transparent=True)
    plt.close()

In [174]:
element_key = 'Sb_L'
ref_image = data_list[0][element] / data_list[0]['i0']


fig = plt.figure(figsize=(5, 5), dpi=200)
ax = fig.add_axes([0.1, 0.1, 0.5, 0.8])

def _plot_image(index):
    scaler = data_list[index]['i0'].copy()
    scaler[scaler < 1000] = np.mean(scaler)
    scaler[scaler > 1e6] = np.median(scaler)

    plot_img = data_list[index][element_key] / scaler
    plot_img = rotate(plot_img, -angle_list[index])
    shift = shift_list[index]
    #print(shift)

    element_img = ax.imshow(plot_img, extent=[0 + shift[1],
                                              ref_image.shape[1] + shift[1],
                                              0 - shift[0],
                                              ref_image.shape[0] - shift[0]],
                                              vmin=0.2, vmax=0.35)
    
    return element_img

element_img = _plot_image(0)
#element_img = ax.imshow(data_list[0][element_key] / data_list[0]['i0'])

ax.set_xlim(0, ref_image.shape[1])
ax.set_ylim(0, ref_image.shape[0])

slider_ax = fig.add_axes([0.7, 0.1, 0.03, 0.8])
energy_slider = Slider(
    ax=slider_ax,
    label='Index',
    valmin=0,
    valmax=len(energy_list) - 1,
    valinit=0,
    orientation='vertical'
)

# The function to be called anytime a slider's value changes
def update(val):
    global element_img
    val_ind = int(np.round(val, 0))
    element_img.remove()
    element_img = _plot_image(val_ind)
    #element_img = ax.imshow(data_list[val_ind][element_key] / data_list[val_ind]['i0'])
    fig.canvas.draw_idle()
    
energy_slider.on_changed(update)

fig.show()

In [246]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

ax.imshow(ref_image, extent=[0, ref_image.shape[1], 0, ref_image.shape[0]])

ax.set_xlim(0, ref_image.shape[1])
ax.set_ylim(0, ref_image.shape[0])

fig.show()

In [247]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200)

ax.imshow(rotate(moving_image, -recovered_angle), extent=[0 + shift[1],
                                ref_image.shape[1] + shift[1],
                                0 - shift[0],
                                ref_image.shape[0] - shift[0]])

ax.set_xlim(0, ref_image.shape[1])
ax.set_ylim(0, ref_image.shape[0])

fig.show()