# Transfer learning for Melanoma Classification using EfficientNetB0


## 1. Set up

#### Set up for importing utilities

In [None]:
import os
import sys

# Put the parent directory on the import path so the shared `utilities`
# module used by later cells can be imported.
module_path = os.path.abspath('..')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

#### Import efficientnet

In [None]:
import efficientnet.tfkeras as efn

#### Random state

In [None]:
from utilities import random_state as _init_random_state

# Call the project's random-state helper once. From here on the name
# `random_state` refers to the value it returns, not to the helper itself.
random_state = _init_random_state()

#### Export Config

In [None]:
export_results = True
# Saving the model (.h5) is gated on this flag: h5py has issues in the local
# environment but works on the cluster.
on_cluster = True
export_folder = 'results/Transfer_Learning/161122-unfreeze50'

date_format = "%d%m%Y%H%M%S"  # timestamp format embedded in exported file names

if export_results:
    import datetime
    import os

    # Create the export folder on first use and report that it was created.
    folder_missing = not os.path.exists(export_folder)
    if folder_missing:
        os.makedirs(export_folder)
        print(f"Created new directory {export_folder}")

#### Resolution Setup

In [None]:
# Side length (pixels) of the square images fed to the network; also passed to
# the loader so images are resized to this size. (Noted default: 128.)
img_pixel = 224

#### Timer

Start the timer

In [None]:
import time

# Wall-clock reference point; the total elapsed time is reported at the end.
start = time.time()

## 2. Get data

#### Data config

In [None]:
# Dataset location and composition.
downsampled_data = True  # use the "_downsampled" folders / ground-truth CSVs
base_path = "data/30"
current_train_melanoma_percentage = 0.3  # melanoma share of the training set; drives the class weights

#### Get image paths
When developing models on the cluster, drop the max_images cap and call get_all_img_paths(img_folder) instead of get_img_paths.

In [None]:
from utilities import get_all_img_paths, get_img_paths

# Image folders carry a "_downsampled" suffix when the downsampled dataset is used.
_split_suffix = "_downsampled" if downsampled_data else ""
img_folder_train = f"{base_path}/train{_split_suffix}"
img_folder_test = f"{base_path}/test{_split_suffix}"

# Cap on images loaded per split; lower `subset_fraction` to develop on a subset.
subset_fraction = 1
max_images_train = int(13653 * subset_fraction)
max_images_test = int(5804 * subset_fraction)

img_paths_train = get_img_paths(img_folder_train, max_images_train)
img_paths_test = get_img_paths(img_folder_test, max_images_test)

#### Load data
Load the images listed in img_paths into a DataFrame. Each image is resized and flattened into an array, so this step may take a while.

In [None]:
from utilities import load_train_test

_csv_suffix = "_downsampled" if downsampled_data else ""
groundtruth_file_train = f"{base_path}/ISIC_2020_2019_train{_csv_suffix}.csv"
groundtruth_file_test = f"{base_path}/ISIC_2020_2019_test{_csv_suffix}.csv"

# Available loading strategies for load_train_test:
options = [
    "sequential",  # load train first, then test
    "parallel_train_test",  # load train and test in parallel (data within each set loaded sequentially)
    "sequential_train_test_parallel_chunks",  # load train, then test, loading the data within each set in parallel
    "parallel_fusion",  # load train and test in parallel, and the data within each set in parallel too
]

# Choose one of the options above.
option = "parallel_fusion"

df_train, df_test = load_train_test(
    img_paths_train, groundtruth_file_train,
    img_paths_test, groundtruth_file_test,
    option, img_pixel,
)




#### Split into target and predictors

In [None]:
from utilities import split_predictors_target

# Separate predictors (X) from the target label (y) for both splits.
(X_train, y_train), (X_test, y_test) = (
    split_predictors_target(frame) for frame in (df_train, df_test)
)

In [None]:
from utilities import unflatten_images_df

# The loader flattens each image into a row; reshape back into images for the CNN.
# TODO: have the loader return image arrays directly so this step can be dropped.
X_train, X_test = (
    unflatten_images_df(frame, img_pixel=img_pixel) for frame in (X_train, X_test)
)

