In [1]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
import skimage.io as io
from tqdm import tqdm
from scipy.spatial import KDTree, distance_matrix
from tqdm import tqdm
from itertools import product, combinations
%matplotlib qt

In [2]:
import xrdmaptools
from xrdmaptools.XRDMap import XRDMap
from xrdmaptools.RockingCurveStack import RockingCurveStack
from xrdmaptools.geometry.geometry import get_q_vect

  _create_built_program_from_source_cached(


Connecting to databrokers...failed.


In [3]:
def read_metadata(filename, filedir):
    import json

    with open(f'{filedir}{filename}', 'r') as f:
        json_str = f.read()
        out = json.loads(json_str)
    return out

In [4]:
%run -i "C:\Users\emusterma\OneDrive - Brookhaven National Laboratory\Documents\Postdoc\Repositories\SRX_sXRD_analysis\xrdmaptools\RockingCurveStack.py"
plt.close()

In [5]:
scan_range = '156229-156251'
scan_range = '156253-156275'
scan_range = '156277-156299'
scan_range = '156301-156323'
scan_range = '156325-156347'
scan_range = '156349-156371'
scan_range = '156373-156395'

filedir = 'E:\\Musterman_data\\20240610\\energy_rc\\'
filename = f'scan{scan_range}_dexela_energy_rc.tif'

energy, i0, im, it = np.genfromtxt(filedir + 'scan156229-156251_energy_rc_parameters.txt')
sclr_dict = dict(zip(['i0', 'im', 'it'], [i0, im, it]))
md = read_metadata(f'scan{scan_range}_energy_rc_metadata.txt', filedir)
base_md = {key:value for key, value in md.items() if key in ['scan_id', 'theta', 'dwell']}
extra_md = {key:value for key, value in md.items() if key not in base_md.keys()}
base_md['scanid'] = base_md['scan_id']
del base_md['scan_id']

rsm = RockingCurveStack.from_image_stack(filename,
                                         wd=filedir,
                                         energy=energy,
                                         #theta=None,
                                         sclr_dict=sclr_dict,
                                         **base_md,
                                         extra_metadata=extra_md,
                                         save_hdf=False)

Loading images...done!


In [6]:
calib_dir = 'E:\\Musterman_data\\20240610\\calibrations\\'
rsm.set_calibration('scan156160_dexela_calibration.poni', filedir=calib_dir)

Setting detector calibration...
Calibration performed under different settings. Adjusting calibration.


In [7]:
dark_id = 156203
dark_dir = 'E:\\Musterman_data\\20240610\\dark_fields\\'
dir_mask = [str(dark_id) in d for d in os.listdir(dark_dir)]

dark_field = io.imread(f'{dark_dir}{np.array(os.listdir(dark_dir))[dir_mask][0]}').astype(np.float32)
rsm.map.correct_dark_field(dark_field=dark_field)

Correcting dark-field...done!


In [8]:
rsm.map.normalize_scaler()

Normalizing images by i0 scaler...done!


In [9]:
rsm.map.apply_polarization_correction()
rsm.map.apply_solidangle_correction()

Applying X-ray polarization correction...done!
Applying solid angle correction...done!


In [10]:
rsm.map.estimate_background(method='bruckner', binning=8, min_prominence=0.1)
rsm.map.remove_background()

Estimating background with Bruckner algorithm.


100%|██████████| 241/241 [01:12<00:00,  3.33it/s]

Removing background...




done!


In [11]:
rsm.map.rescale_images(arr_max=rsm.map.estimate_saturated_pixel())

In [12]:
rsm.find_blobs(threshold_method='minimum',
               multiplier=5,
               size=3,
               expansion=5,
               override_rescale=True)

Searching images for blobs...


100%|██████████| 241/241 [00:53<00:00,  4.51it/s]


In [13]:
edges = ([[] for _ in range(12)])
full_q_arr = np.empty((rsm.map.num_images, 3, *rsm.map.image_shape),
                      dtype=rsm.map.dtype)
for i, wavelength in tqdm(enumerate(rsm.wavelength),
                          total=rsm.map.num_images):
    q_arr = get_q_vect(rsm.tth_arr,
                       rsm.chi_arr,
                       wavelength=wavelength,
                       degrees=rsm.polar_units == 'deg')
    full_q_arr[i] = q_arr

edges = [None,] * 12
# First image edges
edges[0] = full_q_arr[0, :, 0, :]
edges[1] = full_q_arr[0, :, -1, :]
edges[2] = full_q_arr[0, :, :, 0]
edges[3] = full_q_arr[0, :, :, -1]
# Last image edges
edges[4] = full_q_arr[-1, :, 0, :]
edges[5] = full_q_arr[-1, :, -1, :]
edges[6] = full_q_arr[-1, :, :, 0]
edges[7] = full_q_arr[-1, :, :, -1]
# Image corners
edges[8] = full_q_arr[:, :, 0, 0].T
edges[9] = full_q_arr[:, :, 0, -1].T
edges[10] = full_q_arr[:, :, -1, 0].T
edges[11] = full_q_arr[:, :, -1, -1].T

q_mins = np.min(full_q_arr, axis=(0, 2, 3))
q_maxs = np.max(full_q_arr, axis=(0, 2, 3))

100%|██████████| 241/241 [01:26<00:00,  2.79it/s]


In [14]:
qx = full_q_arr[:, 0][rsm.map.blob_masks.squeeze()]
qy = full_q_arr[:, 1][rsm.map.blob_masks.squeeze()]
qz = full_q_arr[:, 2][rsm.map.blob_masks.squeeze()]
qs = np.asarray([qx, qy, qz]).T # May not need transpose
intensity = rsm.map.images[rsm.map.blob_masks]

In [15]:
from xrdmaptools.reflections.spot_blob_search_3d import rsm_blob_search, rsm_spot_search

In [16]:
labels = rsm_blob_search(qs, max_dist=0.05, max_neighbors=5, subsample=20)

Scheduling blob search...


100%|██████████| 490656/490656 [00:11<00:00, 42408.16it/s]


Finding blobs...


100%|██████████| 490656/490656 [02:33<00:00, 3195.98it/s]


Upsampling data...


100%|██████████| 9813119/9813119 [06:14<00:00, 26231.54it/s]


In [None]:
int_mask = intensity > np.min(intensity) + 0.01 * (np.max(intensity) - np.min(intensity))
int_labels = rsm_blob_search(qs[int_mask], max_dist=0.05, max_neighbors=5, subsample=10)

Scheduling blob search...


100%|██████████| 639911/639911 [00:11<00:00, 55980.13it/s]


Finding blobs...


100%|██████████| 639911/639911 [03:53<00:00, 2745.54it/s]


Upsampling data...


100%|██████████| 6399103/6399103 [03:55<00:00, 27165.97it/s]


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection': '3d'})

skip = 250

ax.scatter(*qs[::skip].T, c=intensity[::skip], s=1, alpha=0.01)

#ax.scatter(*edges, s=1, c='gray')

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection': '3d'})

skip = 1
blob_label = 1
another_mask = (labels == blob_label) & int_mask

ax.scatter(*qs[another_mask][::skip].T, c=intensity[another_mask][::skip], s=1, alpha=0.1)

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [16]:
from scipy.interpolate import griddata

def map_2_grid(qs, intensity, gridstep=0.005):

    # Find bounds
    x_min = np.min(qs[:, 0])
    x_max = np.max(qs[:, 0])
    y_min = np.min(qs[:, 1])
    y_max = np.max(qs[:, 1])
    z_min = np.min(qs[:, 2])
    z_max = np.max(qs[:, 2])

    # Generate q-space grid
    xx = np.linspace(x_min, x_max, int((x_max - x_min) / gridstep))
    yy = np.linspace(y_min, y_max, int((y_max - y_min) / gridstep))
    zz = np.linspace(z_min, z_max, int((z_max - z_min) / gridstep))

    grid = np.array(np.meshgrid(xx, yy, zz, indexing='ij'))
    grid = grid.reshape(3, -1).T

    int_grid = griddata(qs, intensity, grid, method='nearest')
    #int_grid = int_grid.reshape(yy.shape[0], xx.shape[0], zz.shape[0]).T
    int_grid = int_grid.reshape(xx.shape[0], yy.shape[0], zz.shape[0])

    return np.array([*np.meshgrid(xx, yy, zz, indexing='ij'), int_grid])

In [None]:
X, Y, Z, int_grid = map_2_grid(spot_qs[int_mask], spot_ints[int_mask], gridstep=0.00125)

In [None]:
import plotly.graph_objects as go

plot_grid = int_grid
#plot_grid[plot_grid < 1e-3] = 1e-3
#plot_grid = np.log(plot_grid).flatten()
data = []

data.append(go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=plot_grid.flatten(),
    isomin=np.min(int_grid) + 2,
    isomax=np.max(int_grid) / 2,
    opacity=0.1, # needs to be small to see through all surfaces
    surface_count=25, # needs to be a large number for good volume rendering
    colorscale='viridis'
    ))

spot_mask = label_ints >= np.min(label_ints) + 0.00 * (np.max(label_ints) - np.min(label_ints))

data.append(go.Scatter3d(
    x = np.asarray(spots)[spot_mask, 0],
    y = np.asarray(spots)[spot_mask, 1],
    z = np.asarray(spots)[spot_mask, 2],
    mode='markers',
    opacity=1,
    marker=dict(
        size=3,
        color='red'
    )
))

fig = go.Figure(data=data)

x_range = np.max(X) - np.min(X)
y_range = np.max(Y) - np.min(Y)
z_range = np.max(Z) - np.min(Z)

fig.update_layout(scene_aspectmode='manual',
                  scene_aspectratio=dict(x=x_range, y=y_range, z=z_range))

fig.show(renderer='browser')

In [37]:
int_mask = intensity > np.min(intensity) + 0.05 * (np.max(intensity) - np.min(intensity))

spot_labels, spots, label_ints = rsm_spot_search(qs[int_mask], intensity[int_mask], nn_dist=0.005, significance=0.1, subsample=5)

Finding spots...


100%|██████████| 59429/59429 [00:12<00:00, 4653.88it/s]


Upsampling data...


100%|██████████| 297149/297149 [00:08<00:00, 34971.00it/s]


In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection': '3d'})

skip = 5
blob_label = 5
int_mask = intensity > np.min(intensity) + 0.05 * (np.max(intensity) - np.min(intensity))
another_mask = (labels == blob_label) & int_mask

ax.scatter(*qs[another_mask][::skip].T, c=spot_labels[(labels == blob_label)[int_mask]][::skip], s=1, cmap='tab20')

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [None]:
import plotly.graph_objects as go

blob_label = np.nonzero(blob_ints == sorted(blob_ints, reverse=True)[4])[0][0]
int_mask = intensity > np.min(intensity) + 0.05 * (np.max(intensity) - np.min(intensity))
another_mask = (labels == blob_label) & int_mask
X, Y, Z, int_grid = map_2_grid(qs[another_mask], intensity[another_mask], gridstep=0.0025)

data = []

data.append(go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=int_grid.flatten(),
    isomin=np.min(int_grid) + 2,
    isomax=np.max(int_grid) / 2,
    opacity=0.1, # needs to be small to see through all surfaces
    surface_count=25, # needs to be a large number for good volume rendering
    colorscale='viridis'
    ))

x_range = np.max(X) - np.min(X)
y_range = np.max(Y) - np.min(Y)
z_range = np.max(Z) - np.min(Z)

spot_mask = label_ints >= np.min(label_ints) + 0.05 * (np.max(label_ints) - np.min(label_ints))
spot_x = np.asarray(spots)[spot_mask, 0]
spot_y = np.asarray(spots)[spot_mask, 1]
spot_z = np.asarray(spots)[spot_mask, 2]

extent_mask = np.all([
    np.min(X) < spot_x, spot_x < np.max(X),
    np.min(Y) < spot_y, spot_y < np.max(Y),
    np.min(Z) < spot_z, spot_z < np.max(Z),
], axis=0)

data.append(go.Scatter3d(
    x = spot_x[extent_mask],
    y = spot_y[extent_mask],
    z = spot_z[extent_mask],
    mode='markers',
    opacity=1,
    marker=dict(
        size=3,
        color='red'
    )
))

fig = go.Figure(data=data)

fig.update_layout(scene_aspectmode='manual',
                  scene_aspectratio=dict(x=x_range, y=y_range, z=z_range))

fig.show(renderer='browser')

In [179]:
spot_mask = label_ints >= np.min(label_ints) + 0.001 * (np.max(label_ints) - np.min(label_ints))
# sorted_ints = sorted(np.asarray(label_ints)[spot_mask])
# sorted_spots = [x for _, x in sorted(zip(np.asarray(label_ints)[spot_mask],
#                                          np.asarray(spots)[spot_mask]),
#                                          key=lambda pair: pair[0])]

In [180]:
np.sum(spot_mask)

510

In [65]:
from xrdmaptools.utilities.utilities import arbitrary_center_of_mass

labels = rsm_blob_search(np.asarray(spots)[spot_mask], max_dist=0.25)
new_spots = [arbitrary_center_of_mass(np.asarray(label_ints)[spot_mask][labels == label],
                                      *np.asarray(spots)[spot_mask][labels == label].T
                                      ) for label in np.unique(labels)]

Scheduling blob search...


100%|██████████| 894/894 [00:00<00:00, 63079.67it/s]


Finding blobs...


100%|██████████| 894/894 [00:00<00:00, 6591.50it/s]


In [66]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

ax.scatter(*np.asarray(spots)[spot_mask].T, s = 1, c=labels, cmap='tab20')

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.scatter(*np.asarray(new_spots).T, s=10, c='r')

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [20]:
from xrdmaptools.crystal.Phase import Phase
cif_dir = '''C:\\Users\\emusterma\\OneDrive - Brookhaven National Laboratory\\Documents\\Postdoc\\Literature\\CIF\\'''
stibnite = Phase.fromCIF(cif_dir + 'AMCSD\\Stibnite_0008636.cif')
stibnite.energy = 15
stibnite.get_hkl_reflections()
corundum = Phase.fromCIF(cif_dir + 'AMCSD\\Corundum_0009327.cif')
corundum.energy = 15
corundum.get_hkl_reflections()

In [21]:
from xrdmaptools.crystal.Phase import generate_reciprocal_lattice
ref_hkls, ref_qs, ref_fs = generate_reciprocal_lattice(stibnite, 8)

In [181]:
from xrdmaptools.crystal.Phase import generate_reciprocal_lattice
from itertools import combinations
from scipy.spatial import distance_matrix

spot_qs = np.asarray(spots)[spot_mask]
#spot_ints = np.asarray(label_ints)[spot_mask]

# spot_qs = np.asarray(new_spots)
#spot_qs = np.asarray(corr_spot_qs)

phase = stibnite
near_q = 0.1
near_angle = 2

near_q = 0.01
near_angle = 1
    
