#### Imports, check that GPU is used

In [None]:
! nvcc  --version
! pip install tensorflow keras --quiet
! pip install keras-tuner --quiet
! pip install keras-applications
! pip install seaborn --quiet
! pip install kaggle --quiet

! pip install pyyaml h5py  # Required to save models in HDF5 format

# needed for AugMix (removed)
# !pip install keras-cv --quiet

In [None]:
# dataset
import shutil
import os

import pandas as pd
import collections

# model
import numpy as np
import tensorflow as tf
import keras
from keras import layers
from keras.applications import MobileNetV3Large

import keras_tuner as kt
# from keras_cv.layers import AugMix

# graphs/stats
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve
)

In [None]:
# Confirm TensorFlow can see a GPU. Uncomment one of the asserts to make the
# notebook fail fast on a CPU-only runtime.
from tensorflow.python.client import device_lib
# assert 'GPU' in str(device_lib.list_local_devices())
# assert len(tf.config.list_physical_devices('GPU')) > 0

gpu_devices = tf.config.list_physical_devices('GPU')
gpu_devices  # shown as the cell output; an empty list means no GPU is visible

### Dataset preparation

The dataset contains 2 folders:
*   Parasitized
*   Uninfected

with a total of 27,558 images.
**Acknowledgements**
This dataset is taken from the official NIH website: https://ceb.nlm.nih.gov/repositories/malaria-datasets/
and re-uploaded to Kaggle so that anybody who wants to work with it can get started immediately, since downloading the
dataset from the NIH website is quite slow.

1. Kaggle automatic download
2. Load and split the dataset into training/test sets (80/20) and get the label names
3. Compute dataset statistics
4. Dataset normalization (rescaling pixel values to [0, 1])
5. Dataset augmentation

#### Loading, splitting, standardizing the dataset and getting class names

In [4]:
# dataset folder: one sub-folder per class, each holding that class's images
directory = "./cell_images"

filepath = []
label = []

# Keep only sub-directories (a stray top-level file would otherwise make
# os.listdir below fail) and sort them for a deterministic order.
folds = sorted(
    entry for entry in os.listdir(directory)
    if os.path.isdir(os.path.join(directory, entry))
)

for fold in folds:
    f_path = os.path.join(directory, fold)
    for img in sorted(os.listdir(f_path)):
        filepath.append(os.path.join(f_path, img))
        label.append(fold)

# Concatenate data paths with labels.
# NOTE(review): df_train is not used by the tf.data pipeline below, which
# reads the class folders directly via image_dataset_from_directory.
file_path_series = pd.Series(filepath, name='filepath')
Label_path_series = pd.Series(label, name='label')
df_train = pd.concat([file_path_series, Label_path_series], axis=1)

In [None]:
# splitting the dataset and getting class names
img_height = 224  # UPDATED FOR MOBILENET
img_width = 224  # UPDATED FOR MOBILENET
batch_size = 32
SEED = 123  # for reproducibility

# Seed Python's `random`, NumPy and the backend framework in one call so that
# augmentation, weight initialisation and other random choices are repeatable.
keras.utils.set_random_seed(SEED)

# training / test set split (80/20).
# image_dataset_from_directory resizes the images itself, so no reshape is needed.
print('Loading and splitting the tf_dataset')
train_set, test_set = keras.utils.image_dataset_from_directory(
    directory,
    validation_split=0.2,  # 80/20%
    subset="both",
    seed=SEED,             # fixes which files land in each subset
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='binary'    # parasitized/uninfected
)

In [None]:
# Class names come from the sub-folder names; a name's position in this list
# is the integer label Keras assigned to that class.
classes = train_set.class_names
num_classes = len(classes)

print(f'[0={classes[0]}, 1={classes[1]}]')

The pixel values need to be rescaled to the [0, 1] range, so we divide them by 255 with a lambda function. For now only the test set is rescaled; the training set is rescaled after data augmentation.

In [None]:
# Rescale the test set's pixels from [0, 255] to [0, 1].
# The training set is rescaled later, after data augmentation, which operates
# on the raw 0-255 pixels.
test_set = test_set.map(lambda x, y: (x/255, y))

# Sanity check on the dataset that was just rescaled: values should lie in [0, 1].
for image, _ in test_set.take(5):
    img = image.numpy()
    print("Image shape:", img.shape)
    print("Pixel value range: min =", img.min(), ", max =", img.max())

In [None]:

