In [1]:
import tensorflow.keras.backend as K
import numpy as np
import dask as d
import os
import cv2
import json
import multiprocessing
import matplotlib.pyplot as plt

from glob import glob
from copy import deepcopy
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications import resnet50

In [2]:
# CPU count reported by multiprocessing; passed as `workers` to `tl_model.fit` further down.
n_cores = multiprocessing.cpu_count()
# Bare expression so the notebook displays the value.
n_cores

16

In [3]:
from fl_tissue_model_tools import data_prep, dev_config, models, defs
import fl_tissue_model_tools.preprocessing as prep

In [4]:
# Resolve the project's development directories from ../dev_paths.txt.
# `dirs.data_dir` and `dirs.analysis_dir` are read when the data/model paths are built below.
dirs = dev_config.get_dev_directories("../dev_paths.txt")

# Set up model training parameters

In [5]:
# Read the shared training settings (batch size, epochs, validation split, ...).
with open("../model_training/invasion_depth_training_values.json", 'r') as settings_file:
    training_values = json.load(settings_file)

# The JSON file spells "no seed" as the string "None"; convert it to a real None.
if training_values["rs_seed"] == "None":
    training_values["rs_seed"] = None

In [6]:
# Display the training settings loaded from JSON (after the rs_seed normalisation).
training_values

{'batch_size': 32,
 'frozen_epochs': 50,
 'fine_tune_epochs': 50,
 'val_split': 0.2,
 'early_stopping_patience': 25,
 'early_stopping_min_delta': 0.0001,
 'rs_seed': None}

In [7]:
# Read the previously selected hyperparameters (Adam betas, learning rates, ResNet cut layer).
with open("../model_training/invasion_depth_best_hp.json", 'r') as hp_file:
    best_hp = json.loads(hp_file.read())

In [8]:
# Display the hyperparameters loaded from JSON.
best_hp

{'adam_beta_1': 0.9066466810625207,
 'adam_beta_2': 0.9947698608592066,
 'fine_tune_lr': 9.141354728081903e-05,
 'frozen_lr': 0.00022767552973325001,
 'last_resnet_layer': 'conv5_block1_out'}

In [9]:
### Paths ###
root_data_path = f"{dirs.data_dir}/invasion_data/development"
model_training_path = f"{dirs.analysis_dir}/resnet50_invasion_model"
demo_model_training_path = f"{model_training_path}/demo"
# Checkpoint files, one per training phase (frozen base, then fine-tuning)
mcp_best_frozen_weights_file = f"{demo_model_training_path}/best_frozen_weights.h5"
mcp_best_finetune_weights_file = f"{demo_model_training_path}/best_finetune_weights.h5"


### General training parameters ###
resnet_inp_shape = (128, 128, 3)
class_labels = {"no_invasion": 0, "invasion": 1}
# Binary classification -> a single output unit is enough
n_outputs = 1

# Values taken from the training-settings JSON
seed = training_values["rs_seed"]
val_split = training_values["val_split"]
batch_size = training_values["batch_size"]

# Demo run: use a short, fixed schedule instead of training_values["frozen_epochs"] /
# training_values["fine_tune_epochs"]. Note that these counts are smaller than the
# early-stopping patience loaded below, so early stopping cannot trigger in this demo.
frozen_epochs = fine_tune_epochs = 10

# Values taken from the best-hyperparameter JSON
adam_beta_1, adam_beta_2 = best_hp["adam_beta_1"], best_hp["adam_beta_2"]
frozen_lr, fine_tune_lr = best_hp["frozen_lr"], best_hp["fine_tune_lr"]
last_resnet_layer = best_hp["last_resnet_layer"]


### Early stopping ###
es_criterion, es_mode = "val_loss", "min"
# Update these depending on seriousness of experiment
es_patience = training_values["early_stopping_patience"]
es_min_delta = training_values["early_stopping_min_delta"]


### Model saving ###
mcp_criterion, mcp_mode = "val_loss", "min"
mcp_best_only = True
# Must stay True: otherwise the base model "layer" won't save/load properly
mcp_weights_only = True

In [10]:
# Prepare the demo output directory that the checkpoint weight files are written into.
# NOTE(review): presumably creates the directory if it is missing — confirm in data_prep.make_dir.
data_prep.make_dir(demo_model_training_path)

# Prep for loading data

