In [None]:
import numpy as np
import h5py
import setGPU

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Lambda, BatchNormalization, Activation, Concatenate, Dropout, Layer
from tensorflow.keras.layers import ReLU, LeakyReLU
from tensorflow.keras import backend as K
import math
import pickle
from datetime import datetime
from tensorboard import program
import os
import tensorflow_model_optimization as tfmot
from qkeras import QDense, QActivation, QBatchNormalization

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

from functions import preprocess_anomaly_data, custom_loss_negative, custom_loss_training,\
roc_objective,load_model, save_model
from custom_layers import Sampling
from autoencoder_classes import VAE

tsk = tfmot.sparsity.keras

In [None]:
# Open a TF1-compat InteractiveSession whose GPU options have `allow_growth`
# enabled, so GPU memory is allocated on demand (per TF's `allow_growth` option).
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
# The session object is kept in a global so it stays alive for the whole notebook.
session = InteractiveSession(config=config)

### Load data

In [None]:
# Data = (N,19,3,1).flatten()
# Unpack the seven objects stored in the pickle (train/test arrays, BSM samples
# and targets, and a scaler, as the variable names suggest).
# NOTE(review): `pickle.load` can execute arbitrary code from the file; only load
# pickles from a trusted source. The absolute EOS path is user-specific and will
# not resolve on other machines.
with open('/eos/user/e/epuljak/forDelphes/Delphes_QCD_BSM_data.pkl', 'rb') as f:
    X_train_flatten, X_train_scaled, X_test_flatten, X_test_scaled, bsm_data, bsm_target, pt_scaler = pickle.load(f)

### Define parameters for QKeras and Pruning

In [None]:
# Bit width passed to quantized_bits / quantized_relu for the hidden encoder layers;
# the model cell treats 0 as "use plain (non-quantized) Keras layers".
quant_size = 12
# Second argument of quantized_bits for the hidden QDense layers.
integer = 4
# Third argument of quantized_bits for every QDense layer.
symmetric = 0
# 'pruned' enables the pruning schedule and the UpdatePruningStep callback.
pruning='pruned'

In [None]:
if pruning == 'pruned':
    # The pruning schedule is expressed in optimizer steps, so epochs are
    # converted to steps first:
    #     num_samples     = X_train_flatten.shape[0] * (1 - validation_split)
    #     steps_per_epoch = ceil(num_samples / batch_size)
    #     step            = steps_per_epoch * epoch
    # With the values below, sparsity ramps from 0% at the start of epoch 5
    # to 50% at epoch 15 and stays there afterwards.
    # NOTE: the batch size and train fraction must match the values used by the
    # training cell (batch size 1024, validation split 0.2); otherwise the
    # epoch boundaries computed here drift.
    train_fraction = 0.8            # 1 - validation split
    pruning_batch_size = 1024
    pruning_begin_epoch = 5
    pruning_end_epoch = 15

    steps_per_epoch = np.ceil((X_train_flatten.shape[0]*train_fraction)/pruning_batch_size).astype(np.int32)
    begin_step = steps_per_epoch*pruning_begin_epoch
    end_step = steps_per_epoch*pruning_end_epoch
    print('Begin step: ' + str(begin_step) + ', End step: ' + str(end_step))

    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
                            initial_sparsity=0.0, final_sparsity=0.5,
                            begin_step=begin_step, end_step=end_step)
    print(pruning_schedule.get_config())

### Define model
Only the encoder is pruned and quantized; the decoder is built from standard (full-precision) Keras layers.

In [None]:
# Size of the latent space (units of the mu / logvar layers and of the decoder input).
latent_dim = 3
# Number of input features per event; also the width of the decoder output layer.
input_shape = 57

In [None]:
# ---------------------------------------------------------------------------
# Helpers shared by the encoder layers.
# Each call creates a fresh layer, so calling them in network order keeps the
# auto-generated layer names stable.
# ---------------------------------------------------------------------------
def _he_normal():
    """Return a new He-normal initializer with the fixed seed used for every layer."""
    return tf.keras.initializers.HeNormal(seed=42)


