In [None]:
import setGPU

In [None]:
import numpy as np
import h5py
import time
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Model,model_from_json
from tensorflow.keras.layers import Input, Dense, Lambda, BatchNormalization, Activation, Concatenate, Dropout, Layer
from tensorflow.keras.layers import ReLU, LeakyReLU
from tensorflow.keras import backend as K
from qkeras import QDense, QActivation
import math

from datetime import datetime
from tensorboard import program
import os
import pathlib
import tensorflow_model_optimization as tfmot
tsk = tfmot.sparsity.keras

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

from functions import preprocess_anomaly_data, make_mse_loss,\
roc_objective,load_model, save_model
from autoencoder_classes import AE
import pickle

## Load data

In [None]:
# Data = (N,19,3,1).flatten()
# Unpacks seven pickled objects: flattened/scaled train and test arrays, the BSM
# samples with their targets, and `pt_scaler` (not used further in this notebook).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
train_data_path = '/eos/user/e/epuljak/forDelphes/Delphes_QCD_BSM_data.pkl'
with open(train_data_path, 'rb') as f:
    (X_train_flatten, X_train_scaled,
     X_test_flatten, X_test_scaled,
     bsm_data, bsm_target, pt_scaler) = pickle.load(f)

### Define parameters for QKeras and Pruning

In [None]:
# quant_size == 0 selects plain float Dense/LeakyReLU layers; any other value selects
# QKeras layers quantized as quantized_bits(quant_size, integer, ...).
# pruning == 'pruned' enables the pruning schedule and its training callback.
quant_size, integer = 0, 1
pruning = 'pruned'

In [None]:
if pruning == 'pruned':
    # Optimizer steps per epoch, assuming 80% of the events are trained on
    # (validation_split=0.2) in batches of 1024 — keep in sync with the fit cell.
    steps_per_epoch = np.ceil((X_train_flatten.shape[0]*0.8)/1024).astype(np.int32)
    # Sparsity ramps from 0% to 50% between (roughly) epoch 5 and epoch 15.
    begin_step, end_step = steps_per_epoch*5, steps_per_epoch*15
    print(f'Begin step: {begin_step}, End step: {end_step}')
    pruning_schedule = tsk.PolynomialDecay(initial_sparsity=0.0, final_sparsity=0.5,
                                           begin_step=begin_step, end_step=end_step)
    print(pruning_schedule.get_config())

### Define model

In [None]:
latent_dim = 3
# 19 x 3 features per event, flattened (see the data-loading cell)
input_shape = 19 * 3

In [None]:
# Every layer exists in a float variant (quant_size == 0) and a QKeras variant; the
# helpers below build either one, optionally wrapped for pruning. Layer creation
# order matters: the weight-transfer cell copies weights by layer index.


def _bits_spec(bits, int_bits):
    """Return the QKeras quantizer string quantized_bits(bits,int_bits,1, alpha=1.0)."""
    return f'quantized_bits({bits},{int_bits},1, alpha=1.0)'


def _maybe_prune(layer):
    """Wrap `layer` in prune_low_magnitude when pruning == 'pruned'.

    `pruning_schedule` is only defined in that case (pruning-parameters cell), so
    unpruned runs get the bare layer instead of failing with a NameError.
    """
    if pruning == 'pruned':
        return tsk.prune_low_magnitude(layer, pruning_schedule=pruning_schedule)
    return layer


def _dense(x, units, bits=None, int_bits=None, name=None):
    """Apply a (possibly pruned) dense layer with `units` outputs to tensor `x`.

    quant_size == 0 builds a float Dense. Otherwise a QDense is built whose kernel
    and bias use quantized_bits(bits, int_bits, ...), where `bits`/`int_bits`
    default to the notebook-wide `quant_size`/`integer`.
    """
    if quant_size == 0:
        layer = Dense(units, kernel_initializer=tf.keras.initializers.HeUniform(), name=name)
    else:
        spec = _bits_spec(quant_size if bits is None else bits,
                          integer if int_bits is None else int_bits)
        layer = QDense(units, kernel_initializer=tf.keras.initializers.HeUniform(),
                       kernel_quantizer=spec, bias_quantizer=spec, name=name)
    return _maybe_prune(layer)(x)


def _activation(x):
    """Apply a (possibly pruned) LeakyReLU(0.3), or quantized_relu when quantizing."""
    if quant_size == 0:
        layer = LeakyReLU(alpha=0.3)
    else:
        layer = QActivation(f'quantized_relu(bits={quant_size})')
    return _maybe_prune(layer)(x)


#encoder: 57 -> 32 -> 16 -> latent_dim
inputArray = Input(shape=(input_shape,))
x = BatchNormalization()(inputArray)
x = _dense(x, 32)
x = BatchNormalization()(x)
x = _activation(x)
x = _dense(x, 16)
x = BatchNormalization()(x)
x = _activation(x)
# Latent layer keeps a wider fixed quantization (16 bits, 6 integer) and no activation.
encoder = _dense(x, latent_dim, bits=16, int_bits=6)

