In [1]:
import tensorflow.keras.backend as K
import numpy as np
import dask as d
import os
import cv2
import json
import multiprocessing
import matplotlib.pyplot as plt

from glob import glob
from copy import deepcopy
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications import resnet50

In [2]:
# Number of CPUs in the system; used further down as `workers=n_cores` in both `fit` calls.
n_cores = multiprocessing.cpu_count()
n_cores

16

In [3]:
from fl_tissue_model_tools import data_prep, dev_config, models, defs
import fl_tissue_model_tools.preprocessing as prep

In [4]:
# Project directory locations read from the dev paths file.
# Only `dirs.data_dir` and `dirs.analysis_dir` are used in this notebook.
dirs = dev_config.get_dev_directories("../dev_paths.txt")

# Set up model training parameters

In [5]:
# Read the training hyperparameters from the shared JSON file.
with open("../model_training/invasion_depth_training_values.json", 'r') as fp:
    training_values = json.load(fp)

# The file spells "no seed" as the string "None"; turn that into a real None.
# Any other value is left untouched.
if training_values["rs_seed"] == "None":
    training_values["rs_seed"] = None

In [6]:
# Show the loaded hyperparameters (after the "None" -> None seed conversion above).
training_values

{'batch_size': 32,
 'frozen_epochs': 50,
 'fine_tune_epochs': 50,
 'val_split': 0.2,
 'early_stopping_patience': 25,
 'early_stopping_min_delta': 0.0001,
 'rs_seed': None}

In [7]:
### Paths ###
# Root of the invasion dataset; images are globbed from its train/<class>/ subfolders below.
root_data_path = f"{dirs.data_dir}/invasion_data/development"
# Output location for this run's checkpoints.
model_training_path = f"{dirs.analysis_dir}/resnet50_invasion_model"
demo_model_training_path = f"{model_training_path}/demo_v1"
# One best-weights file per training phase (frozen base, then fine-tuning).
mcp_best_frozen_weights_file = f"{demo_model_training_path}/best_frozen_weights.h5"
mcp_best_finetune_weights_file = f"{demo_model_training_path}/best_finetune_weights.h5"


### Values taken from the training JSON ###
seed = training_values["rs_seed"]
val_split = training_values["val_split"]
batch_size = training_values["batch_size"]
frozen_epochs = training_values["frozen_epochs"]
fine_tune_epochs = training_values["fine_tune_epochs"]
# Early-stopping knobs: update these depending on seriousness of experiment
es_patience = training_values["early_stopping_patience"]
es_min_delta = training_values["early_stopping_min_delta"]


### Fixed model / training settings ###
resnet_inp_shape = (128, 128, 3)
# Binary classification -> a single output unit is enough
n_outputs = 1
fine_tune_lr = 1e-5
class_labels = {"no_invasion": 0, "invasion": 1}


### Monitoring criteria ###
# Early stopping and checkpointing both track the validation loss (lower is better).
es_criterion, es_mode = "val_loss", "min"
mcp_criterion, mcp_mode = "val_loss", "min"
mcp_best_only = True
# Must be True, otherwise the base model "layer" won't save/load properly
mcp_weights_only = True

In [8]:
# Prepare the checkpoint output directory before training writes to it
# (presumably creates it if missing — confirm in data_prep.make_dir).
data_prep.make_dir(demo_model_training_path)

# Prep for loading data

In [9]:
# Single random state shared by the path shuffling and both data generators.
# With seed=None (the current config value) NumPy seeds from OS entropy, so runs are not reproducible.
rs = np.random.RandomState(seed)

In [10]:
# Map integer class label -> list of .tif image paths for that class.
# glob() returns paths in arbitrary (filesystem-dependent) order, so sort them first:
# without this, a fixed seed would still produce different shuffles (and therefore a
# different train/val split) on different machines.
data_paths = {
    label: sorted(glob(f"{root_data_path}/train/{class_name}/*.tif"))
    for class_name, label in class_labels.items()
}
# Shuffle each class's list in place; the train/val split below slices these lists.
for class_paths in data_paths.values():
    rs.shuffle(class_paths)

In [11]:
# Per-class image totals, then how many of each class go to validation / training.
data_counts = {label: len(paths) for label, paths in data_paths.items()}
val_counts = {label: round(total * val_split) for label, total in data_counts.items()}
train_counts = {label: data_counts[label] - val_counts[label] for label in data_counts}
# Class weights are computed from the training split only.
train_class_weights = prep.balanced_class_weights_from_counts(train_counts)

In [12]:
# Images found per class; keys are the integer labels from `class_labels`.
data_counts

{0: 454, 1: 102}

In [13]:
# Per-class weights computed from the training-split counts (keyed by integer label).
train_class_weights

{0: 0.6129476584022039, 1: 2.7134146341463414}

In [14]:
# Split each class's shuffled list: the first val_counts[label] paths are held out
# for validation, everything after that is used for training.
val_data_paths = {label: paths[:val_counts[label]] for label, paths in data_paths.items()}
train_data_paths = {label: paths[val_counts[label]:] for label, paths in data_paths.items()}

# Datasets

In [16]:
# Batch generators with class weights.
# (To train without class weights, drop the `class_weights=` argument from both calls.)
train_datagen = data_prep.InvasionDataGenerator(
    train_data_paths,
    class_labels,
    batch_size,
    resnet_inp_shape[:2],
    rs,
    class_weights=train_class_weights,
    augmentation_function=prep.augment_imgs,
)
# NOTE(review): the validation generator also gets `prep.augment_imgs` and shares `rs`
# with the training generator — confirm this is intended, since validation metrics are
# then computed on augmented images if that function is random.
val_datagen = data_prep.InvasionDataGenerator(
    val_data_paths,
    class_labels,
    batch_size,
    resnet_inp_shape[:2],
    rs,
    class_weights=train_class_weights,
    augmentation_function=prep.augment_imgs,
)

