In [1]:
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
from tensorflow import keras

import numpy as np
import collections
import pandas as pd

import matplotlib.pyplot as plt

## Load test dataset

In [None]:
# Input files for the test sample. TODO (from the original note): change these
# paths / load all of the 1 MeV test files, not just the `_test_0` ones.
TEST_MAT_PATH = "./proj_raw_data_test_0.npz"
TEST_LABELS_PATH = "targets_test_0.csv"

# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it once 'arr_0' has been read into memory.
with np.load(TEST_MAT_PATH) as test_npz:
    test_mat = test_npz['arr_0']

# Target energies from the "edep" column (used as E_true in the plots below).
test_labels = pd.read_csv(TEST_LABELS_PATH)["edep"].to_numpy()

In [None]:
# Overwrite every NaN entry of the test inputs with 0, modifying test_mat in place.
np.putmask(test_mat, np.isnan(test_mat), 0)

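## Define custom objects

The saved model references the custom learning-rate schedule below, so the class must be defined before the model is loaded.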

In [None]:
#******CUSTOM LEARNING RATE******#
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Linear warm-up for one epoch, then exponential decay down to 1e-8.

  For steps 0 .. steps_per_epoch - 1 the rate rises linearly,
  lr = initial_learning_rate * (step + 1) / steps_per_epoch, so it reaches
  initial_learning_rate at the end of the first epoch. After that it decays
  geometrically, lr = initial_learning_rate * decay_rate ** (step + 1 - steps_per_epoch),
  with decay_rate chosen so that lr == 1e-8 at the last training step
  (step epochs * steps_per_epoch - 1).

  Args:
    initial_learning_rate: peak rate reached at the end of warm-up.
    epochs: total number of training epochs; must be >= 2 so a decay phase exists.
    steps_per_epoch: optimizer steps per epoch; must be >= 1.

  Raises:
    ValueError: if epochs < 2 or steps_per_epoch < 1 (the formulas below
      would otherwise divide by zero).
  """

  def __init__(self, initial_learning_rate, epochs, steps_per_epoch):
    if epochs < 2:
      raise ValueError(f"epochs must be >= 2, got {epochs}")
    if steps_per_epoch < 1:
      raise ValueError(f"steps_per_epoch must be >= 1, got {steps_per_epoch}")
    self.initial_learning_rate = initial_learning_rate
    self.epochs = epochs
    self.steps_per_epoch = steps_per_epoch
    # Slope of the linear warm-up phase.
    self.m = initial_learning_rate / steps_per_epoch
    # Per-step factor such that after (epochs - 1) * steps_per_epoch decay steps
    # the rate has gone from initial_learning_rate to 1e-8.
    self.decay_rate = tf.constant((10**-8 / initial_learning_rate)**(((epochs - 1)*steps_per_epoch)**-1), dtype=tf.float32)
    print('decay_rate:', self.decay_rate)

  def __call__(self, step):
    # The optimizer may pass an integer (e.g. int64 `iterations`) or float step.
    # Work in float32 so both tf.cond branches share the decay_rate dtype and
    # the Python-float slope can multiply the step.
    step_f = tf.cast(step, tf.float32)
    result = tf.cond(tf.less(step_f, self.steps_per_epoch),
                     # warm-up phase
                     lambda: self.m * (step_f + 1),
                     # decay phase: exponent counts steps past warm-up, starting at 1
                     lambda: self.initial_learning_rate * self.decay_rate**(step_f + 1 - self.steps_per_epoch))

    # Log every evaluated rate to learning_rates.txt.
    tf.print('lr at step', step, 'is', result, output_stream='file://learning_rates.txt')
    return result

  def get_config(self):
    # Keys mirror the __init__ parameters so the schedule can be rebuilt from
    # its config (e.g. when the saved model is reloaded).
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "epochs": self.epochs,
        "steps_per_epoch": self.steps_per_epoch
    }

## Load and evaluate the ResNet model

In [None]:
# The saved model (presumably a SavedModel directory named by its training
# timestamp) references the custom MyLRSchedule class, so it is registered
# via custom_objects for deserialisation.
resnet_model = tf.keras.models.load_model(
    "./20220603-143050",
    custom_objects={"MyLRSchedule": MyLRSchedule},
)

In [None]:
# Predict the deposited energy for every test event, then flatten the output
# (presumably shaped (N, 1)) to a 1-D array with one value per event.
edep_pred = resnet_model.predict(test_mat)
edep_pred = np.reshape(edep_pred, (edep_pred.shape[0],))

In [None]:
# Distribution of the absolute energy error on the test set.
fig, ax = plt.subplots(figsize=(14, 8))
ax.hist(edep_pred - test_labels, bins=100)
ax.set_xlabel("E_pred - E_true, MeV")
ax.set_ylabel("Counts")
ax.set_title("Absolute energy error on the test set")
plt.show()

In [None]:
# 3-sigma cut on the relative residual (E_pred - E_true) / E_true.
# A zero true energy turns the division into inf/nan. A single such value makes
# mean()/std() non-finite, every comparison below False, and the cut would
# silently discard all events, so non-finite residuals are dropped first.
with np.errstate(divide="ignore", invalid="ignore"):
    res_all = (edep_pred - test_labels) / test_labels
res_finite = res_all[np.isfinite(res_all)]

# Single-pass cut: keep residuals within 3 standard deviations of their mean.
res = res_finite[np.abs(res_finite - res_finite.mean()) < 3 * res_finite.std()]

In [None]:
# Relative residual after the 3-sigma cut. res is a ratio of two energies, so
# it is dimensionless and the axis carries no energy unit.
fig, ax = plt.subplots(figsize=(14, 8))
ax.hist(res, bins=100)
ax.set_xlabel("(E_pred - E_true) / E_true")
ax.set_ylabel("Counts")
ax.set_title("Relative energy residual (3-sigma cut)")
plt.show()