In [None]:
import time

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv1D, Dense,  MaxPooling2D, MaxPooling1D, Flatten, Concatenate, GlobalAveragePooling2D, GlobalAveragePooling1D

In [None]:
# Show the installed TensorFlow version so results can be tied to it.
tf.__version__

In [None]:
# Directory holding the train/validation TFRecord files loaded below.
PATH = "/tmp/lhcf-cnn"

# Number of epochs passed to `fit` for every model trained in this notebook.
EPOCH = 2

In [None]:
# Define the TFRecord schema.
# Every tensor is stored as a flat, fixed-length float32 vector;
# `parse_tfrecord_fn` restores the multi-dimensional shapes after parsing.
feature_description = {
    # Flattened to 384 * 384 * 2 values; reshaped to (384, 384, 2) after parsing.
    "posdE_01xy": tf.io.FixedLenFeature([384 * 384 * 2], tf.float32),
    # Flattened to 384 * 2 values; reshaped to (384, 2) after parsing.
    "posdE_23x": tf.io.FixedLenFeature([384 * 2], tf.float32),
    # Flattened to 384 * 2 values; reshaped to (384, 2) after parsing.
    "posdE_23y": tf.io.FixedLenFeature([384 * 2], tf.float32),
    # Vector of 16 float32 values.
    "dE": tf.io.FixedLenFeature([16], tf.float32),
    # Scalar int64 target (the notebook counts values 0 and 1).
    "label": tf.io.FixedLenFeature([], tf.int64)
}

# Function to parse TFRecord records
def parse_tfrecord_fn(example_proto):
    """Decode one serialized record into the model inputs and its label.

    Args:
        example_proto: One serialized example, parsed against
            `feature_description`.

    Returns:
        A `(features, label)` tuple. `features` maps each model input name
        to the tensor reshaped to its original layout; `label` is the
        value stored under the "label" feature.
    """
    parsed = tf.io.parse_single_example(example_proto, feature_description)

    # Undo the flattening applied when the records were written.
    plane_01xy = tf.reshape(parsed["posdE_01xy"], (384, 384, 2))
    plane_23x = tf.reshape(parsed["posdE_23x"], (384, 2))
    plane_23y = tf.reshape(parsed["posdE_23y"], (384, 2))
    energy = tf.reshape(parsed["dE"], (16,))

    # Keys must match the `name=` of the Keras Input layers.
    features = {
        "posdE_01xy_input": plane_01xy,
        "posdE_23x_input": plane_23x,
        "posdE_23y_input": plane_23y,
        "dE_input": energy,
    }
    return features, parsed["label"]

