# Example for adamantane-1,3,5,7-tetracarboxylic acid (ADTA)

In [None]:
import json
from pathlib import Path
import numpy as np
from xca.data_synthesis.builder import cycle_params
from xca.data_synthesis.cctbx import load_cif, calc_structure_factor, convert_to_numpy
from xca.ml.tf_data_proc import dir2TFR, build_test_dataset
from xca.ml.tf_models import CNN_training as training
from xca.ml.tf_parameters import load_hyperparameters

### The following functions will help generate a synthetic dataset.
These largely replicate the contents of `example_scripts`, collected in a notebook so they are easier to inspect.

The default `N_PATTERNS` (the number of patterns requested for each CIF/phase) is set to 100 so the notebook runs quickly for testing.
To reproduce the work of arXiv:2008.00283, change this to 50,000.

In [None]:
# Name of the system; used to locate its CIFs, hyperparameters and output files.
system = "ADTA"
# Patterns simulated per CIF (use 50,000 to reproduce arXiv:2008.00283).
N_PATTERNS = 100

In [None]:
def get_reflections(cif_path, tth_min, tth_max, wavelength, min_intensity=1):
    """Return the Miller indices of significant reflections inside a 2-theta window.

    Loads the CIF, computes its structure factors and keeps every reflection whose
    2-theta lies strictly between ``tth_min`` and ``tth_max`` and whose intensity
    is strictly greater than ``min_intensity``.

    Parameters
    ----------
    cif_path : str or pathlib.Path
        Path to the CIF file, passed straight to ``load_cif``.
    tth_min, tth_max : float
        Exclusive bounds on 2-theta (same units as ``convert_to_numpy`` reports --
        presumably degrees; confirm).
    wavelength : float
        Wavelength forwarded to ``convert_to_numpy``.
    min_intensity : float, optional
        Exclusive lower bound on the intensity ``I``. Defaults to 1, the
        previously hard-coded threshold.

    Returns
    -------
    list of tuple
        The hkl of each kept reflection, in the order produced by
        ``convert_to_numpy``.
    """
    data = load_cif(cif_path)
    sf = calc_structure_factor(data['structure'])
    scattering = convert_to_numpy(sf, wavelength=wavelength, tth_max=tth_max, tth_min=tth_min)
    # The window is re-checked here with strict bounds so the result does not depend
    # on whether convert_to_numpy treats tth_min/tth_max as inclusive.
    return [
        tuple(hkl)
        for hkl, two_theta, intensity in zip(scattering['hkl'], scattering['2theta'], scattering['I'])
        if tth_min < two_theta < tth_max and intensity > min_intensity
    ]