#### Delete data that is no longer needed

In [None]:
import gc

# Drop intermediate objects that training no longer needs, then run the
# garbage collector so their memory can be reclaimed before the model is built.
del df_train, df_test
del img_paths_train, img_paths_test
del img_folder_train, img_folder_test
del groundtruth_file_train, groundtruth_file_test

gc.collect()

## 3. Train model

#### Util functions for training etc.

In [None]:
from keras.callbacks import EarlyStopping
from keras.callbacks import ReduceLROnPlateau 
from keras.callbacks import CSVLogger

# Up-weight the minority (melanoma, label 1) class in the loss.
melanoma_weight = (1 / current_train_melanoma_percentage) / 2
class_weight = {0: 1.0, 1: melanoma_weight}

# Halve the learning rate after 5 epochs without improvement. The same
# scheduler instance is shared by every fit_model call.
lr_scheduler = ReduceLROnPlateau(patience=5, factor=0.5)
            
def unfreeze_model(model, base_model, num_layers = 20):
    """Prepare ``model`` for fine-tuning by unfreezing the top of ``base_model``.

    The last ``num_layers`` layers of ``base_model`` become trainable and all
    earlier backbone layers stay frozen. BatchNormalization layers inside the
    backbone are always kept frozen, following the Keras transfer-learning
    guide's recommendation for fine-tuning. Layers of ``model`` outside
    ``base_model`` (the classification head) are left trainable.

    Args:
        model: full model that contains ``base_model`` as a sub-layer.
        base_model: the pretrained backbone.
        num_layers: number of trailing backbone layers to unfreeze; 0 keeps
            the whole backbone frozen, values above the layer count unfreeze
            every non-BatchNormalization layer.
    """
    # Make the whole model trainable (head included), then decide per
    # backbone layer whether it stays trainable.
    model.trainable = True
    # Index of the first unfrozen layer. Computed explicitly because the slice
    # `[:-0]` would be empty and silently unfreeze everything for num_layers=0.
    first_unfrozen = max(len(base_model.layers) - num_layers, 0)
    for index, layer in enumerate(base_model.layers):
        in_unfrozen_top = index >= first_unfrozen
        layer.trainable = in_unfrozen_top and not isinstance(layer, layers.BatchNormalization)
            
def compile_model(model, lr=0.001):
    """Compile ``model`` for binary classification with a Nadam optimizer.

    Args:
        model: Keras model to compile in place.
        lr: learning rate passed to Nadam.

    Accuracy and recall (logged under the name "recall") are tracked.
    """
    optimizer = Nadam(lr)
    tracked_metrics = ['accuracy', Recall(name="recall")]
    model.compile(
        loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=tracked_metrics,
    )
    
def fit_model(model, *, batch_size=32, epochs=100, patience=10,
              restore_best_weights=False):
    """Train ``model`` on the global training split and return the History.

    Uses class weighting (``class_weight``), early stopping, the shared
    ``lr_scheduler`` and, when ``export_results`` is set, a timestamped CSV
    training log in ``export_folder``.

    Args:
        model: compiled Keras model.
        batch_size: mini-batch size (default 32, as before).
        epochs: maximum number of epochs (default 100, as before).
        patience: early-stopping patience in epochs (default 10, as before).
        restore_best_weights: forwarded to EarlyStopping; the default False
            keeps the original behavior (weights from the last epoch).

    Returns:
        The History object returned by ``model.fit``.
    """
    callbacks = [
        EarlyStopping(patience=patience, restore_best_weights=restore_best_weights),
        lr_scheduler,
    ]
    if export_results:
        timestamp = datetime.datetime.now().strftime(date_format)
        log_path = f"{export_folder}/training_log_{timestamp}.csv"
        callbacks.append(CSVLogger(log_path, separator=",", append=False))

    # NOTE(review): the test split doubles as validation data, so early stopping
    # and LR scheduling are tuned on it and the reported test metrics are
    # optimistic. Consider a separate validation split.
    return model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),
        batch_size=batch_size,
        epochs=epochs,
        class_weight=class_weight,
        callbacks=callbacks,
    )

