In [1]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
from keras.utils import image_dataset_from_directory
from PIL import Image
import io
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
import tensorflow as tf


# Load datasets using Keras utilities
batch_size = 32
img_size = (512, 512)


def _load_split(directory, shuffle):
    """Load one dataset split as batched grayscale images.

    Args:
        directory: Folder handed to `image_dataset_from_directory`.
        shuffle: Whether the returned dataset shuffles its samples.

    Returns:
        The dataset produced by `image_dataset_from_directory`, built with the
        notebook-level `img_size` and `batch_size`.
    """
    return image_dataset_from_directory(
        directory,
        image_size=img_size,
        batch_size=batch_size,
        color_mode='grayscale',
        shuffle=shuffle,
    )


# Only the training split benefits from shuffling; validation and test are
# read in a fixed order so every evaluation pass sees the same batches.
train_ds = _load_split('official_data/train', shuffle=True)
val_ds = _load_split('official_data/valid', shuffle=False)
test_ds = _load_split('official_data/test', shuffle=False)

# Get class names
class_names = train_ds.class_names
num_classes = len(class_names)
# Derived from img_size (plus one grayscale channel) so the model input
# cannot drift out of sync with the size the images are loaded at.
input_shape = (*img_size, 1)

# model = keras.Sequential([
#     keras.layers.Input(shape=input_shape, name='input_layer'),
#     keras.layers.Rescaling(1./255, name='scaling_layer'),  # Normalize pixel values [0,1]
#     keras.layers.Normalization(mean=.5, variance=0.25, name='normalize_layer'),

#     keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', name='conv2d_01'),
#     keras.layers.Conv2D(32, kernel_size=(3, 3), strides=(2, 2), activation='relu', name='conv2d_downscaling01'),
    
#     keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', name='conv2d_02'),
#     keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', strides=(2, 2), name='conv2d_02_downscaling'),
    
#     keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', name='conv2d_03'),
#     #ECA_Layer(k_size=3, name='eca_01'),# ECA insert
#     keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', strides=(2, 2), name='conv2d_03_downscaling'),
#     #ECA_Layer(k_size=3, name='eca_02'),# ECA insert
    
    
#     keras.layers.Conv2D(256, kernel_size=(3, 3), activation='relu',name='conv2d_04'),
#     keras.layers.GlobalAveragePooling2D(name='global_avg_pooling'),
#     keras.layers.Dropout(.2, name='final_dropout'),
#     keras.layers.Dense(num_classes, activation="softmax", name='output_layer'),
# ])




Found 20549 files belonging to 8 classes.
Found 1113 files belonging to 8 classes.
Found 586 files belonging to 8 classes.


In [2]:
epochs = 100


# Learning rate scheduler
def cosine_annealing_scheduler(epoch, lr, initial_lr=1e-3, min_lr=1e-6, cycle_length=None):
    """Cosine-annealed learning rate that restarts at the start of every cycle.

    Within a cycle the rate follows half a cosine wave starting at
    `initial_lr` and decaying towards `min_lr`; at the first epoch of the
    next cycle it jumps back to `initial_lr`.

    Args:
        epoch: Zero-based epoch index.
        lr: Current learning rate. Unused; present because the function is
            called with `(epoch, lr)` by `LearningRateScheduler`.
        initial_lr: Learning rate at the start of each cycle.
        min_lr: Lower bound the cosine curve decays towards.
        cycle_length: Epochs per cycle. Defaults to half of the notebook-level
            `epochs`. Values below 1 are clamped to 1.

    Returns:
        The learning rate for `epoch` as a Python float.
    """
    if cycle_length is None:
        cycle_length = int(epochs / 2)
    # A zero-length cycle (e.g. epochs < 2) would make the modulo and the
    # division below raise, so never let the cycle drop below one epoch.
    cycle_length = max(1, int(cycle_length))

    cosine_decay = 0.5 * (1 + np.cos(np.pi * (epoch % cycle_length) / cycle_length))
    new_lr = (initial_lr - min_lr) * cosine_decay + min_lr

    return float(new_lr)

