In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.utils import Sequence
import cv2
import os
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras import Model, Input
from sequence_generator import SequenceGenerator
from sklearn.metrics import confusion_matrix, classification_report
from collections import Counter
from tensorflow.keras.layers import (Conv2D, BatchNormalization, MaxPooling2D,
                                     GlobalAveragePooling2D, Dense, Dropout, TimeDistributed, LSTM)



In [None]:
# Paths to the sequence metadata CSV and the root folder of the image
# sequences (placeholders: point them at your local copy of the dataset).
csv_path= "...dataset_SPT_lstm/metadata_SPT_images_sequences.csv"
root_path = "...dataset_SPT_lstm/images_SPT_sequences/"

# Position in this list is the class index recovered by argmax on labels and
# predictions; presumably it must match the label encoding used by
# SequenceGenerator — TODO confirm.
class_names = ['EXPLOSION', 'NO_EXPLOSION', 'NO_SIGNAL', 'SPATTERING']

batch_size = 32          # sequences per batch
seq_length = 100         # frames per sequence
img_size = (128,128)     # frame size; spliced into the model input as (seq_length, *img_size, 3)
classes = len(class_names)
epochs = 5

# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation, dropout and any RNG-based shuffling are repeatable across a
# Restart & Run All. (Bit-exact GPU runs may additionally need
# tf.config.experimental.enable_op_determinism().)
seed = 42
tf.keras.utils.set_random_seed(seed)

In [None]:
# Random 70/30 split of the metadata rows into training and validation sets.
df_full_dataset = pd.read_csv(csv_path)
train_split = 0.7
df_train = df_full_dataset.sample(frac=train_split, random_state=42)

# Validation set = every row that was not drawn into the training sample.
df_val = df_full_dataset.drop(df_train.index)

# Settings shared by both generators; only the dataframe and `shuffle` differ.
generator_kwargs = dict(seq_length=seq_length,
                        img_size=img_size,
                        classes=classes,
                        batch_size=batch_size)

# Training data is shuffled; validation order is kept fixed.
train_generator = SequenceGenerator(df_train, root_path, shuffle=True, **generator_kwargs)
validation_generator = SequenceGenerator(df_val, root_path, shuffle=False, **generator_kwargs)

In [None]:

# Preview a window of consecutive frames from one training sequence, each
# titled with the class given by the argmax of that frame's label vector.
x, y = train_generator[0]   # first batch (what next(iter(...)) would yield)

seq_idx = 0
start_frame = 45
num_frames = 11

# Clamp the window to the sequence length (axis 1 of x) so a large
# start_frame / num_frames cannot index past the last frame.
stop_frame = min(start_frame + num_frames, x.shape[1])
frame_indices = list(range(start_frame, stop_frame))
if not frame_indices:
    raise ValueError(
        f"start_frame={start_frame} is outside the sequence (length {x.shape[1]})."
    )

# squeeze=False keeps `axes` 2-D even when a single frame is shown.
fig, axes = plt.subplots(1, len(frame_indices), figsize=(20, 4), squeeze=False)
for ax, frame_idx in zip(axes[0], frame_indices):
    # Channel axis is reversed for display — presumably frames are stored in
    # OpenCV's BGR order; TODO confirm against SequenceGenerator.
    ax.imshow(x[seq_idx, frame_idx][..., ::-1], interpolation='none')
    class_idx = np.argmax(y[seq_idx, frame_idx])
    ax.set_title("Frame:" + str(frame_idx) + "\n" + str(class_names[class_idx]), fontsize=9)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:

## Build the per-frame CNN feature extractor

def cnn_stromboli_model(input_shape_2d=(128,128,3)):
    """Build the per-frame CNN feature extractor.

    Four convolution stages, each a stride-2 'same'-padded ReLU Conv2D followed
    by BatchNormalization. The first three stages end with 2x2 max pooling and
    the last with global average pooling, so a 128x128 frame is reduced to 1x1
    before the global pool. The pooled vector then goes through a 512-unit ReLU
    Dense layer and 50% dropout.

    Args:
        input_shape_2d: shape of one frame, (height, width, channels).

    Returns:
        A keras ``Model`` named "CNN_model" mapping one frame to a
        512-dimensional feature vector.
    """
    # (filters, kernel size) of each convolution stage, in order.
    conv_stages = [(4, (4, 4)), (4, (3, 3)), (8, (3, 3)), (256, (3, 3))]
    last_stage = len(conv_stages) - 1

    frame = Input(shape=input_shape_2d)
    features = frame
    for stage, (n_filters, kernel) in enumerate(conv_stages):
        features = Conv2D(n_filters, kernel, strides=(2, 2), padding='same',
                          activation='relu')(features)
        features = BatchNormalization()(features)
        if stage == last_stage:
            features = GlobalAveragePooling2D()(features)
        else:
            features = MaxPooling2D(pool_size=(2, 2))(features)

    embedding = Dense(512, activation='relu')(features)
    embedding = Dropout(0.5)(embedding)
    return Model(frame, embedding, name="CNN_model")

