In [None]:
import sys, os
import setGPU
import hls4ml
import numpy as np
import h5py
import pickle

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Model, model_from_json
from tensorflow.keras.layers import Input, Dense, Lambda, BatchNormalization, Activation, Concatenate, Dropout, Layer
from tensorflow.keras.layers import ReLU, LeakyReLU, Reshape
from tensorflow.keras import backend as K

from qkeras import QDense, QActivation

import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

from custom_layers import KLLoss
from functions import load_model

In [None]:
# Load the pickled test inputs: the first element is bound to X_test_flatten,
# the second to bsm_data, and the third is discarded.
# NOTE(review): pickle.load can execute arbitrary code while deserializing --
# only load this file if its origin is trusted. The absolute EOS path also ties
# the notebook to one user's storage area.
with open('/eos/user/e/epuljak/autoencoder_models/test_data.pkl', 'rb') as f:
    X_test_flatten, bsm_data, _ = pickle.load(f)

In [None]:
# Presumably the quantization bit width of the model under study; it is also
# used to name the model-graph PDF written later in the notebook.
quant_size = 8

# Custom classes handed to the project's load_model helper so that the saved
# encoder can reference them.
custom_objects = {
    'QDense': QDense,
    'QActivation': QActivation,
    'KLLoss': KLLoss,
}
encoder = load_model(
    '/eos/user/e/epuljak/autoencoder_models/VAE_encoder_PTQ_qkeras8',
    custom_objects=custom_objects,
)

In [None]:
# Show the summary of the loaded encoder as a sanity check before conversion.
encoder.summary()

In [None]:
# Target device string; passed as `fpga_part` when the model is converted to HLS.
hardware = 'xcvu9p-flgb2104-2-e'

In [None]:
# Build the initial hls4ml configuration from the Keras encoder, with one
# entry per layer name and a common default precision.
config = hls4ml.utils.config_from_keras_model(
    encoder,
    default_precision='ap_fixed<16,6,AP_RND_CONV,AP_SAT>',
    granularity='name',
)

In [None]:
# Manual per-layer overrides applied on top of the generated configuration.
layer_config = config['LayerName']

# The input layer's whole 'Precision' entry is replaced by a single type string.
layer_config['input_1'].update({
    'Precision': 'ap_fixed<22,12,AP_RND_CONV,AP_SAT>',
})

# layer name -> {precision field -> fixed-point type}, applied in listed order.
precision_overrides = {
    'q_activation': {
        'result': 'ap_fixed<16,11,AP_RND_CONV,AP_SAT>',
    },
    'batch_normalization': {
        'scale': 'ap_fixed<16,8>',
        'bias': 'ap_fixed<16,4>',
        'result': 'ap_fixed<22,10,AP_RND_CONV,AP_SAT>',
    },
    'q_dense': {
        'result': 'ap_fixed<16,10,AP_RND_CONV,AP_SAT>',
        'accum': 'ap_fixed<16,10,AP_RND_CONV,AP_SAT>',
    },
    'batch_normalization_1': {
        'scale': 'ap_fixed<8,2,AP_RND_CONV,AP_SAT>',
        'bias': 'ap_fixed<8,1,AP_RND_CONV,AP_SAT>',
        'result': 'ap_fixed<16,10,AP_RND_CONV,AP_SAT>',
    },
    'q_activation_1': {
        'result': 'ap_fixed<9,4,AP_RND_CONV,AP_SAT>',
    },
    'q_dense_1': {
        'result': 'ap_fixed<16,7,AP_RND_CONV,AP_SAT>',
        'accum': 'ap_fixed<16,7,AP_RND_CONV,AP_SAT>',
    },
    'batch_normalization_2': {
        'scale': 'ap_fixed<8,2,AP_RND_CONV,AP_SAT>',
        'bias': 'ap_fixed<8,1,AP_RND_CONV,AP_SAT>',
    },
    'q_activation_2': {
        'result': 'ap_fixed<9,4,AP_RND_CONV,AP_SAT>',
    },
    'q_dense_2': {
        'weight': 'ap_fixed<8,-1,AP_RND_CONV,AP_SAT>',
        'bias': 'ap_fixed<8,-1,AP_RND_CONV,AP_SAT>',
        'result': 'ap_fixed<18,3,AP_RND_CONV,AP_SAT>',
        'accum': 'ap_fixed<18,3,AP_RND_CONV,AP_SAT>',
    },
    'q_dense_3': {
        'weight': 'ap_fixed<8,-1,AP_RND_CONV,AP_SAT>',
        'bias': 'ap_fixed<8,-1,AP_RND_CONV,AP_SAT>',
        'result': 'ap_fixed<18,3,AP_RND_CONV,AP_SAT>',
        'accum': 'ap_fixed<18,3,AP_RND_CONV,AP_SAT>',
    },
}
for layer_name, field_types in precision_overrides.items():
    for field_name, fixed_type in field_types.items():
        layer_config[layer_name]['Precision'][field_name] = fixed_type

