# SETUP

## Is a GPU available?

In [1]:
# from tensorflow.python.client import device_lib
# print(device_lib.list_local_devices())

## Configuration (boilerplate)

In [3]:
from tensorflow.keras.callbacks import Callback, EarlyStopping
from numpy import arange

MODEL_NAME = 'resnet50'
IMG_SIZE = (224, 224)
# derived from IMG_SIZE so the two can never disagree; 3 = RGB channels
INPUT_SHAPE = (*IMG_SIZE, 3)
CLASSES = 2
# trial indices run in this session (an earlier session ran trials 0-2)
TRIAL = list(range(3, 10))
# tenths of the ResNet50 base to fine-tune:
# 0 = frozen feature extractor (only the new head trains) ... 10 = fine-tune the whole base
FT_BLOCK = list(range(0, 11))
# candidate mini-batch sizes explored by the grid search
BATCH_SIZE = [32, 64, 128, 256]

## User-defined classes and functions (UDC/UDFs)

In [4]:
from json import dump
from math import floor
import os
from timeit import default_timer as timer

from tensorflow import data
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.backend import clear_session
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing import image_dataset_from_directory

def model_constructor(FT_BLOCK):
    """Build an ImageNet-initialised ResNet50 with a new classification head.

    The base network's layers are split into ten equal chunks of
    ``floor(len(layers) / 10)`` layers. The last ``FT_BLOCK`` chunks stay
    trainable and every earlier base layer is frozen. ``FT_BLOCK=0`` trains only
    the new head. ``FT_BLOCK>=10`` makes every base layer trainable, including
    the remainder layers that the floor() chunking would otherwise leave frozen.

    Args:
        FT_BLOCK: number of tenths (0-10) of the base network to fine-tune.

    Returns:
        An uncompiled ``Model`` mapping ``INPUT_SHAPE`` images to ``CLASSES`` outputs.
    """
    base_model = ResNet50(
        weights='imagenet',
        include_top=False,
        input_shape=INPUT_SHAPE)

    # new head: global average pooling, then one output unit per class
    pooled = layers.GlobalAveragePooling2D(name='avg_pool')(base_model.output)
    # NOTE(review): sigmoid units + binary_crossentropy on one-hot labels score each unit
    # independently; softmax + categorical_crossentropy is the usual pairing for CLASSES > 1.
    outputs = layers.Dense(CLASSES, activation='sigmoid', name='predictions')(pooled)
    model = Model(inputs=base_model.input, outputs=outputs)

    total_layers = len(base_model.layers)
    chunk = floor(total_layers / 10)
    if FT_BLOCK >= 10:
        # floor() drops a remainder when total_layers is not a multiple of 10;
        # "fine-tune 100%" must not leave those first layers frozen.
        frozen_end = 0
    else:
        # clamp so an unusual FT_BLOCK can never produce a negative slice end
        frozen_end = max(0, total_layers - chunk * FT_BLOCK)

    # freeze every base layer before index `frozen_end`; the rest stay trainable
    base_model.trainable = True
    for layer in base_model.layers[:frozen_end]:
        layer.trainable = False

    return model

class TimeCallback(Callback):
    """Keras callback that records the wall-clock duration of each epoch.

    After training, ``self.logs`` is a list with one float (seconds, measured
    with ``timeit.default_timer``) per completed epoch.
    """

    def __init__(self, logs=None):
        # `logs` is accepted only for backward compatibility and is ignored.
        # The base Callback initialiser sets up the state Keras expects (e.g. `model`).
        super().__init__()
        self.logs = []

    def on_epoch_begin(self, epoch, logs=None):
        # remember when this epoch started
        self.starttime = timer()

    def on_epoch_end(self, epoch, logs=None):
        # store the elapsed time of the epoch that just finished
        self.logs.append(timer() - self.starttime)

