In [10]:
%matplotlib qt5

from pathlib import Path
import pickle
from orix import io, plot
from hyperspy import signal as signal
import numpy as np #General numerical and matrix support
from pyxem.utils import indexation_utils as iutls
from orix.quaternion import Rotation, symmetry, Orientation
from orix.vector.vector3d import Vector3d
from orix import plot
import matplotlib
matplotlib.rcParams["backend"] = "Agg"
from orix.crystal_map.crystal_map import CrystalMap
import itertools
import dask.array as da
import time

## Import crystal maps Crystal maps
Not needed for angle cleanup method

In [2]:
# Import xmaps

# xmap1_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_123856_xmap.h5")
# xmap2_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_131643_xmap.h5")
# xmap3_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_135554_xmap.h5")
# xmap4_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_144154_xmap.h5")
# xmap5_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_160122_xmap.h5")
# xmap6_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_164221_xmap.h5")
xmap1_ = io.load("D:/Master_thesis/Centered_xmaps/aligned_xmap_1.h5")
xmap2_ = io.load("D:/Master_thesis/Centered_xmaps/aligned_xmap_2.h5")
xmap3_ = io.load("D:/Master_thesis/Centered_xmaps/aligned_xmap_3.h5")
xmap4_ = io.load("D:/Master_thesis/Centered_xmaps/aligned_xmap_4.h5")
xmap5_ = io.load("D:/Master_thesis/Centered_xmaps/aligned_xmap_5.h5")
xmap6_ = io.load("D:/Master_thesis/Data_from_workstation/180224/20240208_164221_xmap.h5")

In [3]:
# Functions for plotting color maps
def make_colormap(xmap): # Plotting the orientation map 
    ipf_key = plot.IPFColorKeyTSL(xmap.phases[0].point_group, direction=Vector3d.zvector())
    # ipf_key.plot()
    rgb = ipf_key.orientation2color(xmap.orientations)
    xmap.plot(rgb, remove_padding=True, return_figure=True,scalebar_properties=dict(location="lower left", frameon=False))
    # xmap.plot(rgb, overlay = xmap.correlation[:,0],remove_padding=True, return_figure=True,scalebar_properties=dict(location="lower left", frameon=True))
    
def make_colormap_error(xmap,error): # Plot orientation map overlaid with the TMA
    ipf_key = plot.IPFColorKeyTSL(xmap.phases[0].point_group, direction=Vector3d.zvector())
    # ipf_key.plot()
    rgb = ipf_key.orientation2color(xmap.orientations)
    error_overlay = np.ndarray.flatten(error)
    error_overlay = 1-(error_overlay/np.max(error_overlay))
    xmap.plot(rgb, overlay=error_overlay, remove_padding=True, return_figure=True,scalebar_properties=dict(location="lower left", frameon=True))

## Angle cleanup method
### Total misorientation angle (TMA):
$\alpha_{TMA} = |m_{12}-\alpha_{12}|+|m_{13}-\alpha_{13}|+|m_{23}-\alpha_{23}|$

m is calculated misorientation angle between orientations, $\alpha$ is actual angle difference between tilts.

In [5]:
def dict_slice(dict,start_y,end_y): # Slice the dictionaries
    keys = dict.keys()
    for i in keys:
        dict[f'{i}'] = dict[f'{i}'][start_y:end_y,:]
    return dict


def combos_func(iterables_array): # Find all combinations of an array
    return np.array(list(itertools.product(*iterables_array)))

def change_step_size(xmap, pixel_size, scan_unit): # Change the step size of an orix CrystalMap, also update rotations
    x = xmap.x * pixel_size
    y = xmap.y * pixel_size
    rot = xmap.rotations
    phaseid = xmap.phase_id
    prop = xmap.prop
    is_in_data = xmap.is_in_data
    phaselist = xmap.phases
    new_xmap = CrystalMap(rotations = rot,
                        phase_id = phaseid,
                        x = x,
                        y = y,
                        prop = prop,
                        scan_unit = scan_unit,
                        is_in_data = is_in_data,
                        phase_list=phaselist)
    return new_xmap


