In [1]:
from database_io import *
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import ModelCheckpoint


In [2]:
# Directory holding the database files (must already exist).
db_path = "db/"

# File names of the training / validation / test databases
# (these must be sqlite3 databases).
train_fname, validation_fname, test_fname = (
    "surf17_train.db",
    "surf17_validation.db",
    "surf17_test.db",
)

# Dimensions and step counts handed to DatabaseIO below.
dim_syndr, dim_fsyndr = 8, 4
n_steps_net1, n_steps_net2 = 20, 3

data = DatabaseIO(dim_syndr, dim_fsyndr, n_steps_net1, n_steps_net2)

In [3]:
# Best-effort cleanup so this cell can be re-run: close whatever databases a
# previous execution left open before loading them again.
try:
    data.close_databases()
except Exception:
    # Presumably nothing was open yet (first run) -- safe to continue.
    # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate.
    pass
data.load_data(db_path + train_fname, db_path + validation_fname, db_path + test_fname)

# Batch geometry used by the training and validation sequences.
batch_size = 64
n_batches_train = 5000
n_batches_validation = 10



class DecoderSequence(Sequence):
    """Keras ``Sequence`` that serves batches drawn from ``data.gen_batches``.

    One epoch worth of batches is drawn eagerly (at construction and again in
    ``on_epoch_end``) and kept in ``self.epoch_batches`` so that
    ``__getitem__`` is a plain list lookup.

    Args:
        data: Object exposing ``gen_batches(batch_size, n_batches, data_type=...)``
            that yields 6-tuples ``(x1, x2, fx, l1, l2, y)``.
        batch_size: Batch size forwarded to ``gen_batches``.
        n_batches: Number of batches per epoch; also the value of ``len(self)``.
        data_type: Forwarded to ``gen_batches`` (the notebook uses
            ``'training'`` and ``'validation'``).
        **kwargs: Forwarded to the Keras base-class constructor
            (e.g. ``workers``, ``use_multiprocessing``, ``max_queue_size``).
    """

    def __init__(self, data, batch_size, n_batches, data_type, **kwargs):
        # Keras 3 expects the base constructor to run; omitting it triggers
        # the `_warn_if_super_not_called` warning during `fit`.
        super().__init__(**kwargs)
        self.data = data
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.data_type = data_type
        # Populate the batches for the first epoch.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.n_batches

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(inputs, targets)`` pair captured for this epoch."""
        return self.epoch_batches[idx]

    def on_epoch_end(self):
        """Draw a fresh set of batches; called by Keras at the end of each epoch.

        Raises:
            RuntimeError: if ``gen_batches`` yields fewer than ``n_batches``
                batches (instead of leaking a bare ``StopIteration``).
        """
        gen = self.data.gen_batches(
            self.batch_size,
            self.n_batches,
            data_type=self.data_type
        )

        fresh_batches = []
        # zip with range() stops after n_batches without over-consuming `gen`.
        for _, batch in zip(range(self.n_batches), gen):
            # The fifth element (l2) is not fed to the model.
            batch_x1, batch_x2, batch_fx, batch_l1, _, batch_y = batch

            # Keras multi-input format: (tuple_of_inputs, targets).
            inputs = (batch_x1, batch_x2, batch_fx, batch_l1)
            fresh_batches.append((inputs, batch_y))

        if len(fresh_batches) != self.n_batches:
            raise RuntimeError(
                f"gen_batches yielded {len(fresh_batches)} batches for "
                f"data_type={self.data_type!r}, expected {self.n_batches}"
            )

        # Swap in the complete list only once it is fully built.
        self.epoch_batches = fresh_batches

# Batch sources handed to model.fit: one for training, one for validation.
train_seq = DecoderSequence(data, batch_size, n_batches_train, 'training')
val_seq = DecoderSequence(data, batch_size, n_batches_validation, 'validation')


loaded databases and checked exclusiveness training, validation, and test keys
N_training=400000, N_validaiton=1000, N_test=5000.


In [4]:
# Hyper-parameters shared by both sub-networks.
n_units = 64
dropout_rate = 0.2
l2_weight = 1e-5


def _l2():
    """Return a fresh L2 kernel regularizer (one instance per layer)."""
    return keras.regularizers.l2(l2_weight)


# ---- Inputs ----
x1 = Input(shape=(None, dim_syndr), name="x1_full")
x2 = Input(shape=(n_steps_net2, dim_syndr), name="x2_recent")
fx = Input(shape=(dim_fsyndr,), name="final_increment")

# NOTE(review): `seq_len` is accepted as a model input (the batch tuples carry
# it) but nothing between the inputs and `p_final` consumes it. The previous
# `tf.sequence_mask` Lambda built from it was never wired into the LSTMs and
# has been dropped as dead code.
l1_input = Input(shape=(), dtype=tf.int32, name="seq_len")

# Masking skips every timestep whose features are ALL equal to mask_value.
# NOTE(review): an all-zero step inside the real sequence is skipped as well,
# not only trailing padding -- confirm that this is intended.
x1_masked = layers.Masking(mask_value=0.0)(x1)

# ---- Network 1 ---- (full syndrome history)
h1 = layers.LSTM(n_units, return_sequences=True, kernel_regularizer=_l2())(x1_masked)
h1 = layers.Dropout(dropout_rate)(h1)
h1 = layers.LSTM(n_units, kernel_regularizer=_l2())(h1)
h1 = layers.Dropout(dropout_rate)(h1)
p1 = layers.Dense(n_units, activation="relu", name="p1", kernel_regularizer=_l2())(h1)
p1 = layers.Dropout(dropout_rate)(p1)
p1 = layers.Dense(1, activation="sigmoid", kernel_regularizer=_l2(), name="p1_prob")(p1)

# ---- Network 2 ---- (recent syndrome + final increment)
h2 = layers.LSTM(n_units, kernel_regularizer=_l2(), return_sequences=True)(x2)
h2 = layers.Dropout(dropout_rate)(h2)
h2 = layers.LSTM(n_units, kernel_regularizer=_l2())(h2)
h2 = layers.Dropout(dropout_rate)(h2)

# The final increment is appended to the recurrent summary before the head.
h2_aug = layers.Concatenate()([h2, fx])

p2 = layers.Dense(n_units, activation="relu", name="p2", kernel_regularizer=_l2())(h2_aug)
p2 = layers.Dropout(dropout_rate)(p2)
p2 = layers.Dense(1, activation="sigmoid", name="p2_prob", kernel_regularizer=_l2())(p2)

# ---- Final combination p = probabilistic sum ---- #
# p1*(1-p2) + p2*(1-p1): probability that exactly one of the two heads fires.
# NOTE: a Lambda wrapping a Python lambda is stored as serialized bytecode;
# reloading the saved .keras file may require safe_mode=False.
p_final = layers.Lambda(lambda x: x[0]*(1-x[1]) + x[1]*(1-x[0]))([p1, p2])

model = Model(inputs=[x1, x2, fx, l1_input], outputs=p_final)
model.summary()

# Binary classification set-up: cross-entropy loss, Adam at lr = 1e-3,
# accuracy reported alongside the loss.
adam = tf.keras.optimizers.Adam(1e-3)
model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])


num_epochs = 1000

# Keep only the best model seen so far, judged by validation accuracy
# (higher is better, hence mode='max'); log a message on every save.
checkpoint = ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

# Train on the pre-built sequences; each epoch runs n_batches_train training
# steps followed by n_batches_validation validation steps.
results = model.fit(
    train_seq,
    validation_data=val_seq,
    epochs=num_epochs,
    steps_per_epoch=n_batches_train,
    validation_steps=n_batches_validation,
    callbacks=[checkpoint],
    verbose=1,
)




  self._warn_if_super_not_called()


Epoch 1/1000
[1m5000/5000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.7789 - loss: 0.3999
Epoch 1: val_accuracy improved from None to 0.61719, saving model to best_model.keras
[1m5000/5000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m121s[0m 23ms/step - accuracy: 0.8635 - loss: 0.2834 - val_accuracy: 0.6172 - val_loss: 0.6096
Epoch 2/1000
[1m4999/5000[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 19ms/step - accuracy: 0.9189 - loss: 0.1959
Epoch 2: val_accuracy improved from 0.61719 to 0.64688, saving model to best_model.keras
[1m5000/5000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 22ms/step - accuracy: 0.9226 - loss: 0.1900 - val_accuracy: 0.6469 - val_loss: 0.6164
Epoch 3/1000
[1m4998/5000[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 20ms/step - accuracy: 0.9321 - loss: 0.1754
Epoch 3: val_accuracy improved from 0.64688 to 0.71719, saving model to best_model.keras
[1m5000/5000[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[

KeyboardInterrupt: 

In [None]:
# Evaluate the model on one validation batch of `num_samples` samples and
# print the fraction of correctly classified samples.
num_samples = 500
n_correct = 0
n_seen = 0
# NOTE: this rebinds x1 / x2 / fx, which held the model's Input tensors.
for batch in data.gen_batches(num_samples, 1, data_type='validation'):
    x1, x2, fx, l1, _, y_actual = batch
    y_prob = model.predict((x1, x2, fx, l1))

    # Class 1 is predicted when the output probability reaches 0.5.
    # (The earlier version thresholded with `< 0.5` and counted mismatches,
    # which is the same set of samples -- i.e. it counted correct predictions
    # under the name `errors`.)
    y_pred = np.asarray(y_prob).reshape(-1) >= 0.5
    y_true = np.asarray(y_actual).reshape(-1).astype(bool)

    n_correct += int(np.count_nonzero(y_pred == y_true))
    n_seen += y_true.size

# Divide by the number of samples actually evaluated, not the requested count.
accuracy = n_correct / n_seen if n_seen else float("nan")
print(f"Accuracy: {accuracy}")