In [11]:
# Single random state shared by the path shuffling and both data generators below.
# `seed` is None in the loaded settings, so NumPy seeds from OS entropy and the
# shuffle / train-val split will differ from run to run.
rs = np.random.RandomState(seed)

In [12]:
# Collect the train/validation image paths of each class, keyed by integer class label.
# glob() returns paths in arbitrary, filesystem-dependent order, so sort them first;
# otherwise the shuffle below (and hence the train/val split) would differ between
# machines even when a fixed seed is supplied.
tv_class_paths = {
    label: sorted(glob(f"{root_data_path}/train/{class_name}/*.tif"))
    for class_name, label in class_labels.items()
}
# Shuffle each class's path list in place; the split taken later slices these lists.
for class_paths in tv_class_paths.values():
    rs.shuffle(class_paths)

In [13]:
# TODO: handle this in the InvasionDataGenerator class
# Per-class totals, the share of each class held out for validation, and the remainder
# used for training.
tv_counts = {label: len(paths) for label, paths in tv_class_paths.items()}
val_counts = {label: round(total * val_split) for label, total in tv_counts.items()}
train_counts = {label: tv_counts[label] - val_counts[label] for label in tv_counts}
# Class weights are derived from the training counts only.
train_class_weights = prep.balanced_class_weights_from_counts(train_counts)

In [14]:
# Display the number of available train+validation images per class label.
tv_counts

{0: 454, 1: 102}

In [15]:
# Display the number of images per class label left for training.
train_counts

{0: 363, 1: 82}

In [16]:
# Display the number of images per class label held out for validation.
val_counts

{0: 91, 1: 20}

In [17]:
# Display the per-class weights computed from the training counts.
train_class_weights

{0: 0.6129476584022039, 1: 2.7134146341463414}

In [18]:
# Split each class's shuffled path list: the first val_counts[label] paths go to
# validation, everything after that goes to training.
train_data_paths, val_data_paths = {}, {}
for label, paths in tv_class_paths.items():
    n_val = val_counts[label]
    val_data_paths[label] = paths[:n_val]
    train_data_paths[label] = paths[n_val:]

# Datasets

In [19]:
# Build the training and validation generators. Both use the class weights computed
# from the training counts and share the same random state `rs`.
# NOTE(review): the validation generator also receives augmentation_function=prep.augment_imgs.
# If that applies random augmentation, val_loss (which drives EarlyStopping and
# ModelCheckpoint) will be noisy — confirm this is intended.
train_datagen = data_prep.InvasionDataGenerator(
    train_data_paths,
    class_labels,
    batch_size,
    resnet_inp_shape[:2],
    rs,
    class_weights=train_class_weights,
    augmentation_function=prep.augment_imgs,
)
val_datagen = data_prep.InvasionDataGenerator(
    val_data_paths,
    class_labels,
    batch_size,
    resnet_inp_shape[:2],
    rs,
    class_weights=train_class_weights,
    augmentation_function=prep.augment_imgs,
)

In [20]:
# Sanity check: the generator's per-class counts should match `val_counts`.
val_datagen.class_counts

{0: 91, 1: 20}

# Model

In [21]:
# Reset Keras' global state so the model is built in a fresh session when this cell is re-run.
K.clear_session()

In [22]:
# Build the ResNet50 transfer-learning model, cut at `last_resnet_layer`, with the
# base model frozen for the first training phase.
tl_model = models.build_ResNet50_TL(
    n_outputs,
    resnet_inp_shape,
    base_model_trainable=False,
    base_last_layer=last_resnet_layer,
    # Single sigmoid unit for binary output; switch to softmax once n_outputs > 1
    output_act="sigmoid",
)

# Frozen training

In [23]:
# Frozen phase: compile with the frozen-phase learning rate and the tuned Adam betas.
# Accuracy is registered under `weighted_metrics`, so it is reported as a weighted metric.
frozen_optimizer = Adam(beta_1=adam_beta_1, beta_2=adam_beta_2, learning_rate=frozen_lr)
tl_model.compile(
    loss=BinaryCrossentropy(),
    optimizer=frozen_optimizer,
    weighted_metrics=[BinaryAccuracy()],
)

In [24]:
# Print the architecture; with the base frozen, only the dense head's parameters are trainable.
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 4, 4, 2048)        14644096  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 2049      
_________________________________________________________________
activation (Activation)      (None, 1)                 0         
Total params: 14,646,145
Trainable params: 2,049
Non-trainable params: 14,644,096
_________________________________________________________________