def fast_compare(best1,best2,best3,sign1,sign2,sign3,corr1,corr2,corr3,res1,res2,res3,errors,n_best): # Find all combinations, calculate TMA and choose smallest TMA
    a,b,c = np.shape(best1)
    nr = 5
    tilt1,tilt2,tilt3 = np.ones((a,b,c,nr)),np.ones((a,b,c,nr)),np.ones((a,b,c,nr))
    numerated = np.linspace(0,c-1,c,dtype='int')


    tilt1[:,:,:,0],tilt2[:,:,:,0],tilt3[:,:,:,0] = best1,best2,best3
    tilt1[:,:,:,1],tilt2[:,:,:,1],tilt3[:,:,:,1] = sign1,sign2,sign3
    tilt1[:,:,:,2],tilt2[:,:,:,2],tilt3[:,:,:,2] = corr1,corr2,corr3
    tilt1[:,:,:,3],tilt2[:,:,:,3],tilt3[:,:,:,3] = numerated,numerated,numerated

    # make all possible combinations of the three tilts
    iterables_array = np.zeros((a,b,3,c,nr))
    iterables_array[:,:,0],iterables_array[:,:,1],iterables_array[:,:,2] = tilt1,tilt2,tilt3
    combos = np.ones((a,b,c*c*c,3,nr))
    for iy,ix in np.ndindex(best1.shape[0:2]):
        combos[iy,ix] = combos_func(iterables_array[iy,ix])

    # find misorientation angles between all three:
    misoris_tot = np.zeros((a,b,c**3))
    for iy,ix in np.ndindex(a,b):
        ind_res1 = combos[iy,ix,:,0,3].astype(int)
        ind_res2 = combos[iy,ix,:,1,3].astype(int)
        ind_res3 = combos[iy,ix,:,2,3].astype(int)
        os1 = Orientation.from_euler(np.deg2rad(res1['orientation'][iy,ix,ind_res1]),symmetry=symmetry.Oh)#.map_into_symmetry_reduced_zone()
        os2 = Orientation.from_euler(np.deg2rad(res2['orientation'][iy,ix,ind_res2]),symmetry=symmetry.Oh)#.map_into_symmetry_reduced_zone()
        os3 = Orientation.from_euler(np.deg2rad(res3['orientation'][iy,ix,ind_res3]),symmetry=symmetry.Oh)#.map_into_symmetry_reduced_zone()
        
        # Calculate the TMA
        misoris_tot[iy,ix] = (np.abs(os1.angle_with(os2,degrees = True)-5) +   # tilts 1 to 2
                        np.abs(os1.angle_with(os3,degrees = True)-5) + # tilts 1 to 
                        np.abs(os2.angle_with(os3,degrees = True)-7.07))    # tilts 2 to 3 

    min_elems = np.ones((a,b,3,nr))
    for iy,ix in np.ndindex(a,b):
        ind = np.argmin(misoris_tot[iy,ix])

        # if misoris_tot[iy,ix,ind]<0.5: # Could work for not having to update all orientations every time
        #     continue
        min_elems[iy,ix] = combos[iy,ix,ind]
        min_elems[iy,ix,:,-1] = misoris_tot[iy,ix,ind]
        comb_2 = min_elems

        res1['template_index'][iy,ix,0] = comb_2[iy,ix,0,0].astype(int)
        res2['template_index'][iy,ix,0] = comb_2[iy,ix,1,0].astype(int)
        res3['template_index'][iy,ix,0] = comb_2[iy,ix,2,0].astype(int)

        res1['mirrored_template'][iy,ix,0] = comb_2[iy,ix,0,1].astype(int)
        res2['mirrored_template'][iy,ix,0] = comb_2[iy,ix,1,1].astype(int)
        res3['mirrored_template'][iy,ix,0] = comb_2[iy,ix,2,1].astype(int)

        ind_res1 = comb_2[iy,ix,0,3].astype(int)
        ind_res2 = comb_2[iy,ix,1,3].astype(int)
        ind_res3 = comb_2[iy,ix,2,3].astype(int)

        res1['orientation'][iy,ix,0] = res1['orientation'][iy,ix,ind_res1]
        res2['orientation'][iy,ix,0] = res2['orientation'][iy,ix,ind_res2]
        res3['orientation'][iy,ix,0] = res3['orientation'][iy,ix,ind_res3]

        res1['correlation'][iy,ix,0] = comb_2[iy,ix,0,2].astype(int)
        res2['correlation'][iy,ix,0] = comb_2[iy,ix,1,2].astype(int)
        res3['correlation'][iy,ix,0] = comb_2[iy,ix,2,2].astype(int)

        errors[iy,ix] = comb_2[iy,ix,0,-1]
        n_best[iy,ix,0] = ind_res1
        n_best[iy,ix,1] = ind_res2
        n_best[iy,ix,2] = ind_res3


