# Example for Ni-Co-Al ternary alloy

In [None]:
import json
from pathlib import Path
import numpy as np
import pandas as pd
import re
import os
from xca.data_synthesis.builder import cycle_params
from xca.data_synthesis.cctbx import load_cif, calc_structure_factor, convert_to_numpy
from xca.ml.tf_data_proc import dir2TFR, build_test_dataset
from xca.ml.tf_models import CNN_training as training
from xca.ml.tf_parameters import load_hyperparameters

### The following functions will help generate a synthetic dataset.
These functions largely replicate the contents of `example_scripts`; they are reproduced in this notebook so that they are easier to read.

The default `N_PATTERNS` is set to 100 for convenience in testing. To reproduce the work of arXiv:2008.00283,
change this to 50,000.

In [None]:
# Pattern count handed to `pattern_simulation`, which forwards it to `cycle_params` once per CIF file.
N_PATTERNS = 100
# Name of the material system; interpolated into the CIF, hyperparameter and ./tmp output paths below.
system = "NiCoAl"

In [None]:
def get_reflections(cif_path, tth_min, tth_max, wavelength):
    """Return the hkl indices of the significant reflections inside a 2-theta window.

    The structure is loaded from ``cif_path``, its structure factor is
    calculated and converted to arrays, and a reflection is kept when its
    2-theta lies strictly between ``tth_min`` and ``tth_max`` and its
    intensity is greater than 1.

    Parameters
    ----------
    cif_path:
        Location of the CIF file, passed straight to ``load_cif``.
    tth_min, tth_max: float
        Exclusive lower and upper 2-theta bounds.
    wavelength: float
        Wavelength forwarded to ``convert_to_numpy``.

    Returns
    -------
    list of tuple
        One tuple of hkl indices per retained reflection.
    """
    structure = load_cif(cif_path)['structure']
    structure_factor = calc_structure_factor(structure)
    scattering = convert_to_numpy(structure_factor, wavelength=wavelength, tth_max=tth_max, tth_min=tth_min)
    return [
        tuple(hkl)
        for hkl, two_theta, intensity in zip(scattering['hkl'], scattering['2theta'], scattering['I'])
        if tth_min < two_theta < tth_max and intensity > 1
    ]


def log_reflections(cif_paths, tth_min, tth_max, wavelength, outpath=None):
    """Collect the relevant reflections of every CIF, optionally saving them as JSON.

    Parameters
    ----------
    cif_paths: iterable of str or Path
        CIF files to inspect; each entry is coerced to ``Path`` if needed.
    tth_min, tth_max, wavelength:
        Forwarded unchanged to ``get_reflections``.
    outpath: str or Path, optional
        When truthy, the resulting dictionary is dumped to this file as JSON.

    Returns
    -------
    dict
        Maps each CIF file stem to the list returned by ``get_reflections``.
    """
    reflections_by_phase = {}
    for entry in cif_paths:
        cif = entry if isinstance(entry, Path) else Path(entry)
        reflections_by_phase[cif.stem] = get_reflections(cif, tth_min, tth_max, wavelength)

    if outpath:
        with open(outpath, 'w') as handle:
            json.dump(reflections_by_phase, handle)

    return reflections_by_phase

def pattern_simulation(n_patterns):
    """Example pattern simulation as reported in arXiv:2008.00283.

    For every ``*.cif`` file found in ``../example_scripts/cifs-{system}/``
    this calls ``cycle_params`` with a shared parameter set, writing into
    ``./tmp/{system}/<idx>`` where ``idx`` is the position of the CIF in the
    glob result. The module-level ``system`` name selects the directories.

    Parameters
    ----------
    n_patterns: int
        Forwarded to ``cycle_params`` as its first argument for each CIF.

    Returns
    -------
    None
    """
    # (wavelength, weight) pairs; only the first wavelength value is used below,
    # to decide which reflections fall inside the 2-theta window.
    # NOTE(review): `wavelength` is never forwarded to `cycle_params` -- confirm
    # that the simulation's own default matches these values.
    wavelength = [(1.54060, 0.5), (1.54439, 0.5)]
    param_dict = {'noise_std': 5e-3,
                  'instrument_radius': 240.00,
                  'theta_m': 26.6,
                  '2theta_min': 20.0,
                  '2theta_max': 89.93999843671914,
                  'n_datapoints': 3498}
    kwargs = {'bkg_0': (0.0, 0.05)}
    march_range = (0.05, 1)
    sample_height = (-2.0, 2.0)
    shape_limit = 0.05
    cif_paths = list(Path(f'../example_scripts/cifs-{system}/').glob('*.cif'))
    # Reflections per phase (keyed by CIF stem) serve as the candidate preferred axes.
    reflections = log_reflections(cif_paths, param_dict['2theta_min'], param_dict['2theta_max'], wavelength[0][0])
    for idx, cif in enumerate(cif_paths):
        print(cif)
        phase = cif.stem
        # The same dictionary is reused for every phase; only the input CIF changes.
        param_dict['input_cif'] = cif
        # The enumeration index names the output directory for this phase.
        output_path = Path(f'./tmp/{system}') / str(idx)
        output_path.mkdir(parents=True, exist_ok=True)
        cycle_params(n_patterns,
                     output_path,
                     input_params=param_dict,
                     march_range=march_range,
                     shape_limit=shape_limit,
                     sample_height=sample_height,
                     preferred_axes=reflections[phase],
                     **kwargs)

