# MASK ENERGY

In [None]:
import glob
import itertools
import os, sys
import random
import time
from collections import OrderedDict

import h5py
import matplotlib.pyplot as plt
import numpy

In [None]:
## ONLY RUN IF YOU WANT TO SPECIFY A SPECIFIC DEVICE
# Number the GPUs by PCI bus ID and expose only device "0" to this process.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [None]:
# SETTINGS

# 1 = Energy only, 2 = Energy and Zenith (the compile/plot cells also
# handle 3 = Energy, Zenith and Track). Also the number of target columns
# sliced from Y during training.
train_variables = 1
num_labels = train_variables

batch_size = 256
dropout = 0.2
learning_rate = 1e-3
DC_drop_value = dropout
IC_drop_value = dropout
connected_drop_value = dropout
min_energy = 5
max_energy = 100.

# HOW MANY EPOCHS? START AT 0? IF YOU GIVE OLD MODEL, IT WILL USE IT!!!
start_epoch = 0
num_epochs = 14
# Path to a saved model whose weights seed training, or None to start fresh,
# e.g. "output_plots/<filename>/current_model_while_running.hdf5"
old_model_given = None

#INPUT FILE(S): glob pattern; every match is used in sorted order, and the
# training loop consumes one file per epoch.
files = "NuMu_140000_level2.zst_uncleaned_cleanedpulsesonly_lt100_CC_flat_95bins_36034evtperbinall_file0?.transformed.hdf5"
file_names = sorted(glob.glob(files))

#NAME OF DIRECTORY AND FILE TO SAVE TO
filename = 'numu_flat_E_5_100_CC_uncleaned_3600kevents_Enottransformed'
save = True
save_folder_name = "output_plots/%s/" % (filename)
if save:
    # makedirs also creates the parent output_plots/ directory when it is
    # missing (os.mkdir would raise), and is a no-op if the folder exists.
    os.makedirs(save_folder_name, exist_ok=True)

use_old_reco = False

In [None]:
# Load the first input file's DC/IC training arrays; they are handed to
# make_network in the next cell.
if not file_names:
    raise FileNotFoundError("No input files match the pattern: %s" % files)
file1 = file_names[0]
# The context manager closes the file even if one of the reads fails.
with h5py.File(file1, 'r') as f:
    X_train_DC = f['X_train_DC'][:]
    X_train_IC = f['X_train_IC'][:]
del f
print("Train Data DC", X_train_DC.shape)
print("Train Data IC", X_train_IC.shape)

In [None]:
from keras.optimizers import SGD, Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint
from cnn_model import make_network

# Build the network from the DC/IC training arrays, the number of outputs,
# and the three dropout settings.
model_DC = make_network(X_train_DC, X_train_IC, num_labels,
                        DC_drop_value, IC_drop_value, connected_drop_value)

In [None]:
# WRITE OWN LOSS FOR MORE THAN ONE REGRESSION OUTPUT
from keras.losses import (mean_squared_error,
                          mean_squared_logarithmic_error,
                          logcosh,
                          mean_absolute_percentage_error)

# Each loss scores a single output column of the target/prediction tensors.
# Keep the function names: Keras reports the metrics under them, and the
# training loop reads 'EnergyLoss' / 'ZenithLoss' back from the fit history.


def EnergyLoss(y_truth, y_predicted):
    """Mean squared error on output column 0 (energy).

    Alternatives tried: mean_squared_logarithmic_error (optionally /120.)
    and mean_absolute_percentage_error on the same column.
    """
    energy_true = y_truth[:, 0]
    energy_pred = y_predicted[:, 0]
    return mean_squared_error(energy_true, energy_pred)


def ZenithLoss(y_truth, y_predicted):
    """Mean squared error on output column 1 (zenith); logcosh was an alternative."""
    zenith_true = y_truth[:, 1]
    zenith_pred = y_predicted[:, 1]
    return mean_squared_error(zenith_true, zenith_pred)


def TrackLoss(y_truth, y_predicted):
    """Mean squared logarithmic error on output column 2 (track), divided by 100."""
    track_true = y_truth[:, 2]
    track_pred = y_predicted[:, 2]
    return mean_squared_logarithmic_error(track_true, track_pred) / 100.

## Compile ##
# Pick the training loss and the reported metrics from the number of regressed
# outputs. CustomLoss is defined in every branch because the training loop
# recompiles with it when num_labels == 2.
if num_labels == 3:
    def CustomLoss(y_truth, y_predicted):
        """Energy + zenith + track loss."""
        e_part, z_part, t_part = (EnergyLoss(y_truth, y_predicted),
                                  ZenithLoss(y_truth, y_predicted),
                                  TrackLoss(y_truth, y_predicted))
        return e_part + z_part + t_part

    model_DC.compile(optimizer=Adam(lr=learning_rate),
                     loss=CustomLoss,
                     metrics=[EnergyLoss, ZenithLoss, TrackLoss])

elif num_labels == 2:
    def CustomLoss(y_truth, y_predicted):
        """Energy + zenith loss."""
        e_part, z_part = (EnergyLoss(y_truth, y_predicted),
                          ZenithLoss(y_truth, y_predicted))
        return e_part + z_part

    model_DC.compile(optimizer=Adam(lr=learning_rate),
                     loss=CustomLoss,
                     metrics=[EnergyLoss, ZenithLoss])
else:
    def CustomLoss(y_truth, y_predicted):
        """Energy-only loss (EnergyLoss itself is what gets compiled below)."""
        return EnergyLoss(y_truth, y_predicted)

    model_DC.compile(optimizer=Adam(lr=learning_rate),
                     loss=EnergyLoss,
                     metrics=[EnergyLoss])


In [None]:
# Run neural network and record time ##
# Each loop pass is one Keras epoch trained on one input file (cycling through
# file_names); the loss histories are accumulated by hand across fit() calls.
loss = []
val_loss = []
energy_loss = []
zenith_loss = []
val_energy_loss = []
val_zenith_loss = []