In [25]:
# tl_model.get_layer("base_model").summary()

In [26]:
# Callbacks for the frozen phase: stop when the monitored criterion stalls, and keep
# the best weights on disk in the frozen-phase checkpoint file.
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    patience=es_patience,
    min_delta=es_min_delta,
)
mcp_callback = ModelCheckpoint(
    filepath=mcp_best_frozen_weights_file,
    monitor=mcp_criterion,
    mode=mcp_mode,
    save_weights_only=mcp_weights_only,
    save_best_only=mcp_best_only,
)

In [27]:
# Train the classification head with the base model frozen; `h1` keeps the per-epoch history.
h1 = tl_model.fit(
    train_datagen,
    epochs=frozen_epochs,
    validation_data=val_datagen,
    workers=n_cores,
    callbacks=[es_callback, mcp_callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [28]:
# Epoch index (0-based) and value of the lowest validation loss in the frozen phase.
frozen_val_losses = h1.history["val_loss"]
(np.argmin(frozen_val_losses), np.min(frozen_val_losses))

(7, 0.417693167924881)

In [29]:
# Epoch index (0-based) and value of the highest validation binary accuracy in the frozen phase.
frozen_val_accuracies = h1.history["val_binary_accuracy"]
(np.argmax(frozen_val_accuracies), np.max(frozen_val_accuracies))

(8, 0.8607977628707886)

# Load best frozen weights before fine tuning

In [30]:
# Restore the checkpointed best frozen-phase weights so fine-tuning starts from the best
# epoch rather than from the last one.
tl_model.load_weights(mcp_best_frozen_weights_file)

# Fine tune training

In [31]:
# Unfreeze the base model for fine-tuning; its layers are left in inference mode.
# NOTE(review): presumably this keeps BatchNorm statistics fixed — confirm in models.toggle_TL_freeze.
models.toggle_TL_freeze(tl_model)

In [32]:
# Fine-tune phase: recompile after unfreezing so the trainability change takes effect,
# using the fine-tune learning rate with the same tuned Adam betas.
fine_tune_optimizer = Adam(beta_1=adam_beta_1, beta_2=adam_beta_2, learning_rate=fine_tune_lr)
tl_model.compile(
    loss=BinaryCrossentropy(),
    optimizer=fine_tune_optimizer,
    weighted_metrics=[BinaryAccuracy()],
)

In [33]:
# Print the architecture again to confirm that the base model's parameters are now trainable.
tl_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128, 128, 3)]     0         
_________________________________________________________________
base_model (Functional)      (None, 4, 4, 2048)        14644096  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 2049      
_________________________________________________________________
activation (Activation)      (None, 1)                 0         
Total params: 14,646,145
Trainable params: 14,605,313
Non-trainable params: 40,832
_________________________________________________________________


In [34]:
# Fresh callbacks for the fine-tune phase (new instances, so no state carries over from
# the frozen phase); the best weights go to the fine-tune checkpoint file.
es_callback = EarlyStopping(
    monitor=es_criterion,
    mode=es_mode,
    patience=es_patience,
    min_delta=es_min_delta,
)
mcp_callback = ModelCheckpoint(
    filepath=mcp_best_finetune_weights_file,
    monitor=mcp_criterion,
    mode=mcp_mode,
    save_weights_only=mcp_weights_only,
    save_best_only=mcp_best_only,
)

In [35]:
# Fine-tune the whole model (base unfrozen); `h2` keeps the per-epoch history.
h2 = tl_model.fit(
    train_datagen,
    epochs=fine_tune_epochs,
    validation_data=val_datagen,
    workers=n_cores,
    callbacks=[es_callback, mcp_callback],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [36]:
# Epoch index (0-based) and value of the lowest validation loss in the fine-tune phase.
fine_tune_val_losses = h2.history["val_loss"]
(np.argmin(fine_tune_val_losses), np.min(fine_tune_val_losses))

(9, 0.2562311589717865)

In [37]:
# Epoch index (0-based) and value of the highest validation binary accuracy in the fine-tune phase.
fine_tune_val_accuracies = h2.history["val_binary_accuracy"]
(np.argmax(fine_tune_val_accuracies), np.max(fine_tune_val_accuracies))

(8, 0.890742301940918)