# Autoencoder for uninfected malaria cell images

## Imports

In [None]:
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Conv2D, DepthwiseConv2D, MaxPooling2D, UpSampling2D, Dropout, LeakyReLU

In [None]:
# Display the installed TensorFlow version as the cell output.
tf.__version__

### Set TensorFlow logging to errors only

In [None]:
# Silence TensorFlow's Python logger below ERROR level.
tf.get_logger().setLevel(level = 'ERROR')

### Check for GPU

In [None]:
physical_devices = tf.config.list_physical_devices('GPU')

# Allocate GPU memory on demand on EVERY visible GPU, not just the first one.
# This must run before any GPU has been initialised, otherwise TensorFlow
# raises a RuntimeError.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

physical_devices

## Constants

In [None]:
RANDOM_STATE = 7
# Built with os.path.join so the dataset path resolves on Windows and POSIX alike.
BASE_PATH = os.path.join('..', '..', 'Datasets', 'Malaria Cell Images', 'Uninfected')
IMAGE_SIZE = (128, 128)
VAL_SIZE = 0.05            # fraction of source image files held out for validation
BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = 1_000
EPOCHS = 100
LEARNING_RATE = 0.002
PLOTS_DPI = 200
MODEL_NAME = 'Autoencoder_Reconstruction'
PLOTS_DIR = os.path.join('plots', MODEL_NAME)
# One TensorBoard run directory per execution, stamped with the time this cell runs.
TB_LOGS = os.path.join('tensorboard_logs', MODEL_NAME, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

# Figures are saved into PLOTS_DIR later on; plt.savefig does not create
# missing directories, so make sure it exists up front.
os.makedirs(PLOTS_DIR, exist_ok = True)

## Data Loading

In [None]:
# List the PNG files in a deterministic order, then apply ONE seeded shuffle
# that is frozen across iterations (reshuffle_each_iteration=False).
# list_files(..., seed=...) on its own reshuffles every time the dataset is
# iterated, so a take()/skip() train/validation split built on top of it would
# see a different file order in each pipeline and the two sets would overlap.
image_names = Dataset.list_files(os.path.join(BASE_PATH, '*.png'), shuffle = False)
image_count = image_names.cardinality().numpy()
image_names = image_names.shuffle(image_count, seed = RANDOM_STATE, reshuffle_each_iteration = False)
print(f"\nTotal number of image files: {image_count}")

### Image data loading with augmentations

In [None]:
def load_augmented_images(file_path):
    """Read one image file and return three views of it scaled to [0, 1].

    Args:
        file_path: scalar string tensor with the path of a PNG image.

    Returns:
        Tuple ``(rotated clockwise 90, original, rotated counter-clockwise 90)``
        of float32 tensors with 3 channels, resized to ``IMAGE_SIZE``.
    """
    raw = tf.io.read_file(file_path)
    # Files are globbed as *.png, so use the PNG decoder explicitly.
    img = tf.io.decode_png(raw, channels = 3)
    # resize() returns float32; scale once here instead of per rotated view.
    img = tf.image.resize(img, IMAGE_SIZE) / 255.0

    img_rotneg90 = tf.image.rot90(img, k = -1)   # clockwise
    img_rotpos90 = tf.image.rot90(img, k = 1)    # counter-clockwise

    return img_rotneg90, img, img_rotpos90

# Decode every file into its three views, then flatten each
# (rot -90, original, rot +90) triple into three consecutive dataset elements.
image_data = (
    image_names
    .map(load_augmented_images, num_parallel_calls = AUTOTUNE)
    .flat_map(lambda neg90, orig, pos90: Dataset.from_tensor_slices([neg90, orig, pos90]))
)


### Augmented data visualization

In [None]:
title_suffixes = ['Rotated -90', 'Original', 'Rotated +90']

fig, axes = plt.subplots(nrows = 6, ncols = 6, figsize = (18, 18))
fig.suptitle('Uninfected cell images with Image Augmentation', fontsize = 24)

# image_data yields each file as three consecutive views (-90, original, +90),
# so i % 3 selects the matching title suffix.
for i, (ax, img) in enumerate(zip(axes.flat, image_data.take(36))):
    ax.axis(False)
    ax.set_title(f"Img {i + 1} - {title_suffixes[i % 3]}")
    ax.imshow(img.numpy())

# Lay the grid out only after the titles exist so the spacing accounts for them.
fig.tight_layout(rect = [0, 0, 1, 0.97], h_pad = 2)

### Data splitting

In [None]:
def configure_for_performance(ds):
    """Cache, shuffle, batch and prefetch ``ds`` for use with ``Model.fit``.

    The cache sits before the shuffle, so upstream work (file reading and
    decoding) only runs on the first pass while the order is still shuffled.
    """
    return (
        ds.cache()
        .shuffle(buffer_size = SHUFFLE_BUFFER_SIZE)
        .batch(BATCH_SIZE)
        .prefetch(buffer_size = AUTOTUNE)
    )

def create_autoencoder_dataset(img):
    """Pair an image with itself as ``(input, target)`` for reconstruction training."""
    return (img,) * 2

In [None]:
# Hold out VAL_SIZE of the *files*. Each file contributes three consecutive
# views to image_data, so round to whole files BEFORE multiplying by 3;
# otherwise the boundary can fall inside one file's triple and put some of its
# rotations in training and the rest in validation.
val_image_count = int(image_count * VAL_SIZE) * 3

# NOTE: take()/skip() only give disjoint sets if image_data produces the same
# element order on every iteration.
train_images = image_data.skip(val_image_count)
val_images = image_data.take(val_image_count)

train_ds = configure_for_performance(train_images.map(create_autoencoder_dataset))
val_ds = configure_for_performance(val_images.map(create_autoencoder_dataset))

## Model Creation

### Input Layer

In [None]:
# RGB input at IMAGE_SIZE resolution -> (None, 128, 128, 3)
inputLayer = Input(shape = IMAGE_SIZE + (3,), name = 'Input')

### Encoder

In [None]:
# Depthwise conv, 2 filters per input channel: 3 -> 6 channels, (None, 128, 128, 6)
depth_conv = DepthwiseConv2D(
    kernel_size = (4, 4),
    depth_multiplier = 2,
    padding = 'same',
    activation = LeakyReLU(),
    name = "Depth_Conv",
)(inputLayer)
depth_conv.shape

In [None]:
# Encoder stage 1: 8 filters, 4x downsampling -> (None, 32, 32, 8)
conv_1 = Conv2D(
    filters = 8,
    kernel_size = (4, 4),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Enc_Conv_1",
)(depth_conv)
pool_1 = MaxPooling2D(pool_size = (4, 4), padding = 'same', name = "Enc_MaxPool_1")(conv_1)
pool_1.shape

In [None]:
# Encoder stage 2: 16 filters, 4x downsampling -> (None, 8, 8, 16)
conv_2 = Conv2D(
    filters = 16,
    kernel_size = (4, 4),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Enc_Conv_2",
)(pool_1)
pool_2 = MaxPooling2D(pool_size = (4, 4), padding = 'same', name = "Enc_MaxPool_2")(conv_2)
pool_2.shape

In [None]:
# Encoder stage 3: 32 filters, 2x downsampling -> (None, 4, 4, 32)
conv_3 = Conv2D(
    filters = 32,
    kernel_size = (3, 3),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Enc_Conv_3",
)(pool_2)
pool_3 = MaxPooling2D(pool_size = (2, 2), padding = 'same', name = "Enc_MaxPool_3")(conv_3)
pool_3.shape

In [None]:
# Encoder stage 4 (bottleneck): 64 filters, 2x downsampling -> (None, 2, 2, 64)
conv_4 = Conv2D(
    filters = 64,
    kernel_size = (3, 3),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Enc_Conv_4",
)(pool_3)
pool_4 = MaxPooling2D(pool_size = (2, 2), padding = 'same', name = "Enc_MaxPool_4")(conv_4)
pool_4.shape

### Decoder

In [None]:
# Decoder stage 1: 64 filters, 2x upsampling -> (None, 4, 4, 64)
conv_5 = Conv2D(
    filters = 64,
    kernel_size = (3, 3),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Dec_Conv_1",
)(pool_4)
up_1 = UpSampling2D(size = (2, 2), name = "Dec_Upsampling_1")(conv_5)
up_1.shape

In [None]:
# Decoder stage 2: 32 filters, 2x upsampling -> (None, 8, 8, 32)
conv_6 = Conv2D(
    filters = 32,
    kernel_size = (3, 3),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Dec_Conv_2",
)(up_1)
up_2 = UpSampling2D(size = (2, 2), name = "Dec_Upsampling_2")(conv_6)
up_2.shape

In [None]:
# Decoder stage 3: 16 filters, 4x upsampling -> (None, 32, 32, 16)
conv_7 = Conv2D(
    filters = 16,
    kernel_size = (4, 4),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Dec_Conv_3",
)(up_2)
up_3 = UpSampling2D(size = (4, 4), name = "Dec_Upsampling_3")(conv_7)
up_3.shape

In [None]:
# Decoder stage 4: 8 filters, 4x upsampling back to full size -> (None, 128, 128, 8)
conv_8 = Conv2D(
    filters = 8,
    kernel_size = (4, 4),
    padding = 'same',
    activation = LeakyReLU(),
    name = "Dec_Conv_4",
)(up_3)
up_4 = UpSampling2D(size = (4, 4), name = "Dec_Upsampling_4")(conv_8)
up_4.shape

In [None]:
# Light regularisation before the output projection; shape is unchanged.
dropout = Dropout(rate = 0.1, name = "Dropout")(up_4)
dropout.shape

In [None]:
# 1x1 projection to 3 channels; sigmoid keeps pixels in [0, 1] like the inputs.
outputLayer = Conv2D(
    filters = 3,
    kernel_size = (1, 1),
    padding = 'same',
    activation = 'sigmoid',
    name = 'Reconstruction_Output',
)(dropout)
outputLayer.shape

### Model compilation

In [None]:
# Wire the functional graph into a model and train it to reproduce its input.
autoencoder = Model(
    inputs = inputLayer,
    outputs = outputLayer,
    name = MODEL_NAME,
)
autoencoder.compile(
    optimizer = Adam(learning_rate = LEARNING_RATE),
    loss = 'binary_crossentropy',
)

### Model Summary

In [None]:
# Render the architecture diagram (with tensor shapes) into PLOTS_DIR.
tf.keras.utils.plot_model(
    autoencoder,
    to_file = os.path.join(PLOTS_DIR, 'model.jpg'),
    show_shapes = True,
    dpi = PLOTS_DPI,
)

In [None]:
# Print the layer-by-layer summary (output shapes and parameter counts).
autoencoder.summary()

## Model Training

### Callbacks

In [None]:
# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stop = EarlyStopping(
    monitor = 'val_loss',
    patience = 10,
    restore_best_weights = True,
)
# Write logs into the timestamped TB_LOGS run directory.
tensorboard = TensorBoard(log_dir = TB_LOGS)
# Multiply the learning rate by 0.2 after 4 epochs without val_loss improvement.
reduce_lr = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.2,
    patience = 4,
    cooldown = 1,
    verbose = 1,
)