# ModelCheckpoint rewrites this file after every epoch; later epochs reload it.
checkpoint_path = '%scurrent_model_while_running.hdf5' % save_folder_name


def _load_epoch_arrays(path):
    """Read one input file's training and validation arrays.

    Returns (Y_train, X_train_DC, X_train_IC, X_validate_DC, X_validate_IC,
    Y_validate). The file is closed even if one of the reads fails.
    """
    with h5py.File(path, 'r') as hdf:
        return (hdf['Y_train'][:], hdf['X_train_DC'][:], hdf['X_train_IC'][:],
                hdf['X_validate_DC'][:], hdf['X_validate_IC'][:],
                hdf['Y_validate'][:])


def _write_loss_file(path, names, histories):
    """Write each history to *path* as one ``name = [v1, v2, ..., ]`` line."""
    with open(path, "w") as out:
        for name, values in zip(names, histories):
            out.write("%s = [" % name)
            out.write("".join("%s, " % value for value in values))
            out.write("]\n")


end_epoch = start_epoch + num_epochs
current_epoch = start_epoch
t0 = time.time()
for epoch in range(start_epoch, end_epoch):

    ## NEED TO HAVE OUTPUT DATA ALREADY TRANSFORMED!!! ##
    # Get new dataset: epoch i trains on file i modulo the number of files.
    input_file = file_names[epoch % len(file_names)]
    print("Now using file %s" % input_file)
    (Y_train, X_train_DC, X_train_IC,
     X_validate_DC, X_validate_IC, Y_validate) = _load_epoch_arrays(input_file)

    # NOTE(review): targets are scaled by 100 here, but the plotting cells
    # rescale both truth and prediction by max_energy / maxabs_factor as if
    # the network output were on Y_test's scale -- confirm the intended scaling.
    Y_train_use = Y_train[:, :train_variables] * 100
    Y_val_use = Y_validate[:, :train_variables] * 100

    # Compile model (with a fresh Adam optimizer each epoch). Mirror the
    # compile cell so three-label training keeps its zenith and track terms
    # instead of silently falling back to the energy-only loss.
    if num_labels == 3:
        model_DC.compile(loss=CustomLoss,
                         optimizer=Adam(lr=learning_rate),
                         metrics=[EnergyLoss, ZenithLoss, TrackLoss])
    elif num_labels == 2:
        model_DC.compile(loss=CustomLoss,
                         optimizer=Adam(lr=learning_rate),
                         metrics=[EnergyLoss, ZenithLoss])
    else:
        model_DC.compile(loss=EnergyLoss,
                         optimizer=Adam(lr=learning_rate),
                         metrics=[EnergyLoss])

    # Use old weights: a user-supplied model takes priority (used once);
    # otherwise, after epoch 0, reload the checkpoint written by the last fit.
    if epoch > 0 and not old_model_given:
        model_DC.load_weights(checkpoint_path)
    elif old_model_given:
        print("Using given model %s" % old_model_given)
        model_DC.load_weights(old_model_given)
        old_model_given = None
    else:
        print("Training set: %i, Validation set: %i" % (len(Y_train_use), len(Y_val_use)))
        print(current_epoch, end_epoch)

    # Run one epoch with dataset. fit() trains from initial_epoch up to (not
    # including) epochs, so epochs must be initial_epoch + 1.
    network_history = model_DC.fit([X_train_DC, X_train_IC], Y_train_use,
                                   validation_data=([X_validate_DC, X_validate_IC], Y_val_use),
                                   batch_size=batch_size,
                                   initial_epoch=current_epoch,
                                   epochs=current_epoch + 1,
                                   callbacks=[ModelCheckpoint(checkpoint_path)],
                                   verbose=1)

    # Save loss
    epoch_history = network_history.history
    loss += epoch_history['loss']
    val_loss += epoch_history['val_loss']
    energy_loss += epoch_history['EnergyLoss']
    val_energy_loss += epoch_history['val_EnergyLoss']
    # ZenithLoss is a compiled metric whenever two or more labels are trained.
    if num_labels >= 2:
        zenith_loss += epoch_history['ZenithLoss']
        val_zenith_loss += epoch_history['val_ZenithLoss']

    # After each full pass over the input files (every len(file_names)
    # epochs), save the whole model and the loss histories so far.
    if epoch % len(file_names) == len(file_names) - 1:
        print("saving checkpoint")
        model_DC.save("%s%s_%iepochs_model.hdf5" % (save_folder_name, filename, current_epoch + 1))

        if num_labels >= 2:
            losses = [loss, energy_loss, zenith_loss, val_loss, val_energy_loss, val_zenith_loss]
            losses_names = ['loss', 'energy_loss', 'zenith_loss', 'val_loss', 'val_energy_loss', 'val_zenith_loss']
        else:
            losses = [loss, energy_loss, val_loss, val_energy_loss]
            losses_names = ['loss', 'energy_loss', 'val_loss', 'val_energy_loss']
        _write_loss_file("%ssaveloss__%iepochs.txt" % (save_folder_name, current_epoch + 1),
                         losses_names, losses)

    # NOTE: killing the model here (tf.keras.backend.clear_session()) and
    # rebuilding it with make_network was tried and does not work.

    current_epoch += 1

t1 = time.time()
print("This took me %f minutes" % ((t1 - t0) / 60.))
model_DC.save("%s%s_model.hdf5" % (save_folder_name, filename))

In [None]:
# Write every loss history of this run to saveloss.txt, one
# "name = [v1, v2, ..., ]" line each (the zenith lists are simply empty when
# only energy was trained).
losses = [loss, energy_loss, zenith_loss, val_loss, val_energy_loss, val_zenith_loss]
losses_names = ['loss', 'energy_loss', 'zenith_loss', 'val_loss', 'val_energy_loss', 'val_zenith_loss']
# The context manager closes the file even if a write fails.
with open("%ssaveloss.txt" % save_folder_name, "w") as loss_file:
    for loss_name, loss_values in zip(losses_names, losses):
        loss_file.write("%s = [" % loss_name)
        for a_loss in loss_values:
            loss_file.write("%s, " % a_loss)
        loss_file.write("]\n")