spot_q_mags = np.linalg.norm(spot_qs, axis=1)
max_q = np.max(spot_q_mags)

# Combine these at some point...
stibnite.get_hkl_reflections()
ref_hkls, ref_qs, ref_fs = generate_reciprocal_lattice(stibnite, 1.15 * max_q) # 15% window
# ref_qs = ref_qs @ (np.eye(3) - eij_full)
ref_q_mags = np.linalg.norm(ref_qs, axis=1)

# Minimum step size in q-space.
min_q = np.min(np.linalg.norm(phase.Q([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), axis=0))
if near_q > min_q * 0.85:
    raise ValueError('Near_q threshold is greater 85% of minimum lattice spacing. This seems unwise...')

# Find difference between measured and calculated q magnitudes
mag_diff_arr = np.abs(spot_q_mags[:, np.newaxis]
                      - ref_q_mags[np.newaxis, :])

# Eliminate any reflections outside phase-allowed spots
phase_mask = np.any(mag_diff_arr < near_q, axis=1)
mag_diff_arr = mag_diff_arr[phase_mask]
spot_qs = spot_qs[phase_mask]
spot_q_mags = spot_q_mags[phase_mask]
#spot_ints = spot_ints[phase_mask]

# Generate all pairs of spots which are crystallographically feasible
spot_pair_indices = list(combinations(range(len(spot_qs)), 2))
#spot_diff_arr = np.abs(spot_q_mags[:, np.newaxis]
#                       - spot_q_mags[np.newaxis, :])
spot_pair_dist = distance_matrix(spot_qs, spot_qs)
allowed_pairs = [spot_pair_dist[tuple(indices)] > min_q * 0.85 for indices in spot_pair_indices]
#allowed_pairs = [spot_pair_dist[indices] > min_q * 0.85 for indices in spot_pair_indices]
spot_pair_indices = np.asarray(spot_pair_indices)[allowed_pairs]

# Determine all angles
spot_angles = multi_vector_angles(spot_qs, spot_qs, degrees=True)
ref_angles = multi_vector_angles(ref_qs, ref_qs, degrees=True)

valid_pairs, valid_combos = [], []
for pair in tqdm(spot_pair_indices):
    ref_combos = list(product(*[np.nonzero(mag_diff_arr[i] < near_q)[0] for i in pair]))

    angle_mask = [np.abs(spot_angles[tuple(pair)] - ref_angles[tuple(combo)]) < near_angle for combo in ref_combos]
    doublet_mask = [combo[0] != combo[1] for combo in ref_combos]

    ref_combos = np.asarray(ref_combos)[np.array(angle_mask) & np.array(doublet_mask)]
    
    if len(ref_combos) > 0:
        valid_pairs.append(tuple(pair))
        valid_combos.append([tuple(combo) for combo in ref_combos])

100%|██████████| 112697/112697 [00:24<00:00, 4653.27it/s]


In [182]:
expanded_pair = [np.nan,] * len(spot_qs)
expanded_pair_list = []
for pair, combos in zip(valid_pairs, valid_combos):
    for combo in combos:
        expanded_pair_i = expanded_pair.copy()
        expanded_pair_i[pair[0]] = combo[0]
        expanded_pair_i[pair[1]] = combo[1]
        expanded_pair_list.append(expanded_pair_i)
        
expanded_pair_list = np.asarray(expanded_pair_list)
expanded_pair_list.shape

(159660, 508)

In [183]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')

skip = np.round(len(expanded_pair_list) / 1000, 0).astype(int)

for idx in tqdm(range(len(expanded_pair_list[::skip]))):
    ax.plot(*spot_qs[~np.isnan(expanded_pair_list[::skip][idx])].T, c='k', alpha=0.01, lw=1)

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

100%|██████████| 998/998 [00:00<00:00, 2017.86it/s]


In [26]:
def nan_combine(*connections):

    connections = np.asarray(connections)
    if np.any(np.sum(~np.isnan(connections), axis=0) > 1):
        raise RuntimeError('Overwritten connections!')

    base_connection = np.asarray([np.nan,] * len(connections[0]))

    for connection in connections:
        base_connection[~np.isnan(connection)] = connection[~np.isnan(connection)]

    return base_connection

In [27]:
def _find_next_connections(full_pairs,
                           partial_connections,
                           common_full,
                           common_partial=None,
                           ):
    
    full_pairs = np.asarray(full_pairs)
    partial_connections = np.asarray(partial_connections)
    
    # Must have at least two partials to advance connection rank
    if len(partial_connections) < 2:
        return [], []
    
    common_indices = np.nonzero(~np.isnan(common_full))[0]
    common_vals = common_full[common_indices]

    if common_partial is not None:
        common_index = np.nonzero(~np.isnan(common_partial))[0][0]
        common_val = common_partial[common_index]
        common_mask = full_pairs[:, common_index] == common_val
    else:
        common_mask = [True, ] * len(full_pairs)
    
    next_partials = []
    valid_connections = []
    for i in range(len(partial_connections)):
        partial = partial_connections[i]
        part_index = np.nonzero(~np.isnan(partial))[0][0]
        part_val = partial[part_index]

        # Ignore indexing repeated spots or references reflections
        if (part_index in common_indices
            or part_val in common_vals):
            continue
            
        next_conn_mask = full_pairs[:, part_index] == part_val
        
        if np.sum(common_mask & next_conn_mask) > 0:
            next_partials.append(partial)
            valid_connections.append(nan_combine(common_full, partial))

    return next_partials, valid_connections

In [184]:
from scipy.spatial.transform import Rotation
from xrdmaptools.geometry.geometry import q_2_polar

def get_rmse(spots0, spots1):
    rmse = np.mean([np.sqrt(np.sum([(p - q)**2 for p, q in zip(v1, v2)]))
                    for v1, v2 in zip(spots0, spots1)])
    return rmse

def fit_orientation_index(connection,
                          spot_qs,
                          ref_qs):
    fit_spot_indices = np.nonzero(~np.isnan(connection))
    fit_spot_qs = spot_qs[fit_spot_indices]

    fit_ref_indices = connection[fit_spot_indices].astype(int)
    fit_ref_qs = np.asarray(ref_qs)[fit_ref_indices]

    fit_orientation, fit_rssd = Rotation.align_vectors(fit_ref_qs, fit_spot_qs)

    rmse = get_rmse(fit_ref_qs,
                    fit_orientation.apply(fit_spot_qs, inverse=False))

    return fit_orientation, rmse

print('Determining all orientations from valid pairs...')
pair_orientations = []
pair_mis_mag = []
pair_rmse = []
for pair_i in tqdm(range(len(expanded_pair_list))):
    pair_ref_hkls = [ref_hkls[int(ind)] for ind
                    in expanded_pair_list[pair_i][~np.isnan(expanded_pair_list[pair_i])]]
    pair_ref_qs = [ref_qs[int(ind)] for ind
                   in expanded_pair_list[pair_i][~np.isnan(expanded_pair_list[pair_i])]]
    pair_spot_qs = spot_qs[np.nonzero(~np.isnan(expanded_pair_list[0]))]

    # Check for colinearity; 3D orientation cannot be determined
    pair_divs = np.array(pair_ref_hkls[0]) / np.array(pair_ref_hkls[1])
    if len(np.unique(pair_divs[~np.isnan(pair_divs)])) < 2:
        pair_orientations.append(np.nan) # assume validity
        pair_rmse.append(np.nan)
        pair_mis_mag.append(np.nan)
        continue
    
    orientation, rmse = fit_orientation_index(expanded_pair_list[pair_i],
                                              spot_qs,
                                              ref_qs)

    pair_orientations.append(orientation)
    pair_mis_mag.append(np.degrees(orientation.magnitude()))

    pair_rmse.append(rmse)

pair_orientations = np.asarray(pair_orientations)
pair_mis_mag = np.asarray(pair_mis_mag)
pair_rmse = np.asarray(pair_rmse)

print('Reducing symmetrically equivalent pairs...')
eval_pair_mask = np.array([True,] * len(expanded_pair_list))
keep_pair_mask = eval_pair_mask.copy()
min_wavelength, max_wavelength = np.min(rsm.wavelength), np.max(rsm.wavelength)
min_tth, max_tth = np.min(rsm.tth_arr), np.max(rsm.tth_arr)
min_chi, max_chi = np.min(rsm.chi_arr), np.max(rsm.chi_arr)
for pair_i in tqdm(range(len(expanded_pair_list))):

    if not eval_pair_mask[pair_i]:
        continue
    
    # Cannot symmetrically reduce orientations that cannot be determined
    if np.isnan(pair_rmse[pair_i]):
        eval_pair_mask[pair_i] = False
        continue
    
    # TESTING FEATURE
    # Discard poorly fitting pairs
    if pair_rmse[pair_i] > min_q:
        eval_pair_mask[pair_i] = False
        keep_pair_mask[pair_i] = False
        print('Horrible fitting')
        continue

    # Test to make sure all points are in probed volume
    IN_SCAN_RANGE = True
    pair_ref_qs = ref_qs[expanded_pair_list[pair_i][np.nonzero(~np.isnan(expanded_pair_list[pair_i]))[0]].astype(int)]
    rot_qs = pair_ref_qs @ pair_orientations[pair_i].as_matrix()
    tth, chi, wavelength = q_2_polar(rot_qs, degrees=True)
    tth_mask = np.any([tth < min_tth * 0.85,
                       tth > max_tth * 1.15],
                       axis=0)
    chi_mask = np.any([chi < min_chi * 0.85,
                       chi > max_chi * 1.15],
                       axis=0)
    wavelength_mask = np.any([wavelength < min_wavelength * 0.85,
                              wavelength > max_wavelength * 1.15],
                              axis=0)
    
    if np.any([tth_mask, chi_mask, wavelength_mask]):
        IN_SCAN_RANGE = False
        #IN_SCAN_RANGE = True

    # def trunc(values, decs=0):
    #     return np.trunc(values * 10**decs) / (10**decs)

    # similar_pair_mask = trunc(pair_rmse, 13) == trunc(pair_rmse[pair_i], 13)
    similar_pair_mask = np.round(pair_rmse, 10) == np.round(pair_rmse[pair_i], 10)
    #sym_pairs = expanded_pair_list[similar_pair_mask]
    eval_pair_mask[similar_pair_mask] = False
    keep_pair_mask[similar_pair_mask] = False
    
    if IN_SCAN_RANGE and np.sum(similar_pair_mask) > 1:
        # most_positive_index = np.argmax([np.sign(np.asarray(ref_hkls)[pair[np.nonzero(~np.isnan(pair))[0]].astype(int)]).sum()
        #                                 for pair in sym_pairs])
        # most_positive_index = np.sign(np.asarray(ref_hkls)[sym_pairs[np.nonzero(~np.isnan(sym_pairs))].astype(int)].reshape(-1, 2, 3)).sum(axis=(1, 2)).argmax()

        min_mis_mag = np.min(pair_mis_mag[similar_pair_mask])
        mis_ind = np.nonzero(pair_mis_mag[similar_pair_mask] < min_mis_mag + near_angle)[0] # some wiggle room

        keep_pair_mask[np.nonzero(similar_pair_mask)[0][mis_ind]] = True

# pair_orientations = pair_orientations[keep_pair_mask]
# pair_mis_mag = pair_mis_mag[keep_pair_mask]
# pair_rmse = pair_rmse[keep_pair_mask]
np.sum(keep_pair_mask)

Determining all orientations from valid pairs...


100%|██████████| 159660/159660 [00:19<00:00, 8163.48it/s]


Reducing symmetrically equivalent pairs...


100%|██████████| 159660/159660 [01:08<00:00, 2321.35it/s]


80188

In [185]:
# Quicker cast and search from set of valid pairs
from scipy.spatial import KDTree
from scipy.optimize import curve_fit

def generate_q_mask(qs,
                    tth_ext,
                    chi_ext,
                    wavelength_ext,
                    q_ext,
                    ext=0.15,
                    degrees=False):
    
    # Check extent parameters
    for param in [tth_ext, chi_ext, wavelength_ext]:
        if len(param) != 2:
            raise ValueError('Input extents must be of length 2.')
        if param[0] > param[1]:
            raise ValueError('Input extents must be (minimum, maximum)')
        
    # Check qs shape

    q_ext_mask = np.all([
        np.all([qs[:, 0] > q_ext[0][0] * (1 - ext),
                qs[:, 0] < q_ext[1][0] * (1 + ext)], axis=0),
        np.all([qs[:, 1] > q_ext[0][1] * (1 - ext),
                qs[:, 1] < q_ext[1][1] * (1 + ext)], axis=0),
        np.all([qs[:, 2] > q_ext[0][2] * (1 - ext),
                qs[:, 2] < q_ext[1][2] * (1 + ext)], axis=0),
    ], axis=0)
    
    tth, chi, wavelength = q_2_polar(qs, degrees=degrees)

    tth_mask = np.any([tth < tth_ext[0] * (1 - ext),
                       tth > tth_ext[1] * (1 + ext)],
                       axis=0)
    chi_mask = np.any([chi < chi_ext[0] * (1 - ext),
                       chi > chi_ext[1] * (1 + ext)],
                       axis=0)
    wavelength_mask = np.any([wavelength < wavelength_ext[0] * (1 - ext),
                              wavelength > wavelength_ext[1] * (1 + ext)],
                              axis=0)

    # q_mask = ~(tth_mask | chi_mask | wavelength_mask)
    q_mask = q_ext_mask & ~(tth_mask | chi_mask | wavelength_mask)

    return q_mask


# When we don't know, we through higher order polynomials
# at it until it works
def poly6(x, a, b, c, d, e, f, g):
    return a * x**6 + b * x**5 + c * x**4 + d * x**3 + e * x**2 + f * x + g

chi_upr_popt, _ = curve_fit(poly6, rsm.tth_arr[0], rsm.chi_arr[0])
chi_lwr_popt, _ = curve_fit(poly6, rsm.tth_arr[-1], rsm.chi_arr[-1])
tth_lwr_popt, _ = curve_fit(poly6, rsm.chi_arr[:, 0], rsm.tth_arr[:, 0])
tth_upr_popt, _ = curve_fit(poly6, rsm.chi_arr[:, -1], rsm.tth_arr[:, -1])

def generate_q_mask(qs,
                    tth_ext,
                    chi_ext,
                    wavelength_ext,
                    popts=[chi_upr_popt,
                           chi_lwr_popt,
                           tth_upr_popt,
                           tth_lwr_popt],
                    ext=0.05
                    ):
    
    # Check extent parameters
    for param in [tth_ext, chi_ext, wavelength_ext]:
        if len(param) != 2:
            raise ValueError('Input extents must be of length 2.')
        if param[0] > param[1]:
            raise ValueError('Input extents must be (minimum, maximum)')
        
    tth, chi, wavelength = q_2_polar(qs, degrees=True)

    chi_upr_mask = chi <= poly6(tth, *popts[0]) * (1 + ext)
    chi_lwr_mask = chi >= poly6(tth, *popts[1]) * (1 - ext)
    tth_upr_mask = tth <= poly6(chi, *popts[2]) * (1 + ext)
    tth_lwr_mask = tth >= poly6(chi, *popts[3]) * (1 - ext)

    tth_mask = np.all([tth >= tth_ext[0] * (1 - ext),
                       tth <= tth_ext[1] * (1 + ext)],
                       axis=0)
    chi_mask = np.all([chi >= chi_ext[0] * (1 - ext),
                       chi <= chi_ext[1] * (1 + ext)],
                       axis=0)

    wavelength_mask = np.all([
        wavelength >= wavelength_ext[0] * (1 - ext),
        wavelength <= wavelength_ext[1] * (1 + ext)
    ], axis=0)

    q_mask = np.all([chi_lwr_mask,
                     chi_upr_mask,
                     tth_lwr_mask,
                     tth_upr_mask,
                     tth_mask,
                     chi_mask,
                     wavelength_mask], axis=0)
    
    return q_mask


def fixed_pair_casting_indexing(pair_list,
                                spot_qs,
                                ref_qs,
                                iter_max=50):

    full_connections = []
    full_orientations = []
    full_connections_rmse = []
    for i in tqdm(range(len(pair_list))):

        pair = pair_list[i]
        # pair_orientation = pair_orientations[keep_pair_mask][i]
        pair_spot_inds = np.nonzero(~np.isnan(pair))[0]
        pair_ref_inds = pair[pair_spot_inds]
        prev_connection = pair.copy()

        iter_count = 0
        iter_max = 50
        ITERATE = True
        while ITERATE:
            connection = pair.copy()
            orientation, rmse = fit_orientation_index(prev_connection,
                                                      spot_qs,
                                                      ref_qs)

            rot_qs = ref_qs @ orientation.as_matrix()
            q_mask = generate_q_mask(rot_qs,
                                     (min_tth, max_tth),
                                     (min_chi, max_chi),
                                     (min_wavelength, max_wavelength),
                                     (q_mins, q_maxs),
                                     degrees=True)

            # kdtree built from spots so we can query the reference lattice and avoid non-crystallographic indexing
            kdtree = KDTree(spot_qs)
            pot_conn = kdtree.query_ball_point(rot_qs[q_mask], r=near_q)
            # Remove original pair reflections
            for ind in pair_ref_inds:
                if ind in np.nonzero(q_mask)[0]:
                    pot_conn[np.nonzero(np.nonzero(q_mask)[0] == ind)[0][0]] = []

            # Expand connection
            for conn_i, conn in enumerate(pot_conn):
                if len(conn) > 0:
                    # Remove reflections near original pair
                    for ind in pair_spot_inds:
                        if ind in conn:
                            conn.remove(ind)
                    if len(conn) == 0:
                        continue
                    elif len(conn) == 1:
                        # Add candidate reflection
                        connection[conn[0]] = np.nonzero(q_mask)[0][conn_i]
                    else:
                        # Add closest of multiple candidate reflections
                        ref_dist, ref_idx = kdtree.query(rot_qs[q_mask][conn_i])
                        connection[ref_idx] = np.nonzero(q_mask)[0][conn_i]
            
            # Compare connection with previous connection
            connection_spots = np.nonzero(~np.isnan(connection))[0]
            prev_connection_spots = np.nonzero(~np.isnan(prev_connection))[0]

            if len(connection_spots) == len(prev_connection_spots):
                connection_refs = connection[connection_spots]
                prev_connection_refs = prev_connection[prev_connection_spots]

                if (np.all(connection_spots == prev_connection_spots)
                    and np.all(connection_refs == prev_connection_refs)):
                    ITERATE = False

            prev_connection = connection.copy()
            iter_count += 1
            if iter_count >= iter_max:
                ITERATE = False
                # Re-update orientation
                orientation, rmse = fit_orientation_index(connection, 
                                                          spot_qs,
                                                          ref_qs)
                print(f'Max iterations reached for pair {i}.')

        full_connections.append(connection)
        full_orientations.append(orientation)
        full_connections_rmse.append(rmse)
        # if np.sum(~np.isnan(connection)) > 9:
        #     break

    return full_connections, full_orientations, full_connections_rmse


def initial_pair_casting_indexing(pair_list,
                                  spot_qs,
                                  ref_qs,
                                  iter_max=50):

    full_connections = []
    full_orientations = []
    full_connections_rmse = []
    for i in tqdm(range(len(pair_list))):

        pair = pair_list[i]
        prev_connection = pair.copy()

        iter_count = 0
        iter_max = 50
        ITERATE = True
        while ITERATE:
            # Blank baseline connection
            connection = pair.copy()
            connection[:] = np.nan

            orientation, rmse = fit_orientation_index(prev_connection,
                                                      spot_qs,
                                                      ref_qs)

            rot_qs = ref_qs @ orientation.as_matrix()
            q_mask = generate_q_mask(rot_qs,
                                     (min_tth, max_tth),
                                     (min_chi, max_chi),
                                     (min_wavelength, max_wavelength),
                                     ext=0.05)

            # kdtree built from spots so we can query the reference lattice and avoid non-crystallographic indexing
            kdtree = KDTree(spot_qs)
            pot_conn = kdtree.query_ball_point(rot_qs[q_mask], r=near_q)

            # Build new connection
            for conn_i, conn in enumerate(pot_conn):
                if len(conn) > 0:
                    if len(conn) == 0:
                        continue
                    elif len(conn) == 1:
                        # Add candidate reflection
                        connection[conn[0]] = np.nonzero(q_mask)[0][conn_i]
                    else:
                        # Add closest of multiple candidate reflections
                        ref_dist, ref_idx = kdtree.query(rot_qs[q_mask][conn_i])
                        connection[ref_idx] = np.nonzero(q_mask)[0][conn_i]
            
            # Eliminate less than pairs and replace with pair
            if np.sum(~np.isnan(connection)) <= 1:
                connection = pair.copy()
                ITERATE = False
                orientation, rmse = fit_orientation_index(connection, 
                                                          spot_qs,
                                                          ref_qs)
                break

            # Compare connection with previous connection
            connection_spots = np.nonzero(~np.isnan(connection))[0]
            prev_connection_spots = np.nonzero(~np.isnan(prev_connection))[0]

            if len(connection_spots) == len(prev_connection_spots):
                connection_refs = connection[connection_spots]
                prev_connection_refs = prev_connection[prev_connection_spots]

                if (np.all(connection_spots == prev_connection_spots)
                    and np.all(connection_refs == prev_connection_refs)):
                    ITERATE = False

            prev_connection = connection.copy()
            iter_count += 1
            if iter_count >= iter_max:
                ITERATE = False
                # Re-update orientation
                orientation, rmse = fit_orientation_index(connection, 
                                                          spot_qs,
                                                          ref_qs)
                print(f'Max iterations reached for pair {i}.')

        full_connections.append(connection)
        full_orientations.append(orientation)
        full_connections_rmse.append(rmse)

    return full_connections, full_orientations, full_connections_rmse



In [186]:
# Decompose orientation with pair casting
best_connections = []
excluded_spot_indices = []
included_spot_mask = np.asarray([True,] * len(expanded_pair_list[0]))
included_spot_mask[excluded_spot_indices] = False
original_pair_list = expanded_pair_list[keep_pair_mask]
pair_list = original_pair_list

iter_count = 0
ITERATE = True
while ITERATE:
    print(f'Finding orientation {iter_count + 1}.')

    # Find new connections
    connections, orientations, rmse = initial_pair_casting_indexing(pair_list,
                                                                    spot_qs[included_spot_mask],
                                                                    ref_qs)
    
    qof_norm_list = []
    for conn, orientation in zip(connections, orientations):
        rot_qs = orientation.apply(ref_qs, inverse=True)
        q_mask = generate_q_mask(rot_qs,
                                 (min_tth, max_tth),
                                 (min_chi, max_chi),
                                 (min_wavelength, max_wavelength),
                                 ext=0.05)
        
        fit_spot_qs = spot_qs[included_spot_mask][np.nonzero(~np.isnan(conn))[0]]
        fit_ref_qs = np.asarray(rot_qs)[conn[np.nonzero(~np.isnan(conn))[0]].astype(int)]
        all_ref_qs = rot_qs[q_mask]
        all_ref_fs = ref_fs[q_mask]

        qof, norm_qof = get_quality_of_fit(fit_spot_qs, fit_ref_qs, all_ref_qs, all_ref_fs, sigma=near_q * 1)
        qof_norm_list.append(norm_qof)

    best_connection = connections[np.argmax(qof_norm_list)]
    
    # Find best connection with highest connectivity
    # connection_length = [np.sum(~np.isnan(conn)) for conn in connections]
    # connection_mask = connection_length == np.max(connection_length)
    # best_connection = connections[np.nonzero(connection_mask)[0][np.argmin(np.asarray(rmse)[connection_mask])]]
    if np.sum(~np.isnan(best_connection)) <= 1: # I am not sure why this is even necessary
        ITERATE = False
        break
    expanded_best_connection = np.asarray([np.nan,] * len(included_spot_mask))
    expanded_best_connection[included_spot_mask] = best_connection
    best_connections.append(expanded_best_connection)

    # if iter_count == 0:
    #     raise

    # Update connections
    excluded_spot_indices.extend(*list(np.nonzero(~np.isnan(expanded_best_connection))))
    included_spot_mask[excluded_spot_indices] = False

    # Remove pairs where spots have already been indexed
    new_pairs = []
    for pair in pair_list:
        # All nan means the pair does not use any of the ecluded indices
        if np.all([np.isnan(pair[index]) for index in np.nonzero(~np.isnan(best_connection))[0]]):
            new_pairs.append(pair[np.isnan(best_connection)])
    pair_list = np.asarray(new_pairs)

    iter_count += 1
    if (len(spot_qs) - len(excluded_spot_indices) < 1 # Impossible to solve orientation
        or len(pair_list) < 1 # No more valid pairs to solve
        or iter_count >= 50): # Avoid infinite loops
        ITERATE = False

Finding orientation 1.


  fit_orientation, fit_rssd = Rotation.align_vectors(fit_ref_qs, fit_spot_qs)
  6%|▌         | 4529/80188 [00:23<08:29, 148.58it/s]

Max iterations reached for pair 4496.


 19%|█▊        | 14945/80188 [01:13<07:00, 155.13it/s]

Max iterations reached for pair 14916.


 26%|██▋       | 21192/80188 [01:49<07:25, 132.47it/s]

Max iterations reached for pair 21170.


 27%|██▋       | 21767/80188 [01:53<07:11, 135.35it/s]

Max iterations reached for pair 21743.


 27%|██▋       | 21801/80188 [01:53<08:16, 117.57it/s]

Max iterations reached for pair 21779.


 34%|███▎      | 26932/80188 [02:23<06:33, 135.50it/s]

Max iterations reached for pair 26907.


 64%|██████▎   | 50974/80188 [04:12<02:42, 179.50it/s]

Max iterations reached for pair 50946.


 88%|████████▊ | 70777/80188 [05:38<00:47, 199.97it/s]

Max iterations reached for pair 70752.


 93%|█████████▎| 74215/80188 [05:51<00:33, 180.57it/s]

Max iterations reached for pair 74179.


 96%|█████████▌| 76599/80188 [06:01<00:15, 224.71it/s]

Max iterations reached for pair 76560.


 98%|█████████▊| 78949/80188 [06:11<00:07, 168.93it/s]

Max iterations reached for pair 78921.


100%|██████████| 80188/80188 [06:16<00:00, 213.22it/s]


Finding orientation 2.


  6%|▌         | 4379/76225 [00:21<07:12, 166.28it/s]

Max iterations reached for pair 4356.


 19%|█▉        | 14459/76225 [01:12<05:36, 183.68it/s]

Max iterations reached for pair 14432.


 27%|██▋       | 20232/76225 [01:41<06:41, 139.31it/s]

Max iterations reached for pair 20196.


 27%|██▋       | 20778/76225 [01:43<05:39, 163.09it/s]

Max iterations reached for pair 20756.


 27%|██▋       | 20816/76225 [01:44<06:53, 133.99it/s]

Max iterations reached for pair 20792.


 34%|███▍      | 25783/76225 [02:07<04:58, 169.21it/s]

Max iterations reached for pair 25744.


 64%|██████▎   | 48589/76225 [05:39<07:07, 64.59it/s]  

Max iterations reached for pair 48577.


 88%|████████▊ | 67392/76225 [08:40<02:14, 65.45it/s] 

Max iterations reached for pair 67380.


 93%|█████████▎| 70606/76225 [09:09<01:20, 70.13it/s] 

Max iterations reached for pair 70590.


 96%|█████████▌| 72895/76225 [09:31<00:39, 83.70it/s] 

Max iterations reached for pair 72881.


 98%|█████████▊| 75048/76225 [09:51<00:17, 65.53it/s] 

Max iterations reached for pair 75037.


100%|██████████| 76225/76225 [10:02<00:00, 126.60it/s]


Finding orientation 3.


  5%|▌         | 3860/72614 [00:41<30:05, 38.07it/s] 

Max iterations reached for pair 3851.


 19%|█▉        | 13701/72614 [02:21<12:46, 76.82it/s] 

Max iterations reached for pair 13683.


 27%|██▋       | 19313/72614 [03:19<13:37, 65.18it/s] 

Max iterations reached for pair 19297.


 27%|██▋       | 19857/72614 [03:24<14:54, 58.96it/s] 

Max iterations reached for pair 19846.


 27%|██▋       | 19894/72614 [03:25<16:14, 54.11it/s]

Max iterations reached for pair 19882.


 34%|███▍      | 24708/72614 [04:11<14:40, 54.38it/s] 

Max iterations reached for pair 24699.


 72%|███████▏  | 52334/72614 [08:28<05:08, 65.80it/s] 

Max iterations reached for pair 52315.


 88%|████████▊ | 64157/72614 [10:14<02:37, 53.81it/s] 

Max iterations reached for pair 64141.


 93%|█████████▎| 67282/72614 [10:45<01:28, 59.93it/s] 

Max iterations reached for pair 67271.


 94%|█████████▍| 68278/72614 [10:54<00:56, 76.23it/s] 

Max iterations reached for pair 68265.


 96%|█████████▌| 69520/72614 [11:05<00:34, 88.64it/s] 

Max iterations reached for pair 69498.


 98%|█████████▊| 71517/72614 [11:24<00:16, 67.53it/s] 

Max iterations reached for pair 71496.


100%|██████████| 72614/72614 [11:34<00:00, 104.58it/s]


Finding orientation 4.


100%|██████████| 69722/69722 [10:56<00:00, 106.23it/s]


Finding orientation 5.


100%|██████████| 66745/66745 [10:20<00:00, 107.56it/s]


Finding orientation 6.


100%|██████████| 63951/63951 [09:43<00:00, 109.63it/s]


Finding orientation 7.


100%|██████████| 61520/61520 [10:04<00:00, 101.81it/s]


Finding orientation 8.


100%|██████████| 58226/58226 [11:12<00:00, 86.53it/s] 


Finding orientation 9.


100%|██████████| 55472/55472 [08:45<00:00, 105.60it/s]


Finding orientation 10.


100%|██████████| 53369/53369 [03:38<00:00, 243.78it/s]


Finding orientation 11.


 51%|█████▏    | 26662/51956 [01:50<02:13, 188.96it/s]

Max iterations reached for pair 26618.


 52%|█████▏    | 26767/51956 [01:50<02:08, 195.37it/s]

Max iterations reached for pair 26734.


 52%|█████▏    | 26810/51956 [01:50<02:36, 160.98it/s]

Max iterations reached for pair 26782.


100%|██████████| 51956/51956 [03:33<00:00, 243.67it/s]


Finding orientation 12.


 52%|█████▏    | 26027/50328 [01:46<02:04, 194.55it/s]

Max iterations reached for pair 25976.


 52%|█████▏    | 26131/50328 [01:47<02:06, 190.95it/s]

Max iterations reached for pair 26089.


 52%|█████▏    | 26176/50328 [01:47<02:21, 171.20it/s]

Max iterations reached for pair 26137.


100%|██████████| 50328/50328 [03:22<00:00, 248.77it/s]


Finding orientation 13.


 50%|█████     | 24500/48638 [01:37<02:09, 187.11it/s]

Max iterations reached for pair 24467.


 51%|█████     | 24611/48638 [01:38<02:04, 193.41it/s]

Max iterations reached for pair 24580.


 51%|█████     | 24657/48638 [01:38<02:20, 171.24it/s]

Max iterations reached for pair 24628.


100%|██████████| 48638/48638 [03:09<00:00, 256.70it/s]


Finding orientation 14.


 51%|█████     | 24168/47219 [01:38<01:59, 193.14it/s]

Max iterations reached for pair 24121.


 51%|█████▏    | 24271/47219 [01:38<01:59, 192.41it/s]

Max iterations reached for pair 24231.


 51%|█████▏    | 24317/47219 [01:39<02:12, 173.39it/s]

Max iterations reached for pair 24279.


100%|██████████| 47219/47219 [03:06<00:00, 253.52it/s]


Finding orientation 15.


 51%|█████     | 23154/45669 [01:30<01:56, 193.82it/s]

Max iterations reached for pair 23114.


 51%|█████     | 23269/45669 [01:31<01:54, 195.82it/s]

Max iterations reached for pair 23224.


 51%|█████     | 23318/45669 [01:31<02:04, 179.48it/s]

Max iterations reached for pair 23272.


100%|██████████| 45669/45669 [02:54<00:00, 261.68it/s]


Finding orientation 16.


 51%|█████     | 22563/44215 [01:27<01:52, 191.72it/s]

Max iterations reached for pair 22529.


 51%|█████▏    | 22677/44215 [01:28<01:46, 201.63it/s]

Max iterations reached for pair 22639.


 51%|█████▏    | 22727/44215 [01:28<01:58, 181.31it/s]

Max iterations reached for pair 22685.


100%|██████████| 44215/44215 [02:48<00:00, 263.09it/s]


Finding orientation 17.


 50%|█████     | 21540/42758 [01:21<01:43, 204.53it/s]

Max iterations reached for pair 21510.


 51%|█████     | 21653/42758 [01:21<01:40, 209.11it/s]

Max iterations reached for pair 21618.


 51%|█████     | 21703/42758 [01:22<01:51, 188.54it/s]

Max iterations reached for pair 21664.


100%|██████████| 42758/42758 [02:38<00:00, 269.21it/s]


Finding orientation 18.


 50%|█████     | 21237/42119 [01:19<01:30, 229.77it/s]

Max iterations reached for pair 21179.


 51%|█████     | 21323/42119 [01:19<01:39, 209.67it/s]

Max iterations reached for pair 21286.


 51%|█████     | 21374/42119 [01:20<01:51, 185.77it/s]

Max iterations reached for pair 21332.


100%|██████████| 42119/42119 [02:35<00:00, 270.79it/s]


Finding orientation 19.


 50%|█████     | 20499/40893 [01:16<01:33, 217.22it/s]

Max iterations reached for pair 20453.


 50%|█████     | 20584/40893 [01:16<01:50, 183.58it/s]

Max iterations reached for pair 20559.


 51%|█████     | 20657/40893 [01:17<01:46, 189.74it/s]

Max iterations reached for pair 20605.


100%|██████████| 40893/40893 [02:28<00:00, 274.80it/s]


Finding orientation 20.


 50%|████▉     | 19653/39686 [01:12<01:30, 222.51it/s]

Max iterations reached for pair 19597.


 50%|████▉     | 19735/39686 [01:13<01:41, 197.37it/s]

Max iterations reached for pair 19703.


 50%|████▉     | 19786/39686 [01:13<01:49, 181.55it/s]

Max iterations reached for pair 19749.


100%|██████████| 39686/39686 [02:25<00:00, 273.59it/s]


Finding orientation 21.


 49%|████▉     | 19294/39084 [01:09<01:30, 219.08it/s]

Max iterations reached for pair 19259.


 50%|████▉     | 19410/39084 [01:10<01:33, 210.70it/s]

Max iterations reached for pair 19365.


 50%|████▉     | 19462/39084 [01:10<01:42, 191.86it/s]

Max iterations reached for pair 19410.


100%|██████████| 39084/39084 [02:20<00:00, 278.28it/s]


Finding orientation 22.


 49%|████▉     | 18476/37684 [01:07<01:28, 216.96it/s]

Max iterations reached for pair 18426.


 49%|████▉     | 18559/37684 [01:07<01:34, 203.36it/s]

Max iterations reached for pair 18528.


 49%|████▉     | 18609/37684 [01:08<01:45, 181.60it/s]

Max iterations reached for pair 18573.


100%|██████████| 37684/37684 [02:15<00:00, 278.84it/s]


Finding orientation 23.


 48%|████▊     | 17562/36222 [01:02<01:22, 225.07it/s]

Max iterations reached for pair 17515.


 49%|████▊     | 17641/36222 [01:02<01:42, 180.98it/s]

Max iterations reached for pair 17616.


 49%|████▉     | 17692/36222 [01:02<01:44, 177.52it/s]

Max iterations reached for pair 17660.


100%|██████████| 36222/36222 [02:06<00:00, 286.67it/s]


Finding orientation 24.


 49%|████▉     | 17361/35244 [01:00<01:12, 245.58it/s]

Max iterations reached for pair 17296.


 49%|████▉     | 17416/35244 [01:01<01:33, 190.99it/s]

Max iterations reached for pair 17393.


 50%|████▉     | 17495/35244 [01:01<01:28, 200.50it/s]

Max iterations reached for pair 17435.


100%|██████████| 35244/35244 [02:02<00:00, 287.29it/s]


Finding orientation 25.


 49%|████▉     | 16686/34126 [00:58<01:15, 230.11it/s]

Max iterations reached for pair 16641.


 49%|████▉     | 16772/34126 [00:59<01:21, 212.49it/s]

Max iterations reached for pair 16738.


 49%|████▉     | 16824/34126 [00:59<01:32, 186.69it/s]

Max iterations reached for pair 16779.


100%|██████████| 34126/34126 [01:58<00:00, 288.47it/s]


Finding orientation 26.


 48%|████▊     | 15678/32832 [00:54<01:12, 236.02it/s]

Max iterations reached for pair 15631.


 48%|████▊     | 15761/32832 [00:55<01:27, 195.30it/s]

Max iterations reached for pair 15727.


 48%|████▊     | 15813/32832 [00:55<01:33, 182.30it/s]

Max iterations reached for pair 15768.


100%|██████████| 32832/32832 [01:54<00:00, 287.66it/s]


Finding orientation 27.


 48%|████▊     | 15320/32170 [00:53<01:13, 230.11it/s]

Max iterations reached for pair 15262.


 48%|████▊     | 15375/32170 [00:53<01:29, 188.34it/s]

Max iterations reached for pair 15357.


 48%|████▊     | 15429/32170 [00:53<01:31, 183.28it/s]

Max iterations reached for pair 15398.


100%|██████████| 32170/32170 [01:50<00:00, 290.92it/s]


Finding orientation 28.


 47%|████▋     | 14744/31227 [00:51<01:09, 237.41it/s]

Max iterations reached for pair 14706.


 47%|████▋     | 14832/31227 [00:51<01:17, 212.18it/s]

Max iterations reached for pair 14801.


 48%|████▊     | 14881/31227 [00:52<01:28, 184.83it/s]

Max iterations reached for pair 14841.


100%|██████████| 31227/31227 [01:46<00:00, 292.02it/s]


Finding orientation 29.


 47%|████▋     | 14107/30283 [00:43<01:00, 268.01it/s]

Max iterations reached for pair 14057.


 47%|████▋     | 14168/30283 [00:44<01:16, 210.75it/s]

Max iterations reached for pair 14152.


 47%|████▋     | 14223/30283 [00:44<01:22, 194.33it/s]

Max iterations reached for pair 14192.


100%|██████████| 30283/30283 [01:34<00:00, 321.13it/s]


Finding orientation 30.


 46%|████▌     | 13548/29378 [00:41<00:57, 276.29it/s]

Max iterations reached for pair 13508.


 46%|████▋     | 13614/29378 [00:41<01:15, 209.68it/s]

Max iterations reached for pair 13603.


 47%|████▋     | 13675/29378 [00:42<01:16, 204.61it/s]

Max iterations reached for pair 13643.


100%|██████████| 29378/29378 [01:30<00:00, 326.14it/s]


Finding orientation 31.


 47%|████▋     | 13433/28825 [00:41<00:58, 263.81it/s]

Max iterations reached for pair 13384.


 47%|████▋     | 13495/28825 [00:41<01:12, 211.05it/s]

Max iterations reached for pair 13476.


 47%|████▋     | 13582/28825 [00:42<01:09, 220.44it/s]

Max iterations reached for pair 13516.


100%|██████████| 28825/28825 [01:28<00:00, 325.91it/s]


Finding orientation 32.


 46%|████▌     | 12802/27893 [00:39<00:59, 253.80it/s]

Max iterations reached for pair 12749.


 46%|████▌     | 12859/27893 [00:39<01:14, 202.78it/s]

Max iterations reached for pair 12841.


 46%|████▋     | 12912/27893 [00:40<01:17, 194.28it/s]

Max iterations reached for pair 12881.


100%|██████████| 27893/27893 [01:25<00:00, 325.68it/s]


Finding orientation 33.


 46%|████▌     | 12576/27493 [00:41<00:58, 256.74it/s]

Max iterations reached for pair 12520.


 46%|████▌     | 12634/27493 [00:41<01:11, 207.10it/s]

Max iterations reached for pair 12612.


 46%|████▌     | 12687/27493 [00:41<01:17, 191.96it/s]

Max iterations reached for pair 12651.


100%|██████████| 27493/27493 [01:29<00:00, 308.75it/s]


Finding orientation 34.


 45%|████▌     | 12063/26541 [00:39<00:58, 247.91it/s]

Max iterations reached for pair 12022.


 46%|████▌     | 12121/26541 [00:40<01:13, 196.00it/s]

Max iterations reached for pair 12112.


 46%|████▌     | 12211/26541 [00:40<01:07, 213.71it/s]

Max iterations reached for pair 12150.


100%|██████████| 26541/26541 [01:26<00:00, 307.97it/s]


Finding orientation 35.


 45%|████▍     | 11242/25215 [00:36<01:01, 228.44it/s]

Max iterations reached for pair 11212.


 45%|████▍     | 11329/25215 [00:36<01:08, 203.88it/s]

Max iterations reached for pair 11299.


 45%|████▌     | 11381/25215 [00:37<01:12, 190.49it/s]

Max iterations reached for pair 11337.


100%|██████████| 25215/25215 [01:21<00:00, 308.00it/s]


Finding orientation 36.


100%|██████████| 24413/24413 [01:17<00:00, 315.77it/s]


Finding orientation 37.


100%|██████████| 23730/23730 [01:14<00:00, 320.09it/s]


Finding orientation 38.


100%|██████████| 22999/22999 [01:13<00:00, 314.67it/s]


Finding orientation 39.


100%|██████████| 22787/22787 [01:10<00:00, 321.30it/s]


Finding orientation 40.


100%|██████████| 22192/22192 [01:09<00:00, 319.20it/s]


Finding orientation 41.


100%|██████████| 21611/21611 [01:07<00:00, 320.76it/s]


Finding orientation 42.


100%|██████████| 20765/20765 [01:03<00:00, 325.94it/s]


Finding orientation 43.


100%|██████████| 20493/20493 [01:04<00:00, 318.51it/s]


Finding orientation 44.


100%|██████████| 20028/20028 [01:02<00:00, 320.03it/s]


Finding orientation 45.


100%|██████████| 19384/19384 [00:59<00:00, 327.64it/s]


Finding orientation 46.


100%|██████████| 18821/18821 [00:58<00:00, 324.13it/s]


Finding orientation 47.


100%|██████████| 18435/18435 [00:56<00:00, 325.82it/s]


Finding orientation 48.


100%|██████████| 18273/18273 [00:55<00:00, 327.98it/s]


Finding orientation 49.


100%|██████████| 17679/17679 [00:54<00:00, 325.55it/s]


Finding orientation 50.


100%|██████████| 17136/17136 [00:47<00:00, 358.24it/s]


In [187]:
np.asarray([np.sum(~np.isnan(conn)) for conn in best_connections])

array([9, 8, 7, 7, 8, 6, 8, 7, 6, 5, 6, 6, 5, 6, 5, 5, 4, 5, 5, 4, 4, 5,
       4, 5, 4, 4, 4, 4, 4, 3, 4, 3, 4, 5, 4, 3, 4, 3, 3, 3, 4, 3, 2, 3,
       3, 3, 2, 3, 4, 3])

In [747]:
stibnite.c

11.234

In [746]:
stibnite.a

11.314

In [191]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

colors = ['k',] * len(best_connections)
# colors = ['black', 'blue', 'green', 'purple'] + ['none'] * (len(best_connections) - 4)

for best_fit_ind in range(len(best_connections)):
    if best_fit_ind > 20:
        break

    orientation, rmse = fit_orientation_index(best_connections[best_fit_ind],
                                              spot_qs,
                                              ref_qs)
    # print(rmse)

    #rot_qs = ref_qs @ orientation.as_matrix()
    rot_qs = orientation.apply(ref_qs, inverse=True)
    q_mask = generate_q_mask(rot_qs,
                            (min_tth, max_tth),
                            (min_chi, max_chi),
                            (min_wavelength, max_wavelength),
                            ext=0.05)

    ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c=colors[best_fit_ind])

    hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
    fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
    for idx, hkl in enumerate(hkls):
        ax.text(*fit_spots[idx], str(hkl), fontsize=8, c=colors[best_fit_ind])