# The KL-loss layer gets wider types plus its own table settings.
layer_config['kl_loss'].update({
    'Precision': {
        'accum': 'ap_fixed<32,10,AP_RND,AP_SAT>',
        'result': 'ap_fixed<32,10,AP_RND,AP_SAT>',
    },
    'sum_t': 'ap_fixed<32,10>',
    'exp_range': 0.5,
    'exp_table_t': 'ap_fixed<32,10,AP_RND,AP_SAT>',
    'table_size': 1024*4,
})

In [None]:
def print_dict(d, indent=0, align=20):
    """Pretty-print a (possibly nested) dictionary, one entry per line.

    Nested dictionaries are printed under their key with two extra spaces of
    indentation per level. Leaf values are printed as ``key:<padding>value``
    with the padding chosen so that values line up in the same column.

    Args:
        d: Dictionary to print. Keys may be of any type; they are rendered
            with ``str``.
        indent: Current nesting depth (two spaces are emitted per level).
        align: Width, in characters, reserved for indentation plus key before
            the value column. Longer keys simply get no padding.
    """
    for key, value in d.items():
        entry_name = '  ' * indent + str(key)
        if isinstance(value, dict):
            print(entry_name)
            print_dict(value, indent + 1, align)
        else:
            # Measure the rendered key: calling len() on the raw key fails for
            # non-string keys such as integers.
            padding = ' ' * max(align - len(entry_name), 0)
            print(entry_name + ':' + padding + str(value))

# Dump the configuration (generated defaults plus manual overrides) for inspection.
print_dict(config)

In [None]:
# Switch on the 'Trace' flag for every configured layer before converting.
for layer_settings in config['LayerName'].values():
    layer_settings['Trace'] = True

# Convert the Keras encoder into an hls4ml model for the chosen FPGA part.
hls_model = hls4ml.converters.convert_from_keras_model(
    encoder,
    hls_config=config,
    output_dir='output/DVAE_PTQ/xcvu9p-2/',
    fpga_part=hardware,
)

In [None]:
# Write the model graph (with shapes and precisions) to a PDF named after the
# quantization size.
model_plot_file = 'ptq_VAE_qkeras_%d.pdf' % quant_size
hls4ml.utils.plot_model(
    hls_model,
    show_shapes=True,
    show_precision=True,
    to_file=model_plot_file,
)

# Numerical profiling of both models on the first 100000 test events.
profiling_events = X_test_flatten[:100000]
hls4ml.model.profiling.numerical(model=encoder, hls_model=hls_model, X=profiling_events)

In [None]:
# Compare the Keras and HLS models on the first 100000 test events using the
# 'norm_diff' plot type.
hls4ml.model.profiling.compare(keras_model=encoder, hls_model=hls_model, X=X_test_flatten[:100000], plot_type='norm_diff')

## Check ROCs: Keras vs HLS model

In [None]:
# Run both the Keras encoder and the HLS model on all of X_test_flatten.
# NOTE(review): no explicit hls_model.compile() is visible in this notebook;
# presumably an earlier step leaves the HLS model compiled before predict -- confirm.
y = encoder.predict(X_test_flatten)
y_hls = hls_model.predict(X_test_flatten)

In [None]:
# Outputs of the KL layer, collected as [keras, hls] for the test set; the
# signal samples are appended to this list afterwards in the same pair order.
kl_loss_total = [y, y_hls]

In [None]:
# Names of the signal samples; the prediction loop indexes bsm_data by the
# position of each name in this list.
bsm_labels = ['Leptoquark', 'A to 4 leptons', 'hChToTauNu', 'hToTauTau']

# Legend labels, laid out as consecutive (Keras, HLS) pairs: first the QCD
# pair, then one pair per signal sample.
labels = [
    'QCD keras', 'QCD hls',
    r'QKeras LQ $\rightarrow$ b$\tau$', r'HLS LQ $\rightarrow$ b$\tau$',
    r'QKeras A $\rightarrow$ 4L', r'HLS A $\rightarrow$ 4L',
    r'QKeras $h_{\pm} \rightarrow \tau\nu$', r'HLS $h_{\pm} \rightarrow \tau\nu$',
    r'QKeras $h_{0} \rightarrow \tau\tau$', r'HLS $h_{0} \rightarrow \tau\tau$',
]
loss = '$D_{KL}$'

# Matplotlib colour-cycle names; one colour is shared by each (Keras, HLS) pair.
colors = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6']

In [None]:
# Score every signal sample with both models and store the results as a
# (keras, hls) pair, matching the order used for the legend labels.
for sample_idx in range(len(bsm_labels)):
    hls_scores = hls_model.predict(bsm_data[sample_idx])
    keras_scores = encoder.predict(bsm_data[sample_idx])

    kl_loss_total.extend([keras_scores, hls_scores])
    print("========================================================================")

In [None]:
# Common value range over every score array that has a legend label.
# The starting sentinels mean the range never starts above 999999 and never
# ends below 0.
minScore, maxScore = 999999., 0
for idx in range(len(labels)):
    scores = kl_loss_total[idx]
    minScore = min(np.min(scores), minScore)
    maxScore = max(maxScore, np.max(scores))