#decoder: latent_dim -> 16 -> 32 -> 57
x = _dense(encoder, 16)
x = BatchNormalization()(x)
x = _activation(x)
x = _dense(x, 32)
x = BatchNormalization()(x)
x = _activation(x)
# Output layer: fixed 16-bit (10 integer) quantization; named in both variants.
decoder = _dense(x, input_shape, bits=16, int_bits=10, name='output_dense')

#create autoencoder
autoencoder = Model(inputs=inputArray, outputs=decoder)
autoencoder.summary()

In [None]:
ae = AE(autoenc=autoencoder)
# Small learning rate — presumably because the next cell initialises the weights
# from the unpruned baseline, so training here is fine-tuning. `learning_rate`
# replaces the deprecated `lr` keyword.
ae.compile(optimizer=keras.optimizers.Adam(learning_rate=0.00001))

In [None]:
# transfer weights
model_dir = 'AE_models/final_models/withCorrectPrefiltering/'
name_encoder ='AE_notpruned'
baseline_AE = load_model(model_dir+name_encoder)

# Copy the baseline weights into every layer of the new model (encoder AND
# decoder), skipping index 0 (the Input layer). Layers are paired by position, so
# both models must have the same layer structure.
# NOTE(review): assumes each baseline layer has the same weight list as the
# corresponding (possibly pruning-wrapped) layer here; set_weights raises otherwise.
target_layers = ae.autoencoder.layers
source_layers = baseline_AE.layers
assert len(target_layers) == len(source_layers), (
    f'layer count mismatch: {len(target_layers)} vs {len(source_layers)}')
for target, source in zip(target_layers[1:], source_layers[1:]):
    target.set_weights(source.get_weights())

## Train

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN

# tfmot requires UpdatePruningStep while training pruning-wrapped layers.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] if pruning == 'pruned' else []
callbacks += [
    # Divide the LR by 10 after 2 epochs without val_loss improvement (floor 1e-6).
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1, mode='auto',
                      min_delta=0.0001, cooldown=2, min_lr=1E-6),
    TerminateOnNaN(),
    EarlyStopping(monitor='val_loss', verbose=1, patience=10),
]

In [None]:
EPOCHS = 100
# Also hard-coded as 1024 in the pruning-schedule cell; keep the two in sync.
BATCH_SIZE = 1024
# Fraction of the training data held out for validation. The fit cell uses 0.2 and
# the pruning schedule assumes a 0.8 training fraction, so this must be 0.2.
VALIDATION_SPLIT = 0.2

In [None]:
print("TRAINING")
# Inputs are the flattened features, targets their scaled version. The 0.2 split must
# match the 0.8 training fraction used to size the pruning schedule.
history = ae.fit(
    X_train_flatten,
    X_train_scaled,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    callbacks=callbacks,
)


### Save model

In [None]:
# Remove the pruning wrappers, keeping the trained weights, for saving and inference.
final_autoencoder = tsk.strip_pruning(ae.autoencoder)
final_autoencoder.summary()

In [None]:
pruned_model_path = 'AE_models/final_models/withCorrectPrefiltering/AE_pruned'
save_model(pruned_model_path, final_autoencoder)

In [None]:
# final_autoencoder = load_model('AE_models/final_models/withCorrectPrefiltering/AE_pruned')

### Plot training/validation loss

In [None]:
# Training vs. validation loss per epoch (pyplot and inline mode are set up in the
# imports cell, so no re-import is needed here).
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Training loss')
ax.plot(history.history['val_loss'], label='Validation loss')
ax.set(title='Training and validation loss - MSE', xlabel='Epochs', ylabel='Loss')
ax.legend(loc='best')
plt.show()

### Check sparsity of weights

In [None]:
# check pruned weights: percentage of exactly-zero entries in each weight tensor
for variable, values in zip(final_autoencoder.weights, final_autoencoder.get_weights()):
    zero_pct = np.sum(values == 0) / values.size * 100
    print(f"{variable.name} -- Total:{values.size}, Zeros: {zero_pct:.2f}%")

In [None]:
# plot model after training: one histogram per dense kernel of the stripped model.
# Kernels are picked as the 2-D weight tensors (BatchNormalization weights are 1-D,
# activations have none) instead of by hard-coded layer index, so the selection
# does not depend on the exact layer layout.
kernels = []
for layer in final_autoencoder.layers:
    layer_weights = layer.get_weights()
    if layer_weights and layer_weights[0].ndim == 2:
        kernels.append(layer_weights[0])

# (label, alpha) per kernel, in layer order; alpha=None is matplotlib's default.
hist_styles = [('Encoder 32', 0.5), ('Encoder 16', 0.7), ('latent', None),
               ('Decoder 16', 0.6), ('Decoder 32', 0.7), ('Output', 0.5)]
assert len(kernels) == len(hist_styles), \
    f'expected {len(hist_styles)} dense kernels, found {len(kernels)}'

plt.figure(figsize=(7,5))
for kernel, (label, alpha) in zip(kernels, hist_styles):
    plt.hist(kernel.ravel(), label=label, bins=100, alpha=alpha)