### Training history

In [None]:
%%time

history = autoencoder.fit(
    train_ds, 
    epochs = EPOCHS, 
    verbose = 1, 
    validation_data = val_ds,
    callbacks = [early_stop, tensorboard, reduce_lr]
    )

## Model Evaluation

### Model Loss over Epochs

In [None]:
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = history.epoch

plt.figure(figsize = (16, 8))
plt.plot(epochs_range, loss, label = 'Training Loss')
plt.plot(epochs_range, val_loss, label = 'Validation Loss')
# Label the axes so the saved figure is readable on its own.
plt.xlabel('Epoch')
plt.ylabel('Loss (binary cross-entropy)')
plt.legend(loc = 'upper right')
plt.title('Training and Validation Loss')
# NOTE(review): only the loss is plotted (the model is compiled without
# metrics); the 'acc_and_loss' file name is kept unchanged in case other
# material refers to it.
plt.savefig(os.path.join(PLOTS_DIR, 'acc_and_loss.jpg'), dpi = PLOTS_DPI, bbox_inches='tight')
plt.show()

### Prediction Visualization

In [None]:
fig, axes = plt.subplots(nrows = 6, ncols = 6, figsize = (18, 18))
fig.suptitle('Autoencoder predictions on uninfected cells', fontsize = 24)
panels = axes.ravel()
for ax in panels:
    ax.axis(False)

