In [None]:
import matplotlib as mpl
mpl.use('Agg')  # don't display mpl windows (will cause error in non-gui environment)

from collections import defaultdict
import math
import os
import shutil
import pandas as pd
from bayes_opt import BayesianOptimization
import keras

import core.history as ch
import core.fine_model as cm
from core.fine_model import FineModel

import cr_interface as cri
import keras_utils as ku
import analysis
from lib import Timer, notify
import traceback

In [None]:
# ----------------------------------------------------------------------------
# Experiment bookkeeping
# ----------------------------------------------------------------------------

# Experiment index used to tag saved model weights, training history etc.
# Bump this for every run (and keep a record of which index was which run).
EXP = 98

# ----------------------------------------------------------------------------
# Data
# ----------------------------------------------------------------------------

# Number of folds for k-fold cross-validation
K = 5

# Seed for the k-fold split
K_SPLIT_SEED = 1

# Multiplier for out_of_myocardial (OAP, OBS) slices
BALANCE = 5

# When True, only 10% of all slices are sampled (for sanity checking purposes)
SAMPLE = False

# ----------------------------------------------------------------------------
# Training
# ----------------------------------------------------------------------------

BATCH_SIZE = 32

EPOCHS = 2

# Candidate learning rates; training code refers to them by list index
LEARNING_RATES = [0.01, 1e-5]

# Interval (in epochs) between intermediate weight saves
T = 10

# Whether to save the intermediate weights at all
SAVE_ALL_WEIGHTS = True

# Models to train (uncomment a key to include that network)
MODEL_KEYS = [
    # 'xception',
    'mobileneta25',
    # 'mobilenetv2a35',
    # 'vgg16',
    # 'resnet50v2',
    # 'inception_v3',
    # 'inception_resnet_v2',
    # 'densenet121',
    # 'nasnet_mobile',
]

In [None]:
def optimize_learning_rate(fm: FineModel, depth_index, train_gens, val_gens, test_gen):
    """
    Placeholder for learning-rate selection at a given freeze depth.

    NOTE: this function is not implemented. Its body consists solely of this
    docstring, so calling it performs no work and returns None. The text below
    records the INTENDED contract only.

    Intended behavior:
    Train the fine model (frozen at some given depth) on every fold of data,
    and choose the optimal learning rate BASED ON THE FINAL VALIDATION ACCURACY.
    Consider the learning rates defined in the global variable LEARNING_RATES.

    Intended saved model KEY: [load weights via fm.load_weights(KEY)]
    EXP01_D01
    Fully trained model for the optimal learning rate

    :param fm:
    FineModel to train, i.e., the base network to train on

    :param depth_index:
    The INDEX of the "freeze depth" for the given FineModel

    :param train_gens:
    List of train ImageDataGenerators, one per fold

    :param val_gens:
    List of validation ImageDataGenerators, one per fold

    :param test_gen:
    Test ImageDataGenerator

    :return: None
    """
    
    
