# Classifying on real data


In [1]:
# Imports
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import sys
import json
import matplotlib.pyplot as plt
from master_scripts.data_functions import *
from master_scripts.analysis_functions import *
from tqdm import tqdm
%load_ext autoreload
%autoreload 2


## Data import

In [None]:
# Locations of the input data and outputs, plus the file names of the
# trained models used in this notebook.
# NOTE(review): the DOUBLE_* model names look like placeholders — confirm
# before relying on them.
config = dict(
    # Input data
    DATA_PATH="../../data/real/anodedata_short.txt",
    DATA_FILENAME="anodedata_short.txt",
    SAMPLE_PATH="../../data/sample/CeBr10k_1.txt",
    # Output directories
    MODEL_PATH="../../data/output/models/",
    OUTPUT_PATH="../../data/output/",
    # Model file names
    CLASSIFIER="Project-0.97.hdf5",
    SINGLE_ENERGY_MODEL="cnn_energy_single_r2_0.97.hdf5",
    SINGLE_POSITION_MODEL="cnn_pos_single_mse_0.00083.hdf5",
    DOUBLE_ENERGY_MODEL="double_energy_model_name.hdf5",
    DOUBLE_POSITION_MODEL="double_position_model_name.hdf5",
)



In [None]:
import os

# Load the raw data: the events returned here are kept under `raw_events`
# because `events` is replaced below by the classified events from JSON.
raw_events, images = import_real_data(config)  # images not normalized
images = normalize_image_data(images)

# Load classification results (doesn't contain images). The file name is
# derived from DATA_FILENAME so the classified events always belong to the
# same data file as `images` -- events[...]['image_idx'] indexes into it.
data_stem = os.path.splitext(config['DATA_FILENAME'])[0]
classified_path = os.path.join(
    config['OUTPUT_PATH'], "events_classified_" + data_stem + ".json"
)
with open(classified_path, encoding="utf-8") as fp:
    events = json.load(fp)


# Results
## Plots
### Histogram of descriptor vs predicted class

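#### Preliminary results on `anodedata_short.txt`

Number of events for each event descriptor, split by the class predicted by the classifier (single or double):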

| Event descriptor | Event type                   | Predicted single | Predicted double |
| :---           |  :---:                       | :---:   | :---:   |
|        1        |           Implant            |  1007   |  59923  |
|        2        |            Decay             | 174274  |    0    |
|        4        |          Light ion           |  4328   |  96827  |
|       10        |    Decay + Double (time)     |   421   |    0    |
|       12        |  Light ion + Double (time)   |    1    |    0    |
|       16        |        Double (space)        |    2    |   11    |



In [None]:
# Collect the ids of all events whose descriptor is 16, i.e. "Double (space)"
# according to the results table above.
doubles = [
    event_id
    for event_id, event in events.items()
    if event['event_descriptor'] == 16
]

In [None]:
import math

# Plot every double (space) event, labelled with its id and the class the
# model predicted. The grid grows with the number of doubles (3 per row) so
# that none are silently dropped; axes left over in the last row are hidden.
n_cols = 3
n_rows = max(1, math.ceil(len(doubles) / n_cols))
fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False)
for ax_col, event_id in zip(ax.ravel(), doubles):
    event = events[event_id]
    # Each image is reshaped to a 16x16 grid for display.
    ax_col.imshow(images[event['image_idx']].reshape((16, 16)), origin='lower')
    # With origin='lower', y=14.6 places the labels near the top of the image.
    ax_col.text(0, 14.6, "id: " + event_id, fontsize=11, color='red')
    ax_col.text(8, 14.6, event['event_class'], fontsize=11, color='red')
for unused_ax in ax.ravel()[len(doubles):]:
    unused_ax.axis('off')
fig.savefig("test.pdf")
    
    