# Load and preprocess data in batches
def load_dataset(tfrecord_file, batch_size=32, shuffle_buffer=1000):
    """Build a batched, prefetched `tf.data` pipeline from a TFRecord file.

    Args:
        tfrecord_file: Path handed to `tf.data.TFRecordDataset`.
        batch_size: Number of examples per batch.
        shuffle_buffer: Size of the shuffle buffer. Pass `None` to skip
            shuffling and keep the records in file order.

    Returns:
        A `tf.data.Dataset` yielding `(features, label)` batches as produced
        by `parse_tfrecord_fn`.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
    # Identity test: `None` is a singleton, and `==` could be overridden by
    # whatever object a caller passes in.
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer)
    # Batching and prefetching are identical for both variants.
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Training pipeline uses the default shuffle buffer; validation keeps file order
train_dataset = load_dataset(
    f"{PATH}/train.tfrecord",
    batch_size=32,
)
validation_dataset = load_dataset(
    f"{PATH}/validation.tfrecord",
    batch_size=32,
    shuffle_buffer=None,
)

In [None]:
# Inspect the structure of one batch produced by the training pipeline.
train_dataset.element_spec

In [None]:
# Second pipeline over the training file with the shuffle step skipped
# (shuffle_buffer=None); used below for plots and for collecting labels and
# predictions in separate passes.
train_dataset_noshuffle = load_dataset(f"{PATH}/train.tfrecord", batch_size=32, shuffle_buffer=None)

In [None]:
# Count the number of training examples with label 0 and with label 1
count_label_0 = 0
count_label_1 = 0

# The counts do not depend on record order, so iterate the unshuffled
# pipeline: same records, but no shuffle buffer has to be filled first.
for _, labels in train_dataset_noshuffle:
    # Convert the label tensor to numpy for easy counting
    labels_numpy = labels.numpy()
    count_label_0 += int(np.count_nonzero(labels_numpy == 0))
    count_label_1 += int(np.count_nonzero(labels_numpy == 1))

print(f"Number of examples with label 0: {count_label_0}")
print(f"Number of examples with label 1: {count_label_1}")

# Guard the division: with no label-1 examples (or an empty dataset) the
# ratio is undefined and must not raise or silently warn.
if count_label_1:
    ratio = count_label_0 / count_label_1
else:
    ratio = float("inf") if count_label_0 else float("nan")
print("Ratio: ", ratio)

In [None]:
# Extract the first batch from the unshuffled training dataset
batch = next(iter(train_dataset_noshuffle.take(1)))
example = batch[0]  # Features: dict keyed by model input name
label = batch[1]  # Labels (previously this wrongly pointed at the features)

# Extract a single example: index 2 is the third example of the batch
posdE_01xy_example = example["posdE_01xy_input"].numpy()[2]

# NOTE(review): the 23x/23y plots label their channels "Plane 2"/"Plane 3";
# confirm whether the two channels here should read "Plane 0"/"Plane 1".

# Visualize the image from posdE_01xy_input
plt.imshow(posdE_01xy_example[:, :, 0], cmap='viridis')  # Display the first channel
plt.colorbar()
plt.title("posdE_01xy_input - Plane 1")
plt.show()

# To visualize the second channel separately
plt.imshow(posdE_01xy_example[:, :, 1], cmap='viridis')  # Display the second channel
plt.colorbar()
plt.title("posdE_01xy_input - Plane 2")
plt.show()

In [None]:
# Extract the first batch from the unshuffled training dataset
batch = next(iter(train_dataset_noshuffle.take(1)))
example = batch[0]  # Features: dict keyed by model input name
label = batch[1]  # Labels (previously this wrongly pointed at the features)

# Extract a single example: index 2 is the third example of the batch
posdE_23x_example = example["posdE_23x_input"].numpy()[2]

# Visualize the two channels as separate lines
plt.figure(figsize=(10, 6))
plt.plot(posdE_23x_example[:, 0], label="Plane 2")
plt.plot(posdE_23x_example[:, 1], label="Plane 3")
plt.title("posdE_23x_input")
plt.xlabel("Index")
plt.ylabel("Value")
plt.legend()
plt.show()


In [None]:
# Extract the first batch from the unshuffled training dataset
batch = next(iter(train_dataset_noshuffle.take(1)))
example = batch[0]  # Features: dict keyed by model input name
label = batch[1]  # Labels (previously this wrongly pointed at the features)

# Extract a single example: index 2 is the third example of the batch
posdE_23y_example = example["posdE_23y_input"].numpy()[2]

# Visualize the two channels as separate lines
plt.figure(figsize=(10, 6))
plt.plot(posdE_23y_example[:, 0], label="Plane 2")
plt.plot(posdE_23y_example[:, 1], label="Plane 3")
plt.title("posdE_23y_input")
plt.xlabel("Index")
plt.ylabel("Value")
plt.legend()
plt.show()

In [None]:
# Extract the first batch from the unshuffled training dataset, like the
# other example plots: the shuffled pipeline would show a different event on
# every run and fill its shuffle buffer just to draw one example.
batch = next(iter(train_dataset_noshuffle.take(1)))
example = batch[0]  # Features: dict keyed by model input name
label = batch[1]  # Labels (previously this wrongly pointed at the features)

# Extract a single example of dE_input: index 2 is the third example of the batch
dE_example = example["dE_input"].numpy()[2]

# Visualize dE_input as a bar chart
plt.figure(figsize=(8, 5))
plt.bar(range(len(dE_example)), dE_example)
plt.title("dE_input")
plt.xlabel("Index")
plt.ylabel("Value")
plt.show()

In [None]:
# Neural Network Definition
#
# Four input branches (one 2-D convolutional, two 1-D convolutional, one
# dense) are flattened, concatenated and fed to a single sigmoid unit.
# Each Input `name` matches a key of the feature dict returned by
# `parse_tfrecord_fn`.

# Branch 1: Conv2D over the (384, 384, 2) input
input_posdE_01xy = Input(shape=(384, 384, 2), name="posdE_01xy_input")
x1 = Conv2D(4, (3, 3), activation="relu", padding="same")(input_posdE_01xy)
x1 = MaxPooling2D((2, 2))(x1)
# Uncomment this line to reduce parameters further
# x1 = GlobalAveragePooling2D()(x1)
x1 = Flatten()(x1)

# Branch 2: Conv1D over the (384, 2) posdE_23x input
input_posdE_23x = Input(shape=(384, 2), name="posdE_23x_input")
x2 = Conv1D(4, 3, activation="relu", padding="same")(input_posdE_23x)
x2 = MaxPooling1D(2)(x2)
# Uncomment this line to reduce parameters further
# x2 = GlobalAveragePooling1D()(x2)
x2 = Flatten()(x2)

# Branch 3: same layer stack as branch 2, applied to the posdE_23y input
input_posdE_23y = Input(shape=(384, 2), name="posdE_23y_input")
x3 = Conv1D(4, 3, activation="relu", padding="same")(input_posdE_23y)
x3 = MaxPooling1D(2)(x3)
# Uncomment this line to reduce parameters further
# x3 = GlobalAveragePooling1D()(x3)
x3 = Flatten()(x3)

# Branch 4: Dense layer over the 16-value dE input
input_dE = Input(shape=(16,), name="dE_input")
x4 = Dense(4, activation="relu")(input_dE)

# Combine the outputs of all branches
x = Concatenate()([x1, x2, x3, x4])

# Output for binary classification: one sigmoid unit
output = Dense(1, activation="sigmoid", name="output")(x)

# Define the model
model = Model(inputs=[input_posdE_01xy, input_posdE_23x, input_posdE_23y, input_dE], outputs=output)

In [None]:
# The AUC metric class is imported here, where it is first used
from tensorflow.keras.metrics import AUC

# Adam optimizer with binary cross-entropy; track accuracy and AUC
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy", AUC()],
)

# Print the layer-by-layer summary of the compiled model
model.summary()

In [None]:
# Render the model graph with tensor shapes to a PNG file.
keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)

In [None]:
# Define class weights (optional); enable together with the `class_weight`
# argument commented out in the `fit` call below.
# class_weight = {0: 1, 1: 3}

# Train the model for EPOCH epochs; `history.history` holds the per-epoch
# metrics read by the plotting cells that follow.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCH,
    # class_weight=class_weight
)

In [None]:
# Keras auto-numbers unnamed metrics ("auc", "auc_1", ...), and the number
# changes when the compile cell is re-run. Look the key up instead of
# hard-coding it.
auc_key = next(key for key in history.history if key.startswith("auc"))

plt.plot(history.history[auc_key], label='Training AUC')
plt.plot(history.history[f"val_{auc_key}"], label='Validation AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.show()

In [None]:
# Training vs. validation accuracy per epoch, drawn on an explicit axes object
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='Training Accuracy')
ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()

In [None]:
# Training vs. validation loss per epoch.
# The legend entries used to say "Accuracy" although the curves are losses.
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# --- Training set (unshuffled pipeline) ---
# Gather the label batches first, then predict in a second pass over the same
# pipeline so that the predictions line up with the labels row by row.
train_label_batches = [
    labels for _features, labels in train_dataset_noshuffle.as_numpy_iterator()
]
y_train_true = np.concatenate(train_label_batches, axis=0)
train_predictions = model.predict(train_dataset_noshuffle)

# Scores of the training examples, separated by true label
train_is_label_0 = y_train_true == 0
train_is_label_1 = y_train_true == 1
train_predictions_0 = train_predictions[train_is_label_0]
train_predictions_1 = train_predictions[train_is_label_1]

# --- Validation set (built without shuffling) ---
val_label_batches = [
    labels for _features, labels in validation_dataset.as_numpy_iterator()
]
y_val_true = np.concatenate(val_label_batches, axis=0)
val_predictions = model.predict(validation_dataset)

# Scores of the validation examples, separated by true label
val_is_label_0 = y_val_true == 0
val_is_label_1 = y_val_true == 1
val_predictions_0 = val_predictions[val_is_label_0]
val_predictions_1 = val_predictions[val_is_label_1]

In [None]:
# Plot normalized histograms of the model scores, split by true label
plt.figure(figsize=(10, 6))

# The model output is a sigmoid, so every score lies in [0, 1]. One shared
# binning keeps the four histograms directly comparable (auto-ranged bins
# would differ per sample).
score_bins = np.linspace(0.0, 1.0, 101)
# Validation markers belong at the bin centers, not at the right bin edges.
score_bin_centers = 0.5 * (score_bins[:-1] + score_bins[1:])

# Training histograms, each normalized so that its total area equals 1
train_hist_0, bins_0, _ = plt.hist(
    train_predictions_0, bins=score_bins, alpha=0.4, color='darkorange',
    label='Train - Label 0', edgecolor='black', density=True
)
train_hist_1, bins_1, _ = plt.hist(
    train_predictions_1, bins=score_bins, alpha=0.4, color='blue',
    label='Train - Label 1', edgecolor='black', density=True
)

# Validation histograms are computed without drawing bars...
val_hist_0, bin_val_0 = np.histogram(val_predictions_0, bins=score_bins, density=True)
val_hist_1, bin_val_1 = np.histogram(val_predictions_1, bins=score_bins, density=True)

# ...and overlaid as markers on top of the training bars
plt.plot(score_bin_centers, val_hist_0, '*', color='darkorange', label='Validation - Label 0')
plt.plot(score_bin_centers, val_hist_1, '*', color='blue', label='Validation - Label 1')

# Add labels and legend
plt.xlabel("Prediction Values")
plt.ylabel("Normalized Density")
plt.title("Normalized Histogram of Predictions for Train and Validation")
plt.legend()
plt.grid(True)

plt.show()

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Turn the sigmoid scores into hard 0/1 decisions with a 0.5 threshold
y_train_pred = (train_predictions >= 0.5).astype(int)
y_val_pred = (val_predictions >= 0.5).astype(int)

# Training confusion matrix
train_cm = confusion_matrix(y_train_true, y_train_pred)
train_cm_display = ConfusionMatrixDisplay(
    confusion_matrix=train_cm,
    display_labels=['Label 0', 'Label 1'],
)
train_cm_display.plot()
plt.title('Confusion Matrix - Training')
plt.show()

# Validation confusion matrix
val_cm = confusion_matrix(y_val_true, y_val_pred)
val_cm_display = ConfusionMatrixDisplay(
    confusion_matrix=val_cm,
    display_labels=['Label 0', 'Label 1'],
)
val_cm_display.plot()
plt.title('Confusion Matrix - Validation')
plt.show()

## Test the dense-only model (dE input only)

In [None]:
# Baseline that uses only the 16-value dE input.
# NOTE: `input_dE`, `x4` and `output` rebind names already used when building
# the full model; rebinding them does not modify `model`.

# Input for Dense layer
input_dE = Input(shape=(16,), name="dE_input")
x4 = Dense(4, activation="relu")(input_dE)

# Output for binary classification: one sigmoid unit
output = Dense(1, activation="sigmoid", name="output")(x4)

# Define the model
model_dense = Model(inputs=[input_dE], outputs=output)

In [None]:
# Same training setup as the full model: Adam, binary cross-entropy,
# accuracy and AUC as tracked metrics
model_dense.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy", AUC()],
)

# Print the layer-by-layer summary of the compiled model
model_dense.summary()

In [None]:
# Render the dense-only model graph with tensor shapes to a PNG file.
keras.utils.plot_model(model_dense, "dense_model.png", show_shapes=True)

In [None]:
# Train the model
# The datasets yield all four inputs, while this model only declares
# "dE_input".
# NOTE(review): presumably Keras picks the matching key and ignores the
# others; confirm for the installed version.
history_dense = model_dense.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCH,
)

In [None]:
# The suffix Keras appends to an unnamed AUC metric ("auc", "auc_1", ...)
# depends on how many AUC objects were created in this kernel, so the key is
# looked up rather than hard-coded.
auc_key = next(key for key in history_dense.history if key.startswith("auc"))

plt.plot(history_dense.history[auc_key], label='Training AUC')
plt.plot(history_dense.history[f"val_{auc_key}"], label='Validation AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.show()

In [None]:
# Training vs. validation accuracy of the dense-only model, per epoch
fig, ax = plt.subplots()
ax.plot(history_dense.history['accuracy'], label='Training Accuracy')
ax.plot(history_dense.history['val_accuracy'], label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()

In [None]:
# Training vs. validation loss of the dense-only model, per epoch.
# The legend entries used to say "Accuracy" although the curves are losses.
plt.plot(history_dense.history['loss'], label='Training Loss')
plt.plot(history_dense.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Dense-only scores on the unshuffled training data; the true labels
# (`y_train_true`) are reused from the full-model evaluation cell
train_predictions_dense = model_dense.predict(train_dataset_noshuffle)

# Separate the training scores by true label.
# These names overwrite the full-model splits of the same name.
dense_train_is_label_0 = y_train_true == 0
dense_train_is_label_1 = y_train_true == 1
train_predictions_0 = train_predictions_dense[dense_train_is_label_0]
train_predictions_1 = train_predictions_dense[dense_train_is_label_1]

# Dense-only scores on the validation data; `y_val_true` is reused as well
val_predictions_dense = model_dense.predict(validation_dataset)

# Separate the validation scores by true label
dense_val_is_label_0 = y_val_true == 0
dense_val_is_label_1 = y_val_true == 1
val_predictions_0 = val_predictions_dense[dense_val_is_label_0]
val_predictions_1 = val_predictions_dense[dense_val_is_label_1]

In [None]:
# Plot normalized histograms of the dense-only scores, split by true label
plt.figure(figsize=(10, 6))

# The model output is a sigmoid, so every score lies in [0, 1]. One shared
# binning keeps the four histograms directly comparable (auto-ranged bins
# would differ per sample).
score_bins = np.linspace(0.0, 1.0, 101)
# Validation markers belong at the bin centers, not at the right bin edges.
score_bin_centers = 0.5 * (score_bins[:-1] + score_bins[1:])

# Training histograms, each normalized so that its total area equals 1
train_hist_0, bins_0, _ = plt.hist(
    train_predictions_0, bins=score_bins, alpha=0.4, color='darkorange',
    label='Train - Label 0', edgecolor='black', density=True
)
train_hist_1, bins_1, _ = plt.hist(
    train_predictions_1, bins=score_bins, alpha=0.4, color='blue',
    label='Train - Label 1', edgecolor='black', density=True
)

# Validation histograms are computed without drawing bars...
val_hist_0, bin_val_0 = np.histogram(val_predictions_0, bins=score_bins, density=True)
val_hist_1, bin_val_1 = np.histogram(val_predictions_1, bins=score_bins, density=True)

# ...and overlaid as markers on top of the training bars
plt.plot(score_bin_centers, val_hist_0, '*', color='darkorange', label='Validation - Label 0')
plt.plot(score_bin_centers, val_hist_1, '*', color='blue', label='Validation - Label 1')

# Add labels and legend
plt.xlabel("Prediction Values")
plt.ylabel("Normalized Density")
plt.title("Normalized Histogram of Predictions for Train and Validation")
plt.legend()
plt.grid(True)

plt.show()

In [None]:
# Turn the dense-only sigmoid scores into hard 0/1 decisions (threshold 0.5)
y_train_pred_dense = (train_predictions_dense >= 0.5).astype(int)
y_val_pred_dense = (val_predictions_dense >= 0.5).astype(int)

# Training confusion matrix
train_cm = confusion_matrix(y_train_true, y_train_pred_dense)
train_cm_display = ConfusionMatrixDisplay(
    confusion_matrix=train_cm,
    display_labels=['Label 0', 'Label 1'],
)
train_cm_display.plot()
plt.title('Confusion Matrix - Training')
plt.show()

# Validation confusion matrix
val_cm = confusion_matrix(y_val_true, y_val_pred_dense)
val_cm_display = ConfusionMatrixDisplay(
    confusion_matrix=val_cm,
    display_labels=['Label 0', 'Label 1'],
)
val_cm_display.plot()
plt.title('Confusion Matrix - Validation')
plt.show()

In [None]:
# comparison of roc curves from different models
from sklearn.metrics import roc_curve, roc_auc_score

def plot_roc(X, c, model, title, fmt=''):
    """Add one ROC curve to the current matplotlib figure.

    The curve is drawn with `1 - FPR` on the x axis against the true-positive
    rate on the y axis; the legend entry carries the AUC as a percentage.

    Args:
        X: Predicted scores.
        c: True binary labels.
        model: Not used by the function; kept so existing calls keep working.
        title: Text placed in front of the AUC in the legend entry.
        fmt: Matplotlib format string forwarded to `plt.plot`.
    """
    false_positive_rate, true_positive_rate, _ = roc_curve(c, X)
    true_negative_rate = 1.0 - false_positive_rate
    legend_label = f'{title} (AUC: {100*roc_auc_score(c,X):.1f}%)'
    plt.plot(true_negative_rate, true_positive_rate, fmt, label=legend_label)

In [None]:
# Figure-level text first, then the curves
plt.title("ROC curves", fontsize=14)
plt.xlabel("True Negative Rate", fontsize=12)
plt.ylabel("True Positive Rate", fontsize=12)

# One entry per curve, in legend order: (scores, true labels, model, legend title)
roc_curves = [
    (train_predictions, y_train_true, model, 'Training Full'),
    (val_predictions, y_val_true, model, "Validation Full"),
    (train_predictions_dense, y_train_true, model_dense, 'Training Dense'),
    (val_predictions_dense, y_val_true, model_dense, "Validation Dense"),
]
for roc_scores, roc_truth, roc_model, roc_title in roc_curves:
    plot_roc(roc_scores, roc_truth, roc_model, roc_title)

# Zoom into the upper-right corner of the plot
plt.xlim(0.75, 1.01)
plt.ylim(0.75, 1.01)
plt.legend(loc='lower left')

plt.show()