In [1]:
import tensorflow.keras.backend as K
import numpy as np
import pandas as pd
import json
import multiprocessing
import matplotlib.pyplot as plt

from glob import glob
from copy import deepcopy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [2]:
# CPU count reported by the system; passed as `workers=` to model.fit in the
# training loop.
n_cores = multiprocessing.cpu_count()
n_cores

16

In [3]:
from fl_tissue_model_tools import data_prep, dev_config, models, defs
import fl_tissue_model_tools.preprocessing as prep

In [4]:
# Resolve project directories from the local developer path file.
# `dirs.data_dir` and `dirs.analysis_dir` are the attributes read by this notebook.
dirs = dev_config.get_dev_directories("../dev_paths.txt")

# Set up model training parameters

In [5]:
# Read the training settings (batch size, epoch counts, validation split,
# early-stopping values, random seed) from the shared JSON file.
with open("../model_training/invasion_depth_training_values.json", 'r') as fp:
    training_values = json.load(fp)

# The file spells "no seed" as the string "None"; turn that into a real None.
if training_values["rs_seed"] == "None":
    training_values["rs_seed"] = None

In [6]:
# Echo the loaded training settings so they are recorded in the notebook output.
training_values

{'batch_size': 32,
 'frozen_epochs': 50,
 'fine_tune_epochs': 50,
 'val_split': 0.2,
 'early_stopping_patience': 25,
 'early_stopping_min_delta': 0.0001,
 'rs_seed': None}

In [7]:
# Load the stored best-hyperparameter values (Adam betas, per-stage learning
# rates, and the name of the last ResNet layer to use).
with open("../model_training/invasion_depth_best_hp.json", 'r') as fp:
    best_hp = json.load(fp)

In [8]:
# Echo the loaded hyperparameters so they are recorded in the notebook output.
best_hp

{'adam_beta_1': 0.9066466810625207,
 'adam_beta_2': 0.9947698608592066,
 'fine_tune_lr': 9.141354728081903e-05,
 'frozen_lr': 0.00022767552973325001,
 'last_resnet_layer': 'conv5_block1_out'}

In [9]:
### Filesystem locations ###
# Input images live under the data directory; checkpoints and training
# histories for the ensemble go under the analysis directory.
root_data_path = f"{dirs.data_dir}/invasion_data/development"
model_training_path = f"{dirs.analysis_dir}/resnet50_invasion_model"
best_ensemble_training_path = f"{model_training_path}/best_ensemble"


### Model building & training parameters ###
# Model input shape; its first two entries are also handed to the data
# generators as the target image size.
resnet_inp_shape = (128, 128, 3)
class_labels = {
    "no_invasion": 0,
    "invasion": 1,
}

# Binary classification -> a single output unit is enough
n_outputs = 1

# Values taken from the training-settings JSON
seed, val_split, batch_size, frozen_epochs, fine_tune_epochs = (
    training_values[name]
    for name in ("rs_seed", "val_split", "batch_size", "frozen_epochs", "fine_tune_epochs")
)

# Values taken from the best-hyperparameter JSON
adam_beta_1, adam_beta_2, frozen_lr, fine_tune_lr, last_resnet_layer = (
    best_hp[name]
    for name in ("adam_beta_1", "adam_beta_2", "frozen_lr", "fine_tune_lr", "last_resnet_layer")
)


### Early stopping ###
# Stop when the validation loss stops decreasing
es_criterion, es_mode = "val_loss", "min"
# Update these depending on seriousness of experiment
es_patience, es_min_delta = (
    training_values["early_stopping_patience"],
    training_values["early_stopping_min_delta"],
)


### Model saving ###
# Checkpoint only on a new minimum of the validation loss
mcp_criterion, mcp_mode = "val_loss", "min"
mcp_best_only = True
# Must be True, otherwise the base model "layer" won't save/load properly
mcp_weights_only = True

In [10]:
# Prepare the output directory that receives the per-model weight files and
# history CSVs written by the training loop.
# NOTE(review): presumably creates the directory if it is missing - confirm in data_prep.make_dir.
data_prep.make_dir(best_ensemble_training_path)

In [11]:
# One RandomState shared by the path shuffling and by both data generators in
# the training loop. When `seed` is None NumPy seeds it unpredictably, so runs
# are not repeatable.
rs = np.random.RandomState(seed)

# Train ensemble of models using best hyperparameters

In [12]:
# Number of ensemble members; the training loop runs once per member.
n_models = 5