# Model

In [17]:
# Reset Keras' global state before building the model, so re-running the cells below starts clean.
K.clear_session()

In [18]:
# Transfer-learning model on a ResNet50 base, built with the base frozen for the first phase.
# `base_last_layer` selects the base layer whose output feeds the new head. Alternatives
# left by the author: conv5_block{3,2,1}_out, conv4_block{5,4,3,2,1}_out, conv3_block4_out.
tl_model = models.build_ResNet50_TL(
    n_outputs, resnet_inp_shape,
    base_last_layer="conv4_block6_out",
    output_act="sigmoid",  # switch to softmax once n_outputs > 1
    base_model_trainable=False,
)
# `weighted_metrics` (rather than `metrics`) so the reported accuracy takes sample weights
# into account — presumably the generators emit them from `class_weights`; confirm.
tl_model.compile(
    optimizer=Adam(),
    loss=BinaryCrossentropy(),
    weighted_metrics=[BinaryAccuracy()],
)

In [19]:
# Sanity check: with the base frozen, only the new dense head should count as trainable.
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 8, 8, 1024)        8589184   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
_________________________________________________________________
activation (Activation)      (None, 1)                 0         
Total params: 8,590,209
Trainable params: 1,025
Non-trainable params: 8,589,184
_________________________________________________________________


In [20]:
# tl_model.get_layer("base_model").summary()

In [21]:
# Callbacks for the frozen-base phase.
# Stop when the monitored value has not improved by es_min_delta for es_patience epochs.
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    min_delta=es_min_delta,
    patience=es_patience,
)
# Keep only the best weights of this phase on disk.
cp_callback = ModelCheckpoint(
    mcp_best_frozen_weights_file,
    monitor=mcp_criterion,
    mode=mcp_mode,
    save_best_only=mcp_best_only,
    save_weights_only=mcp_weights_only,
)

In [22]:
# Phase 1: fit with the ResNet50 base frozen; `h1` keeps the per-epoch history.
h1 = tl_model.fit(
    train_datagen, validation_data=val_datagen,
    epochs=frozen_epochs, workers=n_cores,
    callbacks=[es_callback, cp_callback],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [23]:
# (index of the best epoch, lowest validation loss) for the frozen phase.
# The index is 0-based, so add 1 to match the "Epoch k/50" numbering in the log.
np.argmin(h1.history["val_loss"]), np.min(h1.history["val_loss"])

(33, 0.30371546745300293)

In [24]:
# (0-based index of the best epoch, highest validation accuracy) for the frozen phase.
# This need not be the epoch that was checkpointed: the checkpoint tracks val_loss.
np.argmax(h1.history["val_binary_accuracy"]), np.max(h1.history["val_binary_accuracy"])

(48, 0.8894818425178528)

# Load the best frozen-phase weights before fine-tuning

In [25]:
# Start fine-tuning from the checkpointed best (lowest val_loss) frozen-phase weights,
# not from whatever the last epoch left in the model.
tl_model.load_weights(mcp_best_frozen_weights_file)

# Train model (all layers)

In [26]:
# Make base model trainable (leave layers in inference mode).
# The model is re-compiled in the next cell so the changed trainable state takes effect.
models.toggle_TL_freeze(tl_model)

In [27]:
# Re-compile for fine-tuning with a much smaller learning rate (fine_tune_lr).
# As in the frozen phase, accuracy is reported through `weighted_metrics`
# (use `metrics=[BinaryAccuracy()]` instead for the unweighted variant).
tl_model.compile(
    optimizer=Adam(learning_rate=fine_tune_lr),
    loss=BinaryCrossentropy(),
    weighted_metrics=[BinaryAccuracy()],
)

In [28]:
# Sanity check after unfreezing: the base model's parameters should now show up as trainable.
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 8, 8, 1024)        8589184   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1025      
_________________________________________________________________
activation (Activation)      (None, 1)                 0         
Total params: 8,590,209
Trainable params: 8,559,617
Non-trainable params: 30,592
_________________________________________________________________


In [29]:
# Fresh callbacks for the fine-tuning phase (new early-stopping state, separate checkpoint file).
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    min_delta=es_min_delta,
    patience=es_patience,
)
cp_callback = ModelCheckpoint(
    mcp_best_finetune_weights_file,
    monitor=mcp_criterion,
    mode=mcp_mode,
    save_best_only=mcp_best_only,
    save_weights_only=mcp_weights_only,
)

In [30]:
# Phase 2: fine-tune with the base unfrozen; `h2` keeps the per-epoch history.
h2 = tl_model.fit(
    train_datagen, validation_data=val_datagen,
    epochs=fine_tune_epochs, workers=n_cores,
    callbacks=[es_callback, cp_callback],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [31]:
# (0-based index of the best epoch, lowest validation loss) for the fine-tuning phase.
# This is the epoch whose weights the checkpoint wrote to the fine-tune weights file.
np.argmin(h2.history["val_loss"]), np.min(h2.history["val_loss"])

(35, 0.2274419516324997)

In [32]:
# (0-based index of the best epoch, highest validation accuracy) for the fine-tuning phase.
np.argmax(h2.history["val_binary_accuracy"]), np.max(h2.history["val_binary_accuracy"])

(30, 0.9275327324867249)