# Predicting phases for experimental dataset

In [1]:
# Import packages
%matplotlib inline
import tensorflow as tf
import numpy as np
import os, glob
import gc
import tqdm
import hyperspy.api as hs
from tempfile import TemporaryFile



## Import list of experimental datafile paths to process
If there is only one file to predict, `data` must still be a list, containing that single path.

In [2]:
# Choose the experimental datasets to predict on.
# `file_end` is the filename suffix to match; alternative suffix: 'cropped.npz'
file_end = 'cropped_sqrt.npz'

# True  -> keep files whose name does NOT contain 'nonthreshold'
# False -> keep files whose name DOES contain 'nonthreshold'
tresholded = True #False
base_root = r"D:\jf631\mg24111-1"

# Walk the whole tree under `base_root` and collect the matching paths, sorted.
data = sorted(
    os.path.join(folder, filename)
    for folder, _subdirs, filenames in os.walk(base_root)
    for filename in filenames
    if filename.endswith(file_end)
    and ('nonthreshold' in filename) == (not tresholded)
)
data[:3]

['D:\\jf631\\mg24111-1\\20200927_094445_TripCatBeamDam_3p59A2_150kx_10umAp_20cmCL_scan_array_253x255_diff_plane_515x515_centred_thresholded_radial_norm_cropped_sqrt.npz',
 'D:\\jf631\\mg24111-1\\20200927_101523_TripCatBeamDam_3p59A2_150kx_10umAp_20cmCL_scan_array_254x255_diff_plane_515x515_centred_thresholded_radial_norm_cropped_sqrt.npz',
 'D:\\jf631\\mg24111-1\\20200927_101748_TripCatBeamDam_3p59A2_150kx_10umAp_20cmCL_scan_array_255x255_diff_plane_515x515_centred_thresholded_radial_norm_cropped_sqrt.npz']

## Select a list of models to use to predict
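If there is only one model, `models` must still be a list, containing the single path to that model's `.h5` file. Each model also needs a matching `<model name>_phases.npy` file next to it, which is loaded when predicting.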

In [3]:
# Collect every Keras model (*.h5) sitting in the current working directory.
cwd = os.path.abspath(os.getcwd())
h5_pattern = os.path.join(cwd, '*.h5')
models = sorted(glob.glob(h5_pattern))
print(len(models))

models

2


['C:\\Users\\Sauron\\Documents\\jf631\\SED_scripts\\nn_models\\20210216_1D_multiclass_corrupted_noisy_bkg\\NN_6classes_4epochs_130133Train_ac0.7795_130133Test0.7786.h5',
 'C:\\Users\\Sauron\\Documents\\jf631\\SED_scripts\\nn_models\\20210216_1D_multiclass_corrupted_noisy_bkg\\NN_6classes_4epochs_130133Train_ac0.7958_130133Test0.7949.h5']

In [4]:
# Upper bound on the number of rows (y) kept from each probability map before
# stacking. The real crop is the smaller of this value and the smallest y size
# actually loaded, because hs.stack needs identically shaped signals.
min_y_image_size = 253
date_folder = '20210217'


def cont_to_categorical(continous_ar):
    """One-hot encode the arg-max taken along axis 1.

    Parameters
    ----------
    continous_ar : np.ndarray
        Array with at least two dimensions; axis 1 is the axis the arg-max
        is taken over (the class axis in this notebook).

    Returns
    -------
    np.ndarray
        Array with the shape and dtype of `continous_ar`, holding 1 at the
        arg-max position along axis 1 (first maximum wins on ties) and 0
        everywhere else.
    """
    winners = np.expand_dims(np.argmax(continous_ar, axis=1), axis=1)
    one_hot = np.zeros_like(continous_ar)
    np.put_along_axis(one_hot, winners, 1, axis=1)
    return one_hot


######### Run the predicting
for model_path in models:

    # Load model
    model = tf.keras.models.load_model(model_path)
    model_root = os.path.splitext(model_path)[0]  # path without the '.h5' extension
    model_basename = os.path.basename(model_root)

    if tresholded:
        model_name = model_basename + '_tresholded'
    else:
        model_name = model_basename + '_nontresholded'

    signals = []

    for fname in tqdm.tqdm(data):

        # Close the .npz archive as soon as the array has been read.
        with np.load(fname) as exp_npzfile:
            exp_data = exp_npzfile['exp1d']

        # Flatten the two leading axes so every position becomes one sample
        # with a trailing channel axis of size 1 for the model.
        n_rows, n_cols = exp_data.shape[0], exp_data.shape[1]
        exp_data_reshape = np.reshape(exp_data, (n_rows * n_cols, exp_data.shape[2], 1))
        exp_preds = model.predict(exp_data_reshape)

        # Back to the 2D layout, then put the class axis first so that each
        # class becomes one 2D image of the Signal2D.
        exp_pred_reshape = np.reshape(exp_preds, (n_rows, n_cols, exp_preds.shape[-1]))
        exp_pred_reshape = np.moveaxis(exp_pred_reshape, -1, 0)
        signals.append(hs.signals.Signal2D(exp_pred_reshape))

        # Drop the large intermediates before the next file is loaded.
        del exp_data
        del exp_data_reshape
        del exp_preds
        del exp_pred_reshape
        gc.collect()

    # Crop and stack.
    # isig is indexed (x, y); y is the second-to-last data axis. Never crop to
    # more rows than the smallest loaded map has, otherwise stacking fails.
    crop_y = min([min_y_image_size] + [s.data.shape[-2] for s in signals])
    signals_cropped = [s.isig[:, :crop_y] for s in signals]
    signals = hs.stack(signals_cropped)

    # Add phases in the metadata (loaded from '<model path without .h5>_phases.npy')
    phase_path = model_root + '_phases.npy'
    phases = list(np.load(phase_path))
    signals.metadata.General.set_item("Phases", phases)


    # Save results in 2 folders
    def save(signal, name, base_root=base_root, model_path=model_path, date_folder=date_folder):
        """Write `signal` as `name` into both output locations, overwriting existing files.

        The defaults are bound when the function is defined, i.e. once per
        model, so `model_path` refers to the model currently being processed.
        The two targets are
        `<base_root>/stacked_predictions/<date_folder>/<name>` and
        `<model directory>/stacked_predictions/<name>`.
        """
        signal.save(os.path.join(base_root, 'stacked_predictions', date_folder, name), overwrite=True)
        signal.save(os.path.join(os.path.dirname(model_path), 'stacked_predictions', name), overwrite=True)

    # Save the stack with probabilites
    name = 'probability_preditions_' + model_name + '.hspy'
    save(signals, name)

    # Save the sparse categorical results: the arg-max over the class axis,
    # i.e. 0-based class indices (0 ... n_classes - 1).
    signal_cat = hs.signals.Signal2D(signals.data.argmax(axis=1))
    signal_cat.metadata.General.set_item("Phases", phases)
    name = 'sparse_categorical_' + model_name + '.hspy'
    save(signal_cat, name)

    # Save the one-hot encoded categorical results [0,1,...,0].
    # The encoding is computed on the whole (file, class, y, x) array at once
    # instead of calling the helper once per pixel; the copy keeps the axes
    # and metadata of the probability stack.
    signal_cat_expanded = signals.deepcopy()
    signal_cat_expanded.data = cont_to_categorical(signals.data)
    name = 'onehot_categorical_' + model_name + '.hspy'
    save(signal_cat_expanded, name)

100%|██████████| 76/76 [03:38<00:00,  2.88s/it]

  0%|          | 0/76 [00:00<?, ?it/s][A
  1%|▏         | 1/76 [00:02<03:01,  2.42s/it][A
  3%|▎         | 2/76 [00:04<02:59,  2.43s/it][A
  4%|▍         | 3/76 [00:07<02:57,  2.43s/it][A
  5%|▌         | 4/76 [00:09<02:55,  2.44s/it][A
  7%|▋         | 5/76 [00:12<02:50,  2.40s/it][A
  8%|▊         | 6/76 [00:14<02:49,  2.42s/it][A
  9%|▉         | 7/76 [00:16<02:42,  2.36s/it][A
 11%|█         | 8/76 [00:19<02:43,  2.40s/it][A
 12%|█▏        | 9/76 [00:21<02:40,  2.39s/it][A
 13%|█▎        | 10/76 [00:24<02:40,  2.43s/it][A
 14%|█▍        | 11/76 [00:26<02:39,  2.46s/it][A
 16%|█▌        | 12/76 [00:29<02:36,  2.45s/it][A
 17%|█▋        | 13/76 [00:31<02:41,  2.57s/it][A
 18%|█▊        | 14/76 [00:34<02:44,  2.65s/it][A
 20%|█▉        | 15/76 [00:37<02:48,  2.76s/it][A
 21%|██        | 16/76 [00:41<02:53,  2.90s/it][A
 22%|██▏       | 17/76 [00:43<02:50,  2.89s/it][A
 24%|██▎       | 18/76 [00:47<03:00,  3.10s/it][A


[########################################] | 100% Completed |  0.1s

[########################################] | 100% Completed |  0.1s


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=64515.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Text summary of the stacked probability signal left by the prediction loop
# (it belongs to the last model that was processed).
print(signals)

In [None]:
# Switch matplotlib to the Qt backend for the plots below
# (the notebook started with `%matplotlib inline`).
%matplotlib qt
# Stacked per-class probabilities.
signals.plot(cmap='viridis')
# Arg-max class index per pixel.
signal_cat.plot(cmap='Dark2')
# One-hot encoded version of the arg-max classes.
signal_cat_expanded.plot()