## Build the time-distributed CNN-LSTM (TD-CNN-LSTM) model

def cnn_lstm_stromboli_model(input_shape, classes):
    """Build the TD-CNN-LSTM sequence classifier with one prediction per frame.

    Args:
        input_shape: shape of one input sequence,
            (time_steps, height, width, channels).
        classes: number of output classes.

    Returns:
        A keras ``Model`` mapping a sequence to per-time-step softmax
        probabilities over ``classes``.
    """
    sequence_input = Input(shape=input_shape)

    # One CNN instance (shared weights) is applied to every frame independently.
    frame_shape = tuple(input_shape[axis] for axis in (1, 2, 3))
    frame_encoder = cnn_stromboli_model(input_shape_2d=frame_shape)
    frame_encoder.trainable = True  # trained jointly with the recurrent head

    hidden = TimeDistributed(frame_encoder)(sequence_input)
    # return_sequences=True on both layers keeps one output per time step.
    for n_units in (256, 128):
        hidden = LSTM(units=n_units, return_sequences=True, dropout=0.5)(hidden)

    per_frame_probs = TimeDistributed(Dense(classes, activation='softmax'))(hidden)
    return Model(inputs=sequence_input, outputs=per_frame_probs)



# Instantiate the TD-CNN-LSTM for sequences of `seq_length` 3-channel frames
# of size `img_size`.
input_shape = (seq_length, *img_size, 3)
model = cnn_lstm_stromboli_model(input_shape, classes)

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()



In [None]:
# Callbacks and optimizer configuration.

# Save the model whenever val_loss improves (the path is a placeholder).
checkpoint= tf.keras.callbacks.ModelCheckpoint('... path_to_SAve_model .keras',
                                              save_best_only=True,
                                              monitor='val_loss')

# Stop after `patience` epochs without val_loss improvement and restore the
# best weights, so the in-memory `model` evaluated after training matches the
# checkpointed best model rather than the last epoch's weights.
# NOTE: Keras 3 restores at the end of training; older tf.keras 2.x restores
# only if early stopping actually fires — there, reload the checkpoint before
# evaluating.
stop= tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                       patience=5,
                                       restore_best_weights=True)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00001),
    # categorical_crossentropy expects one-hot targets (labels are argmax'ed per frame elsewhere)
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
# Train with per-epoch validation; the checkpoint and early-stopping callbacks
# both watch val_loss.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=epochs,
    callbacks=[checkpoint, stop],
)

In [None]:
# Training vs. validation curves, one panel per metric. A 1xN layout is used so
# the figure has no empty panels.
metric_panels = [
    ('loss', 'Loss function TD-CNN-LSTM', 'Loss function'),
    ('accuracy', 'Accuracy TD-CNN-LSTM', 'Accuracy'),
]

fig, axes = plt.subplots(1, len(metric_panels), figsize=(16, 5), squeeze=False)
for ax, (metric, title, ylabel) in zip(axes[0], metric_panels):
    ax.plot(history.history[metric], label='Training')
    ax.plot(history.history['val_' + metric], label='Validation')
    ax.set_title(title)
    ax.set_xlabel('Epochs', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()


In [None]:
# Frame-level confusion matrix on the validation set: every frame of every
# sequence is counted as one sample.
cm = np.zeros((classes, classes), dtype=np.int64)

# If GPU memory runs out here, the loop can be wrapped in
# `with tf.device("/CPU:0"):`.
for batch_idx in range(len(validation_generator)):
    X_batch, y_batch = validation_generator[batch_idx]
    # batch_size=1: predict one sequence at a time to bound memory use.
    probs = model.predict(X_batch, batch_size=1, verbose=0)

    true_labels = np.argmax(y_batch, axis=-1).ravel()
    pred_labels = np.argmax(probs, axis=-1).ravel()
    cm += confusion_matrix(true_labels, pred_labels, labels=range(classes))

# Plot
fig, ax = plt.subplots(figsize=(8, 8))
image = ax.imshow(cm)
fig.colorbar(image, ax=ax)
ax.set_title("Confusion Matrix (Frame)")
tick_positions = range(len(class_names))
ax.set_xticks(tick_positions)
ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_yticks(tick_positions)
ax.set_yticklabels(class_names)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")

# With the default colormap high counts are bright, so they get dark text.
threshold = cm.max() / 2 if cm.max() else 0
for row, col in np.ndindex(cm.shape):
    ax.text(col, row, f"{int(cm[row, col])}", ha="center", va="center",
            color="black" if cm[row, col] > threshold else "white")

plt.tight_layout()
plt.show()