# Confusion Matrix callback using Keras TensorBoard
class ConfusionMatrixCallback(keras.callbacks.Callback):
    """Write a confusion-matrix image to a TensorBoard log after every epoch.

    NOTE(review): a later cell defines another class with this same name;
    whichever cell ran last is the one the training cell picks up.

    Args:
        val_data: Iterable yielding `(images, labels)` batches; labels must
            expose `.numpy()`.
        class_names: Optional sequence of display names, one per class. When
            omitted, the classes are taken from the labels seen in the data.
        log_dir: Directory the summary writer logs to.
    """

    def __init__(self, val_data, class_names=None, log_dir='logs/cm'):
        super().__init__()
        self.val_data = val_data
        self.class_names = class_names
        self.file_writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        """Predict on `val_data`, plot the confusion matrix and log it as an image."""
        y_true = []
        y_pred = []

        for images, labels in self.val_data:
            preds = self.model.predict(images, verbose=0)
            y_pred.extend(np.argmax(preds, axis=1))
            y_true.extend(labels.numpy())

        # class_names defaults to None; in that case let sklearn infer the
        # label set instead of failing on len(None).
        if self.class_names is None:
            label_ids = None
        else:
            label_ids = list(range(len(self.class_names)))

        fig = plt.figure(figsize=(15, 15))
        try:
            ax = fig.add_subplot(111)
            ConfusionMatrixDisplay.from_predictions(
                y_true,
                y_pred,
                labels=label_ids,
                display_labels=self.class_names,
                ax=ax
            )
            ax.set_title(f"Confusion Matrix Epoch: {epoch}")

            buf = io.BytesIO()
            # Save through the figure object rather than pyplot's implicit
            # "current figure" state.
            fig.savefig(buf, format='png')
        finally:
            # Always release the figure, even if plotting or saving raised.
            plt.close(fig)

        image = tf.image.decode_png(buf.getvalue(), channels=4)
        image = tf.expand_dims(image, 0)  # add the leading batch axis

        with self.file_writer.as_default():
            tf.summary.image("Confusion Matrix", image, step=epoch)
        # Push the image to disk now so it is visible while training is running.
        self.file_writer.flush()


In [3]:
import io
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tensorflow import keras
 
class ConfusionMatrixCallback(keras.callbacks.Callback):
    """Log count and row-percentage confusion matrices to TensorBoard each epoch.

    Args:
        val_data: Iterable yielding `(images, labels)` batches; labels must
            expose `.numpy()`.
        class_names: Optional sequence of display names, one per class. When
            omitted, the classes are taken from the labels seen in the data.
        log_dir: Directory the summary writer logs to.
    """

    def __init__(self, val_data, class_names=None, log_dir='logs/cm'):
        super().__init__()
        self.val_data = val_data
        self.class_names = class_names
        self.file_writer = tf.summary.create_file_writer(log_dir)

    def _collect_predictions(self):
        """Run the model over `val_data` and return `(y_true, y_pred)` lists."""
        y_true = []
        y_pred = []
        for images, labels in self.val_data:
            preds = self.model.predict(images, verbose=0)
            y_pred.extend(np.argmax(preds, axis=1))
            y_true.extend(labels.numpy())
        return y_true, y_pred

    @staticmethod
    def _figure_to_image(fig):
        """Render `fig` to PNG, close it, and return a batched RGBA image tensor."""
        buf = io.BytesIO()
        try:
            # Save through the figure object rather than pyplot's implicit
            # "current figure" state.
            fig.savefig(buf, format='png')
        finally:
            # Release the figure even if saving failed.
            plt.close(fig)
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        return tf.expand_dims(image, 0)  # add the leading batch axis

    def on_epoch_end(self, epoch, logs=None):
        """Predict on `val_data` and log both confusion-matrix images for `epoch`."""
        y_true, y_pred = self._collect_predictions()

        # class_names defaults to None; fall back to inferred labels and
        # seaborn's automatic tick labels instead of failing on len(None).
        if self.class_names is None:
            label_ids = None
            tick_labels = "auto"
        else:
            label_ids = list(range(len(self.class_names)))
            tick_labels = self.class_names

        # Numeric confusion matrix
        fig1 = plt.figure(figsize=(10, 10))
        ax1 = fig1.add_subplot(111)
        ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=label_ids,
            display_labels=self.class_names,
            ax=ax1
        )
        ax1.set_title(f"Confusion Matrix Epoch: {epoch}")
        image1 = self._figure_to_image(fig1)

        # Percentage confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has a zero row total; leave that row at
        # 0% instead of producing NaN from a division by zero.
        cm_percentage = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        ) * 100

        fig2, ax2 = plt.subplots(figsize=(10, 10))
        sns.heatmap(cm_percentage, annot=True, fmt=".1f", cmap="Blues",
                    xticklabels=tick_labels, yticklabels=tick_labels, ax=ax2)
        ax2.set_xlabel('Predicted Label')
        ax2.set_ylabel('True Label')
        ax2.set_title(f"Confusion Matrix (Percentage) Epoch: {epoch}")
        image2 = self._figure_to_image(fig2)

        with self.file_writer.as_default():
            tf.summary.image("Confusion Matrix - Count", image1, step=epoch)
            tf.summary.image("Confusion Matrix - Percentage", image2, step=epoch)
        # Push the images to disk now so they are visible while training runs.
        self.file_writer.flush()

