# Create Simulated Datasets of 2D Peaks

## Coordinates Output
This notebook generates both the training and the testing datasets.
Taking a step back from full image simulation, it simplifies the problem by working with peak coordinates.

This will output coordinates in both Cartesian and k space (based on pure crystallographic structure) rather than diffraction simulation images.

## Initialization

In [None]:
# Packages
%matplotlib qt
import numpy as np
import pandas as pd
import hyperspy.api as hs
import pyxem as pxm
import diffpy.structure
from matplotlib import pyplot as plt
from tempfile import TemporaryFile
from diffsims.libraries.structure_library import StructureLibrary
from diffsims.generators.diffraction_generator import DiffractionGenerator
from diffsims.generators.library_generator import DiffractionLibraryGenerator, VectorLibraryGenerator
from diffsims.sims.diffraction_simulation import DiffractionSimulation
from diffsims.libraries.diffraction_library import DiffractionLibrary
from pyxem.generators.indexation_generator import VectorIndexationGenerator
from pyxem.generators.subpixelrefinement_generator import SubpixelrefinementGenerator
from pyxem.signals.diffraction_vectors import DiffractionVectors
import tqdm
import gc
import os

In [None]:
### Variables
# User-editable configuration read by the cells below.

# Paths
# Absolute, machine-specific project root; edit when running on another machine.
root = r'C:/Users/anish/Documents/GitHub/ml_pyxem/mini_2/'

# Phases
# Directory expected to contain the structure files listed in `phase_files`.
structures_path = os.path.join(root, 'crystal_phases')
# Structure files to simulate; the file name up to the first dot becomes the phase key.
phase_files = ['p4mbm_tetragonal.cif',]
add_bkg_phase = False # Append a background (noise-only) phase at the end? If True, the final datasets are meant to hold n_phases + 1 classes. NOTE(review): not referenced in the visible cells.

# Calibration values
# NOTE(review): presumably reciprocal-space units per detector pixel - confirm the unit.
calibration = 0.00588 # For several calibrations, create a list (i.e. calibrations = [0.00588]) and loop over its values

# Processing values
# Number of random orientations requested from get_random_euler.
n_angle_points = 10

# Domain amplification
# Forwarded as `with_direct_beam` when the diffraction library is generated.
with_direct_beam = False

# Noise addition values (do not change)
remove_peaks = False  # if True, remove_random_peaks is applied to the library
add_noise = True  # if True, add_noise_to_simulation is mapped over the training data
include_also_non_noisy_simulation = False # If add_noise is set, also keep the non-noisy data in the final dataset?
# Passed as `snr`: the probability that a pixel is left untouched by the spike noise.
snrs = [0.9, 0.99]
# Passed as `int_salt`: the value written into spiked pixels (patterns are normalised to a maximum of 1 first).
intensity_spikes = [0.25,]

# Simulation microscope values (for azimuthal integration)
detector_size = 515 #px
beam_energy = 200.0 #keV
wavelength = 2.5079e-12 #m
detector_pix_size = 55e-6 #m
# NOTE(review): `wavelength`, `detector_pix_size` and `detector` are not read by the visible cells.
from pyxem.detectors import Medipix515x515Detector
detector = Medipix515x515Detector()

In [None]:
# Rough size estimate of the simulated dataset (before noise augmentation).
# One pattern is simulated per (orientation, phase) pair. The relrod and
# spot-spread factors used by earlier versions are not defined in this
# notebook, so they are no longer part of the product.
val = n_angle_points * len(phase_files) #* len(snrs) * len(intensity_spikes)
print('Approx amount of 2D diffraction patterns that will be produced: {}'.format(val))
memory = detector_size**2 * val * 4 / 1e9  #4 bytes per float32 value
print('Approx memory needed: {} GB'.format(memory))

## Simulate Data

### Define Functions