def num_img_per_class(dataset):
    """Count how many images of each class a batched dataset contains.

    Args:
        dataset: iterable of ``(images, labels)`` batches whose ``labels``
            expose ``.numpy()`` returning integer-valued class indices, e.g.
            shape ``(batch_size, 1)`` for ``label_mode='binary'``.

    Returns:
        collections.Counter mapping each class index (a Python int) to the
        number of images with that label.
    """
    class_counts = collections.Counter()
    for _, labels in dataset:
        # reshape(-1) flattens (batch_size, 1) -> (batch_size,) and, unlike
        # squeeze(), still yields a 1-D array for a final batch of size 1.
        class_indices = labels.numpy().reshape(-1).astype(int)
        class_counts.update(class_indices.tolist())
    return class_counts

# Count the images of each class in both splits. Note that the "validation"
# counts are for the held-out test split.
train_class_counts, val_class_counts = (
    num_img_per_class(split) for split in (train_set, test_set)
)

print('Training set:', train_class_counts)
print('Validation set:', val_class_counts)

In [None]:
# bar diagram of training and test class distributions
# Derive the index -> name mapping from the dataset itself (classes[i] is
# label i) so the plot cannot disagree with the labels Keras assigned.
labels_map = dict(enumerate(classes))
class_ids = sorted(labels_map)

# Prepare data
x_labels = [labels_map[i] for i in class_ids]
x_pos = range(len(x_labels))
train_vals = [train_class_counts[i] for i in class_ids]
val_vals = [val_class_counts[i] for i in class_ids]

# Set style
sns.set_style("whitegrid")
plt.figure(figsize=(8, 6))

# Plot bars
bar_width = 0.35
bars1 = plt.bar([i - bar_width/2 for i in x_pos], train_vals, width=bar_width, label='Training', color='skyblue')
bars2 = plt.bar([i + bar_width/2 for i in x_pos], val_vals, width=bar_width, label='Test', color='orange')

# Add counts on top of bars; the gap is 1% of the tallest bar so it scales
# with the data instead of being a fixed number of images.
label_offset = 0.01 * max(train_vals + val_vals, default=0)
for bar in bars1 + bars2:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, height + label_offset, f'{height}', ha='center', va='bottom', fontsize=9)

# Final touches
plt.xticks(ticks=x_pos, labels=x_labels)
plt.ylabel("Number of Images")
plt.title("Class Distribution in Training vs Test Set")
plt.legend()
plt.tight_layout()
plt.show()

As you can see, the class distribution is almost identical in the training and test sets.
We will now show some images.

In [None]:
# Show the first batch of training images with their class names.
GRID_ROWS, GRID_COLS = 4, 8

plt.figure(figsize=(10, 10))

for images, labels in train_set.take(1):
    # Never draw more panels than the grid holds, or more images than the
    # batch actually contains.
    num_images = min(len(images), GRID_ROWS * GRID_COLS)
    for i in range(num_images):
        plt.subplot(GRID_ROWS, GRID_COLS, i + 1)
        # train_set has not been rescaled yet, so pixels are still 0-255.
        plt.imshow(images[i].numpy().astype("uint8"))
        # labels[i] has shape (1,) with label_mode='binary'; .item() gets the scalar.
        plt.title(classes[int(labels[i].numpy().item())], fontsize=6)
        plt.axis("off")

plt.subplots_adjust(hspace=0)
plt.show()

#### Data augmentation (custom pipeline)

We will now extend (double) the training dataset by adding images obtained with various data augmentation techniques. The augmented dataset is merged with the training set to enlarge it. Finally, the merged dataset is normalized.

In [None]:

# We found that AugMix, despite its strong reported results, does not work well
# on these medical images because it introduces too many colour variations, so
# we moved to a more conventional data augmentation pipeline. The AugMix
# configuration that was tried, for reference
# (see https://arxiv.org/abs/1912.02781 for details on AugMix):
# AugMix(
#         severity=1,
#         chain_depth=1,
#         alpha=0.1,
#         value_range=(0, 255),
#     ),

preprocessing = keras.Sequential([
    # geometric transformations
    layers.RandomRotation(factor=0.2),
    layers.RandomFlip(mode='horizontal_and_vertical'),

    # illumination transformations
    layers.RandomBrightness(factor=0.15),
    layers.RandomContrast(factor=0.15),

    # some noise
    # NOTE(review): pixels are still on the 0-255 scale here, so stddev=0.05
    # is a very small perturbation. Confirm whether 0.05 was meant for 0-1 data.
    layers.GaussianNoise(stddev=0.05),
    # rescaling to [0, 1] is done separately, after the datasets are merged
])