In [4]:
from layers import build_cbam_model

# Callbacks
callbacks = [
    # Keep only the checkpoint with the lowest validation loss.
    keras.callbacks.ModelCheckpoint(filepath="models/CBAM_model.keras", save_best_only=True, monitor="val_loss"),
    # Stop after a quarter of the epoch budget without val_loss improvement.
    # restore_best_weights puts the best-val_loss weights back on the model
    # when early stopping ends training, so the test evaluation that follows
    # is not run on weights from `patience` epochs past the best one.
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=int(epochs/4),
        restore_best_weights=True,
    ),
    keras.callbacks.TensorBoard(log_dir="logs"),
    keras.callbacks.LearningRateScheduler(cosine_annealing_scheduler, verbose=1),
    ConfusionMatrixCallback(val_ds, class_names=class_names),
]

model = build_cbam_model(input_shape=input_shape, num_classes=num_classes)


In [5]:
# Show the summary of the model returned by build_cbam_model before training.
model.summary()

In [6]:

# Training configuration: integer class labels, so the sparse variants of the
# cross-entropy loss and the accuracy metric are used.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
adam = keras.optimizers.Adam(learning_rate=1e-3, weight_decay=1e-6)
accuracy_metric = keras.metrics.SparseCategoricalAccuracy(name="acc")

model.compile(loss=loss_fn, optimizer=adam, metrics=[accuracy_metric])

# Train on the training split, validating after each epoch.
model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=callbacks)

# Held-out test metrics: score[0] is the loss, score[1] the "acc" metric.
score = model.evaluate(test_ds, verbose=0)
print(f"Test loss: {score[0]}")
print(f"Test accuracy: {score[1]}")


Epoch 1: LearningRateScheduler setting learning rate to 0.001.
Epoch 1/100


2025-07-17 13:16:15.371925: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng57{k2=0,k13=2,k14=2,k18=1,k23=0} for conv %cudnn-conv-bw-filter.8 = (f32[32,1,3,3]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.6304, %bitcast.6771), window={size=3x3}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(train_step)/jit(main)/transpose(jvp(conv2d_01))/conv_general_dilated" source_file="/home/lesliebinbin/codings/github-dh-cv/.venv/lib/python3.11/site-packages/keras/src/backend/jax/nn.py" source_line=356}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-07-17 13:16:15.557112: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.185378892s
Trying algorithm eng57{k2=0,k13=2,k14=2,k18=1,k23=0} for conv %cudn

[1m643/643[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 140ms/step - acc: 0.2952 - loss: 1.8083 - val_acc: 0.3270 - val_loss: 1.6641 - learning_rate: 0.0010

Epoch 2: LearningRateScheduler setting learning rate to 0.0009990143508499217.
Epoch 2/100
[1m643/643[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 118ms/step - acc: 0.3737 - loss: 1.5919 - val_acc: 0.3998 - val_loss: 1.5031 - learning_rate: 9.9901e-04

Epoch 3: LearningRateScheduler setting learning rate to 0.0009960612933065818.
Epoch 3/100
[1m643/643[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 115ms/step - acc: 0.4008 - loss: 1.4963 - val_acc: 0.4025 - val_loss: 1.4587 - learning_rate: 9.9606e-04

Epoch 4: LearningRateScheduler setting learning rate to 0.00099115248173898.
Epoch 4/100
[1m643/643[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m230s[0m 358ms/step - acc: 0.4174 - loss: 1.4387 - val_acc: 0.4160 - val_loss: 1.3838 - learning_rate: 9.9115e-04

Epoch 5: LearningRateScheduler setting 