def _quantized_bits_spec(bits, int_bits):
    """Build the QKeras quantizer string 'quantized_bits(bits,int_bits,symmetric, alpha=1)'.

    `symmetric` is read from the notebook-level parameter of the same name.
    Keeping the string in one place guarantees that `alpha=1` always sits
    inside the parentheses.
    """
    return 'quantized_bits(' + str(bits) + ',' + str(int_bits) + ',' + str(symmetric) + ', alpha=1)'


def _prune(layer):
    """Wrap `layer` for magnitude pruning using the notebook-wide `pruning_schedule`."""
    return tsk.prune_low_magnitude(layer, pruning_schedule=pruning_schedule)


def _encoder_dense(units, bits, int_bits, name=None):
    """Return a pruned dense layer: plain Dense when quant_size == 0, QDense otherwise.

    `name` is only applied to the non-quantized Dense layer; the QDense layers
    keep their auto-generated names.
    """
    if quant_size == 0:
        layer = Dense(units, name=name, kernel_initializer=_he_normal())
    else:
        spec = _quantized_bits_spec(bits, int_bits)
        layer = QDense(units, kernel_initializer=_he_normal(),
                       kernel_quantizer=spec, bias_quantizer=spec)
    return _prune(layer)


def _encoder_batchnorm():
    """Return a pruning-wrapped QBatchNormalization layer."""
    return _prune(QBatchNormalization())


def _encoder_relu():
    """Return a pruning-wrapped ReLU: Activation('relu') or a quantized_relu of quant_size bits."""
    if quant_size == 0:
        return _prune(Activation('relu'))
    return _prune(QActivation('quantized_relu(bits=' + str(quant_size) + ')'))


#encoder
inputArray = Input(shape=(input_shape,))
# Quantize the raw inputs, then normalise them before the first dense layer.
x = QActivation('quantized_bits(16,10,0,alpha=1)')(inputArray)
x = QBatchNormalization()(x)

# Two hidden stages: dense -> batch norm -> ReLU, with 32 then 16 units.
for hidden_units in (32, 16):
    x = _encoder_dense(hidden_units, quant_size, integer)(x)
    x = _encoder_batchnorm()(x)
    x = _encoder_relu()(x)

# Latent heads use a fixed quantized_bits(16,6,...) spec instead of quant_size/integer.
mu = _encoder_dense(latent_dim, 16, 6, name='latent_mu')(x)
logvar = _encoder_dense(latent_dim, 16, 6, name='latent_logvar')(x)

# Use reparameterization trick to ensure correct gradient
z = Sampling()([mu, logvar])

# Create encoder
encoder = Model(inputArray, [mu, logvar, z], name='encoder')
encoder.summary()