In [None]:
bin_size = 100
plt.figure(figsize=(10, 8))
for idx, curve_label in enumerate(labels):
    scores = kl_loss_total[idx]
    # Consecutive (Keras, HLS) entries share one colour; the second entry of
    # each pair is additionally drawn semi-transparent.
    pair_style = {'color': colors[idx // 2]}
    if idx % 2 == 1:
        pair_style['alpha'] = 0.6
    plt.hist(
        scores.reshape(scores.shape[0] * 1),
        bins=bin_size,
        label=curve_label,
        density=True,
        range=(minScore, maxScore),
        histtype='step',
        fill=False,
        linewidth=1.5,
        **pair_style,
    )
plt.semilogy()
plt.xlabel("Loss")
plt.ylabel("Probability (a.u.)")
plt.grid(True)
plt.title('KL loss')
plt.legend(loc='best')
plt.show()

In [None]:
from sklearn.metrics import roc_curve, auc

# ROC results per signal sample. After this cell every list holds two entries,
# in the order [Keras, HLS].
tpr_lq, fpr_lq, auc_lq = [], [], []
tpr_ato4l, fpr_ato4l, auc_ato4l = [], [], []
tpr_ch, fpr_ch, auc_ch = [], [], []
tpr_to, fpr_to, auc_to = [], [], []

# Background targets (label 0), sized to the QCD score array of each model.
target_qcd = np.zeros(kl_loss_total[0].shape[0])
target_qcd_hls = np.zeros(kl_loss_total[1].shape[0])

# kl_loss_total layout: [QCD keras, QCD hls, signal0 keras, signal0 hls, ...].
# Entries 0 and 1 are the background itself, so ROC curves are only built for
# the signal entries starting at index 2. (The previous guard
# `i == 0 and i == 1` could never be true, so the two background entries were
# needlessly run through roc_curve against themselves and then discarded.)
roc_stores = [
    (tpr_lq, fpr_lq, auc_lq),
    (tpr_ato4l, fpr_ato4l, auc_ato4l),
    (tpr_ch, fpr_ch, auc_ch),
    (tpr_to, fpr_to, auc_to),
]
background_targets = (target_qcd, target_qcd_hls)  # index 0: Keras, 1: HLS

for signal_idx, (tpr_store, fpr_store, auc_store) in enumerate(roc_stores):
    for model_idx, background_target in enumerate(background_targets):
        signal_scores = kl_loss_total[2 + 2 * signal_idx + model_idx]
        background_scores = kl_loss_total[model_idx]

        # Signal events are labelled 1, background events 0; the score arrays
        # are concatenated in the same order as the labels.
        true_labels = np.concatenate((np.ones(signal_scores.shape[0]), background_target))
        all_scores = np.concatenate((signal_scores, background_scores))

        fpr_curve, tpr_curve, _thresholds = roc_curve(true_labels, all_scores)

        tpr_store.append(tpr_curve)
        fpr_store.append(fpr_curve)
        auc_store.append(auc(fpr_curve, tpr_curve))

In [None]:
plt.figure(figsize=(12, 8))

# One entry per signal sample: its ROC lists plus its (Keras, HLS) legend labels.
roc_groups = [
    (tpr_lq, fpr_lq, auc_lq, labels[2:4]),
    (tpr_ato4l, fpr_ato4l, auc_ato4l, labels[4:6]),
    (tpr_ch, fpr_ch, auc_ch, labels[6:8]),
    (tpr_to, fpr_to, auc_to, labels[8:]),
]
for curve_color, (tpr_list, fpr_list, auc_list, pair_labels) in zip(colors, roc_groups):
    curves = zip(tpr_list, fpr_list, auc_list, pair_labels)
    # The loop variable is deliberately not called `auc`: that name is the
    # sklearn function imported by the ROC cell, and rebinding it to a float
    # would break any later call to it in the same kernel.
    for model_idx, (tpr, fpr, auc_value, curve_label) in enumerate(curves):
        if model_idx == 1:
            # Second curve of each pair (HLS): dashed and semi-transparent.
            line_style = {'linestyle': 'dashed', 'alpha': 0.6}
        else:
            line_style = {'linestyle': '-'}
        # The line style is given only by keyword; combining a "-" format
        # string with a `linestyle` keyword is a redundant definition.
        plt.plot(fpr, tpr, label='%s (auc = %.1f%%)' % (curve_label, auc_value * 100.),
                 linewidth=1.5, color=curve_color, **line_style)

plt.semilogx()
plt.semilogy()
plt.ylabel("True Positive Rate", fontsize=15)
plt.xlabel("False Positive Rate", fontsize=15)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.grid(True)
plt.legend(bbox_to_anchor=[1.2, 0.5], loc='best', frameon=True)
plt.tight_layout()
# Diagonal reference line (TPR == FPR) and a vertical marker at FPR = 1e-5.
plt.plot(np.linspace(0, 1), np.linspace(0, 1), '--', color='0.75')
plt.axvline(0.00001, color='red', linestyle='dashed', linewidth=1)
plt.show()