def time_converter(sec):
    """Format a duration in seconds as ``HH:MM:SS.ss``, print it, and return it.

    Hours and minutes are zero-padded to at least two digits (more hours widen
    the field). Seconds are zero-padded to width 5 with two decimals.

    Args:
        sec: duration in seconds (int or float).

    Returns:
        The formatted string, which is also printed.
    """
    hours, rem = divmod(sec, 3600)
    minutes, seconds = divmod(rem, 60)
    formatted = "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)
    print(formatted)
    return formatted

def save_history(history, tag, history_dir=None):
    """Write a training history dict to ``<history_dir>/<tag>.json``.

    Args:
        history: object whose ``history`` attribute is a JSON-serialisable
            mapping (e.g. what ``Model.fit`` returns). numpy scalars are
            converted to plain floats.
        tag: file stem of the JSON file.
        history_dir: target directory; defaults to the module-level
            ``HISTORY_DIR`` set by the grid-search loop. Created if missing.

    Returns:
        Path of the written file.
    """
    if history_dir is None:
        history_dir = HISTORY_DIR
    # open(..., 'w') cannot create intermediate folders
    os.makedirs(history_dir, exist_ok=True)
    file_path = os.path.join(history_dir, f'{tag}.json')
    with open(file_path, 'w') as f:
        # default=float handles numpy scalars (e.g. float32), which json cannot serialise
        dump(history.history, f, default=float)
    return file_path
        
def data_preparation(BATCH_SIZE,
                     train_dir='../../data/binaryclass_clean/train/',
                     test_dir='../../data/binaryclass_clean/test/',
                     validation_split=0.1,
                     seed=0,
                     test_batch_size=1):
    """Load the train/validation/test image datasets and enable prefetching.

    Training and validation are both taken from ``train_dir`` using the same
    ``seed`` and ``validation_split``, so they are the two complementary subsets
    of one split. Labels are one-hot (``label_mode='categorical'``) and images
    are resized to ``IMG_SIZE``.

    Args:
        BATCH_SIZE: batch size of the training and validation datasets.
        train_dir: image folder (one sub-folder per class) for train/validation.
        test_dir: image folder (one sub-folder per class) for the test set.
        validation_split: fraction of ``train_dir`` held out for validation.
        seed: split seed shared by the training and validation subsets.
        test_batch_size: batch size of the test dataset.

    Returns:
        ``(train_ds, val_ds, test_ds)``, each wrapped with ``prefetch(AUTOTUNE)``.
    """
    def _load(directory, batch_size, **extra):
        # loader options shared by every split
        return image_dataset_from_directory(
            directory=directory,
            label_mode='categorical',
            batch_size=batch_size,
            image_size=IMG_SIZE,
            **extra)

    split = {'seed': seed, 'validation_split': validation_split}
    train_ds = _load(train_dir, BATCH_SIZE, subset='training', **split)
    val_ds = _load(train_dir, BATCH_SIZE, subset='validation', **split)
    # the test folder is used whole: no seed / validation_split
    test_ds = _load(test_dir, test_batch_size)

    autotune = data.AUTOTUNE
    return tuple(ds.prefetch(buffer_size=autotune) for ds in (train_ds, val_ds, test_ds))

# GRID SEARCH