#decoder (not pruned, not quantized)
d_input = Input(shape=(latent_dim,), name='decoder_input')
x = d_input
# Mirror of the encoder's hidden stages: 16 then 32 units, each followed by
# batch norm and ReLU (a LeakyReLU(alpha=0.3) alternative is currently disabled).
for hidden_units in (16, 32):
    x = Dense(hidden_units, kernel_initializer=_he_normal())(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
dec = Dense(input_shape, kernel_initializer=_he_normal())(x)

# Create decoder
decoder = Model(d_input, dec, name='decoder')
decoder.summary()

In [None]:
# Combine encoder and decoder into the project's VAE model class.
vae = VAE(encoder, decoder)
# Only an optimizer is passed to compile(); presumably the loss is computed
# inside the VAE class itself — verify in autoencoder_classes.
vae.compile(optimizer=keras.optimizers.Adam())

In [None]:
# load BP or higher bit width model
model_dir = 'VAE_models/final_models/withCorrectPrefiltering/'
name_encoder ='VAE_encoder_qkeras14_QBN_pw_higher'
name_decoder ='VAE_decoder_qkeras14_QBN_pw_higher'
# , 'QDense': QDense, 'QActivation': QActivation, 'QBatchNormalization': QBatchNormalization
custom_objects={'Sampling': Sampling, 'QDense': QDense, 'QActivation': QActivation, 'QBatchNormalization': QBatchNormalization}

BP_encoder = load_model(model_dir+name_encoder, custom_objects)
BP_decoder = load_model(model_dir+name_decoder, custom_objects)

In [None]:
# Seed the new encoder with the reference encoder's weights.
# The input layer (index 0) is skipped. Target layer i receives the weights of
# reference layer i-1 because the new encoder has an extra QActivation layer
# (drop the offset when the reference is itself loaded from a QKeras model).
# NOTE(review): the target layers are pruning wrappers, which may hold extra
# variables besides the wrapped layer's weights — confirm `set_weights`
# accepts the reference layer's weight list as-is.
for i, layer in enumerate(vae.encoder.layers[1:], start=1):
    layer.set_weights(BP_encoder.layers[i-1].get_weights())

In [None]:
# check weights
# for i in range(0,12):
#     if i < 2: continue
#     print('QModel layer: '+str(vae.encoder.layers[i])+', BP model layer: '+str(BP_encoder.layers[i-1]))
#     for w1, w2 in zip(vae.encoder.layers[i].get_weights(), BP_encoder.layers[i-1].get_weights()):
#         if len(w1) > 1 and len(w2) > 1:
#             for weight1, weight2 in zip(w1, w2):
#                 print(np.array_equal(weight1, weight2))
#         else:
#             print(np.array_equal(vae.encoder.layers[i].get_weights(), BP_encoder.layers[i-1].get_weights()))

In [None]:
# Seed the new decoder with the reference decoder's weights.
# Both decoders share the same layer layout, so indices map one-to-one;
# the input layer (index 0) is skipped.
for i, layer in enumerate(vae.decoder.layers[1:], start=1):
    layer.set_weights(BP_decoder.layers[i].get_weights())

In [None]:
# check weights
# for i, layer in enumerate(vae.decoder.layers):
#     if i == 0: continue
#     print(' New Model layer: '+str(vae.decoder.layers[i])+', BP model layer: '+str(BP_decoder.layers[i]))
#     for w1, w2 in zip(vae.decoder.layers[i].get_weights(), BP_decoder.layers[i].get_weights()):
#         print(np.array_equal(w1,w2))
#         if len(w1) > 1 and len(w2) > 1:
#             for weight1, weight2 in zip(w1, w2):
#                 print(np.array_equal(weight1, weight2))
#         else:
#             print(np.array_equal(vae.decoder.layers[i].get_weights(), BP_decoder.layers[i].get_weights()))

### Train model

In [None]:
# Start the callback list; UpdatePruningStep is registered only when pruning is enabled.
# (tfmot.sparsity.keras.PruningSummaries(log_dir='vae_prunning') can be added here as well.)
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()] if pruning == 'pruned' else []

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, TerminateOnNaN

# Learning-rate decay on plateau, abort on NaN loss, and early stopping on val_loss.
# NOTE: this cell appends to the existing list, so re-running it adds duplicates;
# re-run the cell that creates `callbacks` first.
callbacks.extend([
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, verbose=1, mode='auto',
                      min_delta=0.0001, cooldown=2, min_lr=1E-6),
    TerminateOnNaN(),
    EarlyStopping(monitor='val_loss', verbose=1, patience=10),
])

In [None]:
# Training hyper-parameters.
# Maximum number of epochs (early stopping may end training sooner).
EPOCHS = 100
# Mini-batch size; the pruning-schedule cell hard-codes the same value (1024).
BATCH_SIZE = 1024
# Fraction of the training data held out for validation; the pruning-schedule
# cell hard-codes the complementary train fraction (0.8).
VALIDATION_SPLIT = 0.2

In [None]:
print("TRAINING")
# Inputs are the flattened events, targets the scaled events. All training
# hyper-parameters come from the constants cell so they are changed in one place.
history = vae.fit(X_train_flatten, X_train_scaled, epochs = EPOCHS, batch_size = BATCH_SIZE,
                  validation_split=VALIDATION_SPLIT,
                  callbacks=callbacks)