def cleanup_for_part(res1,res2,res3,errors,n_best,start_y,end_y): # Takes parts of the dataset and calls fast_compare()
    res1,res2,res3 = dict_slice(res1.copy(),start_y,end_y),dict_slice(res2.copy(),start_y,end_y),dict_slice(res3.copy(),start_y,end_y)

    d_best1, d_best2, d_best3 = res1['template_index'], res2['template_index'], res3['template_index']
    d_sign1, d_sign2, d_sign3  = res1['mirrored_template'], res2['mirrored_template'], res3['mirrored_template']
    d_corr1, d_corr2, d_corr3 = res1['correlation'], res2['correlation'], res3['correlation']    
    d_sign1 = d_sign1.astype(int)
    d_sign1[d_sign1==0] = -1
    d_sign2 = d_sign2.astype(int)
    d_sign2[d_sign2==0] = -1
    d_sign3 = d_sign3.astype(int)
    d_sign3[d_sign3==0] = -1
    a,b,c = np.shape(d_best1)[0:3]

    c_size = 5 # Chunk size

    dask_array_return = da.map_blocks(func=fast_compare,best1=d_best1,best2=d_best2,best3=d_best3,
                                      sign1=d_sign1,sign2=d_sign2,sign3=d_sign3,
                                      corr1=d_corr1,corr2=d_corr2,corr3=d_corr3,
                                      res1=res1,res2=res2,res3=res3,errors=errors,n_best=n_best, dtype=object,chunks=(c_size,c_size,c))

    dask_array_return.compute()

    
            
def cleanup(res1,res2,res3,phase): # Main function for the ACM. Call to run the ACM. Returns the updated Crystal maps, the error_array (TMA), and the n_best_array
    start_time = time.process_time()
    cs = symmetry.Oh

    y,x = np.shape(res1['template_index'])[0:2]
    errors = np.zeros((y,x))
    n_best = np.zeros((y,x,3))

    # Nan to Num
    res1['correlation'][:,:] = np.nan_to_num(res1['correlation'][:,:])
    res2['correlation'][:,:] = np.nan_to_num(res2['correlation'][:,:])
    res3['correlation'][:,:] = np.nan_to_num(res3['correlation'][:,:])

    divide_lim = 2 # pixels
    for i in range(divide_lim,y+divide_lim,divide_lim):
        if i%10==0:
            print(i)
        if i>y:
            cleanup_for_part(res1,res2,res3,errors[i-divide_lim:y,:],n_best[i-divide_lim:y,:],i-divide_lim,y)
        else:
            cleanup_for_part(res1,res2,res3,errors[i-divide_lim:i,:],n_best[i-divide_lim:i,:],i-divide_lim,i)
            
    print('Creating xmap')
    # Make xmap:
    xmap1 = iutls.results_dict_to_crystal_map(res1, phase, diffraction_library=library)
    xmap1.phases[0].space_group = 227 
    xmap1.correlation = np.nan_to_num(xmap1.correlation) # If any correlation scores are NaN
    xmap1 = change_step_size(xmap1,15.356,'nm')

    xmap2 = iutls.results_dict_to_crystal_map(res2, phase, diffraction_library=library)
    xmap2.phases[0].space_group = 227 
    xmap2.correlation = np.nan_to_num(xmap2.correlation) # If any correlation scores are NaN
    xmap2 = change_step_size(xmap2,15.356,'nm')

    xmap3 = iutls.results_dict_to_crystal_map(res3, phase, diffraction_library=library)
    xmap3.phases[0].space_group = 227 
    xmap3.correlation = np.nan_to_num(xmap3.correlation) # If any correlation scores are nan
    xmap3 = change_step_size(xmap3,15.356,'nm')
    end_time = time.process_time()
    print(f'Elapsed time: {end_time-start_time}s')

    return xmap1, xmap2, xmap3, errors, n_best