for edge in edges:
    ax.plot(*edge, c='gray', lw=1)

ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')
ax.scatter(0, 0, 0, s=10, c='blue')
ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [149]:
rmse_list = []
qof_list = []
norm_qof_list = []
for conn in tqdm(connections):
    orientation, rmse = fit_orientation_index(conn, spot_qs[included_spot_mask], ref_qs)

    #rot_qs = ref_qs @ orientation.as_matrix()
    rot_qs = orientation.apply(ref_qs, inverse=True)
    q_mask = generate_q_mask(rot_qs,
                            (min_tth, max_tth),
                            (min_chi, max_chi),
                            (min_wavelength, max_wavelength),
                            ext=0.05)


    fit_spot_qs = spot_qs[included_spot_mask][np.nonzero(~np.isnan(conn))[0]]
    fit_ref_qs = np.asarray(rot_qs)[conn[np.nonzero(~np.isnan(conn))[0]].astype(int)]
    all_ref_qs = rot_qs[q_mask]
    all_ref_fs = ref_fs[q_mask]

    qof, norm_qof = get_quality_of_fit(fit_spot_qs, fit_ref_qs, all_ref_qs, all_ref_fs, sigma=near_q * 1)

    rmse_list.append(rmse)
    qof_list.append(qof)
    norm_qof_list.append(norm_qof)

  fit_orientation, fit_rssd = Rotation.align_vectors(fit_ref_qs, fit_spot_qs)