def train_model_all_folds(fm, depth_index, lr_index,
                          epochs, train_gens, val_gens, test_gen):
    """
    Train the model (frozen at some depth) once per fold, evaluate every
    trained fold on the test generator, and export the per-metric analysis.

    For each of the K folds the model is first reset: via fm.reload_model()
    when depth_index is 0, otherwise from the weights stored under the previous
    depth's key (EXPxx_Dyy). It is then frozen at depth_index, compiled with
    the selected learning rate, and trained in chunks of T epochs.


    Saves model weights with the following KEYS: [load weights via fm.load_weights(KEY)]
    EXP01_D01_L03_F01:
    Fully trained model for depth index 1, learning rate index 3, fold index 1
    EXP01_D01_L03_F01_E025:
    Partially trained model for the same run after epoch 25
    (only written when SAVE_ALL_WEIGHTS is true)

    Saves training history with the following KEYS: [get data via ch.get_history(model_name, KEY)]
    EXP01_D01_L03_F01:
    Training history for depth index 1, learning rate index 3, fold index 1

    NOTE: the D/L/F components are the indices exactly as passed in / iterated,
    so they are zero-based (the first fold is written as F00).


    :param fm:
    FineModel to train, i.e., the base network to train on

    :param depth_index:
    The INDEX of the "freeze depth" for the given FineModel

    :param lr_index:
    The INDEX of the learning rate, i.e., lr = LEARNING_RATES[lr_index]

    :param epochs:
    Number of epochs to train each fold. It does not have to be a multiple of
    the save interval T; the final chunk is shortened to end at `epochs`.

    :param train_gens:
    List of train ImageDataGenerators, one per fold

    :param val_gens:
    List of validation ImageDataGenerators, one per fold

    :param test_gen:
    Test ImageDataGenerator, shared by all folds

    :return:
    tuple(test_loss, test_acc): loss and accuracy on test_gen, measured after
    the final epoch of each fold and AVERAGED over the folds
    """
    # key templates; EXP is baked in now, the remaining fields are filled later
    exp_key = 'EXP{:02}'.format(EXP)
    depth_key = exp_key + '_D{:02}'
    fold_key = depth_key + '_L{:02}_F{:02}'
    epoch_key = fold_key + '_E{:03}'

    lr = LEARNING_RATES[lr_index]
    test_losses = []
    test_accs = []

    # train the model K times, one for each fold
    for fold in range(K):
        # load model at previous state
        previous_depth_index = depth_index - 1
        if previous_depth_index < 0:
            fm.reload_model()
        else:
            fm.load_weights(depth_key.format(previous_depth_index))
        fm.set_depth(depth_index)
        fm.compile_model(lr=lr)
        model = fm.get_model()

        train_gen = train_gens[fold]
        val_gen = val_gens[fold]
        # key shared by this fold's history and its final weights
        run_key = fold_key.format(depth_index, lr_index, fold)

        print('[debug] batch: {}'.format(BATCH_SIZE))
        print('[debug] size: {}'.format(train_gen.n))
        print('[debug] steps: {}'.format(len(train_gen)))

        # train T epochs at a time so history/weights can be stored in between
        start_epoch = 0
        while start_epoch < epochs:
            print('[debug] epoch {}'.format(start_epoch))
            # never run past the requested number of epochs
            target_epoch = min(start_epoch + T, epochs)
            result = model.fit_generator(
                train_gen,
                validation_data=val_gen,
                steps_per_epoch=len(train_gen),
                validation_steps=len(val_gen),
                workers=16,
                use_multiprocessing=True,
                shuffle=True,
                epochs=target_epoch,
                initial_epoch=start_epoch,
            )
            start_epoch = target_epoch

            # update training history
            ch.append_history(result.history, fm.get_name(), run_key)

            if SAVE_ALL_WEIGHTS:
                # save intermediate weights
                fm.save_weights(epoch_key.format(
                    depth_index, lr_index, fold, target_epoch,
                ))

        # save final weights
        fm.save_weights(run_key)

        print('[debug] test size: {}'.format(test_gen.n))
        print('[debug] test steps: {}'.format(len(test_gen)))

        loss, acc = model.evaluate_generator(test_gen, steps=len(test_gen))

        print('[debug] test_loss={}, test_acc={}'.format(loss, acc))

        test_losses.append(loss)
        test_accs.append(acc)

    print('Exporting analysis')
    for metric in analysis.metric_names.keys():
        analysis.analyze_lr(fm, fm.get_name(), depth_index, lr_index, lr, metric, exp=EXP)

    # one entry per fold, so this is the mean over the K folds
    avg_loss = sum(test_losses) / len(test_losses)
    avg_acc = sum(test_accs) / len(test_accs)

    print('[debug] avg_test_loss={}, avg_test_acc={}'.format(avg_loss, avg_acc))

    return avg_loss, avg_acc