# training=True is required: GaussianNoise is only active in training mode, so
# calling the pipeline with the default flag would silently skip the noise.
augmented_dataset = train_set.map(
    lambda x, y: (preprocessing(x, training=True), y),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# concatenate the two datasets to form a big one: all original batches first,
# followed by the augmented batches.
# Not idempotent: re-running this cell doubles train_set again.
train_set = train_set.concatenate(augmented_dataset)

# cardinality() counts batches, so these sizes are upper bounds when the last
# batch is partial.
print('Augmented train set size:', augmented_dataset.cardinality().numpy()*batch_size)
print('Merged train set size:', train_set.cardinality().numpy()*batch_size)



**Do not rerun the cell above on its own**, otherwise the training set doubles in size each time.

Now we rescale the merged training set (original + augmented images) to [0, 1], as we did for the test set.

In [None]:
# Rescale the merged (original + augmented) training set to [0, 1].
# Run once: re-running this cell would divide by 255 a second time.
def _rescale(x, y):
    return x / 255, y

train_set = train_set.map(_rescale)

# Sanity check on the first five batches.
for batch, _ in train_set.take(5):
    pixels = batch.numpy()
    print("Image shape:", pixels.shape)
    print("Pixel value range: min =", pixels.min(), ", max =", pixels.max())

The images are now preprocessed and rescaled to [0, 1].

In [None]:
# Show a batch of augmented training images.
# NOTE: train_set is the original batches followed by the augmented ones (see
# the concatenate in the augmentation cell), so train_set.take(1) would only
# ever show un-augmented images. Draw the batch from augmented_dataset instead:
# it is the augmentation pipeline before the /255 rescaling, so pixels are on
# the 0-255 scale.
GRID_ROWS, GRID_COLS = 4, 8

plt.figure(figsize=(10, 10))

for images, labels in augmented_dataset.take(1):
    num_images = min(len(images), GRID_ROWS * GRID_COLS)
    for i in range(num_images):
        plt.subplot(GRID_ROWS, GRID_COLS, i + 1)
        # Augmentation (e.g. additive noise) can push values slightly outside
        # [0, 255]; clip so the uint8 cast cannot wrap around.
        plt.imshow(np.clip(images[i].numpy(), 0, 255).astype("uint8"))
        plt.title(classes[int(labels[i].numpy().item())], fontsize=6)
        plt.axis("off")

plt.subplots_adjust(hspace=0)
plt.show()


As you can see, the images have changed.

### Transfer learning using MobileNetV3 & KerasTuner

1. Build the model, replacing the top layers of MobileNetV3 with custom ones
2. Use KerasTuner to select the best head hyperparameters (number of dense layers and neurons, dropout rates, optimizer and learning rate)
3. Retrain the model on the entire train_set
4. Evaluate the model's performance

In [14]:
def model_builder(hp):
    """Build and compile a MobileNetV3Large transfer-learning model for KerasTuner.

    An ImageNet-pretrained MobileNetV3Large (frozen) is followed by global
    average pooling, one or two tunable Dense(ReLU) + Dropout blocks and a
    single sigmoid unit for binary classification.

    Tuned hyperparameters:
        units_1, units_2     -- Dense widths, 64..512 in steps of 64
        dropout_1, dropout_2 -- Dropout rates, 0.0..0.5 in steps of 0.05
        num_layers           -- number of Dense + Dropout blocks (1 or 2)
        optimizer            -- 'adam', 'sgd' or 'rmsprop'
        learning_rate        -- log-sampled in [1e-4, 1e-1]

    Args:
        hp: the keras_tuner HyperParameters object supplied by the tuner.

    Returns:
        A compiled keras.Sequential model (binary cross-entropy loss, accuracy
        metric) that expects images with pixel values in [0, 1]. Its first
        layer is the frozen backbone sub-model, so setting
        ``model.layers[0].trainable = True`` unfreezes MobileNetV3 for
        fine-tuning.
    """
    input_shape = (224, 224, 3)  # Height, Width, Channels (RGB)

    mobile_net = keras.applications.MobileNetV3Large(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # By default (include_preprocessing=True) MobileNetV3 contains its own
    # Rescaling layer and expects raw pixels in [0, 255]. The datasets in this
    # notebook are divided by 255 beforehand, which would squash every image
    # into a tiny range near -1 inside MobileNet, so scale the inputs back to
    # [0, 255] first. Rescaling and MobileNet are wrapped in one sub-model so
    # the backbone is still model.layers[0].
    backbone = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Rescaling(255.0),
            mobile_net,
        ],
        name='mobilenet_v3_backbone',
    )
    backbone.trainable = False  # keep the pretrained weights frozen while tuning

    model = keras.Sequential()
    model.add(backbone)
    model.add(layers.GlobalAveragePooling2D())

    # Tune the width and dropout of each dense block (all four hyperparameters
    # are registered up front, with the same names and order as before).
    neurons = [
        hp.Int('units_1', min_value=64, max_value=512, step=64),
        hp.Int('units_2', min_value=64, max_value=512, step=64),
    ]
    drops = [
        hp.Float('dropout_1', min_value=0.0, max_value=0.5, step=0.05),
        hp.Float('dropout_2', min_value=0.0, max_value=0.5, step=0.05),
    ]

    num_layers = hp.Choice('num_layers', values=[1, 2])
    for units, rate in zip(neurons[:num_layers], drops[:num_layers]):
        model.add(layers.Dense(units=units, activation='relu'))
        model.add(layers.Dropout(rate=rate))

    # Output layer for binary classification: probability of label 1
    model.add(layers.Dense(1, activation='sigmoid'))

    # tune the optimizer and its learning rate
    optimizer_choice = hp.Choice('optimizer', values=['adam', 'sgd', 'rmsprop'])
    lrate = hp.Float('learning_rate', min_value=1e-4, max_value=1e-1, sampling='LOG')

    optimizer_classes = {
        'adam': keras.optimizers.Adam,
        'sgd': keras.optimizers.SGD,
        'rmsprop': keras.optimizers.RMSprop,
    }
    optimizer = optimizer_classes[optimizer_choice](learning_rate=lrate)

    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
    return model