#### Util functions for evaluation

In [None]:
from matplotlib import pyplot
import seaborn as sns
from sklearn.metrics import confusion_matrix

def eval_model(model, history, name):
    """Plot the training curves and report test-set results for one run.

    Args:
        model: trained model to evaluate on the test split.
        history: History returned by the corresponding ``fit_model`` call.
        name: tag used in exported file names (e.g. "frozen", "unfrozen").
    """
    display_history(history=history, name=name)
    # Collect garbage between plotting and prediction to keep memory down.
    gc.collect()
    display_confusion_matrix(model, name)

def display_history(history, name):
    """Plot train vs. validation loss, accuracy and recall from ``history``.

    Previously every panel read the global ``history_frozen`` instead of the
    ``history`` argument, so the "unfrozen" run showed the frozen run's curves.

    Args:
        history: History object returned by ``model.fit`` (via fit_model).
        name: tag used in the exported figure's file name.
    """
    # (metric key in history.history, panel title); the validation series is
    # stored under the same key with a "val_" prefix.
    panels = [("loss", "Loss"), ("accuracy", "Accuracy"), ("recall", "Recall")]

    _, axs = pyplot.subplots(len(panels), 1, figsize=(20, 15))
    for ax, (metric, title) in zip(axs, panels):
        ax.plot(history.history[metric], label='train')
        ax.plot(history.history['val_' + metric], label='test')
        ax.set_title(title)
        ax.legend()

    if export_results:
        timestamp = datetime.datetime.now().strftime(date_format)
        pyplot.savefig(
            f"{export_folder}/loss_and_accuracy_during_training_{name}_{timestamp}.png")

    pyplot.show()
    
def display_confusion_matrix(model, name):
    """Report test-set classification metrics and plot the confusion matrix.

    Previously the first parameter was named ``mode`` and ignored (the global
    ``model`` was used), and ``predict`` was run twice on ``X_test``.

    Args:
        model: trained model; its ``predict`` output is thresholded at 0.5.
        name: tag used in exported file names.
    """
    from sklearn.metrics import classification_report

    # Predict once and threshold at 0.5 to obtain hard 0/1 labels.
    y_pred = (model.predict(X_test) > 0.5).astype("int32")

    report = classification_report(y_test, y_pred, digits=4)
    print(f'\nClassification_report=\n{report}')

    if export_results:
        timestamp = datetime.datetime.now().strftime(date_format)
        report_path = f"{export_folder}/classification_report_{name}_{timestamp}.txt"
        with open(report_path, 'w') as report_file:
            report_file.write(report)

    class_names = ["no melanoma", "melanoma"]

    cf = confusion_matrix(y_test, y_pred)
    pyplot.figure()  # draw the heatmap on a fresh figure, not whatever is current
    plot = sns.heatmap(cf, annot=True, fmt=".0f",
                       xticklabels=class_names,
                       yticklabels=class_names)
    plot.set(xlabel='Prediction', ylabel='Actual')

    if export_results:
        timestamp = datetime.datetime.now().strftime(date_format)
        plot.get_figure().savefig(f"{export_folder}/confusion_matrix_{name}_{timestamp}.png")

#### Build model

In [None]:
import sklearn 
from tensorflow.keras.models import Sequential
from tensorflow.keras import optimizers, losses, activations, models
from tensorflow.keras.layers import Convolution2D, Dense, Flatten, Dropout, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D, Concatenate
from tensorflow.keras.metrics import Recall
from tensorflow.keras.optimizers import RMSprop, Adam, SGD, Nadam
import keras
from tensorflow.keras import layers

# Build the model: an ImageNet-pretrained EfficientNetB0 backbone (no top)
# followed by a small binary-classification head with a sigmoid output.
base_model = efn.EfficientNetB0(input_shape=(img_pixel, img_pixel, 3), weights='imagenet', include_top=False)
base_model.trainable = False  # freeze the whole backbone for the first training phase

print(base_model.layers)

