# Import libraries

In [None]:
import os
import random
from datetime import datetime

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from keras import Input
from keras.models import Sequential
from keras.metrics import Precision, Recall, AUC
from keras.optimizers import Adam
from keras.applications.resnet import ResNet50, preprocess_input
from keras.layers import Dense, Dropout, Flatten, Lambda, RandomCrop, Resizing, Rescaling
from keras.callbacks import TensorBoard, EarlyStopping
from sklearn.metrics import confusion_matrix


# Set configuration variables

In [None]:
# Fix the global random seeds of Python, NumPy and TensorFlow so that, as far
# as these RNGs are concerned, repeated runs of the notebook behave the same.
global_seed = 100
random.seed(global_seed)
np.random.seed(global_seed)
tf.random.set_seed(global_seed)

In [None]:
# Root folders of the image data. Both are read by
# image_dataset_from_directory below with inferred labels, i.e. one
# sub-folder per class.
dataset_root = 'dataset'
train_data_dir = dataset_root + '/train'
test_data_dir = dataset_root + '/test'

In [None]:
# Images are loaded at 256x256 and then cropped by the model to the 224x224
# input size used for the ResNet50 backbone.
img_height_before_resizing = img_width_before_resizing = 256
img_height = img_width = 224

# Training schedule.
batch_size, epochs = 32, 50
lr = 1e-5  # learning rate intended for the Adam optimizer

## Configure callbacks

In [None]:
# Each run logs into its own timestamped sub-directory of "logs". Repeated
# runs then don't write into the same folder, and
# `%tensorboard --logdir logs` shows them as separate runs.
logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = TensorBoard(log_dir=logdir)

# Stop when val_loss has not improved for 10 epochs. restore_best_weights
# rolls the model back to the best-val_loss epoch when stopping. Without it,
# the weights kept (and later saved) are from the last, worse epoch.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# Load datasets

In [None]:
# Training subset: 75% of train_data_dir. The validation cell must use the
# same validation_split and seed so both subsets come from one consistent,
# non-overlapping split.
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    subset="training",
    validation_split=0.25,
    seed=100,
    batch_size=batch_size,
    image_size=(img_height_before_resizing, img_width_before_resizing),
)

In [None]:
# Validation subset: the remaining 25% of train_data_dir. validation_split
# and seed match the training cell so the two subsets do not overlap.
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    subset="validation",
    validation_split=0.25,
    seed=100,
    batch_size=batch_size,
    image_size=(img_height_before_resizing, img_width_before_resizing),
)

In [None]:
# Held-out test images, read in a fixed (non-shuffled) order.
# image_dataset_from_directory shuffles by default. That is pointless for
# evaluation, and it makes labels and predictions collected in separate
# passes over the dataset only line up by accident.
test_ds = tf.keras.utils.image_dataset_from_directory(
    test_data_dir,
    shuffle=False,
    image_size=(img_height_before_resizing, img_width_before_resizing),
    batch_size=batch_size,
)

## Plot some sample images

In [None]:
import matplotlib.pyplot as plt

# Class names inferred from the sub-folder names; label index i corresponds
# to class_names[i]. Also used by the confusion-matrix plot in the testing
# cell.
class_names = train_ds.class_names

# Show the first 9 images of one training batch in a 3x3 grid, titled by
# class.
fig = plt.figure(figsize=(10, 10))
for batch_images, batch_labels in train_ds.take(1):
    for idx in range(9):
        ax = fig.add_subplot(3, 3, idx + 1)
        ax.imshow(batch_images[idx].numpy().astype("uint8"))
        ax.set_title(class_names[batch_labels[idx]])
        ax.axis("off")

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# cache() keeps the decoded batches in memory after the first pass. It
# replays them in the order it recorded, so without a shuffle *after* the
# cache, every epoch after the first would see the training batches in the
# same order. The datasets are already batched, so this shuffles whole
# batches. Validation/test order does not affect aggregate metrics, so those
# are only cached and prefetched.
train_ds = train_ds.cache().shuffle(1000, seed=100).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Define model

In [None]:
# ImageNet-pretrained ResNet50 without its classification head, used as a
# frozen feature extractor. Keras expects input_shape as
# (height, width, channels).
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

# Freeze every backbone layer individually so only the new head is trained.
# base_model.trainable itself stays True, so individual layers can later be
# unfrozen for fine-tuning.
for layer in base_model.layers:
    layer.trainable = False