In [None]:
#Reformat loss if needed, usually not
# Print each history as a statement that can be pasted into the hard-coded
# plot_* lists further down to extend them with this run's epochs.
for paste_prefix, history_values in (
        ("plot_loss = plot_loss + ", loss),
        ("plot_val_loss = plot_val_loss + ", val_loss),
        ("plot_energy_loss = plot_energy_loss + ", energy_loss),
        ("plot_zenith_loss = plot_zenith_loss + ", zenith_loss),
        ("plot_val_energy_loss = plot_val_energy_loss + ", val_energy_loss),
        ("plot_val_zenith_loss = plot_val_zenith_loss + ", val_zenith_loss),
):
    print(paste_prefix, history_values)

In [None]:
# plot_history_loss is otherwise only imported in a later cell; import it here
# so this cell also works when the notebook is run top to bottom.
from PlottingFunctions import plot_history_loss
plot_history_loss(loss, val_loss, save, save_folder_name, logscale=True)

In [None]:
# Put all the test sets together
# Collect each file's test arrays and concatenate once at the end;
# concatenating inside the loop re-copies the growing arrays every iteration.
Y_test_parts = []
X_test_DC_parts = []
X_test_IC_parts = []

for file in file_names:
    with h5py.File(file, 'r') as f:
        Y_test = f['Y_test'][:]
        X_test_DC = f['X_test_DC'][:]
        X_test_IC = f['X_test_IC'][:]
    Y_test_parts.append(Y_test)
    X_test_DC_parts.append(X_test_DC)
    X_test_IC_parts.append(X_test_IC)

if not Y_test_parts:
    raise ValueError("No input files found; cannot build the combined test set")
Y_test_use = numpy.concatenate(Y_test_parts)
X_test_DC_use = numpy.concatenate(X_test_DC_parts)
X_test_IC_use = numpy.concatenate(X_test_IC_parts)
print(Y_test_use.shape)

In [None]:
#score = model_DC.evaluate([X_test_DC_use,X_test_IC_use], Y_test_use, batch_size=256)
#print("final score on test data: loss: {:.4f} / accuracy: {:.4f}".format(score[0], score[1]))
#print(network_history.history.keys())
#print(score)

In [None]:
# PREDICT USING NEURAL NETWORK (TEST)
# Predict on the combined test set and report the wall-clock time taken.
t0 = time.time()
Y_test_predicted = model_DC.predict([X_test_DC_use, X_test_IC_use])
t1 = time.time()  # time.time() takes no arguments
print("This took me %f seconds for %i events" % ((t1 - t0), Y_test_predicted.shape[0]))

In [None]:
from PlottingFunctions import plot_single_resolution
from PlottingFunctions import plot_2D_prediction
from PlottingFunctions import plot_distributions
from PlottingFunctions import plot_bin_slices
from PlottingFunctions import plot_history_loss
from PlottingFunctions import plot_history_loss_split

In [None]:
# Per-output plot settings, indexed by network output column:
# 0 = energy, 1 = cos(zenith), 2 = track.
plots_names = ["Energy", "CosZenith", "Track"]
plots_units = ["GeV", "", "m"]
maxabs_factors = [100., 1., 200.]
# The track range is only available when a track column was trained.
if num_labels > 2:
    maxvals = [max_energy, 1., Y_test_use[:, 2].max() * maxabs_factors[2]]
else:
    maxvals = [max_energy, 1., 0]
minvals = [min_energy, -1., 0.]
use_fractions = [True, False, True]
# Output columns are 0-based: plot columns 0 .. num_labels-1.
for num in range(num_labels):

    plot_num = num  # column plotted from both the truth and the prediction
    plot_name = plots_names[num]
    plot_units = plots_units[num]
    maxabs_factor = maxabs_factors[num]
    maxval = maxvals[num]
    minval = minvals[num]
    use_frac = use_fractions[num]
    print("Plotting %s at position %i in test output" % (plot_name, num))

    plot_2D_prediction(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                       save, save_folder_name,
                       minval=minval, maxval=maxval,
                       variable=plot_name, units=plot_units)
    plot_2D_prediction(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                       save, save_folder_name,
                       minval=None, maxval=None,
                       variable=plot_name, units=plot_units)
    plot_single_resolution(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                           save=save, savefolder=save_folder_name,
                           variable=plot_name, units=plot_units)
    plot_single_resolution(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                           minaxis=-2*maxval, maxaxis=maxval*2,
                           save=save, savefolder=save_folder_name,
                           variable=plot_name, units=plot_units)
    plot_distributions(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                       save, save_folder_name,
                       variable=plot_name, units=plot_units)
    plot_bin_slices(Y_test_use[:, plot_num]*maxabs_factor, Y_test_predicted[:, plot_num]*maxabs_factor,
                    use_fraction=use_frac,
                    bins=10, min_val=minval, max_val=maxval,
                    save=save, savefolder=save_folder_name,
                    variable=plot_name, units=plot_units)
    # Non-energy outputs are additionally sliced in bins of true energy.
    if num > 0:
        plot_bin_slices(Y_test_use[:, num], Y_test_predicted[:, num],
                        min_energy=min_energy, max_energy=max_energy, true_energy=Y_test_use[:, 0]*max_energy,
                        use_fraction=False,
                        bins=10, min_val=minval, max_val=maxval,
                        save=save, savefolder=save_folder_name,
                        variable=plot_name, units=plot_units)

