# Example for BaTiO$_3$

In [None]:
import json
from pathlib import Path
import numpy as np
import tensorflow as tf
from xca.data_synthesis.builder import cycle_params
from xca.data_synthesis.cctbx import load_cif, calc_structure_factor, convert_to_numpy
from xca.ml.tf_data_proc import dir2TFR
from xca.ml.tf_models import CNN_training as training
from xca.ml.tf_parameters import load_hyperparameters

### The following functions will help generate a synthetic dataset.
They largely replicate the contents of `example_scripts` and are included in this notebook so they are easier to inspect.

The default `N_PATTERNS` is set to 100 for convenience in testing. To reproduce the work of arXiv:2008.00283,
change this to 50,000.

In [None]:
# Value handed to pattern_simulation() below; 100 keeps test runs short.
N_PATTERNS = 100
# Tag interpolated into the cif directory, the ./tmp output paths, and the
# training-JSON filename used by the cells below.
system = "BaTiO"

In [None]:
def get_reflections(cif_path, tth_min, tth_max, wavelength):
    """Return the hkl tuples of reflections lying strictly between tth_min and tth_max.

    The structure is loaded from ``cif_path``, its structure factor is calculated,
    and the result is converted to numpy for the given ``wavelength``. A reflection
    is kept only if its '2theta' value is strictly inside (tth_min, tth_max) and
    its 'I' value is greater than 1.
    """
    structure_data = load_cif(cif_path)
    structure_factor = calc_structure_factor(structure_data['structure'])
    scattering = convert_to_numpy(
        structure_factor, wavelength=wavelength, tth_max=tth_max, tth_min=tth_min
    )
    return [
        tuple(hkl)
        for hkl, two_theta, intensity in zip(
            scattering['hkl'], scattering['2theta'], scattering['I']
        )
        if tth_min < two_theta < tth_max and intensity > 1
    ]


def log_reflections(cif_paths, tth_min, tth_max, wavelength, outpath=None):
    """Build a {cif stem: reflections} mapping for a list of cifs, optionally saved as JSON.

    Each entry of ``cif_paths`` may be a ``Path`` or anything ``Path`` accepts.
    The reflections for each cif come from ``get_reflections``. When ``outpath``
    is truthy the mapping is also dumped there as JSON. The mapping is always
    returned.
    """
    reflections_by_phase = {}
    for entry in cif_paths:
        cif_file = entry if isinstance(entry, Path) else Path(entry)
        reflections_by_phase[cif_file.stem] = get_reflections(
            cif_file, tth_min, tth_max, wavelength
        )

    # Nothing to write: hand the mapping straight back.
    if not outpath:
        return reflections_by_phase

    with open(outpath, 'w') as handle:
        json.dump(reflections_by_phase, handle)
    return reflections_by_phase

def pattern_simulation(n_patterns):
    """Example pattern simulation as reported in arXiv:2008.00283.

    For every ``*.cif`` found in ``../example_scripts/cifs-{system}/`` this calls
    ``cycle_params`` once, forwarding ``n_patterns`` as its first argument and
    writing into ``./tmp/{system}/<idx>``. ``<idx>`` is the position of the cif
    in the *sorted* list of paths, so the directory number assigned to each
    phase is the same on every machine and every run.

    Relies on the notebook-level ``system`` variable for all paths.
    """
    param_dict = {'wavelength': 0.1671,
                  'noise_std': 5e-4,
                  'instrument_radius': 1065.8822732979447,
                  'theta_m': 0.0,
                  '2theta_min': 0.011231808788013649,
                  '2theta_max': 24.853167100343246,
                  'n_datapoints': 3488}
    # Background coefficient ranges forwarded to cycle_params as extra keywords.
    kwargs = {'bkg_1': (-1e-4, 1e-4),
              'bkg_0': (0, 1e-3)}
    # Path.glob yields entries in arbitrary order; sort so that the index used
    # for each phase's output directory is reproducible.
    cif_paths = sorted(Path(f'../example_scripts/cifs-{system}/').glob('*.cif'))
    march_range = (0.8, 1.0)
    sample_height = (-2.0, 2.0)
    shape_limit = 1e-1
    # Read the wavelength from param_dict so the reflection search and the
    # simulation can never disagree.
    reflections = log_reflections(cif_paths,
                                  param_dict['2theta_min'],
                                  param_dict['2theta_max'],
                                  param_dict['wavelength'])
    for idx, cif in enumerate(cif_paths):
        print(cif)
        phase = cif.stem
        param_dict['input_cif'] = cif
        output_path = Path(f'./tmp/{system}') / str(idx)
        output_path.mkdir(parents=True, exist_ok=True)
        cycle_params(n_patterns,
                     output_path,
                     input_params=param_dict,
                     march_range=march_range,
                     shape_limit=shape_limit,
                     sample_height=sample_height,
                     preferred_axes=reflections[phase],
                     **kwargs)

### Simulate the patterns
These patterns are output as numpy files, which are converted to tensorflow records for convenience.


In [None]:
# Generate the synthetic patterns under ./tmp/{system}/<idx>/ (one sub-directory per cif).
pattern_simulation(N_PATTERNS)
# Convert that output directory into a single TFRecords file next to it.
dir2TFR(f"./tmp/{system}", f"./tmp/{system}.tfrecords")

### Training from this new dataset

In [None]:
# Hyperparameters are read from the example_scripts JSON that matches `system`.
params = load_hyperparameters(params_file=f"../example_scripts/{system}_training.json")
# `training` is xca.ml.tf_models.CNN_training (see the imports cell); it hands
# back a results object and the model.
res, model = training(params=params)
print(f"Results for {system}")
print(res)
# Save the model alongside the simulated data in ./tmp.
model.save(f"./tmp/{system}_model")

### Now we classify the experimental test data using the trained model

In [None]:
# Numeric body of the experimental CSV; the first (header) row is skipped here.
exp = np.loadtxt('../example_data/BaTiO.csv', delimiter=',', skiprows=1)
temperatures = []
# Re-open the same file to parse only the header row that loadtxt skipped.
with open('../example_data/BaTiO.csv', 'r') as f:
    names = f.readline().split(',')
    for name in names:
        # Take the second '_'-separated field and drop its final character before
        # converting to float — presumably a unit suffix such as 'K'; TODO confirm.
        # NOTE(review): readline() keeps the trailing newline, so for the last
        # header name [:-1] may strip '\n' instead of the suffix — verify.
        temperatures.append(float(name.split('_')[1][:-1]))
# NOTE(review): this is a reshape, not a transpose. It assumes the loaded array
# is already ordered as 60 patterns of 3488 points — TODO confirm against the
# CSV layout.
exp = tf.reshape(tf.convert_to_tensor(exp, dtype=tf.float32), (60, 3488, 1))
# Forward pass with training=False; the input is supplied under the key 'X'.
y_pred = model({'X': exp}, training=False)


In [None]:
import matplotlib.pyplot as plt
# One curve per output column of y_pred, drawn against pattern index.
for i in range(y_pred.shape[1]):
    plt.plot(y_pred[:,i])
# Label every 19th pattern (indices 0, 19, 38, 57) with its parsed temperature.
plt.xticks(range(0, 60, 19), [temperatures[i] for i in range(0, 60, 19)])
plt.xlabel('Temperature [K]')
plt.ylabel('Phase probability')
plt.ylim(0,1)
plt.show()