In [13]:
# Train `n_models` ensemble members. Each member gets its own random
# train/validation split, is trained first with the base network frozen and
# then fine-tuned, and writes its best weights per stage plus a combined
# training history to `best_ensemble_training_path`.
for model_idx in range(n_models):
    ### Prepare data (each model should be trained on a randomly assigned train/validation set) ###
    # Map integer class label -> list of image paths for that class.
    # glob() gives no ordering guarantee, so the paths are sorted before being
    # shuffled; otherwise a fixed seed could still produce different splits on
    # different filesystems.
    data_paths = {v: sorted(glob(f"{root_data_path}/train/{k}/*.tif")) for k, v in class_labels.items()}
    mcp_best_frozen_weights_file = f"{best_ensemble_training_path}/best_frozen_weights_{model_idx}.h5"
    mcp_best_finetune_weights_file = f"{best_ensemble_training_path}/best_finetune_weights_{model_idx}.h5"

    # Shuffle every class's path list in place with the shared RandomState.
    for k, v in data_paths.items():
        rs.shuffle(v)

    # Per-class split sizes: `val_split` of each class goes to validation
    # (rounded), the remainder to training.
    data_counts = {k: len(v) for k, v in data_paths.items()}
    val_counts = {k: round(v * val_split) for k, v in data_counts.items()}
    train_counts = {k: v - val_counts[k] for k, v in data_counts.items()}
    # Class weights are derived from the training counts only.
    train_class_weights = prep.balanced_class_weights_from_counts(train_counts)

    # The first `val_counts[k]` shuffled paths of each class form the
    # validation set; everything after them is the training set.
    train_data_paths = {k: v[val_counts[k]:] for k, v in data_paths.items()}
    val_data_paths = {k: v[:val_counts[k]] for k, v in data_paths.items()}


    ### Build datasets ###
    # With class weights
    train_datagen = data_prep.InvasionDataGenerator(
        train_data_paths,
        class_labels,
        batch_size,
        resnet_inp_shape[:2],
        rs,
        class_weights=train_class_weights,
        augmentation_function=prep.augment_imgs
    )
    # NOTE(review): the validation generator receives the same augmentation
    # function and the training class weights as the training generator. If
    # the generator applies `augmentation_function` to every batch (confirm in
    # data_prep.InvasionDataGenerator), the monitored `val_loss` is computed on
    # randomly augmented images, which makes early stopping and checkpoint
    # selection noisy. Left unchanged here; confirm whether this is intended.
    val_datagen = data_prep.InvasionDataGenerator(
        val_data_paths,
        class_labels,
        batch_size,
        resnet_inp_shape[:2],
        rs,
        class_weights=train_class_weights,
        augmentation_function=prep.augment_imgs
    )


    ### Build model ###
    # Drop Keras state left over from the previous ensemble member before
    # building a fresh model.
    K.clear_session()
    tl_model = models.build_ResNet50_TL(
        n_outputs,
        resnet_inp_shape,
        base_last_layer=last_resnet_layer,
        base_model_trainable=False
    )


    ### Frozen training ###
    tl_model.compile(
        optimizer=Adam(learning_rate=frozen_lr, beta_1=adam_beta_1, beta_2=adam_beta_2),
        loss=BinaryCrossentropy(),
        weighted_metrics=[BinaryAccuracy()]
    )

    es_callback = EarlyStopping(
        monitor=es_criterion,
        mode=es_mode,
        min_delta=es_min_delta,
        patience=es_patience
    )
    mcp_callback = ModelCheckpoint(
        mcp_best_frozen_weights_file,
        monitor=mcp_criterion,
        mode=mcp_mode,
        save_best_only=mcp_best_only,
        save_weights_only=mcp_weights_only
    )

    h1 = tl_model.fit(
        train_datagen,
        validation_data=val_datagen,
        epochs=frozen_epochs,
        callbacks=[es_callback, mcp_callback],
        workers=n_cores
    )


    ### Finetune training ###
    # NOTE(review): no weights are reloaded between the two stages, so
    # fine-tuning continues from whatever weights the model holds when the
    # frozen stage ends, not from the best frozen checkpoint on disk.
    # `toggle_TL_freeze` presumably flips the base model to trainable - confirm
    # in models.toggle_TL_freeze.
    models.toggle_TL_freeze(tl_model)
    # Recompile so the new optimizer (fine-tune learning rate) takes effect.
    tl_model.compile(
        optimizer=Adam(learning_rate=fine_tune_lr, beta_1=adam_beta_1, beta_2=adam_beta_2),
        loss=BinaryCrossentropy(),
        weighted_metrics=[BinaryAccuracy()]
    )

    # Fresh callbacks: the fine-tune stage tracks its own best `val_loss` and
    # writes to its own weights file.
    es_callback = EarlyStopping(
        monitor=es_criterion,
        mode=es_mode,
        min_delta=es_min_delta,
        patience=es_patience
    )
    mcp_callback = ModelCheckpoint(
        mcp_best_finetune_weights_file,
        monitor=mcp_criterion,
        mode=mcp_mode,
        save_best_only=mcp_best_only,
        save_weights_only=mcp_weights_only
    )

    h2 = tl_model.fit(
        train_datagen,
        validation_data=val_datagen,
        epochs=fine_tune_epochs,
        callbacks=[es_callback, mcp_callback],
        workers=n_cores
    )


    ### Save results for comparison later ###
    # One row per epoch actually run, tagged with the stage it belongs to.
    h1_df = pd.DataFrame(h1.history)
    h1_df["training_stage"] = ["frozen"] * len(h1.epoch)

    h2_df = pd.DataFrame(h2.history)
    h2_df["training_stage"] = ["finetune"] * len(h2.epoch)

    # Frozen-stage rows first, then fine-tune rows, with a fresh running index.
    h_df = pd.concat([h1_df, h2_df], axis=0, ignore_index=True)

    h_df.to_csv(f"{best_ensemble_training_path}/best_model_history_{model_idx}.csv", index=False)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/5