In [None]:
# ONE WAY TO PLOT, ABOVE METHOD IS MORE COMPACT BUT ALSO MORE CONFUSING
# Energy: truth and prediction are rescaled by max_energy to GeV, and the
# binned plots use the same energy range as the SETTINGS cell.
plot_2D_prediction(Y_test_use[:, 0]*max_energy, Y_test_predicted[:, 0]*max_energy,
                   save, save_folder_name,
                   bins=int(max_energy - min_energy),
                   minval=min_energy, maxval=max_energy,
                   variable="Energy", units='GeV')
plot_single_resolution(Y_test_use[:, 0]*max_energy, Y_test_predicted[:, 0]*max_energy,
                       save=save, savefolder=save_folder_name,
                       variable="Energy", units='GeV')
plot_bin_slices(Y_test_use[:, 0]*max_energy, Y_test_predicted[:, 0]*max_energy,
                use_fraction=True,
                bins=10, min_val=min_energy, max_val=max_energy,
                save=save, savefolder=save_folder_name,
                variable="Energy", units="GeV")
# cos(zenith): plotted unscaled over [-1, 1].
if num_labels > 1:
    plot_single_resolution(Y_test_use[:, 1], Y_test_predicted[:, 1],
                           save=save, savefolder=save_folder_name,
                           variable="CosZenith", units='')
    plot_2D_prediction(Y_test_use[:, 1], Y_test_predicted[:, 1], save, save_folder_name,
                       minval=-1, maxval=1, variable="CosZenith", units='')
    plot_bin_slices(Y_test_use[:, 1], Y_test_predicted[:, 1],
                    use_fraction=False,
                    bins=10, min_val=-1., max_val=1.,
                    save=save, savefolder=save_folder_name,
                    variable="CosZenith", units="")
# Track: plotted unscaled, 2D range fixed at 0-150.
if num_labels > 2:
    plot_single_resolution(Y_test_use[:, 2], Y_test_predicted[:, 2],
                           save=save, savefolder=save_folder_name,
                           variable="Track", units='m')
    plot_2D_prediction(Y_test_use[:, 2], Y_test_predicted[:, 2], save, save_folder_name,
                       minval=0, maxval=150, variable="Track", units='m')