In [None]:
# Load every configured phase file into a {phase name: structure} mapping.
phase_dict = {}
for phase in phase_files:
    # The file name up to the first dot is used as the phase key.
    name = phase.split(".")[0]
    # Resolve the file against the configured structures directory rather than
    # a 'crystal_phases' folder relative to the current working directory.
    phase_dict[name] = diffpy.structure.loadStructure(os.path.join(structures_path, phase))
    print('n_phases = {}'.format(len(phase_dict)))

In [None]:
def get_random_euler(npoints):
    """Return ``npoints`` reproducible pseudo-random Euler-angle triplets in degrees.

    The first and third angles are drawn uniformly from [0, 360). The second
    angle is ``arccos(u)`` with ``u`` drawn from the 201 evenly spaced values
    -1.00, -0.99, ..., 1.00, so it lies in [0, 180].

    Parameters
    ----------
    npoints : int
        Number of triplets to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(npoints, 3)``; each row is ``(alpha, beta, gamma)``
        in degrees.

    Notes
    -----
    The global NumPy random state is reseeded with ``1`` on every call, so the
    same ``npoints`` always yields the same angles. Any later use of
    ``np.random`` continues from that reseeded state.
    """
    np.random.seed(1)
    # Draw order (u, alpha, gamma) is kept so the generated angles stay the same.
    u = np.random.randint(-100, 100 + 1, size=(npoints,)) / 100
    euler_alpha = 2 * np.pi * np.random.random(size=(npoints,))
    euler_gamma = 2 * np.pi * np.random.random(size=(npoints,))
    # Polar angle of a point at height u on the unit sphere.
    euler_beta = np.arccos(u)
    return np.array([np.rad2deg(euler_alpha),
                     np.rad2deg(euler_beta),
                     np.rad2deg(euler_gamma)]).T