## Loading and aligning data
The input datasets must be on dictionary form to be able to change the content.

In [6]:
libpath = Path("./Libraries/lib_035res.pkl")
with open(libpath, 'rb') as fp: # With b sub
    library = pickle.load(fp)

resultpath1 = Path("D:/Master_thesis/Data_from_workstation/180224/20240208_123856_xmap_result.pkl")
with open(resultpath1, 'rb') as fp: # With b sub
    result1 = pickle.load(fp)

resultpath2 = Path("D:/Master_thesis/Data_from_workstation/180224/20240208_131643_xmap_result.pkl")
with open(resultpath2, 'rb') as fp: # With b sub
    result2 = pickle.load(fp)

resultpath3 = Path("D:/Master_thesis/Data_from_workstation/180224/20240208_160122_xmap_result.pkl")
with open(resultpath3, 'rb') as fp: # With b sub
    result3 = pickle.load(fp)


phasepath1 = Path("D:/Master_thesis/Data_from_workstation/180224/20240208_123856_xmap_phasedict.pkl")
with open(phasepath1, 'rb') as fp: # With b sub
    phasedict1 = pickle.load(fp)

In [7]:
# Choose the size of the dataset to process, and the number of best templates to use
keys = result1.keys()
for i in keys:
    result1[f'{i}'] = result1[f'{i}'][35:-5,5:]
    result2[f'{i}'] = result2[f'{i}'][29:-11,5:]
    result3[f'{i}'] = result3[f'{i}'][:,:-5]
print(np.shape(result1['template_index']),np.shape(result2['template_index']),np.shape(result3['template_index']))

limit = 23
x1,x2,y1,y2 = 240,260,380,400
# x1,x2,y1,y2 = 0,-1,0,-1

for i in keys:
    result1[f'{i}'] = result1[f'{i}'][y1:y2,x1:x2,0:limit]
    result2[f'{i}'] = result2[f'{i}'][y1:y2,x1:x2,0:limit]
    result3[f'{i}'] = result3[f'{i}'][y1:y2,x1:x2,0:limit]
print(np.shape(result1['template_index']),np.shape(result2['template_index']),np.shape(result3['template_index']))


(520, 275, 30) (520, 275, 30) (520, 275, 30)
(20, 20, 23) (20, 20, 23) (20, 20, 23)


## Run the ACM

In [8]:
xmap1,xmap2,xmap3,error_array,n_best_array = cleanup(result1,result2,result3,phasedict1)

10
20
Creating xmap
Elapsed time: 17.328125s


In [9]:
make_colormap(xmap1)

In [None]:
# Can find the best template indexes for the tilts.
xt,yt = 10,10
print(f"n best tilt 1: {np.where(result1['template_index'][xt,yt]==result1['template_index'][xt,yt][0])[0]}")
print(f"n best tilt 2: {np.where(result2['template_index'][xt,yt]==result2['template_index'][xt,yt][0])[0]}")
print(f"n best tilt 3: {np.where(result3['template_index'][xt,yt]==result3['template_index'][xt,yt][0])[0]}")

## Saving cleaned up Crystal maps

In [None]:
name = 'grain_boundary'

io.save(f"D:/Master_thesis/Cleanup_xmaps/{name}_1.h5", xmap1) 
io.save(f"D:/Master_thesis/Cleanup_xmaps/{name}_2.h5", xmap2) 
io.save(f"D:/Master_thesis/Cleanup_xmaps/{name}_3.h5", xmap3) 

io.save(f"D:/Master_thesis/Cleanup_xmaps_ang/{name}_1.ang", xmap1) 
io.save(f"D:/Master_thesis/Cleanup_xmaps_ang/{name}_2.ang", xmap2) 
io.save(f"D:/Master_thesis/Cleanup_xmaps_ang/{name}_3.ang", xmap3) 

with open(f"D:/Master_thesis/Cleanup_xmaps/error_{name}.pkl", 'wb') as fp:
    pickle.dump(error_array, fp)

with open(f"D:/Master_thesis/Cleanup_xmaps/n_best_{name}.pkl", 'wb') as fp:
    pickle.dump(n_best_array, fp)