100%|██████████| 14023/14023 [00:39<00:00, 359.13it/s]


In [31]:
def get_quality_of_fit(spot_qs,
                       ref_qs,
                       all_ref_qs,
                       all_ref_fs,
                       sigma=near_q):
    
    # Requirements:
    # 1. Penalize missing reflections weighted according to their expected intensity
    # 2. Do not penalize extra reflections which are not indexed (allows for overlapping orientations)
    # 3. Penalize reflections weighted by their distance from expected positions

    # # Determine which reflections are indexed
    # found_spot_mask = [tuple(ref) in [tuple(x) for x in ref_qs] for ref in all_ref_qs]

    # Normalize structure_factors to approximate total expected intensity
    # all_ref_fs /= np.sum(all_ref_fs)

    dist = [np.sqrt(np.sum([(p - q)**2 for p, q in zip(v1, v2)]))
                        for v1, v2 in zip(spot_qs, ref_qs)]

    # # 1D Gaussian centered at zero
    # def gauss_1d(x, amp, fwhm):
    #     sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
    #     return amp * np.exp(-(x)**2 / (2 * sigma**2))

    # qof = np.sum(gauss_1d(np.asarray(dist),
    #                       all_ref_fs[found_spot_mask],
    #                       fwhm))
    
    # Gaussian with structure factor amplitude and near_q standard deviation
    # centered at zero sampled at distance
    # qof = np.sum(
    #     np.log(all_ref_fs[found_spot_mask])
    #     * np.exp(-(np.asarray(dist))**2
    #              / (2 * sigma**2)))

    # Gaussian with specified standard deviation
    # centered at zero sampled at distance
    qof = np.sum(np.exp(-(np.asarray(dist))**2 / (2 * sigma**2)))
    # dist_qual /= len(spot_qs) # Normalized to one

    # # int_explained = np.sum(all_ref_fs[found_spot_mask]) / np.sum(all_ref_fs)
    # int_explained = len(spot_qs) / len(all_ref_fs)

    # qof = 0.5 * (dist_qual + int_explained)
    # norm_qof = qof
        
    # max_qof = np.log(np.sum(all_ref_fs))
    max_qof = len(all_ref_qs)
    norm_qof = qof / max_qof

    return qof, norm_qof

