In [None]:
import os 
import pickle
import hashlib

import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy 
from sklearn.metrics import accuracy_score
import numpy as np

from qkeras.qlayers import QDense
from qkeras.quantizers import ternary

# Put the Vivado bin directory at the front of PATH.
# A missing XILINX_VIVADO variable still raises KeyError, as before.
# NOTE(review): nothing in the visible cells calls Vivado; presumably a later
# step needs it -- confirm.
vivado_bin = os.path.join(os.environ['XILINX_VIVADO'], 'bin')
# Only prepend when it is not already the first entry; otherwise every re-run
# of this cell would stack another copy onto PATH.
if os.environ['PATH'].split(os.pathsep)[0] != vivado_bin:
    os.environ['PATH'] = vivado_bin + os.pathsep + os.environ['PATH']

# Fix the global random seed so runs are repeatable.
keras.utils.set_random_seed(32)

## Download data

In [None]:
!wget https://zenodo.org/api/records/14427490/files-archive
!unzip files-archive -d data
# !rm files-archive # optional cleanup 

## Setup data

In [None]:
# Both splits are read from the directory the archive was extracted into.
train_data_dir = test_data_dir = "./data"

# Readout window bounds. The loading cells keep the columns
# [start_location*2, end_window*2) of each sample.
start_location, window_size = 100, 400
end_window = start_location + window_size

In [None]:
"""Loadning training split"""
x_train_path = os.path.join(train_data_dir, '0528_X_train_0_770.npy')
y_train_path = os.path.join(train_data_dir, '0528_y_train_0_770.npy')

# Fail early, naming the missing file, before attempting to load anything.
assert os.path.exists(x_train_path), f"ERROR: File {x_train_path} does not exist."
assert os.path.exists(y_train_path), f"ERROR: File {y_train_path} does not exist."

X_train_val = np.load(x_train_path)
y_train_val = np.load(y_train_path)

# Ensure the same dataset is loaded: compare the MD5 digest of each loaded
# array with the expected value.
assert hashlib.md5(X_train_val).hexdigest() == 'b61226c86b7dee0201a9158455e08ffb',  "Checksum failed. Wrong file was loaded or file may be corrupted."
assert hashlib.md5(y_train_val).hexdigest() == 'c59ce37dc7c73d2d546e7ea180fa8d31',  "Checksum failed. Wrong file was loaded or file may be corrupted."

# Get readout window: keep columns [start_location*2, end_window*2).
# NOTE(review): the factor of 2 suggests two values per time step -- confirm.
X_train_val = X_train_val[:,start_location*2:end_window*2]
# The message names X_train_val (it previously said X_test, which pointed at
# the wrong array when this check failed).
assert len(X_train_val[0]) == (end_window-start_location)*2, f"ERROR: X_train_val sample size {len(X_train_val[0])} does not match (start window, end window) ({start_location},{end_window}) size."

In [None]:
"""Loading testing split"""
# Paths of the test-split feature and label arrays.
x_test_path = os.path.join(test_data_dir, '0528_X_test_0_770.npy')
y_test_path = os.path.join(test_data_dir, '0528_y_test_0_770.npy')

# Fail early, naming the missing file, before attempting to load anything.
for npy_path in (x_test_path, y_test_path):
    assert os.path.exists(npy_path), f"ERROR: File {npy_path} does not exist."

X_test, y_test = np.load(x_test_path), np.load(y_test_path)

# Guard against a wrong or corrupted file by comparing the MD5 digest of each
# loaded array with its expected value.
checksum_error = "Checksum failed. Wrong file was loaded or file may be corrupted."
assert hashlib.md5(X_test).hexdigest() == 'b7d85f42522a0a57e877422bc5947cde', checksum_error
assert hashlib.md5(y_test).hexdigest() == '8c9cce1821372380371ade5f0ccfd4a2', checksum_error

# Keep only the readout-window columns of every sample.
readout_columns = slice(start_location*2, end_window*2)
X_test = X_test[:, readout_columns]
expected_size = (end_window - start_location) * 2
actual_size = len(X_test[0])
assert actual_size == expected_size, f"ERROR: X_test sample size {actual_size} does not match (start window, end window) ({start_location},{end_window}) size."

## Construct the model
 
QKeras ("Quantized Keras") is a library for deep heterogeneous quantization of ML models. We use `QDense` layers in place of `Dense`. We're also training with model sparsity, since QKeras layers are prunable.

In [None]:
# Small quantized network: a 4-unit QDense layer with no activation, batch
# normalization, then a single sigmoid output unit. Kernels and biases of both
# QDense layers use the ternary() quantizer.
model = keras.models.Sequential()
model.add(QDense(
    4,
    activation=None,
    name='fc1',
    # One input per column kept by the readout-window slice, i.e.
    # (end_window - start_location) * 2 == window_size * 2. Deriving it from
    # the config keeps the model in sync if the window is changed.
    input_shape=(window_size * 2,),
    kernel_quantizer=ternary(),
    bias_quantizer=ternary(),
))
model.add(BatchNormalization(name='batchnorm1'))
model.add(QDense(
    1,
    name='fc2',
    activation='sigmoid',
    kernel_quantizer=ternary(),
    bias_quantizer=ternary(),
))