def get_reciprocal_radius(detector_size, calibration):
    """Return ``calibration`` multiplied by the integer half-width of the detector.

    The half-width is ``detector_size // 2`` (floor division), so an odd
    detector size is rounded down before scaling.
    """
    return calibration * (detector_size // 2)


def create_diffraction_library(phase_dict, euler_list, beam_energy, calibration, detector_size, with_direct_beam):
    """Build a diffraction library for every phase over one shared orientation list.

    Parameters
    ----------
    phase_dict : dict
        Maps phase name to its structure object; keys and values are passed
        to ``StructureLibrary`` in dictionary order.
    euler_list : sequence
        Orientations to simulate; the same list is used for every phase.
    beam_energy : float
        Forwarded to ``DiffractionGenerator``.
    calibration : float
        Forwarded to the library generator and used for the reciprocal radius.
    detector_size : int
        Detector width in pixels; its floor half sets ``half_shape``.
    with_direct_beam : bool
        Forwarded to ``get_diffraction_library``.

    Returns
    -------
    The object returned by ``DiffractionLibraryGenerator.get_diffraction_library``.
    """
    names = list(phase_dict.keys())
    structures = list(phase_dict.values())
    # Every phase is paired with the very same orientation list.
    orientations_per_phase = [euler_list] * len(names)
    structure_library = StructureLibrary(names, structures, orientations_per_phase)

    library_generator = DiffractionLibraryGenerator(DiffractionGenerator(beam_energy))

    max_radius = get_reciprocal_radius(detector_size, calibration)
    half_size = detector_size // 2
    return library_generator.get_diffraction_library(
        structure_library,
        calibration=calibration,
        reciprocal_radius=max_radius,
        half_shape=(half_size, half_size),
        with_direct_beam=with_direct_beam,
    )

### Create Diffraction Patterns

In [None]:
# One (initially empty) list of simulated patterns per phase name.
data = {key: [] for key in phase_dict}

# Orientations, the derived reciprocal radius and the simulated library.
euler_list = get_random_euler(n_angle_points)
reciprocal_radius = get_reciprocal_radius(detector_size, calibration)
library = create_diffraction_library(phase_dict, euler_list, beam_energy, calibration, detector_size, with_direct_beam)
#print(library)

# Render one pattern per (orientation, phase) pair and file it under its phase.
for euler in euler_list:
    for phase in library.keys():
        simulation = library.get_library_entry(phase=phase, angle=euler)['Sim']
        pattern = DiffractionSimulation.get_diffraction_pattern(simulation)
        data[phase].append(pattern)

### Plotting Check

In [None]:
#for euler in euler_list: #shows diffraction pattern for every configuration simulated
#    for phase in library.keys():
#        pattern = DiffractionSimulation.get_diffraction_pattern(library.get_library_entry(phase=phase,angle=euler)['Sim'])
#        plt.figure() 
#        plt.imshow(pattern, cmap='viridis', vmax=0.3)

In [None]:
#px_coords_array = library['p4mbm_tetragonal']['pixel_coords']
#px_coords_i_x = library['p4mbm_tetragonal']['pixel_coords'][0][:,0]
#px_coords_i_y = library['p4mbm_tetragonal']['pixel_coords'][0][:,1]
#plt.figure(1)
#plt.scatter(px_coords_i_x,px_coords_i_y) #shows diffraction pattern of a single configuration - based on Cartesian coordinates

#rec_coords_array = library['p4mbm_tetragonal']['rec_coords']
#rec_coords_i_x = library['p4mbm_tetragonal']['rec_coords'][0][:,0]
#rec_coords_i_y = library['p4mbm_tetragonal']['rec_coords'][0][:,1]
#plt.figure(2)
#plt.scatter(rec_coords_i_x,rec_coords_i_y) #shows diffraction pattern of a single configuration - based on reciprocal coordinates

### Stack Data

In [None]:
import dask.array as da

# Stack the per-phase pattern lists into a single dask array, one phase after
# the other along the first axis.
# NOTE(review): assumes every pattern exposes its pixels through `.data` and is
# detector_size x detector_size - confirm against the simulation output.
for i, value in enumerate(data.values()):
    list_data = da.from_array([x.data for x in value], chunks=(10, detector_size, detector_size))

    if i == 0:
        #list_data = np.expand_dims(list_data, 1)
        training_data = list_data
    else:
        #list_data = np.expand_dims(list_data, 1)
        training_data = da.vstack([training_data, list_data],)

# `library` is deliberately NOT deleted here: the data-augmentation cells
# further down still read it.
del data
del list_data
gc.collect()

# Target shape (phase, orientation, y, x). The second entry used to end in a
# dangling `*` (the rest of the line was commented out), which silently
# multiplied it with `detector_size` on the following line.
shape = (len(phase_dict.keys()),
         n_angle_points,  # * len(relrod_list) * len(spot_spread_list) * len(simulated_direct_beam_bool)
         detector_size,
         detector_size)

#training_data = training_data.reshape(shape)
#training_data = pxm.LazyElectronDiffraction2D(training_data)
#training_data.set_diffraction_calibration(calibration)
print(training_data)

## Data Augmentation

### Define Functions

In [None]:
def remove_random_peaks(n_patterns, phase_dict, size, source_library=None):
    """Return a copy of a diffraction library with random peaks removed.

    For every phase name in ``phase_dict`` and each of the first
    ``n_patterns`` entries, ``size`` distinct peak indices are drawn at random
    and removed from the ``'rec_coords'``, ``'pixel_coords'`` and
    ``'intensities'`` arrays alike, so the three arrays stay aligned.

    Parameters
    ----------
    n_patterns : int
        Number of patterns per phase to process (indices ``0 .. n_patterns-1``).
    phase_dict : mapping
        Only its keys (phase names) are used.
    size : int
        Number of peaks to remove from each pattern.
    source_library : mapping, optional
        Library to augment. Defaults to the module-level ``library``.

    Returns
    -------
    A deep copy of the source library with the peaks removed; the source
    library itself is left unmodified.

    Raises
    ------
    ValueError
        Raised by the random choice when ``size`` exceeds the number of peaks
        in a pattern.
    """
    import copy

    import numpy as np
    from numpy.random import default_rng

    if source_library is None:
        source_library = library
    # Work on a copy so the un-augmented library stays available to callers.
    augmented_library = copy.deepcopy(source_library)
    rng = default_rng()

    for phase_name in phase_dict:
        entry = augmented_library[phase_name]
        for i in range(n_patterns):
            n_peaks = len(entry['rec_coords'][i])
            rand_index = rng.choice(n_peaks, size, replace=False)
            # np.delete keeps each array's dimensionality (1-D intensities
            # stay 1-D) while dropping the same rows everywhere.
            for field in ('rec_coords', 'pixel_coords', 'intensities'):
                entry[field][i] = np.delete(np.asarray(entry[field][i]), rand_index, axis=0)

    return augmented_library

### Testing

In [None]:
# Optionally drop one random peak from every simulated pattern.
if remove_peaks:
    augmented_library = remove_random_peaks(n_patterns=n_angle_points, phase_dict=phase_dict, size=1)
else:
    augmented_library = library

# print(augmented_library)

# Sanity check: number of peaks in the first pattern of the first phase.
# The phase key is taken from phase_dict instead of being hard-coded, so the
# check keeps working when `phase_files` is changed.
first_phase = next(iter(phase_dict))
x = len(augmented_library[first_phase]['pixel_coords'][0])
print(x)

In [None]:
def add_noise_to_simulation(simulation_arr, snr, int_salt,):
    """Return a noisy, normalised copy of ``simulation_arr``.

    Steps, in order:

    1. A Poisson sample drawn with the copied array as its rate is added to
       the copy.
    2. The result is divided by its maximum, unless that maximum is zero.
    3. Each pixel is assigned one of three labels with probabilities
       ``snr``, ``(1 - snr) / 2`` and ``(1 - snr) / 2``; pixels with the third
       label are overwritten with ``int_salt`` (bright spikes). Pixels with
       the second label are left unchanged.

    The input array is not modified. Random numbers come from the global
    ``np.random`` state.
    """
    import numpy as np

    # Poisson noise on top of a private copy of the pattern.
    noisy = simulation_arr.copy()
    noisy += np.random.poisson(noisy)

    # Normalise to a maximum of 1; an all-zero pattern is kept as it is.
    peak = noisy.max()
    if peak != 0:
        noisy = noisy / noisy.max()

    # Add bright spots randomly across the detector.
    keep_prob = snr
    label = np.random.choice(a=(0, 1, 2),
                             size=np.shape(noisy),
                             p=[keep_prob, (1 - keep_prob) / 2., (1 - keep_prob) / 2.])
    spiked = noisy.copy()
    spiked[label == 2] = int_salt
    return spiked

In [None]:
# Map the noise addition function on signal
if add_noise:
    # `.map` is a HyperSpy signal method. A bare dask array (as produced by
    # the stacking cell) does not have it, so wrap it in a lazy diffraction
    # signal first; an object that is already a signal is used unchanged.
    if not hasattr(training_data, 'map'):
        training_data = pxm.LazyElectronDiffraction2D(training_data)

    training_data_noisy = []

    # Include the non-corrupted data in the dataset?
    if include_also_non_noisy_simulation:
        training_data_noisy.append(training_data)

    # Append one noisy copy per (snr, spike intensity) combination.
    signal_noisy = None  # keeps the `del` below valid when no combination is configured
    for snr in snrs:
        for int_spike in intensity_spikes:

            signal_noisy = training_data.map(add_noise_to_simulation,
                                             snr=snr, int_salt=int_spike,
                                             inplace=False, parallel=True)

            training_data_noisy.append(signal_noisy)

    # Drop the references that are no longer needed before stacking.
    del training_data
    del signal_noisy
    gc.collect()

    training_data_noisy = hs.stack(training_data_noisy, axis=0)

else:
    # No noise addition
    training_data_noisy = training_data

## K Space Output

This will output only coordinates in the k space (based on pure crystallographic structure) rather than diffraction simulation images