In [150]:
np.max(norm_qof_list)

0.6659233808314196

In [151]:
connections[np.argmax(norm_qof_list)]

array([1166.,   nan, 1014., 2877., 2409.,   nan,   nan, 3796., 2212.,
       3722.,   nan, 1699., 3790.,   nan,   nan,   nan, 4388.,   nan,
         nan, 1086.,   nan, 2711., 3452.,   nan,   nan,   nan,   nan,
         nan,   nan,   nan,   nan,   nan,  175., 2041.,   nan,   nan,
         nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
         nan,   nan,   nan,   nan])

In [152]:
connection_lengths = [np.sum(~np.isnan(conn)) for conn in connections]
connection_mask = connection_lengths == np.max(connection_lengths)
connections[np.nonzero(connection_mask)[0][np.argmin(np.asarray(rmse_list)[connection_mask])]]

array([1166.,   nan, 1014., 2877., 2409.,   nan,   nan, 3796., 2212.,
       3722.,   nan, 1699., 3790.,   nan,   nan,   nan, 4388.,   nan,
         nan, 1086.,   nan, 2711., 3452.,   nan,   nan,   nan,   nan,
         nan,   nan,   nan,   nan,   nan,  175., 2041.,   nan,   nan,
         nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,   nan,
         nan,   nan,   nan,   nan])

In [153]:
norm_qof_list[np.nonzero(connection_mask)[0][np.argmin(np.asarray(rmse_list)[connection_mask])]]

0.6659233808314196

In [156]:
np.unique(norm_qof_list)[::-1][np.unique(norm_qof_list)[::-1] > 0.5]

array([0.66592338, 0.63529521, 0.61285578, 0.61285578, 0.61216955,
       0.58175019])

In [157]:
np.asarray(norm_qof_list)[np.asarray(norm_qof_list) > 0.5]

array([0.66592338, 0.61216955, 0.63529521, 0.61285578, 0.63529521,
       0.61285578, 0.63529521, 0.63529521, 0.66592338, 0.63529521,
       0.61285578, 0.66592338, 0.66592338, 0.63529521, 0.63529521,
       0.63529521, 0.66592338, 0.66592338, 0.66592338, 0.66592338,
       0.66592338, 0.58175019, 0.66592338, 0.66592338, 0.63529521,
       0.66592338, 0.66592338, 0.66592338, 0.63529521, 0.61285578,
       0.66592338, 0.58175019, 0.63529521, 0.63529521, 0.66592338,
       0.66592338, 0.58175019, 0.63529521, 0.63529521, 0.63529521,
       0.63529521, 0.63529521, 0.63529521, 0.61285578, 0.66592338,
       0.66592338, 0.61285578, 0.63529521, 0.63529521, 0.61285578,
       0.63529521, 0.63529521, 0.61285578, 0.66592338, 0.61285578,
       0.63529521, 0.63529521, 0.63529521, 0.61285578, 0.63529521,
       0.63529521, 0.61285578, 0.63529521, 0.63529521, 0.66592338,
       0.63529521, 0.63529521, 0.63529521, 0.66592338, 0.58175019,
       0.66592338, 0.66592338, 0.58175019, 0.66592338, 0.66592

In [158]:
np.nonzero(np.asarray(norm_qof_list) > 0.5)[0]

array([    6,    11,    15,    16,    17,    18,    21,    55,    75,
          89,    90,    91,   102,   125,   137,   203,   208,   299,
         337,   357,   373,   374,   378,   402,   438,   472,   589,
         635,   711,   720,   727,   729,   740,   855,   896,   943,
         945,   951,   988,  1044,  1112,  1201,  1239,  1243,  1517,
        1628,  1986,  1988,  2015,  2075,  2083,  2084,  2092,  2122,
        2136,  2137,  2151,  2171,  2209,  2258,  2303,  2324,  2328,
        2508,  2538,  2713,  2808,  2815,  2820,  2824,  2825,  2827,
        2828,  2829,  2831,  2845,  2857,  2862,  2863,  2891,  2892,
        2894,  2962,  2964,  2968,  2971,  2981,  2992,  3002,  3075,
        3081,  3110,  3115,  3147,  3157,  3166,  3366,  3385,  3510,
        4033,  4053,  4071,  4088,  4192,  4241,  4244,  4266,  4296,
        4466,  4473,  4688,  4693,  4705,  4718,  4733,  4809,  4811,
        4834,  4838,  4865,  4878,  4888,  5075,  5093,  5245,  5247,
        5248,  5253,

In [163]:
norm_qof_list[11]

0.6121695541001125

In [167]:
connection_lengths = [np.sum(~np.isnan(conn)) for conn in connections]
connection_mask = connection_lengths == np.max(connection_lengths) - 1
np.asarray(rmse_list)[connection_mask]

array([0.04039877, 0.0496286 , 0.04039877, 0.0496286 , 0.04039877,
       0.04039877, 0.04039877, 0.0496286 , 0.04039877, 0.04039877,
       0.04039877, 0.05006308, 0.04039877, 0.04039877, 0.0496286 ,
       0.05006308, 0.04039877, 0.04039877, 0.05006308, 0.04039877,
       0.04039877, 0.04039877, 0.04039877, 0.04039877, 0.04039877,
       0.0496286 , 0.0496286 , 0.04039877, 0.04039877, 0.0496286 ,
       0.04039877, 0.04039877, 0.0496286 , 0.0496286 , 0.04039877,
       0.04039877, 0.04039877, 0.0496286 , 0.04039877, 0.04039877,
       0.0496286 , 0.04039877, 0.04039877, 0.04039877, 0.04039877,
       0.04039877, 0.05006308, 0.05006308, 0.04039877, 0.0496286 ,
       0.04039877, 0.05006308, 0.04039877, 0.0496286 , 0.04039877,
       0.0496286 , 0.04039877, 0.04039877, 0.04039877, 0.04039877,
       0.0496286 , 0.04039877, 0.0496286 , 0.04039877, 0.0496286 ,
       0.04039877, 0.04039877, 0.04039877, 0.04039877, 0.04039877,
       0.04039877, 0.04039877, 0.04039877, 0.0496286 , 0.04039

In [169]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

plot_connection = connections[np.argmax(norm_qof_list)]
# plot_connection = connections[np.nonzero(connection_mask)[0][np.argmin(np.asarray(rmse_list)[connection_mask])]]
plot_connection = connections[11]
plot_connection = np.asarray(connections)[connection_mask][1]

orientation, rmse = fit_orientation_index(plot_connection,
                                            spot_qs[included_spot_mask],
                                            ref_qs)

#rot_qs = ref_qs @ orientation.as_matrix()
rot_qs = orientation.apply(ref_qs, inverse=True)
q_mask = generate_q_mask(rot_qs,
                        (min_tth, max_tth),
                        (min_chi, max_chi),
                        (min_wavelength, max_wavelength),
                        ext=0.05)

ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='k')

hkls = [ref_hkls[int(ind)] for ind in plot_connection[~np.isnan(plot_connection)]]
fit_spots = spot_qs[included_spot_mask][np.nonzero(~np.isnan(plot_connection))[0]]
for idx, hkl in enumerate(hkls):
    ax.text(*fit_spots[idx], str(hkl), fontsize=8, c='k')

for edge in edges:
    ax.plot(*edge, c='gray', lw=1)

ax.scatter(*np.asarray(spot_qs[included_spot_mask]).T, s = 1, c='r')
ax.scatter(0, 0, 0, s=10, c='blue')
ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [731]:
from xrdmaptools.geometry.geometry import get_q_vect, q_2_polar, estimate_image_coords
from xrdmaptools.utilities.math import wavelength_2_energy

tth, chi, wavelength = q_2_polar(spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]], degrees=True)
# coords = estimate_image_coords(np.asarray([tth, chi]).T, rsm.tth_arr, rsm.chi_arr)[:, ::-1]

In [732]:
orientation, rmse = fit_orientation_index(best_connections[0],
                                          spot_qs,
                                          ref_qs)

#rot_qs = ref_qs @ orientation.as_matrix()
rot_qs = orientation.apply(ref_qs, inverse=True)

ref_tth, ref_chi, ref_wavelength = q_2_polar(np.asarray(rot_qs)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)], degrees=True)

In [737]:
def det_plane_from_ai(ai, skip=50):
    points = np.asarray([ai.position_array()[::skip, ::skip, i].ravel()
                         for i in [2, 1, 0]])
    
    d = np.mean(points, axis=1, keepdims=True)
    svd = np.linalg.svd(points - d)
    d = d.squeeze()
    # Plane normal n = (a, b, c) and point (d)
    return svd[0][:, -1], d

n, d = det_plane_from_ai(rsm.ai, skip=50)

In [734]:
class k_vector():
    def __init__(self, point, tth, chi, degrees=False):
        
        self.x0, self.y0, self.z0 = point

        if degrees:
            tth = np.radians(tth)
            chi = np.radians(chi)


        self.a = np.sin(tth) * np.cos(chi)
        self.b = np.sin(tth) * np.sin(chi)
        self.c = np.cos(tth)
            
    def __call__(self, t):
        return (self.x0 + self.a * t,
                self.y0 + self.b * t,
                self.z0 + self.c * t)
    
    def get_planar_intercept(self, a, b, c, d):

        t = ((a * (d[0] - self.x0) + b * (d[1] - self.y0) + c * (d[2] - self.z0))
             / (self.a * a + self.b * b + self.c * c))
        
        return self(t)
    
    def copy(self, point=None, tth=None, chi=None):
        if point is None:
            point = (self.x0, self.y0, self.z0)
        if tth is None:
            tth = self.tth
        if chi is None:
            chi = self.chi
        
        return self.__class__(point, tth, chi)

In [735]:
def lstsq_line_intersect(P0, P1):
    # From Traa, Johannes "Least-Squares Intersection of Lines" (2013).
    
    # Generate all line direction vectors 
    n = (P1 - P0) / np.linalg.norm(P1 - P0, axis=1)[:, np.newaxis] # normalized

    # Generate the array of all projectors 
    projs = np.eye(n.shape[1]) - n[:, :, np.newaxis] * n[:, np.newaxis]  # I - n*n.T

    # Generate R matrix and q vector
    R = projs.sum(axis=0)
    q = (projs @ P0[:, :, np.newaxis]).sum(axis=0)

    # Solve the least squares problem for the 
    # Intersection point p: Rp = q
    p = np.linalg.lstsq(R, q, rcond=None)[0]

    return p

In [739]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

skip = 50

pos_arr = rsm.ai.position_array()

x = pos_arr[:, :, 2][::skip, ::skip].ravel()
y = pos_arr[:, :, 1][::skip, ::skip].ravel()
z = pos_arr[:, :, 0][::skip, ::skip].ravel()

ax.scatter(x, y, z, s=1, c='k', alpha=0.1)
ax.scatter(0, 0, 0, s=10, facecolors='none', edgecolors='r')
# ax.scatter(xx.ravel(), yy.ravel(), zz.ravel(), s=1)

spot_k_vectors = []
ref_k_vectors = []
P0 = []
P1 = []
for ind in range(len(tth)):

        spot_k = k_vector((0, 0, 0), tth[ind], chi[ind], degrees=True)
        spot_k_vectors.append(spot_k)
        intercept = spot_k.get_planar_intercept(*n, d)


        ref_k = k_vector(intercept, ref_tth[ind], ref_chi[ind], degrees=True)
        ref_k_vectors.append(ref_k)

        P0.append([ref_k.x0, ref_k.y0, ref_k.z0])
        P1.append([ref_k.a, ref_k.b, ref_k.c])

        ax.scatter(*intercept, s=1, c='r')

        ax.plot(*spot_k(np.linspace(0, 0.5, 100)), c='r', lw=0.1)
        ax.plot(*ref_k(np.linspace(-0.5, 0, 100)), c='blue', lw=0.1)

        # Text
        ax.text(*intercept,
                str(tuple(np.asarray(ref_hkls)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)][ind])),
                fontsize=4, c='k')