### Simulate the patterns
The patterns are written out as NumPy files and then converted to TensorFlow records for convenience.
In this case the file names are used to track key information, which is stored in a dictionary.


In [None]:
# Run the simulation for every CIF; output goes to ./tmp/<system>/<idx>/ (see pattern_simulation).
pattern_simulation(N_PATTERNS)
# Convert the simulated output directory into a single TFRecords file alongside it.
dir2TFR(f"./tmp/{system}", f"./tmp/{system}.tfrecords")


### Training from this new dataset

In [None]:
# Load the training hyperparameters for this system from the example JSON file.
params = load_hyperparameters(params_file=f"../example_scripts/{system}_training.json")
# `training` is xca's CNN_training; it returns the results and the trained model.
res, model = training(params=params)
print(f"Results for {system}")
print(res)
# Save the trained model under ./tmp so later cells (or sessions) can reuse it.
model.save(f"./tmp/{system}_model")

### Now we segregate the test data based on this trained model
The output here is a dictionary where the keys are the ground truth classification,
and the values are the counts of argmax predictions for all of the patterns that correspond to that class.

In [None]:
# Start from a shallow copy of the training hyperparameters and redirect only
# the dataset path to the bundled example data.
test_params = dict(params)
test_params['dataset_path'] = '../example_data/NiCoAl.tfrecords'
test_dataset = build_test_dataset(test_params)


In [None]:
# Tally, for each ground-truth label, how often every class index wins the argmax.
classifications = {}
for batch in test_dataset:
    label = batch['label'].numpy()[0].decode("utf-8")
    counts = classifications.get(label)
    if counts is None:
        # First pattern seen for this label: start its per-class counter at zero.
        counts = np.zeros(params['n_classes'])
        classifications[label] = counts

    y_pred = model({'X':batch['X']}, training=False)
    counts[np.argmax(y_pred)] += 1
classifications

## Mappings

In [None]:
# Build `filemap`: class index -> phase metadata parsed from the CIF file name,
# which is expected to look like "<formula>-<crystal system>-<space group number>".
# NOTE(review): the index follows the glob order, exactly as in pattern_simulation;
# both rely on the directory listing order being the same -- confirm if files change.
cif_paths = list(Path(f'../example_scripts/cifs-{system}/').glob('*.cif'))
elements = ('Ni', 'Co', 'Al')
# One pattern per element, compiled once. A raw string keeps `\.` a regex escape
# instead of an invalid string escape (which Python warns about).
element_patterns = {e: re.compile(r'{}[0-9]*\.?[0-9]*'.format(e)) for e in elements}
filemap = {}
for idx, cif in enumerate(cif_paths):
    name, syst, num = cif.stem.split('-')
    composition = {}
    for e in elements:
        # Uses the first (assuming only) instance of the element in formula name to get composition
        comp = element_patterns[e].findall(name)
        if not comp:
            # Element absent from the formula.
            composition[e] = 0.
            continue
        x = comp[0].replace(e, '')
        # Empty string after removal means only 1 of element
        composition[e] = float(x) if x else 1.
    # Normalise so the fractions of the tracked elements sum to one.
    s = sum(composition.values())
    filemap[idx] = {
        'name': name,
        'crystal-system': syst,
        'space-group': int(num),
        'composition': {e: amount / s for e, amount in composition.items()},
    }

In [None]:
# Forward (class index -> "<name>-<space group>") and reverse lookup tables,
# both derived from `filemap` in its iteration order.
_phase_labels = [
    (int(idx), entry['name'] + '-' + str(entry['space-group']))
    for idx, entry in filemap.items()
]
sim_mapping = {idx: phase_label for idx, phase_label in _phase_labels}
sim_mapping_rev = {phase_label: idx for idx, phase_label in _phase_labels}
sim_mapping

## Constructing a full output of the probability distributions

