# Mask, center and threshold 2D raw data

In [None]:
%matplotlib qt
import pyxem as pxm
import os
import gc
import numpy as np

In [None]:
# Recursively gather every diffraction-plane HDF5 file found below the
# acquisition folder, in sorted order so the processing order is reproducible.
base_root = r"D:\jf631\mg24111-1"
data = []
for root, dirs, files in os.walk(base_root):
    data.extend(
        os.path.join(root, file)
        for file in files
        if file.endswith('_diff_plane_515x515.hdf5')
    )
data.sort()

# Load the detector mask from the working directory and display a small
# window of it (rows 250-258, columns 150-158) as a quick sanity check.
mask = np.load('mask_jordi.npy')
mask[250:259, 150:159]

In [None]:
# Preview the first three discovered file paths.
data[:3]

In [None]:
# Serial pipeline: for every discovered file, mask the data, centre the
# direct beam (rough shift from the mean pattern, then a fine per-pattern
# alignment), and save the mean pattern, the centred data and a thresholded
# copy into the current working directory.
i = 0  # number of files processed so far
for fname in data:
    # Name outputs after the text before the first '.' so the saved files do
    # not end in an unrecognised extension such as '.hdf5_centred'.
    stem = os.path.basename(fname).split('.')[0]

    dp = pxm.load(fname)
    dp *= mask

    # Rough centre: locate the direct beam on the mean diffraction pattern.
    mean_dp = pxm.ElectronDiffraction2D(dp.mean())
    centre = mean_dp.get_direct_beam_position(method='cross_correlate', radius_start=1, radius_finish=10)
    shifts = [[centre.data[0], centre.data[1]]]
    mean_dp.change_dtype('float32')
    mean_dp.save(stem + '_mean_dp.tiff')

    # Apply the same rough shift to every pattern: one (shift, shift) row per
    # navigation position. navigation_size works for any scan shape, including
    # a single column; int() keeps the list repetition a plain Python repeat.
    n_patterns = int(dp.axes_manager.navigation_size)
    n_shifts = np.array(shifts * n_patterns)
    dp.align2D(shifts=-n_shifts, crop=False)

    # Fine alignment of each individual pattern, then save.
    dp.center_direct_beam(method='interpolate', sigma=5, upsample_factor=4, kind='linear', half_square_width=10)
    dp.save(stem + '_centred')

    # Thresholded copy: pixels whose value is exactly 1 are set to 0.
    dp2 = dp.deepcopy()
    dp2.data[dp2.data == 1] = 0
    dp2.save(stem + '_centred-threshold')

    # Free the large arrays before loading the next file.
    del dp
    del dp2
    del mean_dp
    i += 1
    gc.collect()
    

In [None]:
# Show the first file's basename truncated at its first '.' (the stem).
os.path.basename(data[0]).split('.')[0]

# Parallelizing

In [1]:
# Define function (with packages imported within)
def centering_and_thresholding_in_parallel(fname_path, mask, radius_start=1, radius_finish=10, half_square_width=10):
    """Mask, centre and threshold one diffraction dataset, saving results beside it.

    The file is loaded as an electron-diffraction signal and multiplied by
    ``mask``. The direct beam is first centred roughly, using a single shift
    found by cross-correlation on the mean pattern, and then finely per
    pattern. Three files are written into the folder of ``fname_path``, named
    after the text before the first '.' of its basename:

    * ``<stem>_mean_dp.tiff`` - the mean pattern as float32,
    * ``<stem>_centred_nonthresholded`` - the centred data,
    * ``<stem>_centred_thresholded`` - the centred data with every pixel whose
      value is exactly 1 set to 0.

    Existing output files are overwritten.

    Parameters
    ----------
    fname_path : str
        Path of the dataset to process.
    mask : array-like
        Multiplied into the data in place; presumably a 0/1 detector mask
        broadcastable to the diffraction patterns (confirm against caller).
    radius_start, radius_finish : int
        Forwarded to ``get_direct_beam_position`` (cross-correlation method).
    half_square_width : int
        Forwarded to ``center_direct_beam``.

    Returns
    -------
    None
    """
    # Imports are kept inside the function so they are available in whichever
    # worker process executes it.
    import os
    import pyxem as pxm
    import gc
    import numpy as np

    folder = os.path.dirname(fname_path)
    fname = os.path.basename(fname_path).split('.')[0]

    dp = pxm.load(fname_path, signal_type='electron_diffraction')
    dp *= mask

    # Get the rough shifts from the mean dp and save mean dp
    mean_dp = dp.mean()
    centre = mean_dp.get_direct_beam_position(method='cross_correlate', radius_start=radius_start, radius_finish=radius_finish)
    shifts = [[centre.data[0], centre.data[1]]]
    mean_dp.change_dtype('float32')
    # overwrite=True like the other saves: a confirmation prompt cannot be
    # answered from inside a worker process.
    mean_dp.save(os.path.join(folder, fname + '_mean_dp.tiff'), overwrite=True)

    # Align with the rough shifts: one identical row per navigation position.
    # navigation_size works for any scan shape (including a single column);
    # int() keeps the list repetition a plain Python repeat.
    n_patterns = int(dp.axes_manager.navigation_size)
    n_shifts = np.array(shifts * n_patterns)
    dp.align2D(shifts=-n_shifts, crop=False)

    # Align using fine alignment and save
    dp.center_direct_beam(method='interpolate', sigma=5, upsample_factor=4, kind='linear', half_square_width=half_square_width)
    dp.save(os.path.join(folder, fname + '_centred_nonthresholded'), overwrite=True)

    # Threshold the data (pixels equal to 1 become 0) and save
    dp.data[dp.data == 1] = 0
    dp.save(os.path.join(folder, fname + '_centred_thresholded'), overwrite=True)

    # Clean up the RAM
    del dp
    del mean_dp
    gc.collect()
    return

In [3]:
# Define variables to iterate from
import os
import numpy as np
from itertools import product

# Gather, in sorted order, every diffraction-plane HDF5 file found anywhere
# below the acquisition folder.
base_root = r"D:\jf631\mg24111-1"
fnames = []
for root, dirs, files in os.walk(base_root):
    fnames.extend(
        os.path.join(root, file)
        for file in files
        if file.endswith('_diff_plane_515x515.hdf5')
    )
fnames.sort()

# The single mask is wrapped in a one-element list so that product() pairs
# it with every file name.
mask = [np.load('mask_jordi.npy')]

# Iterator of (file path, mask) argument tuples, one per file.
iterations = product(fnames, mask)

In [4]:
# Run
# make sure to always use multiprocess
from multiprocess import Pool
import psutil

# Start the parallel workers: at most 5, and never more than the number of
# physical cores (psutil may return None when it cannot determine it).
physical_cores = psutil.cpu_count(logical=False)
n_cores = min(5, physical_cores) if physical_cores else 5
pool = Pool(n_cores)

try:
    # Execute the computations in parallel; starmap blocks until every
    # (file path, mask) pair has been processed.
    pool.starmap(centering_and_thresholding_in_parallel, iterations)
finally:
    # Always shut the workers down, even if one of the tasks raised, and
    # wait for the worker processes to exit.
    pool.close()
    pool.join()