P0 = np.asarray(P0)
P1 = np.asarray(P1)
zero_point = lstsq_line_intersect(P0, P1).squeeze()
print(np.linalg.norm(zero_point))
ax.scatter(*zero_point, c='g', s=1)

ax.set_xlabel('x [m]')
ax.set_ylabel('y [m]')
ax.set_zlabel('z [m]')
ax.set_aspect('equal')
fig.show()

0.0010710631321127918


In [742]:
from xrdmaptools.geometry.geometry import q_2_polar, get_q_vect, vector_angle

def apply_zero_point_correction(spot_qs,
                                ref_qs,
                                connection,
                                ai):
    # Combined find and correct zero point

    if len(spot_qs) != len(ref_qs):
        raise ValueError('Length of spots and assigned reference qs must be equal.')
    
    # Find detector plane
    n, d = det_plane_from_ai(ai, skip=50)

    # Convert to polar real-space coordinates
    spot_tth, spot_chi, spot_wavelength = q_2_polar(spot_qs)
    ref_tth, ref_chi, ref_wavelength = q_2_polar(ref_qs)

    # Build vector points
    intercepts = []
    points0, points1 = [], []
    for i in range(len(spot_qs)):
        # Create vectors from original zero to intercept detector plane
        spot_k = k_vector((0, 0, 0), spot_tth[i], spot_chi[i])
        intercept = spot_k.get_planar_intercept(*n, d)
        intercepts.append(intercept)
        
        # Draw reference vectors for detector intercepts back towards zero point
        ref_k = k_vector(intercept, ref_tth[i], ref_chi[i])

        # Convert reference vectors into two points
        # Should be a direct way to do this
        points0.append([ref_k.x0, ref_k.y0, ref_k.z0])
        points1.append([ref_k.a, ref_k.b, ref_k.c])

    zero_point = lstsq_line_intersect(np.asarray(points0),
                                        np.asarray(points1)).squeeze()

    dx, dy, dz = (np.asarray(intercepts) - zero_point).T
    upd_tth = np.arccos(dz / np.sqrt(dx**2 + dy**2 + dz**2))
    upd_chi = np.arctan(dy / dx)

    corr_spot_qs = get_q_vect(upd_tth, upd_chi, spot_wavelength).T

    return corr_spot_qs, tuple(zero_point)

# Unused
def get_zero_point(spot_qs,
                   ref_qs,
                   detector_plane_normal,
                   detector_plane_point):
    
    if len(spot_qs) != len(ref_qs):
        raise ValueError('Length of spots and assigned reference qs must be equal.')
    
    spot_tth, spot_chi, _ = q_2_polar(spot_qs)
    ref_tth, ref_chi, _ = q_2_polar(ref_qs)

    # Build vector points
    points0, points1 = [], []
    for i in range(len(spot_qs)):
        
        # Create vectors from original zero to intercept detector plane
        spot_k = k_vector((0, 0, 0), spot_tth[i], spot_chi[i])
        intercept = spot_k.get_planar_intercept(*detector_plane_normal,
                                                detector_plane_point)
        
        # Draw reference vectors for detector intercepts back towards zero point
        ref_k = k_vector(intercept, ref_tth[i], ref_chi[i])

        # Convert reference vectors into two points
        # Should be a direct way to do this
        points0.append([ref_k.x0, ref_k.y0, ref_k.z0])
        points1.append([ref_k.a, ref_k.b, ref_k.c])
    
    zero_point = lstsq_line_intersect(np.asarray(points0),
                                      np.asarray(points1)).squeeze()
    
    return zero_point

# Unused
def correct_zero_point(spot_qs,
                       zero_point,
                       detector_plane_normal,
                       detector_plane_point):
    
    tth, chi, wavelength = q_2_polar(spot_qs)

    upd_tth, upd_chi = [], []
    for i in range(len(spot_qs)):
        # Draw original vector to determine planar intercetp
        spot_k = k_vector((0, 0, 0), tth[i], chi[i])
        intercept = spot_k.get_planar_intercept(*detector_plane_normal,
                                                detector_plane_point)

        # Intercept on detector plane minus the new zero_point
        dx = intercept[0] - zero_point[0]
        dy = intercept[1] - zero_point[1]
        dz = intercept[2] - zero_point[2]

        # Determine new polar angles from new zero point
        upd_tth.append(vector_angle([dx, dy, dz], [0, 0, 1]))
        upd_chi.append(np.arctan(dy / dx))

    # Convert updated polar coordinates back to q-space
    return get_q_vect(np.asarray(upd_tth), np.asarray(upd_chi), wavelength).T

orientation, rmse = fit_orientation_index(best_connections[0],
                                          spot_qs,
                                          ref_qs)

rot_qs = orientation.apply(ref_qs, inverse=True)

corr_spot_qs, zero_point = apply_zero_point_correction(spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]],
                                                       np.asarray(rot_qs)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)],
                                                       rsm.ai)

In [741]:
spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]]

array([[ 1.7062054 ,  0.3085213 , -0.26560894],
       [ 2.8740857 ,  0.5306466 , -0.6257356 ],
       [ 3.504583  ,  0.66707957, -1.0826373 ],
       [ 4.953252  , -0.71816707, -1.5395108 ],
       [ 3.157827  , -1.0764898 , -0.7234503 ],
       [ 5.5778666 , -0.5856224 , -1.9986482 ],
       [ 4.6777163 ,  0.88764435, -1.445399  ]], dtype=float32)

In [59]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

colors = ['k',] * len(best_connections)
# colors = ['black', 'blue', 'green', 'purple'] + ['none'] * (len(best_connections) - 4)

for best_fit_ind in range(len(best_connections)):
    if best_fit_ind > 0:
        break

    orientation, rmse = fit_orientation_index(best_connections[best_fit_ind],
                                              spot_qs,
                                              ref_qs)

    #rot_qs = ref_qs @ orientation.as_matrix()
    rot_qs = orientation.apply(ref_qs, inverse=True)
    q_mask = generate_q_mask(rot_qs,
                            (min_tth, max_tth),
                            (min_chi, max_chi),
                            (min_wavelength, max_wavelength),
                            ext=0.05)

    ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c=colors[best_fit_ind])

    hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
    fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
    for idx, hkl in enumerate(hkls):
        ax.text(*fit_spots[idx], str(hkl), fontsize=8, c=colors[best_fit_ind])

for edge in edges:
    ax.plot(*edge, c='gray', lw=1)

ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r', label='spots')
ax.scatter(0, 0, 0, s=10, c='blue')

ax.scatter(*np.asarray(corr_spot_qs).T, s=1, c='g')

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [97]:
np.linalg.norm(zero_point)

0.004525107835702484

In [170]:
from xrdmaptools.crystal.crystal import LatticeParameters
from scipy import linalg
unstrained = LatticeParameters.from_Phase(stibnite)

def get_strain_orientation(spot_qs, ref_hkls, unstrained):
    I = np.eye(3)

    spot_qs = np.asarray(spot_qs)
    ref_hkls = np.asarray(ref_hkls)

    if len(spot_qs) != len(ref_hkls):
        raise ValueError('Number of spots and assigned hkl indices must be equal.')

    # Fit deformation (displacement?) tensor
    # x carries orientation and lattice parameter information
    x, res, rnk, s = linalg.lstsq(ref_hkls,
                                  spot_qs)

    # Convert to Busing and Levy UB matrix. Remove 2pi factor
    UBmat = x.T / (2 * np.pi)

    # Polar decomposition to remove rotation components.
    # The leftover U is the active rotation and the inverse (transpose) is required for passive definition. Maybe...
    U, B = linalg.polar(UBmat, side='right')

    # Build strained lattice parameters from the polar decomposed stretch tensor (B)
    # B is defined in Busing and Levy and is comprised of the strained reciprocal lattice vectors
    strained = LatticeParameters.from_UBmat(B)

    # Get transformation tensor (T) between strained and unstrained lattices
    Tij = np.dot(strained.Amat, np.linalg.inv(unstrained.Amat))
    # Tij = np.dot(unstrained.Bmat, np.linalg.inv(B)) # Switched positions account for opposite sign

    # Decompose transformation tensor into strain components
    # Is this eulerian or langrangian strain? Or infinitesimal?
    # This is still in crystal coordinates too...
    eij_full = 0.5 * (Tij + Tij.T) - I
    eij_hydro = np.trace(eij_full) / 3
    eij_dev = eij_full - eij_hydro * I

    return eij_dev, eij_hydro, Rotation.from_matrix(U.T)


def apply_crystal_strain(ref_qs, eij_full):
    return np.asarray(ref_qs) @ (np.eye - eij_full)

def apply_sample_strain(ref_qs, eij_full):
    raise NotImplementedError()

def apply_crystal_strain_rotation(ref_qs, eij_full, orientaiton):
    raise NotImplementedError()
    return np.asarray(ref_qs) @ (np.eye - eij_full) @ orientation.T # maybe

In [175]:
eij_dev, eij_hydro, U = get_strain_orientation(
    spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]],
    np.asarray(ref_hkls)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)],
    unstrained)

eij_full = eij_dev + eij_hydro * np.eye(3)

print(eij_hydro * 1e3)
print(eij_dev * 1e3)

19.28674335618578
[[-7.20907167 11.74839409 18.14045776]
 [11.74839409 -4.20521111 22.4494644 ]
 [18.14045776 22.4494644  11.41428278]]


In [63]:
from xrdmaptools.crystal.crystal import LatticeParameters

unstrained = LatticeParameters.from_Phase(stibnite)
print(f'|a = {unstrained.a:.6f}\t|b = {unstrained.b:.6f}\t|c = {unstrained.c:.6f}')
print(f'|alpha = {np.degrees(unstrained.alpha):.3f}\t|beta = {np.degrees(unstrained.beta):.3f} \t|gamma = {np.degrees(unstrained.gamma):.3f}')

|a = 11.314000	|b = 3.837000	|c = 11.234000
|alpha = 90.000	|beta = 90.000 	|gamma = 90.000


In [67]:
from scipy import linalg
from xrdmaptools.crystal.crystal import LatticeParameters
I = np.eye(3)

unstrained = LatticeParameters.from_Phase(stibnite)

orientation, rmse = fit_orientation_index(best_connections[0],
                                          spot_qs,
                                          ref_qs)

x, res, rnk, s = linalg.lstsq(np.asarray(ref_hkls)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)],
                              spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]])


refs = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
spots = [[(2 * np.pi) / (11.314 * 1.01), 0, 0],
         [0, (2 * np.pi) / (3.837 * 0.98), 0],
         [0, 0, (2 * np.pi) / (11.234 * 1.03)]]

spots = np.asarray(spots) @ Rotation.from_euler('xzx', [50, 10, 0]).as_matrix()

# x, res, rnk, s = linalg.lstsq(refs, spots)

UBmat = x.T / (2 * np.pi)
U, B = linalg.polar(UBmat, side='right') # This U is transpose of orientation!!!
strained = LatticeParameters.from_UBmat(B)

Tij = np.dot(strained.Amat, np.linalg.inv(unstrained.Amat))
# Tij = np.dot(unstrained.Bmat, np.linalg.inv(B)) # Also a valid option...
eij_full = 0.5 * (Tij + Tij.T) - I
eij_hydro = np.trace(eij_full) / 3
eij_dev = eij_full - eij_hydro * I
print(eij_full * 1e3)
print(eij_hydro * 1e3)
print(eij_dev * 1e3)
# print(np.degrees(Rotation.from_matrix(orientation.as_matrix() @ U).magnitude()))
print(f'|a = {strained.a:.6f}\t|b = {strained.b:.6f}\t|c = {strained.c:.6f}')
print(f'|alpha = {np.degrees(strained.alpha):.3f}\t|beta = {np.degrees(strained.beta):.3f} \t|gamma = {np.degrees(strained.gamma):.3f}')

[[ 7.57015516  1.52463524  8.149306  ]
 [ 1.52463524 -2.46844595  4.40999834]
 [ 8.149306    4.40999834 12.83157843]]
5.977762545720762
[[ 1.59239261  1.52463524  8.149306  ]
 [ 1.52463524 -8.4462085   4.40999834]
 [ 8.149306    4.40999834  6.85381589]]
|a = 11.399649	|b = 3.827546	|c = 11.380054
|alpha = 89.498	|beta = 89.078 	|gamma = 89.825


In [461]:
Tij = np.dot(unstrained.Bmat, np.linalg.inv(B))
# Tij = np.dot(unstrained.Bmat, np.linalg.inv(UBmat))
# Tij = np.dot(UBmat, np.linalg.inv(unstrained.Bmat))
# Tij = np.dot(np.linalg.inv(UBmat), unstrained.Bmat)
# Tij = np.dot(np.linalg.inv(unstrained.Bmat), UBmat)
eij_full = 0.5 * (Tij + Tij.T) - I
eij_hydro = np.trace(eij_full) / 3
eij_dev = eij_full - eij_hydro * I
print(eij_full * 1e3)
print(eij_hydro * 1e3)
print(eij_dev * 1e3)

[[ 1.00000000e+01 -7.87980098e-14 -8.28183165e-14]
 [-7.87980098e-14 -2.00000000e+01  2.53140591e-14]
 [-8.28183165e-14  2.53140591e-14  3.00000000e+01]]
6.66666666666671
[[ 3.33333333e+00 -7.87980098e-14 -8.28183165e-14]
 [-7.87980098e-14 -2.66666667e+01  2.53140591e-14]
 [-8.28183165e-14  2.53140591e-14  2.33333333e+01]]


In [462]:
wij = 0.5 * (Tij - Tij.T)
w = [wij[2, 1], wij[0, 2], wij[1, 0]]
np.degrees(w)

array([ 2.08139997e-16, -4.74514000e-15,  4.51479340e-15])

In [463]:
Rotation.from_matrix(wij + I).as_euler('xzx', degrees=True)

  if await self.run_code(code, result, async_=asy):


array([2.08139997e-16, 6.54978725e-15, 0.00000000e+00])

In [153]:
from scipy import linalg
from xrdmaptools.crystal.crystal import LatticeParameters
I = np.eye(3)

