In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tensorflow import data
from tensorflow.random import set_seed as tf_set_seed
from tensorflow.keras import Sequential, Input, Model
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Softmax
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.applications import resnet50

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.image import rgb_to_grayscale, grayscale_to_rgb
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
from fl_tissue_model_tools import data_prep
from fl_tissue_model_tools import models

In [9]:
# Root folder containing the pre-split `train/` and `test/` image directories.
# Set the MALARIA_DATA_DIR environment variable to point elsewhere instead of
# editing this path.
root_data_path = os.environ.get("MALARIA_DATA_DIR", "D:/malaria_data/cell_images_split")
seed = 123
# The data generators receive `seed` directly; also seed TensorFlow's global RNG
# so TF-side randomness (e.g. initialization of the new head's weights) is
# reproducible across runs.
tf_set_seed(seed)

# Images are resized by flow_from_directory to resnet_inp_shape[:2] using its
# default interpolation ("nearest") unless an `interpolation` argument is passed.
# Multiples of 32 work best with ResNet
resnet_inp_shape = (128, 128, 3)
# Binary classification -> only need 1 output unit
n_outputs = 1
# Fraction of train/ held out as the "validation" subset (ImageDataGenerator validation_split)
val_split = 0.2
batch_size = 32
frozen_epochs = 10
fine_tune_epochs = 10
fine_tune_lr = 1e-5

# Early stopping
es_criterion = "val_loss"
es_mode = "min"
# Update these depending on seriousness of experiment.
# NOTE: min_delta is an absolute change in the monitored value. With
# patience=1, training stops as soon as one epoch fails to lower val_loss by at
# least 0.1, which is aggressive for a BCE loss already in the ~0.2-0.3 range.
es_patience = 1
es_min_delta = 0.1

# Model saving
output_dir = "../../malaria_v1_output"
cp_criterion = "val_loss"
cp_mode = "min"
cp_frozen_filepath = f"{output_dir}/resnet50_malaria_best_frozen_weights.h5"
cp_fine_tune_filepath = f"{output_dir}/resnet50_malaria_best_fine_tune_weights.h5"
cp_best_only = True
# Need to set to True otherwise base model "layer" won't save/load properly
cp_weights_only = True
# Ensure the checkpoint directory exists before ModelCheckpoint tries to write into it.
os.makedirs(output_dir, exist_ok=True)

In [4]:
# The malaria images are RGB. To measure how the model does on grayscale input,
# collapse each image to a single channel and expand it back to three channels
# so it still fits the 3-channel ResNet50 input.
def rgb2gray2rbg_then_resnet50(img):
    """Grayscale round-trip an RGB image, then apply ResNet50 preprocessing.

    Used as the ``preprocessing_function`` of the ImageDataGenerators below.

    Args:
        img: a single image with a trailing RGB channel axis (presumably the
            (H, W, 3) float array ImageDataGenerator hands in -- confirm).

    Returns:
        ``resnet50.preprocess_input`` applied to the grayscale-as-RGB image.
    """
    as_rgb_gray = grayscale_to_rgb(rgb_to_grayscale(img))
    return resnet50.preprocess_input(as_rgb_gray)

# Data

## Augmentation and train/validation split

In [5]:
# Training-time generator: random flips and rotations for augmentation, and the
# same preprocessing the ResNet50 base expects (after the grayscale round-trip;
# pass resnet50.preprocess_input instead to train on the original RGB images).
# validation_split reserves a val_split fraction of train/ for the
# "validation" subset.
train_datagen = ImageDataGenerator(
    preprocessing_function=rgb2gray2rbg_then_resnet50,
    validation_split=val_split,
    rotation_range=360,
    horizontal_flip=True,
    vertical_flip=True,
)

In [6]:
# Evaluation generator: no augmentation, same grayscale + ResNet50 preprocessing
# as training (use resnet50.preprocess_input instead to evaluate on RGB images).
test_datagen = ImageDataGenerator(preprocessing_function=rgb2gray2rbg_then_resnet50)

## Generators

In [7]:
# Keras assigns class indices by the *order* of `classes` (it expects a list of
# subdirectory names; a dict's values would be ignored), so list them in index
# order: uninfected -> 0, parasitized -> 1.
class_names = ["uninfected", "parasitized"]

# Validation images must not be randomly flipped/rotated: val_loss drives early
# stopping and checkpointing. This generator applies only the preprocessing.
# Because the train/validation partition depends only on the directory contents
# and validation_split, it matches the "training" subset below exactly.
val_datagen = ImageDataGenerator(
    preprocessing_function=rgb2gray2rbg_then_resnet50,
    validation_split=val_split
)

train_generator = train_datagen.flow_from_directory(
    f"{root_data_path}/train",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    classes=class_names,
    class_mode="binary",
    seed=seed,
    subset="training"
)

val_generator = val_datagen.flow_from_directory(
    f"{root_data_path}/train",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    classes=class_names,
    class_mode="binary",
    seed=seed,
    subset="validation"
)

