# In [None]:
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
import keras_cv
from keras.utils import image_dataset_from_directory
from PIL import Image
import io
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
import jax.numpy as jnp
import tensorflow as tf



# Load datasets using Keras utilities.
# Each split directory is read with `image_dataset_from_directory` as
# single-channel images resized to `img_size`. Per the Keras docs no pixel
# rescaling is applied here, so any scaling has to happen inside the model.
batch_size = 32
img_size = (512, 512)


def _load_split(directory, *, shuffle):
    """Return a batched, labelled grayscale dataset for one split directory."""
    return image_dataset_from_directory(
        directory,
        image_size=img_size,
        batch_size=batch_size,
        color_mode='grayscale',
        shuffle=shuffle,
    )


train_ds = _load_split('official_data/train', shuffle=True)
# Evaluation splits are not shuffled: ordering does not change loss/accuracy,
# and a fixed order keeps validation/test passes reproducible while skipping
# the per-epoch shuffle work.
val_ds = _load_split('official_data/valid', shuffle=False)
test_ds = _load_split('official_data/test', shuffle=False)



# Get class names (inferred by the dataset loader from the training split).
class_names = train_ds.class_names
num_classes = len(class_names)
# Spatial size matches the loaded images; one channel for grayscale.
input_shape = (*img_size, 1)

# Build model
backbone = keras_cv.models.ResNet18V2Backbone(
    input_shape=input_shape,
    include_rescaling=False,
)
inputs = keras.Input(shape=input_shape)
# The datasets deliver raw pixel values in [0, 255] and the backbone does no
# rescaling of its own (include_rescaling=False). First map pixels to [0, 1],
# then standardise with mean 0.5 / variance 0.25 (std 0.5), giving roughly
# [-1, 1]. Applying Normalization directly to [0, 255] inputs would produce
# values up to (255 - 0.5) / 0.5 = 509.
x = keras.layers.Rescaling(1.0 / 255)(inputs)
x = keras.layers.Normalization(mean=0.5, variance=0.25)(x)
x = backbone(x)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# Learning rate scheduler
def cosine_annealing_scheduler(epoch, lr, *, initial_lr=1e-3, min_lr=1e-6, T_max=100):
    """Return the learning rate for `epoch` under cosine annealing with warm restarts.

    Within each cycle of `T_max` epochs the rate follows half a cosine from
    `initial_lr` down towards `min_lr`. At every multiple of `T_max` it jumps
    back to `initial_lr` (the schedule is periodic in `epoch`).

    Args:
        epoch: Zero-based epoch index, as passed by `LearningRateScheduler`.
        lr: Current learning rate. Ignored, because the schedule depends only
            on `epoch`. It is accepted because the callback passes it.
        initial_lr: Rate at the start of each cycle.
        min_lr: Rate approached at the end of each cycle.
        T_max: Cycle length in epochs; must be positive.

    Returns:
        The learning rate for this epoch as a Python float.

    Raises:
        ValueError: If `T_max` is not positive.
    """
    if T_max <= 0:
        raise ValueError(f"T_max must be positive, got {T_max}")

    # Modulo makes the schedule restart every T_max epochs.
    cosine_decay = 0.5 * (1 + np.cos(np.pi * (epoch % T_max) / T_max))
    new_lr = (initial_lr - min_lr) * cosine_decay + min_lr
    return float(new_lr)

# Confusion Matrix callback using Keras TensorBoard
class ConfusionMatrixCallback(keras.callbacks.Callback):
    """Write a confusion-matrix image summary at the end of every epoch.

    The model being trained predicts every batch of `val_data`. The argmax
    predictions are plotted with scikit-learn's `ConfusionMatrixDisplay`, and
    the PNG is logged via `tf.summary.image` under `log_dir`.

    Args:
        val_data: Iterable of `(images, labels)` batches with integer class
            labels (presumably a tf.data dataset, as passed below; verify for
            other callers).
        class_names: Optional display names indexed by class id. When omitted,
            the matrix covers every class id that appears in the labels or
            predictions, and the axes show the raw ids.
        log_dir: Directory for the summary file writer.
    """

    def __init__(self, val_data, class_names=None, log_dir='logs/cm'):
        super().__init__()
        self.val_data = val_data
        self.class_names = class_names
        self.file_writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        y_true, y_pred = self._collect_predictions()
        if y_true.size == 0:
            # Nothing to plot for an empty validation set.
            return
        image = self._render_confusion_matrix(y_true, y_pred, epoch)
        with self.file_writer.as_default():
            tf.summary.image("Confusion Matrix", image, step=epoch)
        # Push the image to disk now instead of waiting for the writer's own flush.
        self.file_writer.flush()

    def _collect_predictions(self):
        """Return (y_true, y_pred) 1-D integer arrays over all of `val_data`."""
        true_chunks = []
        pred_chunks = []
        for images, labels in self.val_data:
            # predict_on_batch is Keras' single-batch API. Calling predict()
            # once per batch pays its full per-call setup every iteration.
            probs = self.model.predict_on_batch(images)
            pred_chunks.append(np.argmax(np.asarray(probs), axis=1))
            true_chunks.append(np.asarray(labels))
        if not true_chunks:
            return np.array([], dtype=int), np.array([], dtype=int)
        return np.concatenate(true_chunks), np.concatenate(pred_chunks)

    def _render_confusion_matrix(self, y_true, y_pred, epoch):
        """Plot the confusion matrix and return it as a batched RGBA image tensor."""
        if self.class_names is not None:
            labels = list(range(len(self.class_names)))
        else:
            labels = sorted(set(y_true.tolist()) | set(y_pred.tolist()))

        fig = plt.figure(figsize=(15, 15))
        try:
            ax = fig.add_subplot(111)
            ConfusionMatrixDisplay.from_predictions(
                y_true,
                y_pred,
                labels=labels,
                display_labels=self.class_names,
                ax=ax,
            )
            ax.set_title(f"Confusion Matrix Epoch: {epoch}")
            buf = io.BytesIO()
            # Save this figure explicitly rather than whatever pyplot deems current.
            fig.savefig(buf, format='png')
        finally:
            # Release the figure even if plotting fails, so figures don't accumulate across epochs.
            plt.close(fig)

        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # tf.summary.image expects a leading batch dimension.
        return tf.expand_dims(image, 0)

# Callbacks
callbacks = [
    keras.callbacks.ModelCheckpoint(filepath="models/best_model.keras", save_best_only=True, monitor="val_loss"),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=8),
    keras.callbacks.TensorBoard(log_dir="logs"),
    keras.callbacks.LearningRateScheduler(cosine_annealing_scheduler, verbose=1),
    ConfusionMatrixCallback(val_ds, class_names=class_names),
]

# Compile and fit
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3, weight_decay=1e-6),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

epochs = 150
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks,
)

# Evaluate
score = model.evaluate(test_ds, verbose=0)
print(f"Test loss: {score[0]}")
print(f"Test accuracy: {score[1]}")