unstrained = LatticeParameters.from_Phase(stibnite)

orientation, rmse = fit_orientation_index(best_connections[0],
                                          spot_qs,
                                          ref_qs)

# x, res, rnk, s = linalg.lstsq(np.asarray(ref_hkls)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)],
#                               corr_spot_qs)

UBmat = x.T / (2 * np.pi)
U, B = linalg.polar(UBmat, side='right') # This U is transpose of orientation!!!
strained = LatticeParameters.from_UBmat(B)

Tij = np.dot(strained.Amat, np.linalg.inv(unstrained.Amat))
eij_full = 0.5 * (Tij + Tij.T) - I
eij_hydro = np.trace(eij_full) / 3
eij_dev = eij_full - eij_hydro * I
print(eij_full * 1e3)
print(eij_hydro * 1e3)
print(eij_dev * 1e3)
print(np.degrees(Rotation.from_matrix(orientation.as_matrix() @ U).magnitude()))
print(f'|a = {strained.a:.6f}\t|b = {strained.b:.6f}\t|c = {strained.c:.6f}')
print(f'|alpha = {np.degrees(strained.alpha):.3f}\t|beta = {np.degrees(strained.beta):.3f} \t|gamma = {np.degrees(strained.gamma):.3f}')

[[19.59543486 12.2086758  19.6670827 ]
 [12.2086758  -2.60335478 25.57793052]
 [19.6670827  25.57793052  0.1073777 ]]
5.699819262548737
[[13.8956156  12.2086758  19.6670827 ]
 [12.2086758  -8.30317404 25.57793052]
 [19.6670827  25.57793052 -5.59244156]]
1.2494266459753685
|a = 11.535703	|b = 3.828158	|c = 11.258569
|alpha = 87.020	|beta = 87.751 	|gamma = 88.598


In [136]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

colors = ['k',] * len(best_connections)
# colors = ['black', 'blue', 'green', 'purple'] + ['none'] * (len(best_connections) - 4)

for best_fit_ind in range(len(best_connections)):
    if best_fit_ind > 0:
        break

    orientation, rmse = fit_orientation_index(best_connections[best_fit_ind],
                                              spot_qs,
                                              ref_qs)

    # rot_qs = ref_qs @ orientation.as_matrix()
    # rot_qs = orientation.apply(ref_qs, inverse=True)
    rot_qs = Rotation.from_matrix(U.T).apply(ref_qs, inverse=True)

    q_mask = generate_q_mask(rot_qs,
                            (min_tth, max_tth),
                            (min_chi, max_chi),
                            (min_wavelength, max_wavelength),
                            ext=0.05)

    ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c=colors[best_fit_ind])

    hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
    fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
    for idx, hkl in enumerate(hkls):
        ax.text(*fit_spots[idx], str(hkl), fontsize=8, c=colors[best_fit_ind])

for edge in edges:
    ax.plot(*edge, c='gray', lw=1)

ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')
ax.scatter(0, 0, 0, s=10, c='blue')

fit_qs = ref_qs @ (I - eij_full) @ U.T

ax.scatter(*fit_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='g')

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [139]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

colors = ['k',] * len(best_connections)
# colors = ['black', 'blue', 'green', 'purple'] + ['none'] * (len(best_connections) - 4)

for best_fit_ind in range(len(best_connections)):
    if best_fit_ind > 0:
        break

    orientation, rmse = fit_orientation_index(best_connections[best_fit_ind],
                                              spot_qs,
                                              ref_qs)

    # rot_qs = ref_qs @ orientation.as_matrix()
    # rot_qs = orientation.apply(ref_qs, inverse=True)
    rot_qs = Rotation.from_matrix(U.T).apply(ref_qs, inverse=True)

    q_mask = generate_q_mask(rot_qs,
                            (min_tth, max_tth),
                            (min_chi, max_chi),
                            (min_wavelength, max_wavelength),
                            ext=0)

    ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c=colors[best_fit_ind], label='unstrained ref')

    hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
    # fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
    fit_spots = corr_spot_qs
    for idx, hkl in enumerate(hkls):
        ax.text(*fit_spots[idx], str(hkl), fontsize=8, c=colors[best_fit_ind])

for edge in edges:
    ax.plot(*edge, c='gray', lw=1)

ax.scatter(*np.asarray(corr_spot_qs).T, s = 1, c='r', label='offset spots')
ax.scatter(0, 0, 0, s=10, c='blue', label='(000)')

fit_qs = ref_qs @ (I - eij_full) @ U.T

ax.scatter(*fit_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='g', label='strained ref')

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')
ax.legend()

fig.show()

In [989]:
from scipy import linalg
I = np.eye(3)

orientation, rmse = fit_orientation_index(best_connections[0],
                                          spot_qs,
                                          ref_qs)

rot_spot_qs = orientation.apply(spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]], inverse=False)

x, res, rnk, s = linalg.lstsq(np.asarray(ref_hkls)[best_connections[0][np.nonzero(~np.isnan(best_connections[0]))[0]].astype(int)],
                              spot_qs[np.nonzero(~np.isnan(best_connections[0]))[0]])

UBmat = x.T / (2 * np.pi)
U, B = linalg.polar(UBmat, side='left')
strained = LatticeParameters.from_UBmat(B)

U, B = linalg.polar(UBmat, side='right')
strained = LatticeParameters.from_UBmat(B)

Tij = np.dot(strained.Amat, np.linalg.inv(unstrained.Amat))
eij_full = 0.5 * (Tij + Tij.T) - I
eij_hydro = np.trace(eij_full) / 3
eij_dev = eij_full - eij_hydro * I

print(eij_full.round(6) * 1e3)
print(eij_hydro.round(6) * 1e3)
print(eij_dev.round(6) * 1e3)

print(np.degrees(Rotation.from_matrix(U).as_euler('xzx')).round(3))

[[-1.238 -1.655 -2.317]
 [-1.655  6.24  -3.025]
 [-2.317 -3.025  9.454]]
4.819
[[-6.057 -1.655 -2.317]
 [-1.655  1.421 -3.025]
 [-2.317 -3.025  4.635]]
[-47.745  11.565  51.037]


In [868]:
spots = [[(2 * np.pi) / (11.314 * 1.001), 0, 0],
         [0, (2 * np.pi) / (3.837 * 1.002), 0],
         [0, 0, (2 * np.pi) / (11.234 * 1.003)]]
spots

[[0.5547912673484892, 0, 0],
 [0, 1.6342569765810016, 0],
 [0, 0, 0.5576279268993436]]

In [873]:
([[(2 * np.pi) / (11.314), 0, 0],
[0, (2 * np.pi) / (3.837), 0],
[0, 0, (2 * np.pi) / (11.234)]] @ (I - eij_full)).round(6)

array([[0.554791, 0.      , 0.      ],
       [0.      , 1.63425 , 0.      ],
       [0.      , 0.      , 0.557623]])

In [569]:
for lattice in [unstrained, strained]:
    print(f'a = {lattice.a}')
    print(f'b = {lattice.b}')
    print(f'c = {lattice.c}')
    print(f'alpha = {np.degrees(lattice.alpha)}')
    print(f'beta = {np.degrees(lattice.beta)}')
    print(f'gamma = {np.degrees(lattice.gamma)}')

a = 11.314
b = 3.837
c = 11.234
alpha = 90.0
beta = 90.0
gamma = 90.0
a = 11.29999799841607
b = 3.8609654896003085
c = 11.340530969453713
alpha = 90.34248192109521
beta = 90.26303930871917
gamma = 90.18852746913046


In [370]:
# Brute-force exhaustive search to find all higher order connection from list of valid pairs

valid_conn = dict(zip(
    range(2, 21),
    [[] for _ in range(20)]
))

# Already know all valid pairs
valid_conn[2] = np.asarray(expanded_pair_list[keep_pair_mask])

def print_output(pair_iteration):
    out_str = f"Pair {pair_iteration + 1}/{len(valid_conn[2])}   "

    for rank in range(3, len(valid_conn.keys()) + 2):
        rank_len = len(valid_conn[rank])
        if rank_len > 0:
            out_str += f"|{rank}s: {rank_len}   "

    print(out_str, end='\r')


def scrub_repeats(connections_to_check,
                   connections_to_scrub):
    
    connections_to_keep = []
    
    if len(connections_to_check) > 0:
        connections_to_check = np.asarray(connections_to_check)
    else:
        return connections_to_scrub # Assumes no repeats in input set...

    for connection in connections_to_scrub:
        indices = np.nonzero(~np.isnan(connection))[0]
        vals = connection[indices]

        if np.any(np.all([connections_to_check[:, index] == val
                       for index, val in zip(indices, vals)], axis=0)):
            continue
        else:
            connections_to_keep.append(connection)
        
    return connections_to_keep


print('Finding valid higher order connections...')
rem_pair_mask = np.asarray([True,] * len(valid_conn[2]))
for i in range(len(valid_conn[2])):

    # def trunc(values, decs=0):
    #     return np.trunc(values * 10**decs) / (10**decs)

    # # similar_pair_mask = trunc(pair_rmse, 13) == trunc(pair_rmse[2], 13)
    # similar_pair_mask = np.round(pair_rmse, 10) == np.round(pair_rmse[i], 10)
    # sym_pairs = expanded_pair_list[similar_pair_mask]
    # if len(sym_pairs) > 2:
    #     most_positive_index = np.argmax([np.sign(np.asarray(ref_hkls)[pair[np.nonzero(~np.isnan(pair))[0]].astype(int)]).sum()
    #                                     for pair in sym_pairs])
    #     # most_positive_index = np.sign(np.asarray(ref_hkls)[sym_pairs[np.nonzero(~np.isnan(sym_pairs))].astype(int)].reshape(-1, 2, 3)).sum(axis=(1, 2)).argmax()

    #     if i != np.nonzero(similar_pair_mask)[0][most_positive_index]:
    #         continue

    print_output(i)
    pair = valid_conn[2][i]
    pair_indices = np.nonzero(~np.isnan(pair))[0]
    pair_vals = pair[pair_indices]

    # rem_pair_mask = np.asarray([True,] * len(expanded_pair_list))
    # rem_pair_mask[: i + 1] = False

    # Eliminate remaining pair misorientations above some threshold
    R0 = pair_orientations[keep_pair_mask][i]
    misorientations = []
    for R1 in pair_orientations[keep_pair_mask][rem_pair_mask]:
        if not isinstance(R0, float) and not isinstance(R1, float):
            misorientations.append(np.degrees(Rotation.from_matrix(R0.as_matrix() @ R1.as_matrix().T).magnitude()))
        else:
            misorientations.append(0) # Cannot neglect colinear pairs

    misorientation_mask = rem_pair_mask.copy() # redundant, but differentiates
    misorientation_mask[np.nonzero(rem_pair_mask)] = np.asarray(misorientations) < 360

    # Initial triplet candidates
    trip_mask = np.any([valid_conn[2][:, pair_indices[i]] == pair_vals[i] for i in range(2)], axis=0)
    trip_mask = trip_mask & misorientation_mask
    if np.sum(trip_mask) < 1:
        continue
    
    next_pair0_mask = valid_conn[2][trip_mask, pair_indices[0]] == pair_vals[0]
    next_pairs0 = valid_conn[2][trip_mask][next_pair0_mask]
    next_pairs0[:, pair_indices[0]] = np.nan

    next_pair1_mask = valid_conn[2][trip_mask, pair_indices[1]] == pair_vals[1]
    next_pairs1 = valid_conn[2][trip_mask][next_pair1_mask]
    next_pairs1[:, pair_indices[1]] = np.nan

    partials3 = np.asarray([match for match in next_pairs0 if match in next_pairs1])
    full3 = []
    for triplet in partials3:
        new_triplet = pair.copy()
        new_triplet[~np.isnan(triplet)] = triplet[~np.isnan(triplet)].ravel()
        valid_conn[3].append(new_triplet)
        full3.append(new_triplet)

    # Must have at least 2 triplets to search for qudruplets
    if len(partials3) < 2:
        continue

    continue

    for idx4 in range(len(partials3)):
        print_output(i)
        partials4, full4 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                  partials3,
                                                  full3[idx4],
                                                  common_partial=partials3[idx4])
        
        valid_conn[4].extend(scrub_repeats(valid_conn[4], full4))
        
        if len(partials4) < 2:
            continue

        for idx5 in range(len(partials4)):
            print_output(i)
            partials5, full5 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                      partials4,
                                                      full4[idx5],
                                                      common_partial=partials4[idx5])
            
            valid_conn[5].extend(scrub_repeats(valid_conn[5], full5))
            
            if len(partials5) < 2:
                continue

            for idx6 in range(len(partials5)):
                print_output(i)
                partials6, full6 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                          partials5,
                                                          full5[idx6],
                                                          common_partial=partials5[idx6])
                
                valid_conn[6].extend(scrub_repeats(valid_conn[6], full6))
                
                if len(partials6) < 2:
                    continue
                    
                for idx7 in range(len(partials6)):
                    print_output(i)
                    partials7, full7 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                              partials6,
                                                              full6[idx7],
                                                              common_partial=partials6[idx7])
                    
                    valid_conn[7].extend(scrub_repeats(valid_conn[7], full7))
                    
                    if len(partials7) < 2:
                        continue

                    for idx8 in range(len(partials7)):
                        print_output(i)
                        partials8, full8 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                                  partials7,
                                                                  full7[idx8],
                                                                  common_partial=partials7[idx8])
                        
                        valid_conn[8].extend(scrub_repeats(valid_conn[8], full8))
                        
                        if len(partials8) < 2:
                            continue

                        for idx9 in range(len(partials8)):
                            print_output(i)
                            partials9, full9 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                                      partials8,
                                                                      full8[idx9],
                                                                      common_partial=partials8[idx9])
                            
                            valid_conn[9].extend(scrub_repeats(valid_conn[9], full9))
                            
                            if len(partials9) < 2:
                                continue
                        
                            for idx10 in range(len(partials9)):
                                print_output(i)
                                partials10, full10 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                                            partials9,
                                                                            full9[idx10],
                                                                            common_partial=partials9[idx10])
                                
                                valid_conn[10].extend(scrub_repeats(valid_conn[10], full10))
                                
                                if len(partials10) < 2:
                                    continue

                                for idx11 in range(len(partials10)):
                                    print_output(i)
                                    partials11, full11 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                                                partials10,
                                                                                full10[idx11],
                                                                                common_partial=partials10[idx11])
                                    
                                    valid_conn[11].extend(scrub_repeats(valid_conn[11], full11))
                                    
                                    if len(partials11) < 2:
                                        continue

                                    for idx12 in range(len(partials11)):
                                        print_output(i)
                                        partials12, full12 = _find_next_connections(valid_conn[2][misorientation_mask],
                                                                                    partials11,
                                                                                    full11[idx12],
                                                                                    common_partial=partials11[idx12])
                                        
                                        valid_conn[12].extend(scrub_repeats(valid_conn[12], full12))
                                        
                                        if len(partials12) < 2:
                                            continue
                                        
                                        raise RuntimeError('Currently no supported connections greater than 12!')
    rem_pair_mask[i] = False 
    #break         