In [None]:
# Split the merged training set into train_reduced / val (80/20 of its batches).
# NOTE(review): train_set is the original batches followed by their augmented
# copies (see the concatenate in the augmentation cell), and take/skip split it
# by position. The last 20% of batches, i.e. the validation set, therefore come
# entirely from the augmented half: augmented versions of images that also
# appear in train_reduced. The val_accuracy the tuner optimises is likely
# optimistic. TODO: split off validation data before augmenting.
n_batches = train_set.cardinality().numpy()
train_size = int(n_batches * 0.8)

def _cache_and_prefetch(ds):
    return ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

train_reduced = _cache_and_prefetch(train_set.take(train_size))
val_set = _cache_and_prefetch(train_set.skip(train_size))

# Verify the split (optional). Sizes are batch counts times batch_size, so they
# are upper bounds when the last batch is partial.
print("Reduced training dataset size:", train_reduced.cardinality().numpy()*batch_size)
print("Validation dataset size:", val_set.cardinality().numpy()*batch_size)

In [None]:
# Hyperband search over the model_builder space, ranked by val_accuracy.
# NOTE: Hyperband manages each trial's epoch budget itself (bounded by
# max_epochs), so epochs=50 below presumably acts only as an upper bound.
# By default an existing 'tuned_models/malaria_transfer_learning' project is
# reloaded rather than searched again; verify both against the keras_tuner docs.
tuner = kt.Hyperband(
    hypermodel=model_builder,
    objective='val_accuracy',
    max_epochs=5,
    factor=3,
    directory='tuned_models',
    project_name='malaria_transfer_learning',
)

# End a trial once val_loss has not improved for 2 consecutive epochs.
stop_early = keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
tuner.search(
    train_reduced,
    epochs=50,
    validation_data=val_set,
    callbacks=[stop_early],
)

#### Re-train the best model on the entire train_set

In [17]:
# Cache and prefetch both datasets for faster training and evaluation.
# NOTE(review): cache() without a filename keeps every decoded float32 image
# in RAM (224*224*3*4 bytes, ~0.6 MB each), so check memory for the merged
# training set. Caching also freezes the random augmentations drawn on the
# first pass.
train_set, test_set = (
    ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
    for ds in (train_set, test_set)
)

In [18]:
# Rebuild a new model from the best hyperparameters found by the search
# (model_builder creates fresh layers; the trial weights are not reused).
top_hps = tuner.get_best_hyperparameters(num_trials=1)
best_params = top_hps[0]
best_model = tuner.hypermodel.build(best_params)

Now we fine-tune the MobileNet model to improve its performance on our dataset.