# summary() prints the table itself and returns None, so wrapping it in
# print() would emit an extra "None" line.
model.summary()

## Train the model

In [None]:
# Training hyperparameters.
init_learning_rate = 1e-3
validation_split = 0.05  # 45,000 sample size 
batch_size = 256
epochs = 50
early_stopping_patience = 10

# Where the model weights are stored.
checkpoint_dir = 'checkpoints/'
checkpoint_filename = 'qkeras_ckp_model_best.h5'
ckp_filename = os.path.join(checkpoint_dir, checkpoint_filename)

if not os.path.exists(checkpoint_dir):
    print(f'Checkpoint directory {checkpoint_dir} does not exist.')
    print('Creating directory...')
    # makedirs also creates missing parent directories, and exist_ok avoids a
    # failure if the directory appears between the check and the creation.
    os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Set to True to train from scratch; False loads previously saved weights.
train = False

# Single source for the best-weights path, shared by both branches so that a
# run with train=True produces exactly the file that train=False loads.
best_weights_path = os.path.join(checkpoint_dir, checkpoint_filename)

if train:
    opt = Adam(learning_rate=init_learning_rate)
    callbacks = [
        # Stop once val_loss has not improved for `early_stopping_patience`
        # epochs.
        EarlyStopping(
            monitor='val_loss',
            patience=early_stopping_patience,
            restore_best_weights=True,
        ),
        # Write the lowest-val_loss weights to the path the else-branch loads.
        # Previously nothing wrote that file, so switching to train=False
        # after a training run could not find its weights.
        ModelCheckpoint(
            filepath=best_weights_path,
            monitor='val_loss',
            save_best_only=True,
            save_weights_only=True,
        ),
    ]
    model.compile(
        optimizer=opt,
        # The output layer already applies a sigmoid, hence from_logits=False.
        loss=BinaryCrossentropy(from_logits=False),
        metrics=['accuracy']
    )
    history = model.fit(
        X_train_val,
        y_train_val,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        shuffle=True,
        callbacks=callbacks,
    )

    # Weights of the model as it stands at the end of fit().
    model.save_weights(os.path.join(checkpoint_dir, 'qkeras_model_best.h5'))
    # Save the history dictionary
    with open(os.path.join(checkpoint_dir, 'qkeras_training_history.pkl'), 'wb') as f:
        pickle.dump(history.history, f)
else:
    # Name the missing file and the remedy instead of failing inside
    # load_weights.
    assert os.path.exists(best_weights_path), f"ERROR: File {best_weights_path} does not exist. Set train = True to train the model first."
    model.load_weights(best_weights_path)

## Check performance

In [None]:
# get ground and excited indices (label 1 = excited, label 0 = ground)
e_indices = np.where(y_test == 1)[0]
g_indices = np.where(y_test == 0)[0]

# separate ground and excited samples 
Xe_test = X_test[e_indices]
ye_test = y_test[e_indices]

Xg_test = X_test[g_indices]
yg_test = y_test[g_indices]

# compute total correct for excited state 
# Threshold the sigmoid output at 0.5 and flatten it, so the predictions have
# one entry per sample like the label arrays they are compared against.
ye_pred = model.predict(Xe_test)
ye_pred = np.where(ye_pred < 0.5, 0, 1).reshape(-1)
e_accuracy = accuracy_score(ye_test, ye_pred)

total_correct = (ye_test == ye_pred).sum()
total_incorrect = (ye_test != ye_pred).sum()

# compute total correct for ground state 
# The reshape matters here: without it the element-wise comparisons below
# broadcast the labels against a column of predictions, producing a
# samples-by-samples matrix and therefore wrong counts.
yg_pred = model.predict(Xg_test)
yg_pred = np.where(yg_pred < 0.5, 0, 1).reshape(-1)
g_accuracy = accuracy_score(yg_test, yg_pred)

# NOTE: these reuse the names assigned above, so after this cell they hold the
# ground-state counts.
total_correct = (yg_test == yg_pred).sum()
total_incorrect = (yg_test != yg_pred).sum()

# compute fidelity 
# Mean of the two per-class accuracies. The rescale to [-1, 1] and back is an
# algebraic identity, so the value is unchanged by the last two lines.
test_fidelity = 0.5*(e_accuracy + g_accuracy)
test_fidelity = test_fidelity*2-1
test_fidelity = 1/2 + (0.5*test_fidelity)

# Overall accuracy on the full test split; y_pred keeps the raw model outputs.
y_pred = model.predict(X_test)
test_acc = accuracy_score(y_test, np.where(y_pred < 0.5, 0, 1).reshape(-1))

print('\n===================================')
print('    Accuracy', test_acc)
print('    Fidelity', test_fidelity)