# First validation batch. val_ds yields (input, target) pairs, so keep the
# inputs. At most 12 images fit the grid (3 panels each on a 6 x 6 grid).
val_inputs, _ = next(iter(val_ds))
val_data = val_inputs[:12].numpy()

pred = autoencoder.predict(val_data)

# Per-image min-max normalisation of the signed error so it can be drawn as
# an image. keepdims keeps the stats broadcastable for any batch size, and a
# perfectly flat error (max == min) is divided by 1 instead of 0.
pred_error = val_data - pred
pred_error_min = pred_error.min(axis = (1, 2, 3), keepdims = True)
pred_error_max = pred_error.max(axis = (1, 2, 3), keepdims = True)
error_range = pred_error_max - pred_error_min
norm_error = (pred_error - pred_error_min) / np.where(error_range > 0, error_range, 1.0)

panel_labels = ("Original", "Prediction", "Normalized Error")
for i, views in enumerate(zip(val_data, pred, norm_error)):
    for j, (label, image) in enumerate(zip(panel_labels, views)):
        ax = panels[3 * i + j]
        ax.set_title(f"{label} - {i + 1}")
        ax.imshow(image)

# Lay the grid out only after the titles exist so the spacing accounts for them.
fig.tight_layout(rect = [0, 0, 1, 0.97], h_pad = 2)
fig.savefig(os.path.join(PLOTS_DIR, 'predictions.jpg'), dpi = PLOTS_DPI, bbox_inches='tight')
plt.show()

## Model Saving

In [None]:
# Persist the trained autoencoder under models/<MODEL_NAME>.
model_dir = os.path.join('models', MODEL_NAME)
autoencoder.save(model_dir)