In [None]:
def EDX_proby(path, phase_map=None):
    """
    Turn an EDX composition table into per-phase probabilities.

    For every row of the table and every phase, the probability is the product
    over that phase's elements of a gaussian of (measured fraction - nominal
    fraction). Each row is then normalised so its phase probabilities sum to 1.

    Parameters
    ----------
    path: str or Path
        CSV file with a 'Number' column (used as the index) and one column per
        element. Values are divided by 100, i.e. presumably given in percent.
    phase_map: dict, optional
        Maps a key to a dict holding 'name' and 'composition'
        (element -> fraction). Defaults to the notebook-level ``filemap``.

    Returns
    -------
    df: pandas.DataFrame
        Normalised probabilities, one column per phase name. Phases sharing a
        name share a column (the later one overwrites the earlier one).
    EDX_df: pandas.DataFrame
        The input table divided by 100.
    """
    def gaussian(x, sigma=0.1164977062):
        """
        Calculates gaussian function centered at (0,1).
        https://en.wikipedia.org/wiki/Gaussian_function
        Default Full-Width-Tenth-Max = 0.5

        Parameters
        ----------
        x: float or array-like
        sigma: float
            standard deviation or sqrt(variance) of function

        Returns
        -------
        probability density: type of x
        """
        return np.exp(-np.power(x, 2.) / (2 * np.power(sigma, 2.)))

    phases = filemap if phase_map is None else phase_map
    # Plain division builds a new float frame. Writing floats back through
    # `.iloc[:, :]` into integer columns is deprecated by pandas.
    EDX_df = pd.read_csv(path, index_col='Number') / 100
    df = pd.DataFrame(index=EDX_df.index)
    for phase in phases.values():
        prod = np.ones(len(EDX_df))
        for elem, value in phase['composition'].items():
            prod *= gaussian(EDX_df[elem].values - value)
        df[phase['name']] = prod

    # Row-wise normalisation; a row whose probabilities are all zero becomes NaN.
    df = df.div(df.sum(axis=1), axis=0)

    return df, EDX_df
# Per-phase probabilities and fractional compositions for every row of the example EDX map.
edx_proby, comp = EDX_proby(path='../example_data/EDXmap.csv')


In [None]:
# Collect, per test pattern, the raw model prediction and the prediction
# re-weighted by the EDX-derived phase probabilities.
XRD_all = {}
joint_all = {}
for batch in test_dataset:
    label = batch['label'].numpy()[0].decode("utf-8")
    fname = batch['fname'].numpy()[0].decode("utf-8")
    # The file name without its extension is the row number in the EDX table.
    fnum = int(os.path.splitext(fname)[0])
    XRD_pred = model({'X':batch['X']}, training=False).numpy()[0,:]

    # Look the EDX row up once instead of once per class.
    edx_row = edx_proby.loc[fnum]
    joint_pred = np.copy(XRD_pred)
    # This loops over each prediction to allow for multiple phases of same composition
    for idx in range(len(joint_pred)):
        joint_pred[idx] *= edx_row[filemap[idx]['name']]
    joint_pred /= np.sum(joint_pred)

    XRD_all[fname] = (label, XRD_pred)
    joint_all[fname] = (label, joint_pred)

In [None]:
def write_csv(path, dic, phase_map=None):
    """Write per-pattern predictions to a comma-separated text file.

    The file starts with three header rows (phase name, space group, crystal
    system; one column per class index), followed by one row per pattern
    sorted by the integer value of its file-name stem.

    Parameters
    ----------
    path: str or Path
        Output file; overwritten if it exists.
    dic: dict
        Maps a file name such as ``"12.npy"`` to ``(true_label, predictions)``.
        The stem of the file name must parse as an integer.
    phase_map: dict, optional
        Maps class index ``0..n-1`` to a dict holding 'name', 'space-group'
        and 'crystal-system'. Defaults to the notebook-level ``filemap``.
    """
    phases = filemap if phase_map is None else phase_map
    columns = [phases[idx] for idx in range(len(phases))]

    def stem(fname):
        # Strip whatever extension is present rather than assuming 4 characters.
        return os.path.splitext(fname)[0]

    def header(prefix, field):
        # Every header cell, including the last, is followed by ", ".
        return prefix + "".join("{}, ".format(col[field]) for col in columns)

    lines = [
        header("Numpy idx, True Phase, ", 'name'),
        header(", Space Group, ", 'space-group'),
        header(", Crystal System, ", 'crystal-system'),
    ]
    for key in sorted(dic, key=lambda name: int(stem(name))):
        true, preds = dic[key]
        lines.append("{}, {}, {}".format(
            stem(key),
            # Commas inside the label would otherwise split the column.
            true.replace(',', '/'),
            ", ".join("{:.3e}".format(p) for p in preds),
        ))

    # All rows are built first, so a bad entry does not leave a partial file.
    with open(path, 'w') as f:
        f.write("\n".join(lines) + "\n")

In [None]:
# Export the XRD-only predictions and the EDX-weighted joint predictions as CSV files under ./tmp.
write_csv('./tmp/NiCoAl_xrd_pred.csv', XRD_all)
write_csv('./tmp/NiCoAl_joint_pred.csv', joint_all)