### Save model

In [None]:
# Remove the pruning wrappers from the trained encoder so it can be saved and
# used for inference as a regular model.
final_encoder = tfmot.sparsity.keras.strip_pruning(vae.encoder)
final_encoder.summary()

In [None]:
# Apply the same stripping step to the decoder.
# NOTE(review): the decoder in this notebook is built without pruning wrappers,
# so this presumably just yields an equivalent model — confirm.
final_decoder = tfmot.sparsity.keras.strip_pruning(vae.decoder)
final_decoder.summary()

In [None]:
# Persist the stripped encoder and decoder with the project's save_model helper.
# NOTE(review): the 'qkeras12' tag in both file names is hard-coded; presumably it
# should track `quant_size` (12 here) — update the names if quant_size changes.
save_model('VAE_models/final_models/withCorrectPrefiltering/VAE_encoder_qkeras12_QBN_pw_higher', final_encoder)
save_model('VAE_models/final_models/withCorrectPrefiltering/VAE_decoder_qkeras12_QBN_pw_higher', final_decoder)

In [None]:
# final_encoder = load_model('VAE_models/final_models/withCorrectPrefiltering/VAE_encoder_pruned', custom_objects={'Sampling': Sampling, 'QDense': QDense, 'QActivation': QActivation})
# final_decoder = load_model('VAE_models/final_models/withCorrectPrefiltering/VAE_decoder_pruned', custom_objects={'Sampling': Sampling})


In [None]:
# check quantizers
for layer in final_encoder.layers:
    if hasattr(layer, "kernel_quantizer"):
        print(layer.name, "kernel:", str(layer.kernel_quantizer_internal), "bias:", str(layer.bias_quantizer_internal))
    elif hasattr(layer, "quantizer"):
        print(layer.name, "quantizer:", str(layer.quantizer))

### Check sparsity of weights - encoder

In [None]:
# check pruned weights
for i, w in enumerate(final_encoder.get_weights()):
    print(
        "{} -- Total:{}, Zeros: {:.2f}%".format(
            final_encoder.weights[i].name, w.size, np.sum(w == 0) / w.size * 100
        )
    )

In [None]:
# Hex colour palette used by the weight histograms below (one colour per layer).
colors   = ['#a6bddb','#67a9cf','#3690c0','#02818a','#016c59','#014636']

In [None]:
# Histogram of the kernel weights of each dense layer of the stripped encoder.
#
# The layers are selected by type instead of by hard-coded position: the encoder
# built in this notebook starts with Input -> QActivation -> QBatchNormalization,
# so fixed indices such as 2/5/8/9 point at normalisation/activation layers
# rather than at the dense kernels. QDense subclasses Dense, so the isinstance
# check covers both the quantized and the plain variant.
hist_labels = ['Dense 32', 'Dense 16', 'Mu', 'Sigma']
dense_layers = [layer for layer in final_encoder.layers if isinstance(layer, Dense)]
if len(dense_layers) != len(hist_labels):
    raise ValueError(
        'Expected %d dense layers in the encoder, found %d' % (len(hist_labels), len(dense_layers))
    )

plt.figure(figsize=(10,8))
for layer, label, color in zip(dense_layers, hist_labels, colors):
    # get_weights()[0] is the kernel; flatten it without assuming its shape.
    plt.hist(layer.get_weights()[0].ravel(), label=label, bins=100, color=color)
#plt.yscale('log', nonpositive='clip')
plt.legend(loc='best')
plt.xlabel('Weights')
plt.ylabel('Number of Weights')
# NOTE(review): the title says 'Not pruned' although this notebook runs with
# pruning enabled — confirm the intended title.
plt.title('Not pruned')
plt.show();

### Plot training/validation loss
MSE & KL loss

In [None]:
# Per-epoch loss curves recorded by Keras, converted to arrays so they can be subtracted.
training_log = history.history
loss_train = np.array(training_log['loss'])
kl_loss_train = np.array(training_log['kl_loss'])
loss_val = np.array(training_log['val_loss'])
kl_loss_val = np.array(training_log['val_kl_loss'])