Finding valid higher order connections...
Pair 17/125481   |3s: 1303   

KeyboardInterrupt: 

In [23]:
# Decompose connections
# Based on minimum of most connected structure (could do most connected -1)

def fit_orientation_index(connection,
                          spot_qs,
                          ref_qs):
    fit_spot_indices = np.nonzero(~np.isnan(connection))
    fit_spot_qs = spot_qs[fit_spot_indices]

    fit_ref_indices = connection[fit_spot_indices].astype(int)
    # fit_ref_hkls = np.asarray(ref_hkls)[fit_ref_indices]
    fit_ref_qs = np.asarray(ref_qs)[fit_ref_indices]

    fit_orientation, fit_rssd = Rotation.align_vectors(fit_ref_qs, fit_spot_qs)

    fit_euclidean_errors = [np.sqrt(np.sum([(p - q)**2 for p, q in zip(v1, v2)]))
                            for v1, v2 in zip(fit_ref_qs, fit_orientation.apply(fit_spot_qs, inverse=False))]
    rmse = np.mean(fit_euclidean_errors)

    return fit_orientation, rmse

def check_collinearity(connection_pair, ref_hkls):
    pair_ref_hkls = [ref_hkls[int(ind)] for ind
                    in connection_pair[~np.isnan(connection_pair)]]

    # Check for colinearity; 3D orientation cannot be determined
    pair_divs = np.array(pair_ref_hkls[0]) / np.array(pair_ref_hkls[1])
    return len(np.unique(pair_divs[~np.isnan(pair_divs)])) < 2


mutable_conn = dict(zip(
    valid_conn.keys(),
    [value.copy() for value in valid_conn.values()]    
))

excluded_spot_indices = []
best_connections = []
best_orientations = []
best_errors = []
DECOMPOSING_PATTERN = True
while DECOMPOSING_PATTERN:
    # Remove unused connections
    del_keys = []
    for key, value in mutable_conn.items():
        mutable_conn[key] = np.asarray(value)
        if len(value) == 0:
            del_keys.append(key)
    for key in del_keys:
        del mutable_conn[key]

    # Find best connected rank
    largest_key = np.max([key for key, value in mutable_conn.items() if len(value) > 0])
    largest_connections = mutable_conn[largest_key]

    # Find best bit orientation
    fit_orientations = []
    fit_rmse = []
    print(f'Searching for best fit rank {largest_key} connection out of {len(largest_connections)} possibilities.')
    for connection in largest_connections:
        if largest_key == 2:
            if check_collinearity(connection, ref_hkls):
                continue

        fit_orientation, rmse = fit_orientation_index(connection)    
        fit_orientations.append(fit_orientation)
        fit_rmse.append(rmse)
    
    if len(fit_orientations) < 1:
        DECOMPOSING_PATTERN = False

    best_connections.append(largest_connections[np.argmin(fit_rmse)])
    best_orientations.append(fit_orientations[np.argmin(fit_rmse)])
    best_errors.append(fit_rmse[np.argmin(fit_rmse)])
    fit_spot_indices = np.nonzero(~np.isnan(largest_connections[np.argmin(fit_rmse)]))[0]
    excluded_spot_indices.extend(fit_spot_indices)

    # Scrub all excluded_spot_indices from remaining connections
    conn_masks = []
    for key, connections in mutable_conn.items():
        conn_mask = [True,] * len(connections)
        for idx, connection in enumerate(connections):
            conn_mask[idx] = np.all([index not in np.nonzero(~np.isnan(connection))[0]
                                    for index in excluded_spot_indices])
        mutable_conn[key] = connections[conn_mask]

    if len(mutable_conn[2]) < 1:
        DECOMPOSING_PATTERN = False


NameError: name 'spot_qs' is not defined

In [85]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})


colors = ['k',] * len(best_connections)
colors = ['black', 'blue', 'green', 'purple'] + ['none'] * (len(best_connections) - 4)

for best_fit_ind in range(len(best_connections)):
    if best_fit_ind > 0:
        break

    if np.sum(~np.isnan(best_connections[best_fit_ind])) < 3:
        continue

    rot_qs = ref_qs @ best_orientations[best_fit_ind].as_matrix()
    q_mask = np.all([
        np.all([rot_qs[:, 0] > q_mins[0], rot_qs[:, 0] < q_maxs[0]], axis=0),
        np.all([rot_qs[:, 1] > q_mins[1], rot_qs[:, 1] < q_maxs[1]], axis=0),
        np.all([rot_qs[:, 2] > q_mins[2], rot_qs[:, 2] < q_maxs[2]], axis=0),
    ], axis=0)
    ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c=colors[best_fit_ind])

    hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
    fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
    for idx, hkl in enumerate(hkls):
        ax.text(*fit_spots[idx], str(hkl), fontsize=8, c=colors[best_fit_ind])


ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')
ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [141]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

best_fit_ind = 0

rot_qs = ref_qs @ best_orientations[best_fit_ind].as_matrix()
q_mask = np.all([
    np.all([rot_qs[:, 0] > q_mins[0], rot_qs[:, 0] < q_maxs[0]], axis=0),
    np.all([rot_qs[:, 1] > q_mins[1], rot_qs[:, 1] < q_maxs[1]], axis=0),
    np.all([rot_qs[:, 2] > q_mins[2], rot_qs[:, 2] < q_maxs[2]], axis=0),
], axis=0)
ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='k')

hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
for idx, hkl in enumerate(hkls):
    ax.text(*fit_spots[idx], str(hkl), fontsize=8)


best_fit_ind = 1

rot_qs = ref_qs @ best_orientations[best_fit_ind].as_matrix()
q_mask = np.all([
    np.all([rot_qs[:, 0] > q_mins[0], rot_qs[:, 0] < q_maxs[0]], axis=0),
    np.all([rot_qs[:, 1] > q_mins[1], rot_qs[:, 1] < q_maxs[1]], axis=0),
    np.all([rot_qs[:, 2] > q_mins[2], rot_qs[:, 2] < q_maxs[2]], axis=0),
], axis=0)
ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='blue')

hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
for idx, hkl in enumerate(hkls):
    ax.text(*fit_spots[idx], str(hkl), fontsize=8, c='blue')


best_fit_ind = 2

rot_qs = ref_qs @ best_orientations[best_fit_ind].as_matrix()
q_mask = np.all([
    np.all([rot_qs[:, 0] > q_mins[0], rot_qs[:, 0] < q_maxs[0]], axis=0),
    np.all([rot_qs[:, 1] > q_mins[1], rot_qs[:, 1] < q_maxs[1]], axis=0),
    np.all([rot_qs[:, 2] > q_mins[2], rot_qs[:, 2] < q_maxs[2]], axis=0),
], axis=0)
ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='green')

hkls = [ref_hkls[int(ind)] for ind in best_connections[best_fit_ind][~np.isnan(best_connections[best_fit_ind])]]
fit_spots = spot_qs[np.nonzero(~np.isnan(best_connections[best_fit_ind]))[0]]
for idx, hkl in enumerate(hkls):
    ax.text(*fit_spots[idx], str(hkl), fontsize=8, c='green')


ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')
ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [31]:
largest_key = np.max([key for key, value in valid_conn.items() if len(value) > 0])
largest_connections = valid_conn[largest_key]

print(f'Searching for best fit connection out of {len(largest_connections)} of rank {largest_key} connections.')
fit_orientations = []
fit_mean_errors = []
fit_ref_all_hkls = []
fit_all_spots = []
for connection in largest_connections:

    #connection = valid_octets[0]

    fit_spot_indices = np.nonzero(~np.isnan(connection))
    fit_spot_qs = spot_qs[fit_spot_indices]
    fit_all_spots.append(fit_spot_qs)

    fit_ref_indices = connection[fit_spot_indices].astype(int)
    fit_ref_hkls = np.asarray(ref_hkls)[fit_ref_indices]
    fit_ref_all_hkls.append(fit_ref_hkls)
    fit_ref_qs = np.asarray(ref_qs)[fit_ref_indices]

    fit_orientation, fit_rssd = Rotation.align_vectors(fit_ref_qs, fit_spot_qs)

    fit_euclidean_error = [np.sqrt(np.sum([(p - q)**2 for p, q in zip(v1, v2)]))
                        for v1, v2 in zip(fit_ref_qs, fit_orientation.apply(fit_spot_qs, inverse=False))]
    mean_euclidean_error = np.mean(fit_euclidean_error)

    fit_orientations.append(fit_orientation)
    fit_mean_errors.append(mean_euclidean_error)
fit_mean_errors

Searching for best fit connection out of 12 of rank 9 connections.


[0.04530987972079391,
 0.03632377315726579,
 0.04530987972079413,
 0.036323773157265773,
 0.045309879720794394,
 0.036323773157265746,
 0.04530987972079405,
 0.03632377315726581,
 0.04560220177884898,
 0.045602201778849105,
 0.045602201778849105,
 0.045602201778848994]

In [35]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=200, subplot_kw={'projection':'3d'})

min_ind = np.argmin(fit_mean_errors)
min_ind = 8

ax.scatter(*np.asarray(spot_qs).T, s = 1, c='r')

rot_qs = ref_qs @ fit_orientations[min_ind].as_matrix()
q_mask = np.all([
    np.all([rot_qs[:, 0] > q_mins[0], rot_qs[:, 0] < q_maxs[0]], axis=0),
    np.all([rot_qs[:, 1] > q_mins[1], rot_qs[:, 1] < q_maxs[1]], axis=0),
    np.all([rot_qs[:, 2] > q_mins[2], rot_qs[:, 2] < q_maxs[2]], axis=0),
], axis=0)

ax.scatter(*rot_qs[q_mask].T, s=ref_fs[q_mask] * 0.1, c='k')

for idx, hkl in enumerate(fit_ref_all_hkls[min_ind]):
    ax.text(*fit_all_spots[min_ind][idx], str(hkl), fontsize=8)

ax.set_xlim(q_mins[0], q_maxs[0])
ax.set_ylim(q_mins[1], q_maxs[1])
ax.set_zlim(q_mins[2], q_maxs[2])

ax.set_xlabel('qx [Å⁻¹]')
ax.set_ylabel('qy [Å⁻¹]')
ax.set_zlabel('qz [Å⁻¹]')
ax.set_aspect('equal')

fig.show()

In [22]:
def are_collinear(vectors):
    vecs = np.asarray(vectors)
    collinear_flag = True

    # Probably faster. Not easy to perform pairwise
    if len(vecs) == 2:
        if np.sum(np.abs(np.cross(*vecs))) > 1e-8:
            collinear_flag = False
        return collinear_flag

    # Pairwise analysis fo list of vectors
    const_list = []
    for ind in range(vecs.shape[1]):
        vecs_axis = vecs[:, ind]
        if not np.any(vecs_axis == 0):
            const = np.abs(vecs_axis[:, np.newaxis] / vecs_axis[np.newaxis, :])
            const_list.append(np.round(const, 3))

    combos = list(combinations(range(vecs.shape[1]), 2))
    if len(combos) > 1:
        combos.pop(-1) # last index is redundant

    for combo in combos:
        if np.any(const_list[combo[0]] != const_list[combo[1]]):
            collinear_flag = False
            break
        
    return collinear_flag


def vector_angle(v1, v2, degrees=False):
    angle = np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1, axis=-1) *  np.linalg.norm(v2, axis=-1)))
    if degrees:
        angle = np.degrees(angle)
    return angle


def multi_vector_angles(v1s, v2s, degrees=False):
    v1_units = v1s / np.linalg.norm(v1s, axis=1).reshape(-1, 1)
    v2_units = v2s / np.linalg.norm(v2s, axis=1).reshape(-1, 1)
    angles = np.arccos(np.inner(v1_units, v2_units).round(6)) # Not happy about the round...
    if degrees:
        angles = np.degrees(angles)
    return angles

In [24]:

def pair_voting_indexing(spot_qs,
                         spot_ints,
                         phase,
                         near_q=0.005,
                         near_angle=5):

    spot_q_mags = np.linalg.norm(spot_qs, axis=1)
    max_q = np.max(spot_q_mags)

    # Combine these at some point...
    stibnite.get_hkl_reflections()
    ref_hkls, ref_qs, ref_fs = generate_reciprocal_lattice(stibnite, 1.15 * max_q) # 15% window
    ref_q_mags = np.linalg.norm(ref_qs, axis=1)

    # Minimum step size in q-space.
    min_q = np.min(np.linalg.norm(phase.Q([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), axis=0))

    # Find difference between measured and calculated q magnitudes
    diff_arr = np.abs(spot_q_mags[:, np.newaxis]
                    - ref_q_mags[np.newaxis, :])

    # Eliminate any reflections outside phase-allowed spots
    phase_mask = np.any(diff_arr < near_q, axis=1)
    diff_arr = diff_arr[phase_mask]
    spot_qs = spot_qs[phase_mask]
    spot_q_mags = spot_q_mags[phase_mask]
    spot_ints = spot_ints[phase_mask]

    # Generate all pairs of spots which are crystallographically feasible
    spot_pair_indices = list(combinations(range(len(spot_qs)), 2))
    spot_diff_arr = np.abs(spot_q_mags[:, np.newaxis]
                        - spot_q_mags[np.newaxis, :])
    allowed_pairs = [spot_diff_arr[indices] > min_q * 0.85 for indices in spot_pair_indices]
    spot_pair_indices = np.asarray(spot_pair_indices)[allowed_pairs]

    # Compute all angles
    spot_angles = multi_vector_angles(spot_qs, spot_qs, degrees=True)
    ref_angles = multi_vector_angles(ref_qs, ref_qs, degrees=True)

    votes = [[] for _ in range(len(spot_qs))]

    for pair in spot_pair_indices:
        ref_combos = list(product(*[np.nonzero(diff_arr[i] < near_q)[0] for i in pair]))
        angle_mask = [np.abs(spot_angles[tuple(pair)] - ref_angles[tuple(combo)]) < near_angle for combo in ref_combos]
        for ref_vote in np.asarray(ref_combos)[angle_mask]:
            votes[pair[0]].append(ref_vote[0])
            votes[pair[1]].append(ref_vote[1])

    guess_refl = [None,] * len(spot_qs)

    for i, vote in enumerate(votes):
        spot_inds = np.unique(vote)
        spot_votes = [np.sum(vote == ind) for ind in spot_inds]
        sorted_votes = sorted(spot_votes, reverse=True)
        sorted_ind = [x for x, _ in sorted(zip(spot_inds, spot_votes),
                                        key=lambda pair: pair[0])]

        guess_refl[i] = dict(zip(sorted_ind, sorted_votes))

    return guess_refl, spot_qs, ref_hkls