In [None]:
# Fine-tune: unfreeze the pretrained backbone (the model's first layer) and
# retrain the whole network on the entire training set with a small learning
# rate. The model is recompiled so the trainable change takes effect.
backbone = best_model.layers[0]
backbone.trainable = True

fine_tune_lr = 1e-4
best_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=fine_tune_lr),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# No validation data is passed here, so early stopping watches the training loss.
stop_early = keras.callbacks.EarlyStopping(monitor='loss', patience=2)
history = best_model.fit(train_set, epochs=10, callbacks=[stop_early])

If you want to avoid retraining (a very long process), here are the final training statistics:
training accuracy: 0.9682 - loss: 0.0921

Saving the model in the native Keras (`.keras`) format

In [61]:
# Save the entire model in the native Keras ".keras" format (a single zip
# file, not the TensorFlow SavedModel directory or the legacy HDF5 format).
# os.makedirs replaces the shell `mkdir -p` so this also works outside Unix shells.
os.makedirs('saved_model', exist_ok=True)
keras.saving.save_model(best_model, 'saved_model/malaria_mobile_net.keras')

#### Testing the model on the test_set

To make this code runnable even without retraining the model, we load it from the saved file.

In [None]:
# Reload the saved model into the same variable, replacing the in-memory copy
# instead of keeping two models around.
MODEL_PATH = 'saved_model/malaria_mobile_net.keras'
best_model = keras.saving.load_model(MODEL_PATH)

After training, the model reached a training accuracy of about 0.96 (a hard-coded value from our run; it may vary slightly). We now evaluate it on the test set to estimate the generalization error. This may take a while, as the test set contains more than 5K images.

In [None]:
# Estimate the generalization error on the held-out test set.
# evaluate() returns [loss, accuracy] because the model was compiled with the
# 'accuracy' metric.
loss, accuracy = best_model.evaluate(test_set, verbose=0)

print(f'Best model test loss: {loss:.4f}')
print(f'Best model test accuracy: {accuracy:.4f}')

### Some model stats

#### Model summary

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
best_model.summary()

#### Precision/Recall/F1-score and AUC-ROC

Once again this operation may take a while as the test set is composed of lots of images!

In [40]:
# Collect predicted probabilities and true labels on the test set to estimate
# the generalization error. Predictions and labels are taken from the same
# batch of the same pass, so they stay aligned even if the dataset's order
# differs between passes.
y_true = []
y_probs = []

for x_batch, y_batch in test_set:
    # predict_on_batch avoids the per-call setup overhead of predict() in a loop.
    y_probs.extend(np.asarray(best_model.predict_on_batch(x_batch)).ravel())
    # Binary labels come as (batch_size, 1); flatten them to (batch_size,).
    y_true.extend(y_batch.numpy().ravel())

y_true = np.array(y_true).astype(int)
y_probs = np.array(y_probs)
# The sigmoid output is the probability of label 1 (classes[1]).
y_pred = (y_probs > 0.5).astype(int)

In [None]:
# target_names must follow label order: label i is classes[i] (from the
# dataset's sub-folder names), i.e. 0 = 'Parasitized', 1 = 'Uninfected'.
print(classification_report(y_true, y_pred, target_names=classes))
# For binary labels, AUC scores label 1 (classes[1]) as the positive class.
print(f"AUC-ROC: {roc_auc_score(y_true, y_probs):.4f}")

We now plot the ROC curve

In [None]:
# ROC curve of the sigmoid scores; with 0/1 labels the positive class is
# label 1 (classes[1]).
fpr, tpr, _ = roc_curve(y_true, y_probs)
auc_value = roc_auc_score(y_true, y_probs)

fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(fpr, tpr, label=f"AUC = {auc_value:.4f}")
ax.plot([0, 1], [0, 1], linestyle='--', color='gray')  # chance-level diagonal
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.legend()
ax.grid(True)
plt.show()

We now plot the precision/recall curve

In [None]:
# Precision-recall curve of the sigmoid scores; with 0/1 labels the positive
# class is label 1 (classes[1]).
precision, recall, _ = precision_recall_curve(y_true, y_probs)

fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(recall, precision, color='purple')
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_title("Precision-Recall Curve")
ax.grid(True)
plt.show()

#### Confusion matrix

In [None]:
# Confusion matrix: rows are true labels, columns are predicted labels, both
# in label order (label i is classes[i]).
cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(6, 5))
# Tick labels come from the dataset's own label order rather than a
# hand-written list, so they always match the rows/columns of cm.
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=classes,
            yticklabels=classes)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()