In [4]:
# Grid search over trial x batch size x fine-tuning depth. Every run trains a fresh
# model, saves its History under HISTORY_DIR, and is scored on the test set.
for i in TRIAL:
    print(" - - - - - TRIAL:", i, " - - - - - ")
    # save_history() reads this module-level name, so it is rebound for every trial
    HISTORY_DIR = f'../../logs/FT/{i}'
    # open(..., 'w') inside save_history cannot create a missing folder
    os.makedirs(HISTORY_DIR, exist_ok=True)
    for batch in BATCH_SIZE:
        print("Batch size:", batch)
        train_ds, val_ds, test_ds = data_preparation(batch)
        # keyed by FT block (not list position) so FT_BLOCK need not be 0..N-1
        epochs, loss, accuracy = {}, {}, {}
        for ft in FT_BLOCK:
            tag = f'{MODEL_NAME}_BS{batch}_FT{ft*10}'
            print(tag)
            # reset Keras global state from the previous model; without this every
            # ResNet50 built in the grid keeps accumulating memory
            clear_session()
            model = model_constructor(ft)
            model.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer='adam')
            # NOTE(review): without restore_best_weights=True the evaluated weights are
            # those of the last epoch, up to `patience` epochs past the best val_loss.
            history = model.fit(train_ds, epochs=100, verbose=0, validation_data=val_ds,
                                callbacks=[EarlyStopping(patience=3)])
            save_history(history, tag)
            epochs[ft] = len(history.history['loss'])
            loss[ft], accuracy[ft] = model.evaluate(test_ds)
        # PRINT THE RESULT: one value per line, in FT_BLOCK order
        print(f' * * * * * {MODEL_NAME}_BS{batch}_LOSS * * * * * ')
        for item in loss.values():
            print(item)
        print(f' * * * * * {MODEL_NAME}_BS{batch}_ACCURACY * * * * * ')
        for item in accuracy.values():
            print(item)

 - - - - - TRIAL: 0  - - - - - 
Batch size: 32
Found 9980 files belonging to 2 classes.
Using 8982 files for training.
Found 9980 files belonging to 2 classes.
Using 998 files for validation.
Found 1000 files belonging to 2 classes.
resnet50_BS32_FT0
resnet50_BS32_FT10
resnet50_BS32_FT20
resnet50_BS32_FT30
resnet50_BS32_FT40
resnet50_BS32_FT50
resnet50_BS32_FT60
resnet50_BS32_FT70
resnet50_BS32_FT80
resnet50_BS32_FT90
resnet50_BS32_FT100
 * * * * * resnet50_BS32_LOSS * * * * * 
0.6792939901351929
1.0158129930496216
1.7031893730163574
0.5952131748199463
0.5384291410446167
0.4928511381149292
0.5736762881278992
0.35723742842674255
0.3896218538284302
0.49052777886390686
0.44202759861946106
 * * * * * resnet50_BS32_ACCURACY * * * * * 
0.6940000057220459
0.6990000009536743
0.671999990940094
0.859000027179718
0.8690000176429749
0.8759999871253967
0.8640000224113464
0.8980000019073486
0.8989999890327454
0.8930000066757202
0.890999972820282
Batch size: 64
Found 9980 files belonging to 2 classes

resnet50_BS32_FT50
resnet50_BS32_FT60
resnet50_BS32_FT70
resnet50_BS32_FT80
resnet50_BS32_FT90
resnet50_BS32_FT100
 * * * * * resnet50_BS32_LOSS * * * * * 
0.6062365174293518
1.0236619710922241
0.7001902461051941
0.8035963773727417
0.4520939290523529
0.5590106844902039
0.5925477147102356
0.4673694372177124
0.5579063296318054
0.34533870220184326
0.427597314119339
 * * * * * resnet50_BS32_ACCURACY * * * * * 
0.7160000205039978
0.7480000257492065
0.7770000100135803
0.8109999895095825
0.8930000066757202
0.8450000286102295
0.8529999852180481
0.8579999804496765
0.878000020980835
0.8769999742507935
0.8690000176429749
Batch size: 64
Found 9980 files belonging to 2 classes.
Using 8982 files for training.
Found 9980 files belonging to 2 classes.
Using 998 files for validation.
Found 1000 files belonging to 2 classes.
resnet50_BS64_FT0
resnet50_BS64_FT10
resnet50_BS64_FT20
resnet50_BS64_FT30
resnet50_BS64_FT40
resnet50_BS64_FT50
resnet50_BS64_FT60
resnet50_BS64_FT70
resnet50_BS64_FT80
resnet50_BS