In [None]:
# Hard-coded loss histories (presumably pasted from the printout of the
# "Reformat loss" cell of earlier runs) so the long-run history can be plotted
# without retraining. The val_energy / val_zenith lists end with a trailing
# comma, which Python ignores.
plot_loss = [0.5638999951078618, 0.2239334302986391, 0.19283099259803776, 0.17831793874810273, 0.16933146295170354, 0.1623940180171268, 0.15701697301086504, 0.15175904069072405, 0.1485293520498928, 0.14525219177062468, 0.143026797637058, 0.14061696721136813, 0.13922084267327053, 0.13808920109869366, 0.1359546336845622, 0.1349402300857223, 0.13356690338660182, 0.13329635906354365, 0.13162642174002134, 0.13152821981740725, 0.13094031182784108, 0.12926689196630475, 0.1293103421377404, 0.1285733627405748, 0.12844751202992166, 0.12706007775183095, 0.1274929534685955, 0.1270949258915883, 0.12597768137328932, 0.1255366959911322, 0.1252374675083325, 0.12501614519976817, 0.12405872912757321, 0.12439822209244182, 0.12435718178285252, 0.12360958461380876, 0.12342124097300619, 0.12267334657091009, 0.12307837775705717, 0.12219016584707065, 0.1221145060866197, 0.12247331352731142, 0.12141296436113697, 0.12139676884637733, 0.12083616820901563, 0.12121685678656272, 0.1203979852246249, 0.12055927246278816, 0.12083020453266813, 0.12016058958029342, 0.11986749792894351, 0.11929275963300631, 0.11944469246214576, 0.11897223062079108, 0.11924624522125524, 0.1193746657295588, 0.11855405238334613, 0.11869491323007006, 0.11806207767843921, 0.11838750547033454, 0.1178212810809745, 0.1181448194027188, 0.11824699643142106, 0.11757574292507986, 0.1176088695095417, 0.11708256893992935, 0.11723069140625247, 0.11686804579628911, 0.11721885055449155, 0.11719036693118179, 0.11659451264420748, 0.11640508353711529, 0.11615885877753887, 0.11638572454084159, 0.11564945329298172, 0.11614410708383167, 0.11632392365138963, 0.1156984103696168, 0.11563130562199765, 0.1152476899514935, 0.1150937840757938, 0.11478221347051103, 0.11550504801362754, 0.1154217529873212, 0.11473050427115059, 0.11466010398359529, 0.11469759701219227, 0.11474595719174076, 0.11429667204426827, 0.11474037077880964, 0.11480370693452306, 0.11424703278574988, 0.11396389726147428, 0.11345375379346413, 0.11387615585143376, 
0.11342607003951331, 0.11387816977910441, 0.11406401470386696, 0.11349957245695425, 0.11356921858791517, 0.11316319086574932, 0.11331888141492565, 0.11297303198822799, 0.11319221749235142, 0.11348130788365368, 0.11281584686423048, 0.11300189230350934, 0.11251181622213526, 0.11293486133412485, 0.1120793790256926, 0.1126066575789328, 0.11275569432653208, 0.1122723216734285, 0.11204583556251536, 0.11174303001911853, 0.11216458637211742, 0.11165973027744922, 0.1122614345105422, 0.11228974950989071, 0.11163957688843272, 0.11154822432520584, 0.11103302637759702, 0.11146977024289442, 0.11120121633026293, 0.11142351819220121, 0.11199338040009736, 0.11113276943248536, 0.1110711180181008, 0.11089381241321433, 0.11117872452253381, 0.11093390523958062, 0.11106507492654394, 0.11147786068852941, 0.11056505789247685, 0.11062843444662218, 0.11015588289037755, 0.11047463732607224, 0.110280273830119, 0.11055805779870127, 0.1109046842209636]
plot_energy_loss = [0.15550967, 0.037588038, 0.035568137, 0.034639355, 0.033945713, 0.03327894, 0.032966755, 0.032625023, 0.0322728, 0.03226088, 0.032222714, 0.03186591, 0.03165008, 0.031648275, 0.031532273, 0.031265575, 0.031386234, 0.031446733, 0.03118187, 0.031053169, 0.03117458, 0.030990256, 0.030871568, 0.031020671, 0.031003218, 0.030852046, 0.030774754, 0.030767942, 0.030717527, 0.030529108, 0.030689718, 0.030725678, 0.030544037, 0.030486267, 0.03051118, 0.030521581, 0.030305902, 0.030519001, 0.030561931, 0.030357732, 0.030343901, 0.030383606, 0.030305084, 0.030173955, 0.03038362, 0.030393701, 0.03020146, 0.030174686, 0.030171003, 0.030144816, 0.029999666, 0.030219214, 0.030247126, 0.03010139, 0.02998931, 0.030090354, 0.030008905, 0.029896554, 0.0300258, 0.030118976, 0.029999677, 0.029937148, 0.029967625, 0.029886352, 0.029849455, 0.029981095, 0.030034583, 0.029847555, 0.029787917, 0.02986984, 0.029800914, 0.029693434, 0.029927066, 0.02996283, 0.029773384, 0.029693473, 0.029741764, 0.029753594, 0.029670749, 0.029785246, 0.029834652, 0.029716214, 0.029657431, 0.029666867, 0.029662525, 0.029524075, 0.029764755, 0.029761577, 0.029657222, 0.029691195, 0.02965315, 0.02956341, 0.02942251, 0.029621795, 0.029744212, 0.029585026, 0.02952467, 0.029585836, 0.029502572, 0.029405752, 0.029627448, 0.02959573, 0.029502533, 0.029428137, 0.02947501, 0.029440405, 0.02933643, 0.029499734, 0.029578578, 0.029441897, 0.02938074, 0.029420534, 0.029440029, 0.029273937, 0.029457398, 0.0294969, 0.029380139, 0.029315494, 0.029369796, 0.029357472, 0.029196473, 0.029364625, 0.029439462, 0.029300557, 0.029298505, 0.029320626, 0.029277813, 0.029149361, 0.029397367, 0.029456638, 0.029314142, 0.02926144, 0.02933749, 0.029230272, 0.029159846, 0.02927212, 0.029369324, 0.02925237, 0.029204385, 0.029278098]
plot_zenith_loss = [0.40826288, 0.18632348, 0.15726434, 0.1436809, 0.1353834, 0.12911461, 0.124052815, 0.11912437, 0.1162543, 0.11300839, 0.110805124, 0.10874642, 0.10756007, 0.10645682, 0.104434416, 0.10367584, 0.102185026, 0.10185147, 0.10042583, 0.10047168, 0.099778384, 0.09827941, 0.09842643, 0.097561724, 0.09745144, 0.09622283, 0.096716605, 0.0963242, 0.09526584, 0.09499137, 0.09454786, 0.09429675, 0.09350851, 0.093909584, 0.09385347, 0.09310239, 0.09312135, 0.09215571, 0.09253166, 0.091832735, 0.09177018, 0.0920797, 0.09112598, 0.091220856, 0.090459116, 0.090823464, 0.09020548, 0.09037769, 0.09067296, 0.09001863, 0.08986939, 0.08908186, 0.089189686, 0.08886574, 0.08924892, 0.089282274, 0.08854811, 0.088796094, 0.08803863, 0.088267796, 0.08782257, 0.088200666, 0.08827636, 0.08769377, 0.08776977, 0.08710579, 0.08720519, 0.08701516, 0.08742606, 0.08733001, 0.086790845, 0.08670682, 0.08623904, 0.086427584, 0.08588332, 0.08644478, 0.086572744, 0.085946515, 0.08595981, 0.085468285, 0.0852601, 0.08504919, 0.08584607, 0.085759975, 0.08507587, 0.08513062, 0.08493146, 0.08498455, 0.08464664, 0.085048944, 0.08515062, 0.08468747, 0.084551305, 0.08382776, 0.084135115, 0.0838471, 0.084350444, 0.084483504, 0.083988585, 0.08416785, 0.08352567, 0.083717525, 0.08348033, 0.083770126, 0.08399774, 0.08337707, 0.08366991, 0.08302133, 0.08335715, 0.0826491, 0.08321763, 0.083328165, 0.08282708, 0.08276288, 0.082279354, 0.082673885, 0.08228025, 0.08296915, 0.08292618, 0.08229982, 0.082360156, 0.08166964, 0.08203063, 0.08189389, 0.08212457, 0.08266897, 0.08185607, 0.08192282, 0.08149198, 0.08171702, 0.08162981, 0.08179736, 0.082149215, 0.08133359, 0.08146776, 0.08087539, 0.08111539, 0.08102661, 0.08134692, 0.081617005]
plot_val_loss = [0.394022817072, 0.573355834896, 0.250535101387, 0.228237041945, 0.233041720022, 0.247482862849, 0.151887489232, 0.152687380063, 0.190007121193, 0.207976313826, 0.161015151467, 0.226186754887, 0.289305672327, 0.133072072175, 0.210011766468, 0.143608885297, 0.182214033821, 0.164273559331, 0.154381006992, 0.339779197202, 0.248383012771, 0.127894555715, 0.129773491128, 0.140571468295, 0.210105873732, 0.142254334373, 0.12772407969, 0.176625721526, 0.19905991375, 0.200084088553, 0.127621968444, 0.132488157884, 0.153470396469, 0.122736966303, 0.131031131465, 0.169515412473, 0.128074867064, 0.117216309743, 0.188398196719, 0.138289330989, 0.15674331288, 0.168093228985, 0.343597103952, 0.144657619732, 0.119342761664, 0.133307886822, 0.158677554659, 0.219265044858, 0.121408505262, 0.116091851549, 0.11703690687, 0.120524416426, 0.149856689332, 0.118199916274, 0.161296591539, 0.146916129106, 0.124428334351, 0.127492061259, 0.158533703127, 0.131427133765, 0.21825614932, 0.142256931701, 0.118666523368, 0.165502889985, 0.136004000557, 0.120213584607, 0.112217760043, 0.141374603849, 0.118238516617, 0.175675325272, 0.129140498792, 0.118393313852, 0.115885860156, 0.113745267031, 0.131933056648, 0.154631275642, 0.122732612727, 0.128352834146, 0.115539795819, 0.10983907473, 0.11109437114, 0.131788325542, 0.1170774916, 0.113352908733, 0.132185875365, 0.112582477489, 0.127420854052, 0.135296090329, 0.123481410244, 0.130575774666, 0.132585826084, 0.135616242912, 0.113452993252, 0.122115692074, 0.115205643437, 0.113597899602, 0.150643600432, 0.138697112971, 0.148586288891, 0.136189970679, 0.116151088268, 0.112081134482, 0.119023801092, 0.125231547828, 0.126967870403, 0.124733341263, 0.117867576252, 0.132714470727, 0.132237718251, 0.116900582158, 0.190628802653, 0.110554488097, 0.120119609195, 0.124267466409, 0.106809807461, 0.113784762551, 0.114008703089, 0.139564014417, 0.120376021375, 0.114262362795, 0.118270922935, 0.112380055454, 0.144797024767, 0.122139844046, 
0.129144456244, 0.110558108886, 0.114949967681, 0.115811549424, 0.142504110647, 0.113877401439, 0.112862766196, 0.120207531305, 0.165889982365, 0.12573349922, 0.120982592423, 0.109405072289, 0.120607218346, 0.153879591373, 0.109135548652, 0.116331835364]
plot_val_energy_loss = [0.0775719955564, 0.0593158043921, 0.0465048626065, 0.0628187432885, 0.0413427315652, 0.0377364903688, 0.0352823026478, 0.0348116159439, 0.0600509867072, 0.0370377153158, 0.0408153794706, 0.0877494812012, 0.0365088619292, 0.031599689275, 0.0495697408915, 0.0343536734581, 0.0365188382566, 0.0321844927967, 0.031935788691, 0.0499625578523, 0.0528747923672, 0.0321327857673, 0.0322325155139, 0.0345133244991, 0.102750509977, 0.0346213467419, 0.036926984787, 0.0609267577529, 0.0484205447137, 0.0566472224891, 0.0304790660739, 0.0338081754744, 0.0326008461416, 0.0306869056076, 0.0380394160748, 0.0327578857541, 0.0345912463963, 0.0299802850932, 0.0624068379402, 0.0362360104918, 0.0389572419226, 0.0440796203911, 0.0749352946877, 0.0422819927335, 0.0297917630523, 0.0413207635283, 0.0406340807676, 0.0383983701468, 0.0303600449115, 0.030128179118, 0.0329958684742, 0.0307873357087, 0.0460160113871, 0.0321281589568, 0.0324660949409, 0.0446351505816, 0.0354426056147, 0.0336453020573, 0.0510150268674, 0.0375271216035, 0.0364569909871, 0.0474719442427, 0.0298326872289, 0.0433725714684, 0.0485762283206, 0.0296570509672, 0.0288932397962, 0.0428793691099, 0.0319277867675, 0.0544824041426, 0.0366701595485, 0.0378783009946, 0.0303648952395, 0.0298844985664, 0.0353197529912, 0.0585823729634, 0.0370423197746, 0.0334083475173, 0.0291820783168, 0.0286782179028, 0.0289947707206, 0.0320187434554, 0.0298707298934, 0.0295429453254, 0.0311469566077, 0.0291054584086, 0.0395205505192, 0.0416622012854, 0.030048949644, 0.0381202101707, 0.0361802577972, 0.0432674735785, 0.0295066833496, 0.0355735532939, 0.0304613169283, 0.0313080698252, 0.0516285337508, 0.0317232310772, 0.0449534840882, 0.0545871593058, 0.034324105829, 0.0303748976439, 0.033246640116, 0.0365763753653, 0.0328441224992, 0.035665884614, 0.0364547967911, 0.0404234193265, 0.0416997410357, 0.0291034728289, 0.0330190062523, 0.02905536443, 0.0312474407256, 0.039144679904, 0.0284027867019, 0.029917165637, 0.0330932214856, 
0.0529744401574, 0.0311031397432, 0.0327763147652, 0.0394699536264, 0.031456977129, 0.0352955944836, 0.0342988520861, 0.0370184443891, 0.0289616864175, 0.0334815755486, 0.0377130322158, 0.0544466748834, 0.0337300077081, 0.0315233618021, 0.029530832544, 0.0315783843398, 0.0406441949308, 0.031077362597, 0.0288178343326, 0.0313184298575, 0.0469093173742, 0.0302567780018, 0.03014289774, ]
plot_val_zenith_loss = [0.316430836916, 0.514036953449, 0.204032227397, 0.165422201157, 0.191696166992, 0.209748089314, 0.116612009704, 0.117879398167, 0.129962176085, 0.170938447118, 0.12020304054, 0.138427242637, 0.252778351307, 0.10147818923, 0.160441815853, 0.109250001609, 0.145700469613, 0.132101535797, 0.122440561652, 0.289800554514, 0.195505276322, 0.0957618057728, 0.0975393503904, 0.106063455343, 0.107360117137, 0.107627414167, 0.0907909572124, 0.115702651441, 0.150646179914, 0.143438339233, 0.0971457585692, 0.0986868143082, 0.120859593153, 0.0920443460345, 0.0929920598865, 0.13674557209, 0.0934845209122, 0.0872400924563, 0.12599158287, 0.102048307657, 0.117784507573, 0.124009788036, 0.268672645092, 0.102375894785, 0.0895549729466, 0.0919854566455, 0.118030257523, 0.180852621794, 0.0910465195775, 0.0859640911222, 0.0840432345867, 0.0897377431393, 0.103849686682, 0.0860685929656, 0.128827437758, 0.102281995118, 0.0889874175191, 0.0938480421901, 0.107515193522, 0.0939011573792, 0.181794852018, 0.0947755873203, 0.0888345018029, 0.122133642435, 0.0874286666512, 0.090555883944, 0.0833254605532, 0.098494105041, 0.0863076001406, 0.121190838516, 0.0924712195992, 0.0805192291737, 0.0855230465531, 0.0838647335768, 0.0966108888388, 0.0960490852594, 0.0856944397092, 0.0949526131153, 0.0863588824868, 0.0811661705375, 0.0820997208357, 0.099762186408, 0.0872043520212, 0.0838102772832, 0.101032055914, 0.083475753665, 0.0879005938768, 0.0936380624771, 0.0934232994914, 0.092447668314, 0.096409201622, 0.0923470705748, 0.0839487537742, 0.0865423157811, 0.0847510322928, 0.0822817459702, 0.0990079641342, 0.106972083449, 0.103636652231, 0.0816063806415, 0.0818281248212, 0.0817087739706, 0.0857662111521, 0.0886528864503, 0.0941220670938, 0.089069083333, 0.0814144536853, 0.0922924503684, 0.0905432552099, 0.0877931043506, 0.157614216208, 0.0814997553825, 0.088879160583, 0.0851240754128, 0.0784079059958, 0.0838715583086, 0.0809079259634, 0.0865878611803, 0.0892720520496, 
0.0814873203635, 0.0788025483489, 0.0809253454208, 0.10949857533, 0.0878368839622, 0.0921229720116, 0.0815981850028, 0.0814665406942, 0.0781021863222, 0.0880583226681, 0.0801501199603, 0.0813334062696, 0.0906741097569, 0.1343113482, 0.085088737309, 0.0899101346731, 0.0805906802416, 0.0892946720123, 0.106975831091, 0.0788729190826, 0.0861921831965, ]