In [None]:
# Add new classification layers on top of the frozen ResNet50 backbone.
# The model takes the images exactly as loaded: RGB at
# img_height_before_resizing x img_width_before_resizing with raw 0-255
# pixel values.
model = Sequential()
model.add(Input(shape=(img_height_before_resizing, img_width_before_resizing, 3)))

# Random img_height x img_width crop while training. At inference RandomCrop
# resizes/crops to the same size, so its output already has the backbone's
# input size and no separate Resizing layer is needed.
model.add(RandomCrop(img_height, img_width, seed=123))

# ImageNet ResNet50 weights expect resnet.preprocess_input (RGB->BGR and
# per-channel ImageNet mean subtraction, no scaling), not 1/255 rescaling.
# NOTE(review): Lambda layers serialize their Python function, so reloading
# the saved model may require the same Keras/Python environment.
model.add(Lambda(preprocess_input))

model.add(base_model)
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
# Single sigmoid unit: probability of class index 1 (binary classification).
model.add(Dense(1, activation='sigmoid'))

In [None]:
# Use the configured learning rate. The string 'adam' would silently fall
# back to Keras's default Adam learning rate and ignore `lr`.
# binary_crossentropy pairs with the single sigmoid output unit.
model.compile(optimizer=Adam(learning_rate=lr),
              loss='binary_crossentropy',
              metrics=['accuracy',
                       Precision(name='precision'),
                       Recall(name='recall'),
                       AUC(name='auc')])

# Train the model

In [None]:
# Train the new head; EarlyStopping may end training before `epochs`.
# Keep the returned History so the per-epoch loss/metric values
# (history.history) stay available without TensorBoard.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[tensorboard_callback, early_stopping_callback],
)

# Save the model and visualize training in TensorBoard

In [None]:
# Persist the trained model under the path "modelV2".
# NOTE(review): without a file extension this relies on the TF2-era Keras
# SavedModel default; Keras 3 requires a ".keras"/".h5" path here.
saved_model_path = "modelV2"
model.save(saved_model_path)

In [None]:
# IPython magics (notebook-only, not plain Python): load the TensorBoard
# extension and show, inline, the training logs that tensorboard_callback
# wrote under "logs".
%load_ext tensorboard
%tensorboard --logdir logs

# Testing

In [None]:
# Evaluate on the held-out test set. return_dict=True keys the results by
# metric name, so this does not depend on the order the metrics were compiled
# in.
test_metrics = model.evaluate(test_ds, return_dict=True)
loss = test_metrics['loss']
acc = test_metrics['accuracy']
prec = test_metrics['precision']
rec = test_metrics['recall']
auc = test_metrics['auc']

# Collect labels and predicted probabilities batch by batch in ONE pass over
# test_ds. y_true and y_pred_prob then line up even if the dataset yields
# its elements in a different order on another iteration.
label_batches, prob_batches = [], []
for images, labels in test_ds:
    prob_batches.append(model.predict_on_batch(images))
    label_batches.append(np.asarray(labels))
y_true = np.concatenate(label_batches)
y_pred_prob = np.concatenate(prob_batches)  # shape (N, 1): one sigmoid output per image
y_pred = np.where(y_pred_prob >= 0.5, 1, 0)

# Compute confusion matrix (rows: true label, columns: predicted label).
# ravel() flattens the (N, 1) predictions to match the 1-D labels.
# labels=[0, 1] keeps the matrix 2x2, matching the tick labels below, even
# if a class never occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred.ravel(), labels=[0, 1])

# Display the confusion matrix with a plot, annotating each cell with its
# count. Text is dark on bright cells and light on dark ones.
plt.figure()
plt.imshow(cm)
plt.title('Confusion matrix')
plt.colorbar()
text_threshold = cm.max() / 2
for (row, col), count in np.ndenumerate(cm):
    plt.text(col, row, str(count), ha='center', va='center',
             color='black' if count > text_threshold else 'white')
plt.xticks([0, 1], class_names)
plt.yticks([0, 1], class_names)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()

# Print the evaluation metrics
print(f'Loss: {loss:.4f}')
print(f'Accuracy: {acc:.4f}')
print(f'Precision: {prec:.4f}')
print(f'Recall: {rec:.4f}')
print(f'AUC: {auc:.4f}')
print('Confusion Matrix: ')
print(cm)