def _numpy_scalar_to_builtin(obj):
    """``json.dump`` default hook: convert NumPy scalars (e.g. ``np.int64``) to Python builtins."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def log_reflections(cif_paths, tth_min, tth_max, wavelength, outpath=None):
    """Iterate over CIF files and collect the relevant reflections of each one.

    Parameters
    ----------
    cif_paths : iterable of str or pathlib.Path
        CIF files to process.
    tth_min, tth_max, wavelength : float
        Forwarded unchanged to ``get_reflections``.
    outpath : str or pathlib.Path, optional
        If truthy, the resulting mapping is also written to this file as JSON.

    Returns
    -------
    dict
        Maps each CIF file's stem to the list of hkl tuples returned by
        ``get_reflections``.
    """
    reflections = {}
    for cif_path in cif_paths:
        # Path() accepts both str and Path, so no type check is needed.
        path = Path(cif_path)
        reflections[path.stem] = get_reflections(path, tth_min, tth_max, wavelength)
    if outpath:
        with open(outpath, 'w', encoding='utf-8') as f:
            # hkl entries may be NumPy integer scalars, which json cannot encode natively.
            json.dump(reflections, f, default=_numpy_scalar_to_builtin)

    return reflections

def pattern_simulation(n_patterns):
    """Example pattern simulation as reported in arXiv:2008.00283.

    For every ``*.cif`` file in ``../example_scripts/cifs-{system}/`` (``system``
    is the notebook-level global), calls ``cycle_params`` with ``n_patterns``.
    The reflections found by ``log_reflections`` for that phase are passed as
    ``preferred_axes``. Output for the i-th CIF (in sorted filename order) goes
    to ``./tmp/{system}/i``.

    Parameters
    ----------
    n_patterns : int
        Number of patterns requested from ``cycle_params`` for each CIF.

    Raises
    ------
    FileNotFoundError
        If the CIF directory contains no ``*.cif`` files.
    """
    wavelength = 1.54060
    param_dict = {'noise_std': 2e-3,
                  'instrument_radius': 240.00,
                  '2theta_min': 2.00756514,
                  '2theta_max': 39.99347292,
                  'n_datapoints': 2894}
    kwargs = {'bkg_-1': (0.0, 0.5),
              'bkg_-2': (0.0, 1.0)}
    march_range = (0.05, 1)
    sample_height = (-2, 2)
    shape_limit = 0.05
    cif_dir = Path(f'../example_scripts/cifs-{system}/')
    # glob() order is filesystem-dependent; sorting makes the phase -> output
    # directory index mapping reproducible (that index presumably becomes the
    # class index downstream -- confirm in dir2TFR).
    cif_paths = sorted(cif_dir.glob('*.cif'))
    if not cif_paths:
        raise FileNotFoundError(f"No .cif files found in {cif_dir}")
    reflections = log_reflections(cif_paths, param_dict['2theta_min'], param_dict['2theta_max'], wavelength)
    for idx, cif in enumerate(cif_paths):
        print(cif)
        phase = cif.stem
        # A fresh dict per phase, so nothing done to one phase's inputs leaks into the next.
        phase_params = {**param_dict, 'input_cif': cif}
        output_path = Path(f'./tmp/{system}') / str(idx)
        output_path.mkdir(parents=True, exist_ok=True)
        cycle_params(n_patterns,
                     output_path,
                     input_params=phase_params,
                     march_range=march_range,
                     shape_limit=shape_limit,
                     sample_height=sample_height,
                     preferred_axes=reflections[phase],
                     **kwargs)

### Simulate the patterns
The patterns are written out as NumPy files, which are then converted to a TensorFlow record file for convenience.


In [None]:
# Simulate the patterns into ./tmp/<system>/, then convert that directory into ./tmp/<system>.tfrecords.
simulation_dir = f"./tmp/{system}"
pattern_simulation(n_patterns=N_PATTERNS)
dir2TFR(simulation_dir, simulation_dir + ".tfrecords")

### Training from this new dataset

In [None]:
# Load this system's training hyperparameters and train the CNN.
hyperparameter_file = f"../example_scripts/{system}_training.json"
params = load_hyperparameters(params_file=hyperparameter_file)
res, model = training(params=params)
# Report the training results before saving the model next to the simulated data.
print(f"Results for {system}")
print(res)
model.save(f"./tmp/{system}_model")

### Now we classify the test data with the trained model
The output is a dictionary whose keys are the ground-truth labels and whose values are arrays of length `n_classes`.
Each array counts the argmax predictions over all patterns that carry that label.

In [None]:
# Reuse the training hyperparameters, but point the dataset at this system's
# example test data (derived from `system` like every other path here).
test_params = {**params, 'dataset_path': f'../example_data/{system}.tfrecords'}
test_dataset = build_test_dataset(test_params)


In [None]:
# Tally, per ground-truth label, how often each class index is the model's argmax.
# Every pattern in every batch is counted individually, so the tally is correct
# for any batch size.
classifications = {}
for batch in test_dataset:
    labels = batch['label'].numpy()
    y_pred = model({'X': batch['X']}, training=False)
    # One predicted class index per pattern (argmax over the last, class axis).
    predicted = np.atleast_1d(np.argmax(np.asarray(y_pred), axis=-1))
    for raw_label, pred_idx in zip(labels, predicted):
        label = raw_label.decode("utf-8")
        if label not in classifications:
            classifications[label] = np.zeros(params['n_classes'])
        classifications[label][pred_idx] += 1

In [None]:
# Display the per-label argmax counts (Jupyter renders the cell's last expression).
classifications