In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tensorflow import data
from tensorflow.keras import Sequential, Input, Model
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Softmax
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.applications import resnet50

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
from fl_tissue_model_tools import data_prep
from fl_tissue_model_tools import models

In [3]:
# --- Notebook configuration -------------------------------------------------
# Root folder holding the "train" and "test" image directories.
# (Plain string: there is nothing to interpolate, so no f-prefix is needed.)
root_data_path = "D:/malaria_data/cell_images_split"
# Seed handed to the directory iterators below.
seed = 123
# Resize using "area" method
# NOTE(review): the generators below do not pass `interpolation`, so the Keras
# default is what actually applies — confirm which resize method is intended.
# Multiples of 32 work best with ResNet
resnet_inp_shape = (128, 128, 3)
# Fraction of the training directory held out for validation.
val_split = 0.2
batch_size = 32
# Epochs for the head-only stage; the fine-tuning stage uses half of this.
n_epochs = 2
# Learning rate for the fine-tuning stage (passed to Adam after unfreezing).
fine_tune_lr = 1e-5

# Data

## Data split

In [4]:
# Generator used for the training directory: ResNet50 preprocessing plus
# random horizontal/vertical flips, with a slice reserved for validation.
train_datagen = ImageDataGenerator(
    # Inputs must be preprocessed exactly as the pretrained ResNet50 expects
    preprocessing_function=resnet50.preprocess_input,
    # Hold out `val_split` (20%) of the images for validation
    validation_split=val_split,
    # Augmentation: random flips along both axes
    vertical_flip=True,
    horizontal_flip=True,
)

In [5]:
# Test images get the ResNet50 preprocessing only — no augmentation, no split.
test_datagen = ImageDataGenerator(preprocessing_function=resnet50.preprocess_input)

## Generators

In [6]:
# Training batches: augmented (flips) images from the "training" subset.
train_generator = train_datagen.flow_from_directory(
    f"{root_data_path}/train",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    # `classes` is an ordered list of sub-directory names; the position in the
    # list is the label, so uninfected -> 0 and parasitized -> 1.
    classes=["uninfected", "parasitized"],
    class_mode="binary",
    seed=seed,
    subset="training"
)

# Validation images must NOT be augmented, otherwise the validation metrics
# are measured on randomly flipped images. Use a dedicated generator with the
# same preprocessing and the same split fraction but no flips. Keras derives
# the training/validation subsets from the sorted file listing and the split
# fraction (not from random sampling), so this generator selects the same
# held-out files that `train_datagen` excludes from training.
val_datagen = ImageDataGenerator(
    preprocessing_function=resnet50.preprocess_input,
    validation_split=val_split
)

val_generator = val_datagen.flow_from_directory(
    f"{root_data_path}/train",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    classes=["uninfected", "parasitized"],
    class_mode="binary",
    seed=seed,
    subset="validation"
)

Found 17638 images belonging to 2 classes.
Found 4408 images belonging to 2 classes.


In [7]:
# Test batches, in a fixed (unshuffled) order so predictions can be lined up
# with the files/labels afterwards.
test_generator = test_datagen.flow_from_directory(
    f"{root_data_path}/test",
    target_size=resnet_inp_shape[:2],
    batch_size=batch_size,
    # Ordered list of sub-directory names: index in the list is the label
    # (uninfected -> 0, parasitized -> 1), matching the training generators.
    classes=["uninfected", "parasitized"],
    # class_mode=None: the iterator yields image batches only (no labels),
    # which is the form expected when it is passed to predict().
    class_mode=None,
    shuffle=False
)

Found 5512 images belonging to 2 classes.


# Build model

In [8]:
# Build the transfer-learning model with the ResNet50 base kept frozen, so
# only the newly added head is trained in the first stage.
# `base_last_layer` picks where the ResNet50 base is cut off. Other cut points
# that were tried: "conv5_block3_out", "conv5_block2_out", "conv5_block1_out",
# "conv3_block4_out".
tl_model = models.build_ResNet50_TL(
    resnet_inp_shape,
    base_last_layer="conv4_block6_out",
    output_act="sigmoid",
    base_model_trainable=False,
)

# Binary classification set-up: sigmoid output + binary cross-entropy,
# Adam with its default learning rate for the head-only stage.
tl_model.compile(
    optimizer=Adam(),
    loss=BinaryCrossentropy(),
    metrics=[BinaryAccuracy()],
)

In [9]:
# Sanity check before head-only training: the base was built with
# base_model_trainable=False, so only the new Dense head should be listed
# under "Trainable params".
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 8, 8, 1024)        8589184   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
Total params: 8,590,209
Trainable params: 1,025
Non-trainable params: 8,589,184
_________________________________________________________________


In [10]:
# Layer-by-layer view of the nested ResNet50 base (the layer named
# "base_model" inside tl_model).
tl_model.get_layer("base_model").summary()

Model: "base_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 134, 134, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 64, 64, 64)   9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 64, 64, 64)   256         conv1_conv[0][0]                 
_________________________________________________________________________________________

# Train model (just new layers)

In [11]:
# Stage 1: train only the new head (base frozen).
h1 = tl_model.fit(
    train_generator,
    validation_data=val_generator,
    # len() of a Keras directory iterator is ceil(n / batch_size), so the
    # final partial batch is included. Using `n // batch_size` instead would
    # skip up to batch_size - 1 images per epoch — and for the shuffled
    # validation iterator, a different set of images every epoch.
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    epochs=n_epochs
)

Epoch 1/2
Epoch 2/2


# Train model (all layers)

In [12]:
# Unfreeze the model for fine-tuning. The summary printed after this step
# shows the base model's parameters counted as trainable.
# NOTE(review): presumably this flips `trainable` on the nested "base_model"
# layer — confirm against fl_tissue_model_tools.models.toggle_TL_freeze.
# TODO: open question from the original author — should the nested base model
# be toggled directly (e.g. `<base model>.trainable = True`) instead?
models.toggle_TL_freeze(tl_model)

In [13]:
# Re-compile after changing which layers are trainable so the change takes
# effect, and switch to the small fine-tuning learning rate.
tl_model.compile(
    optimizer=Adam(learning_rate=fine_tune_lr),
    loss=BinaryCrossentropy(),
    metrics=[BinaryAccuracy()],
)

In [14]:
# Sanity check before fine-tuning: after unfreezing, the base model's
# parameters should now appear under "Trainable params".
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 8, 8, 1024)        8589184   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
Total params: 8,590,209
Trainable params: 8,559,617
Non-trainable params: 30,592
_________________________________________________________________


In [15]:
# Stage 2: fine-tune with the base unfrozen, at the low learning rate.
h2 = tl_model.fit(
    train_generator,
    validation_data=val_generator,
    # len() of a Keras directory iterator is ceil(n / batch_size), so the
    # final partial batch is included rather than dropped by `n // batch_size`.
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    # Fine-tune for half as many epochs as stage 1, but never fewer than one:
    # a bare `n_epochs // 2` is 0 when n_epochs == 1 and would silently skip
    # fine-tuning altogether.
    epochs=max(1, n_epochs // 2)
)