In [None]:
def print_all_stats(train, test, folds):
    """
    Print a human-readable summary of the data split.

    Output, in order: patient/image counts and label counts for the
    training/validation set and for the test set, a per-fold table of image
    counts by label, and a column-per-fold listing of the patient ids
    contained in each fold.

    :param train:
    Training/validation collection; must expose a pandas DataFrame as `.df`
    with `pid` and `label` columns

    :param test:
    Test collection with the same `.df` layout

    :param folds:
    Iterable of collections (same `.df` layout), one per fold

    :return: None
    """
    # materialise once: the folds are walked several times below
    folds = list(folds)
    column_width = 16

    # Print stats for a single collection
    def print_stats(collection):
        df = collection.df
        print('{:<3} patients / {:<4} images'.format(df.pid.unique().shape[0], df.shape[0]))
        print(df.label.value_counts().to_string())

    print('Training/Validation Set'.center(80, '-'))
    print_stats(train)

    print('Test Set'.center(80, '-'))
    print_stats(test)

    print()
    # NOTE(review): the factor 5 is hard-coded in this message; presumably it
    # should follow the BALANCE setting -- confirm before changing.
    print('Note that OAP, OBS images in the training/validation set will be duplicated 5 times')
    print('to solve the class imbalance issue')
    print()

    # Print number of images by fold by label (training data)
    stats = dict()
    for number, fold in enumerate(folds, start=1):
        counts = fold.df.label.value_counts()
        counts.loc['total'] = fold.df.shape[0]
        stats[number] = counts
    stats = pd.DataFrame(stats)

    # report the actual number of folds rather than assuming five
    print('{}-Fold Training Set Data'.format(len(folds)).center(80, '-'))
    print(stats.to_string(col_space=8))
    print()

    # Columnwise-print of cr_codes (training data)
    cr_codes_by_fold = [sorted(fold.df.pid.unique()) for fold in folds]
    max_len = max((len(codes) for codes in cr_codes_by_fold), default=0)

    print(''.join(
        'Fold {}'.format(number).ljust(column_width)
        for number in range(1, len(folds) + 1)
    ))
    print('-' * 80)
    for row in range(max_len):
        # str() rather than a 'd' format spec so non-integer ids print too;
        # folds shorter than the longest one get a blank cell
        cells = (
            str(codes[row]) if row < len(codes) else ''
            for codes in cr_codes_by_fold
        )
        print(''.join(cell.ljust(column_width) for cell in cells))
    print()

In [None]:
def run_on_model(model_key, train_folds, test):
    """
    Train one network for every configured learning rate.

    Resets the Keras session, instantiates the FineModel registered under
    `model_key`, builds its train/validation/test generators, and then calls
    train_model_all_folds once per entry of LEARNING_RATES (depth index 0,
    EPOCHS epochs).

    :param model_key:
    Key into the dictionary returned by FineModel.get_dict()

    :param train_folds:
    Folds passed to the model's get_train_val_generators

    :param test:
    Collection passed to the model's get_test_generator

    :return: None
    """
    banner = ' MODEL: {} '.format(model_key)
    print(banner.center(100, '#'))

    # reset Keras global state before building this model
    keras.backend.clear_session()

    model_factory = FineModel.get_dict()[model_key]
    fine_model = model_factory()

    fold_generators = fine_model.get_train_val_generators(train_folds)
    train_gens, val_gens = fold_generators
    test_gen = fine_model.get_test_generator(test)

    for lr_index, rate in enumerate(LEARNING_RATES):
        heading = 'Starting training {} lr={}'.format(fine_model.get_name(), rate)
        print(heading.center(100, '-'))
        train_model_all_folds(
            fine_model, 0, lr_index, EPOCHS,
            train_gens, val_gens, test_gen,
        )

In [None]:
def main():
    """
    Entry point: load the data, report the split, and train every model.

    Loads the labeled, tri-labeled collections for dataset index 0
    (training/validation) and dataset index 1 (test), optionally keeps only a
    10% sample of each when SAMPLE is set, splits the training collection into
    K folds using K_SPLIT_SEED, prints the split statistics, and finally runs
    run_on_model for each key in MODEL_KEYS.

    :return: None
    """
    def load_labeled(dataset_index):
        # each call loads the collection afresh before filtering it
        collection = cri.CrCollection.load()
        return collection.filter_by(dataset_index=dataset_index).tri_label().labeled()

    train = load_labeled(0)
    test = load_labeled(1)

    if SAMPLE:
        train = train.sample(frac=0.1)
        test = test.sample(frac=0.1)

    folds = train.k_split(K, seed=K_SPLIT_SEED)
    print_all_stats(train, test, folds)

    for model_key in MODEL_KEYS:
        run_on_model(model_key, folds, test)

In [None]:
# Run the whole experiment. A failure is reported (printed and pushed through
# notify) instead of being re-raised, so this cell itself never errors out.
try:
    main()
except Exception as e:
    # full traceback first, then the bare exception message on its own line
    error = '\n'.join([traceback.format_exc(), str(e)])
    print(error)
    notify(error)