top_dropout_rate = 0.2
# Use the tensorflow.keras Sequential imported above (instead of standalone
# `keras.Sequential`) so the container and all of its layers come from the
# same Keras implementation.
add_model = Sequential([
    base_model,
    GlobalAveragePooling2D(name="avg_pool"),
    BatchNormalization(),
    Dropout(top_dropout_rate, name="top_dropout"),
    Dense(1, activation='sigmoid'),
])
model = add_model

gc.collect()

#### Train model (while still frozen)

In [None]:
# Phase 1: train only the new head; the backbone was frozen when the model
# was built. lr=0.001 is compile_model's default, spelled out for clarity.
compile_model(model, lr=0.001)
history_frozen = fit_model(model)

In [None]:
# Evaluate the head-only model, collecting garbage before and after the
# memory-heavy plotting and prediction.
gc.collect()
eval_model(model, history_frozen, name="frozen")
gc.collect()

#### Unfreeze and retrain

In [None]:
# Phase 2: unfreeze the top of the backbone and fine-tune with a learning
# rate 10x below compile_model's 0.001 default.
# NOTE(review): export_folder is named "...-unfreeze50" but 20 layers are
# unfrozen here — confirm which value is intended.
num_unfrozen_layers = 20
unfreeze_model(model, base_model, num_unfrozen_layers)
print('done1')
gc.collect()
compile_model(model, lr=0.0001)
print('done2')
gc.collect()
history = fit_model(model)

In [None]:
# Evaluate the fine-tuned model, collecting garbage before and after the
# memory-heavy plotting and prediction.
gc.collect()
eval_model(model, history, name="unfrozen")
gc.collect()

#### Investigate final model

In [None]:
# Print a layer-by-layer summary of the final model.
model.summary()

In [None]:
# Persist the final model as HDF5. Only done on the cluster, where h5py works.
should_save = export_results and on_cluster
if should_save:
    model_path = f"{export_folder}/model_{datetime.datetime.now().strftime(date_format)}.h5"
    model.save(model_path)

#### Timer
Stop the timer

In [None]:
# Stop the timer started at the top of the notebook and report/export the
# total wall-clock time spent loading data and training.
stop = time.time()
elapsed = stop - start
print(f'It took {elapsed} s to load the data and train the model')

if export_results:
    # `with` guarantees the file is closed even if the write fails.
    with open(f'{export_folder}/overall_time.txt', 'w') as f:
        f.write(f'Time it took : {elapsed} s')

## 4. Test model

#### Evaluate accuracy and recall on the train and test sets

In [None]:
# # evaluate the model
# _, train_acc, train_recall = model.evaluate(X_train, y_train)
# _, test_acc, test_recall  = model.evaluate(X_test, y_test)

# print('Accuracy\tTrain: %.3f, Test: %.3f' % (train_acc, test_acc))
# print('Recall\tTrain: %.3f, Test: %.3f' % (train_recall, test_recall))

#### Predict test set (again)

In [None]:
# Predict the test set once: keep the raw scores, and derive hard 0/1 labels
# from them (threshold 0.5) instead of running predict a second time.
y_pred_continuous = model.predict(X_test)
y_pred_discrete = (y_pred_continuous > 0.5).astype("int32")
y_pred = y_pred_discrete

#### Display images and predictions

In [None]:
from utilities import display_results

# Show test images with predicted vs. actual labels. X_test was unflattened
# earlier, hence flat=False. (15 is presumably the number of images shown.)
plt_all = display_results(X_test, y_pred, y_test, 15, img_pixel, flat=False)

if export_results:
    timestamp = datetime.datetime.now().strftime(date_format)
    plt_all.savefig(f"{export_folder}/classification_results_{timestamp}.png")

plt_all.show()

#### Display wrongly classified images

In [None]:
from utilities import display_interesting_results

# Show test images the model got wrong (X_test is unflattened, hence flat=False).
plt_wrong = display_interesting_results(X_test, y_pred, y_test, img_pixel=img_pixel, flat=False)

if export_results:
    timestamp = datetime.datetime.now().strftime(date_format)
    plt_wrong.savefig(f"{export_folder}/incorrect_classification_results_{timestamp}.png")

plt_wrong.show()