# Guard the label convention the rest of the notebook relies on.
assert train_generator.class_indices == {"uninfected": 0, "parasitized": 1}
assert val_generator.class_indices == train_generator.class_indices

Found 17638 images belonging to 2 classes.
Found 4408 images belonging to 2 classes.


In [8]:
# class_mode=None: the generator yields images only (for predict). Ground-truth
# labels are in `test_generator.classes`, in file order because shuffle=False.
# Classes are given as a list because Keras assigns indices by list order:
# uninfected -> 0, parasitized -> 1.
test_generator = test_datagen.flow_from_directory(
    f"{root_data_path}/test",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    classes=["uninfected", "parasitized"],
    class_mode=None,
    shuffle=False
)
assert test_generator.class_indices == {"uninfected": 0, "parasitized": 1}

Found 5512 images belonging to 2 classes.


# Build model

In [10]:
# Transfer-learning model: ResNet50 truncated at `base_last_layer` with a frozen
# base. Other cut points tried: conv5_block3_out, conv5_block2_out,
# conv5_block1_out, conv3_block4_out.
tl_model = models.build_ResNet50_TL(
    n_outputs,
    resnet_inp_shape,
    base_model_trainable=False,
    # Switch to softmax once n_outputs > 1
    output_act="sigmoid",
    base_last_layer="conv4_block6_out",
)
# Sigmoid output pairs with BinaryCrossentropy's default from_logits=False.
tl_model.compile(
    optimizer=Adam(),
    loss=BinaryCrossentropy(),
    metrics=[BinaryAccuracy()],
)

In [11]:
# The base model was built with base_model_trainable=False, so only the new head should count as trainable here.
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 8, 8, 1024)        8589184   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
Total params: 8,590,209
Trainable params: 1,025
Non-trainable params: 8,589,184
_________________________________________________________________


In [12]:
# Optional: uncomment to list the layers of the truncated ResNet50 backbone ("base_model").
# tl_model.get_layer("base_model").summary()

# Train model (new head only; base model frozen)

In [13]:
# Frozen-phase callbacks: stop when `es_criterion` plateaus, and keep only the
# best weights (by `cp_criterion`) at cp_frozen_filepath.
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    patience=es_patience,
    min_delta=es_min_delta,
)
cp_callback = ModelCheckpoint(
    cp_frozen_filepath,
    monitor=cp_criterion,
    mode=cp_mode,
    save_best_only=cp_best_only,
    save_weights_only=cp_weights_only,
)

In [14]:
# Stage 1: train only the new head (base model frozen).
# len(generator) == ceil(n / batch_size), so every training and validation image
# is seen each epoch. `n // batch_size` would drop the final partial batch and
# compute val_loss (which drives early stopping/checkpointing) on a subset.
h1 = tl_model.fit(
    train_generator,
    validation_data=val_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    epochs=frozen_epochs,
    callbacks=[es_callback, cp_callback]
)

Epoch 1/10
Epoch 2/10
 68/551 [==>...........................] - ETA: 1:57 - loss: 0.2696 - binary_accuracy: 0.8915

KeyboardInterrupt: 

# Load best frozen weights before fine-tuning

In [None]:
# Restore the best frozen-phase weights (selected by `cp_criterion`) written by the checkpoint callback.
tl_model.load_weights(filepath=cp_frozen_filepath)

# Train model (all layers)

In [None]:
# Make base model trainable (leave layers in inference mode)
# Make base model trainable (leave layers in inference mode).
# NOTE(review): toggle_TL_freeze lives in fl_tissue_model_tools.models; presumably
# it flips the base model's `trainable` flag -- confirm. Keras only picks up a
# trainability change after the model is compiled again.
models.toggle_TL_freeze(tl_model)

In [None]:
# Recompile so the trainability change takes effect, using a small learning
# rate so fine-tuning does not wreck the pretrained base weights.
tl_model.compile(
    optimizer=Adam(learning_rate=fine_tune_lr),
    loss=BinaryCrossentropy(),
    metrics=[BinaryAccuracy()],
)

In [None]:
# After toggling the freeze, the trainable-parameter count is expected to include the base model -- compare with the frozen summary.
tl_model.summary()

In [None]:
# Fresh callbacks for the fine-tuning phase; best weights go to a separate file
# so the frozen-phase checkpoint is not overwritten.
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    patience=es_patience,
    min_delta=es_min_delta,
)
cp_callback = ModelCheckpoint(
    cp_fine_tune_filepath,
    monitor=cp_criterion,
    mode=cp_mode,
    save_best_only=cp_best_only,
    save_weights_only=cp_weights_only,
)

In [None]:
# Stage 2: fine-tune with the base model unfrozen.
# len(generator) == ceil(n / batch_size), so every training and validation image
# is seen each epoch. `n // batch_size` would drop the final partial batch and
# compute val_loss (which drives early stopping/checkpointing) on a subset.
h2 = tl_model.fit(
    train_generator,
    validation_data=val_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    epochs=fine_tune_epochs,
    callbacks=[es_callback, cp_callback]
)