plt.legend(loc='best')
plt.xlabel('Weights')
plt.ylabel('Number of Weights')
plt.title('Pruned 5 to 15')
plt.show();

### Prediction - background

In [None]:
# Evaluation data: this pickle unpacks into six objects (no pt scaler).
# NOTE(review): this rebinds the X_train_*/X_test_*/bsm_* names — re-running the
# training cell after this one would train on this file. pickle.load can execute
# arbitrary code, so only load trusted files.
eval_data_path = '/eos/user/e/epuljak/forDelphes/Delphes_QCD_BSM_data_half1.pkl'
with open(eval_data_path, 'rb') as f:
    (X_train_flatten, X_train_scaled,
     X_test_flatten, X_test_scaled,
     bsm_data, bsm_target) = pickle.load(f)

In [None]:
# Reconstruct the QCD test events with the pruning-stripped model.
qcd_prediction = final_autoencoder.predict(x=X_test_flatten)

### Prediction - Beyond Standard Model events

In [None]:
# Display names, in the same order as the entries of bsm_data / bsm_target.
bsm_labels = [
    'Leptoquark',
    'A to 4 leptons',
    'hChToTauNu',
    'hToTauTau',
]

In [None]:
# Reconstruct every BSM sample with the same stripped model used for the QCD
# prediction, so QCD and BSM reconstructions are directly comparable.
# Each entry: [label, target, prediction].
bsm_results = []

for i, label in enumerate(bsm_labels):
    bsm_prediction = final_autoencoder.predict(bsm_data[i])
    bsm_results.append([label, bsm_target[i], bsm_prediction])

### Save results

In [None]:
# HDF5 file that receives the QCD/BSM inputs, targets and reconstructions.
output_result = 'AE_result_' + 'pruned' + '.h5'

In [None]:
# Write inputs, scaled targets and reconstructions for QCD and every BSM sample.
# The context manager closes the file even if one of the writes raises.
with h5py.File(output_result, 'w') as h5f:
    h5f.create_dataset('QCD', data=X_test_scaled)
    h5f.create_dataset('QCD_input', data=X_test_flatten)
    h5f.create_dataset('predicted_QCD', data=qcd_prediction)

    for i, (label, target, prediction) in enumerate(bsm_results):
        h5f.create_dataset(f'{label}_scaled', data=target)
        h5f.create_dataset(f'{label}_input', data=bsm_data[i])
        h5f.create_dataset(f'predicted_{label}', data=prediction)

## t-SNE projections of the latent-space representations (QCD and BSM)

In [None]:
import hls4ml
# Per-layer outputs for up to 10k QCD test events, taken from the same stripped model
# used for the predictions; the 'dense_2' entry is used below as the latent space.
keras_trace = hls4ml.model.profiling.get_ymodel_keras(final_autoencoder, X_test_flatten[:10000])

In [None]:
# Same per-layer trace for the first two BSM samples (up to 10k events each),
# from the same stripped model as the QCD trace.
bsm_traces = [
    hls4ml.model.profiling.get_ymodel_keras(final_autoencoder, bsm_data[i][:10000])
    for i, _ in enumerate(bsm_labels[:2])
]

In [None]:
# 2D PROJECTIONS
from sklearn.manifold import TSNE

idx_max = 1000
# Latent activations, at most idx_max events per sample.
# NOTE(review): assumes 'dense_2' is the Dense(latent_dim) layer, i.e. the third
# Dense layer created in this session — confirm if the model cell changes.
z_sets = [
    keras_trace['dense_2'][:idx_max],
    bsm_traces[0]['dense_2'][:idx_max],
    bsm_traces[1]['dense_2'][:idx_max],
]

# Fit ONE t-SNE on the concatenated samples: independently fitted t-SNE embeddings
# live in unrelated coordinate systems and cannot be compared on a shared axis.
# Fixed random_state makes the figure reproducible.
z_embedded_all = TSNE(n_components=2, random_state=0).fit_transform(np.concatenate(z_sets))
split_points = np.cumsum([len(z) for z in z_sets])[:-1]
z_embedded1, z_embedded2, z_embedded3 = np.split(z_embedded_all, split_points)


plt.figure(figsize=(8,8))
plt.plot(z_embedded1[:,0], z_embedded1[:,1],'o', mew=1.2, mfc='none', label='Standard Model', color='indigo')
plt.plot(z_embedded2[:,0], z_embedded2[:,1],'s', mew=1.2, mfc='none', label=r'LQ $\rightarrow$ b$\tau$', color='forestgreen')
plt.plot(z_embedded3[:,0], z_embedded3[:,1],'v', mew=1.2, mfc='none', label=r'A $\rightarrow$ 4L', color='tomato')

plt.xlabel(r'$z_\mathrm{1}$')
plt.ylabel(r'$z_\mathrm{2}$')
plt.legend(loc='best')
plt.savefig(f'TSNE_AE_pruned_{idx_max}.pdf')