In [None]:
# Extend the hard-coded histories with a later run's epochs (presumably pasted
# from the "Reformat loss" cell's printout, which prints lines in this form).
plot_loss = plot_loss + [0.10991829392893103, 0.11017622807060687, 0.10980781718468646, 0.1102257953427338, 0.10944640406289152, 0.11020226664924969, 0.1103650317832762, 0.10972683219771887, 0.10962380198807738, 0.1092282798760557, 0.10949566513008778, 0.10931205108409343, 0.10982436650358733, 0.10984844686503388, 0.10921370255078497, 0.10913355107968363, 0.1089000931222987, 0.10946971010804823, 0.10881310725899108, 0.10932522480990196, 0.10953113573563283, 0.1088439852782717, 0.108916149915214, 0.10834853946935791, 0.10896559641355305, 0.1086332226940246, 0.10929904286297998, 0.1089212283464707, 0.10841479531258857, 0.10817803078455539, 0.10804837135463866, 0.10837045170222367, 0.10825180127714655, 0.10849728999425007, 0.10847512311209906, 0.10795520223717617, 0.10796817343848708, 0.1079291186767512, 0.10795573705049244, 0.10759288962548966, 0.10806378055053008, 0.10854836108394794, 0.10762614416922539, 0.1076798833079034, 0.1073643726990586, 0.10750004853112478, 0.10730273484478828, 0.10788024465513872, 0.10795079024775939, 0.10728362598221035, 0.10733369583665148, 0.10674998484034075, 0.10729071245383238, 0.10695347178527552, 0.10753407259797051, 0.10768764367798221, 0.10668775160100495, 0.10703836090339164, 0.10686864341039061, 0.10721398673778451, 0.10642119348313243, 0.10688629158015057, 0.10735676838509609, 0.10644077722309822, 0.10657183781059351, 0.10632569075641009, 0.10691504878479992, 0.10621829062474789, 0.10660122590368842, 0.10680027815309273 ]
plot_energy_loss = plot_energy_loss + [0.029138973, 0.029081358, 0.029296922, 0.029310776, 0.029242357, 0.02917456, 0.02918312, 0.029172672, 0.029029772, 0.02920605, 0.029252272, 0.02916364, 0.029168999, 0.029155595, 0.029127637, 0.029056627, 0.02920234, 0.029289512, 0.02910375, 0.029037697, 0.029110642, 0.029126149, 0.028956888, 0.029145438, 0.029218549, 0.029133018, 0.029102467, 0.029133653, 0.02899453, 0.02896206, 0.029136086, 0.029177785, 0.029041683, 0.02899293, 0.029057145, 0.028998215, 0.028935153, 0.029135926, 0.029129399, 0.029026987, 0.028968386, 0.02905134, 0.028967079, 0.02887976, 0.02907758, 0.029124146, 0.02901462, 0.028981164, 0.028968856, 0.028935766, 0.028869124, 0.02904879, 0.02907842, 0.028921178, 0.028943744, 0.02898672, 0.02890321, 0.028861715, 0.028997684, 0.029074242, 0.028948637, 0.028874433, 0.02897862, 0.028924357, 0.028807448, 0.028993323, 0.029055335, 0.028936187, 0.02888634, 0.028879171]
plot_zenith_loss = plot_zenith_loss + [0.080789186, 0.081082374, 0.080523714, 0.0809128, 0.080202416, 0.08102792, 0.08118782, 0.08054758, 0.08060123, 0.08002806, 0.08024807, 0.08014849, 0.08065907, 0.0806936, 0.080087796, 0.08007207, 0.07970134, 0.08018929, 0.07971244, 0.08028202, 0.08041049, 0.079731494, 0.07996243, 0.07919696, 0.079744905, 0.07949863, 0.08021285, 0.07979045, 0.079412706, 0.07920249, 0.078920804, 0.07918835, 0.07920819, 0.079502515, 0.07941368, 0.07895945, 0.079031914, 0.07878297, 0.07882694, 0.078561604, 0.079103835, 0.07949359, 0.07866591, 0.07881098, 0.078279786, 0.07838536, 0.07829409, 0.07890411, 0.07899027, 0.07836102, 0.078463495, 0.077703506, 0.078211956, 0.07803517, 0.07858998, 0.07870036, 0.07777488, 0.07817831, 0.07788679, 0.07815569, 0.07748163, 0.07801273, 0.07837176, 0.077517726, 0.0777627, 0.07733289, 0.077861995, 0.07729061, 0.0777098, 0.07793483]
plot_val_loss = plot_val_loss + [0.141041573864, 0.106204366043, 0.129304644519, 0.11279274756, 0.113638713852, 0.131642428264, 0.150804674667, 0.114048917035, 0.117638524639, 0.108708437444, 0.161494200033, 0.124672583914, 0.16404557663, 0.106979214303, 0.111194762935, 0.116429949059, 0.116666371859, 0.12006722354, 0.107142147563, 0.125568532622, 0.108623125923, 0.108739283938, 0.14402418011, 0.109624515914, 0.137716910469, 0.12300079407, 0.111884810854, 0.107742063022, 0.10806494702, 0.1341643604, 0.116198132388, 0.120131259054, 0.118553284553, 0.123811286898, 0.114579847015, 0.119070598362, 0.110218468106, 0.106277802185, 0.126297671618, 0.106699870973, 0.107474640752, 0.121637861397, 0.112373810399, 0.10909041906, 0.10746504487, 0.112737053313, 0.140975886198, 0.141314039117, 0.128497127285, 0.114965444389, 0.105298184654, 0.10640360242, 0.106606740364, 0.156642075803, 0.112980025217, 0.111781594565, 0.118674670343, 0.127448699936, 0.123843126995, 0.122309698971, 0.135920303516, 0.135456622896, 0.130257609121, 0.120112814918, 0.12403980122, 0.123634161972, 0.109042447888, 0.115540045008, 0.143987012654, 0.108756286112]
plot_val_energy_loss = plot_val_energy_loss + [0.050470802933, 0.0290748011321, 0.0305640175939, 0.0351977944374, 0.0346945375204, 0.0317458622158, 0.038238696754, 0.0314339771867, 0.0359351374209, 0.028605369851, 0.0315122157335, 0.0329869203269, 0.0432911626995, 0.0290328487754, 0.0309980083257, 0.0376694537699, 0.0319585055113, 0.0344555675983, 0.0296017341316, 0.0404364801943, 0.0287220478058, 0.0301562324166, 0.050127916038, 0.0285475328565, 0.0418572463095, 0.0425641536713, 0.0307184644043, 0.0283198859543, 0.0306278821081, 0.0460299104452, 0.0342551283538, 0.0312695875764, 0.0334042459726, 0.0318680219352, 0.0310960784554, 0.038249656558, 0.0324535407126, 0.0286438018084, 0.0405994281173, 0.0285669285804, 0.0289953332394, 0.0359781496227, 0.0299219191074, 0.0285216756165, 0.0303792431951, 0.0303864199668, 0.0535961911082, 0.0483165942132, 0.0489188618958, 0.0330305248499, 0.0284034367651, 0.0285720359534, 0.0290869828314, 0.0342359021306, 0.0320095941424, 0.0347361080348, 0.0341264605522, 0.0376222617924, 0.0311523675919, 0.0386647731066, 0.0334633737803, 0.0305991861969, 0.0373284108937, 0.0383871048689, 0.038463357836, 0.0381085984409, 0.0303225181997, 0.0296203680336, 0.0475682653487, 0.0290893986821 ]
plot_val_zenith_loss = plot_val_zenith_loss + [0.0905662253499, 0.0771290510893, 0.0987328067422, 0.0775972008705, 0.0789380446076, 0.0998969003558, 0.112570129335, 0.0826216191053, 0.0817032530904, 0.0801083669066, 0.129981890321, 0.0916740298271, 0.12074995786, 0.0779497921467, 0.0802035853267, 0.0787634775043, 0.0847107842565, 0.0856143087149, 0.0775328278542, 0.0851218774915, 0.0798995643854, 0.07858402282, 0.093898139894, 0.081080250442, 0.09586147964, 0.0804301947355, 0.0811609774828, 0.0794218853116, 0.0774372965097, 0.0881385952234, 0.081949904561, 0.0888647958636, 0.0851471424103, 0.0919379368424, 0.0834871083498, 0.0808272212744, 0.0777679085732, 0.0776365697384, 0.085697427392, 0.0781268775463, 0.0784753039479, 0.0856613516808, 0.0824556350708, 0.0805719196796, 0.0770882740617, 0.0823527574539, 0.0873786360025, 0.0929834321141, 0.0795871540904, 0.0819352343678, 0.0768972858787, 0.0778296366334, 0.0775202289224, 0.122393108904, 0.0809638425708, 0.0770436525345, 0.0845472738147, 0.0898233950138, 0.0926927030087, 0.0836504027247, 0.102462083101, 0.104854740202, 0.0929405167699, 0.0817229002714, 0.0855797380209, 0.0855218842626, 0.0787225663662, 0.0859119743109, 0.0964067578316, 0.0796663910151]

In [None]:
# Plot the accumulated histories on a log scale: the total loss first, then
# the energy and zenith components side by side.
plot_history_loss(plot_loss, plot_val_loss,
                  save, save_folder_name,
                  logscale=True)
plot_history_loss_split(plot_energy_loss, plot_val_energy_loss,
                        plot_zenith_loss, plot_val_zenith_loss,
                        save=save, savefolder=save_folder_name,
                        logscale=True)