In [None]:
import matplotlib.pyplot as plt

# Total loss minus the KL term, for training and validation.
fig, ax = plt.subplots()
ax.plot(loss_train - kl_loss_train, label='Training loss')
ax.plot(loss_val - kl_loss_val, label='Validation loss')
ax.set_title('Training and validation loss - MSE')
ax.legend(loc='best')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# KL-divergence term of the loss, for training and validation.
fig, ax = plt.subplots()
ax.plot(history.history['kl_loss'], label='Training loss')
ax.plot(history.history['val_kl_loss'], label='Validation loss')
ax.set_title('Training and validation KL loss')
ax.legend(loc='best')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

### Prediction - background

In [None]:
# Load a second pickle for evaluation.
# NOTE(review): this rebinds X_train_flatten, X_train_scaled, X_test_* and bsm_*
# that were loaded at the top of the notebook, so cells above will see different
# data if re-run afterwards. This file unpacks into six objects (no scaler),
# whereas the first pickle holds seven — confirm this is the intended file.
# `pickle.load` can execute arbitrary code; only load trusted files.
with open('Delphes_QCD_BSM_data_half2.pkl', 'rb') as f:
    X_train_flatten, X_train_scaled, X_test_flatten, X_test_scaled, bsm_data, bsm_target = pickle.load(f)

In [None]:
# Encode the background (QCD) test events; the encoder returns mu, logvar and the sampled z.
qcd_mean, qcd_logvar, qcd_z = final_encoder.predict(X_test_flatten)
# Reconstruct the events from the sampled latent vectors.
qcd_prediction = final_decoder.predict(qcd_z)

### Prediction - Beyond Standard Model events

In [None]:
# Names of the BSM signal samples; entry i labels bsm_data[i] / bsm_target[i] in the loop below.
bsm_labels = ['Leptoquark','A to 4 leptons', 'hChToTauNu', 'hToTauTau']

In [None]:
# Run every BSM sample through encoder and decoder and collect, per sample:
# [label, target, reconstruction, mu, logvar, z].
bsm_results = []

for i, label in enumerate(bsm_labels):
    mean_pred, logvar_pred, z_pred = final_encoder.predict(bsm_data[i])
    bsm_prediction = final_decoder.predict(z_pred)
    sample_result = [label, bsm_target[i], bsm_prediction, mean_pred, logvar_pred, z_pred]
    bsm_results.append(sample_result)

### Save results

In [None]:
# HDF5 file (relative to the working directory) that receives inputs, reconstructions and latent encodings.
output_result = 'VAE_result_pruned.h5'

In [None]:
# Write inputs, reconstructions and latent encodings for the QCD background and
# for every BSM sample. The context manager guarantees the file is closed (and
# flushed) even if one of the create_dataset calls raises.
with h5py.File(output_result, 'w') as h5f:
    # Background (QCD) test set.
    h5f.create_dataset('QCD', data = X_test_scaled)
    h5f.create_dataset('QCD_input', data=X_test_flatten)
    h5f.create_dataset('predicted_QCD', data = qcd_prediction)
    h5f.create_dataset('encoded_mean_QCD', data = qcd_mean)
    h5f.create_dataset('encoded_logvar_QCD', data = qcd_logvar)
    h5f.create_dataset('encoded_z_QCD', data = qcd_z)

    # One group of datasets per BSM sample, keyed by the sample label.
    for i, bsm in enumerate(bsm_results):
        label, target, prediction, encoded_mean, encoded_logvar, encoded_z = bsm
        h5f.create_dataset('%s_scaled' %label, data=target)
        h5f.create_dataset('%s_input' %label, data=bsm_data[i])
        h5f.create_dataset('predicted_%s' %label, data=prediction)
        h5f.create_dataset('encoded_mean_%s' %label, data=encoded_mean)
        h5f.create_dataset('encoded_logvar_%s' %label, data=encoded_logvar)
        h5f.create_dataset('encoded_z_%s' %label, data=encoded_z)