In [None]:
import json
import math
import os
import re
import sys
sys.path.append('../')

import tensorflow as tf
import segmentation_models as sm

from MRIsegm.datagenerators import create_segmentation_generator
from MRIsegm.graphics import show_dataset, plot_history, show_prediction
from MRIsegm.losses import DiceBCEloss
from MRIsegm.metrics import dice_coef

In [None]:
# Encoder backbone name passed to segmentation_models' Unet.
BACKBONE = 'efficientnetb0'

# Passed to both data generators; presumably keeps the image and mask augmentations
# in lockstep — confirm in create_segmentation_generator.
SEED = 666
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VALIDATION = 8

IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
IMG_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

NUM_OF_EPOCHS = 100

# Expected number of samples in each split. These are hard-coded: update them if the
# contents of the data folders change.
NUM_TRAIN = 406
NUM_VALIDATION = 80

# Round up so the trailing partial batch (406 % 8 == 6 images) belongs to every epoch.
# With floor division those samples spill into the next epoch and epoch boundaries drift
# across passes. This assumes the generator yields a final partial batch, as Keras
# directory iterators do.
EPOCH_STEP_TRAIN = math.ceil(NUM_TRAIN / BATCH_SIZE_TRAIN)
EPOCH_STEP_VALIDATION = math.ceil(NUM_VALIDATION / BATCH_SIZE_VALIDATION)

data_dir_training = '../data/training'
data_dir_train_img = os.path.join(data_dir_training, 'img')
data_dir_train_mask = os.path.join(data_dir_training, 'mask')

data_dir_validation = '../data/validation'
data_dir_validation_img = os.path.join(data_dir_validation, 'img')
data_dir_validation_mask = os.path.join(data_dir_validation, 'mask')


In [None]:
# Training augmentation: the mask generator gets exactly the same settings as the
# image generator, held in its own dict object.
data_gen_args_img = {'rescale': 1. / 255, 'rotation_range': 5, 'horizontal_flip': True}
data_gen_args_mask = dict(data_gen_args_img)

# Validation data is only rescaled, never augmented.
val_data_gen_args_img = {'rescale': 1. / 255}
val_data_gen_args_mask = dict(val_data_gen_args_img)

In [None]:
# Paired image/mask generators. Arguments (positional, in this order): image dir,
# mask dir, batch size, target size, seed, image augmentation args, mask augmentation args.
train_generator = create_segmentation_generator(
    data_dir_train_img,
    data_dir_train_mask,
    BATCH_SIZE_TRAIN,
    IMG_SIZE,
    SEED,
    data_gen_args_img,
    data_gen_args_mask,
)

validation_generator = create_segmentation_generator(
    data_dir_validation_img,
    data_dir_validation_mask,
    BATCH_SIZE_VALIDATION,
    IMG_SIZE,
    SEED,
    val_data_gen_args_img,
    val_data_gen_args_mask,
)

In [None]:
# Visual sanity check of augmented training image/mask pairs.
# NOTE(review): the second argument is presumably the number of samples to show — confirm in MRIsegm.graphics.
show_dataset(train_generator, 3) # training

In [None]:
# The encoder is randomly initialised (encoder_weights=None), so seed TensorFlow's global
# RNG first to make the initial weights reproducible between runs.
tf.random.set_seed(SEED)

# Single-channel input; the sigmoid output gives a per-pixel foreground probability.
model = sm.Unet(BACKBONE, input_shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 1), encoder_weights=None, activation='sigmoid')

optimizer = 'Adam'
# Alternative losses/metrics that can be swapped in:
# iou_loss = sm.losses.JaccardLoss(class_weights=None, class_indexes=None, per_image=False, smooth=1.)
# BinaryFocalLoss = sm.losses.BinaryFocalLoss(alpha=0.25, gamma=2.0)
# iou_score = sm.metrics.IOUScore(smooth=1., name='iou_score')
loss = DiceBCEloss
metrics = [dice_coef]

model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

# summary() prints the table itself and returns None; wrapping it in print() adds a stray "None".
model.summary()

In [None]:
def _component_name(component):
    """Return a printable name for a compile() argument given as a string or an object.

    Strings are returned unchanged. For objects, ``__name__`` is tried first (callables
    such as a loss function), then the Keras-internal ``_name`` (optimizer instances),
    and finally the class name. This means loss or optimizer objects without
    ``__name__`` no longer raise AttributeError.
    """
    if isinstance(component, str):
        return component
    return (getattr(component, '__name__', None)
            or getattr(component, '_name', None)
            or type(component).__name__)


model_name = BACKBONE + f'_{IMAGE_HEIGHT}_{IMAGE_WIDTH}_BTC={BATCH_SIZE_TRAIN}_alpha3***'
model_name += f'_OPT={_component_name(optimizer)}'
model_name += f'_LOSS={_component_name(loss)}'

# model_name is reused as a file name (saved .h5 model, plots, eval file). Replace characters
# that are illegal in Windows file names, such as the '*' in the tag above, with '_'.
model_name = re.sub(r'[<>:"/\\|?*]', '_', model_name)

print('model name: ', model_name)

In [None]:
#optional: logs_dir = '../data/models/logs/' + model_name

#optional: csv_dir = '../data/CSV/'

callbacks = [
    #optional: tf.keras.callbacks.ModelCheckpoint('../data/models/checkpoints/' + model_name + '_checkpoint' + '.h5', save_best_only=True),
    #optional: tf.keras.callbacks.CSVLogger( csv_dir + model_name + '.csv', separator=',', append=False),
    # Stop once val_loss has not improved for 10 epochs. When stopping triggers,
    # restore_best_weights=True rolls the model back to its best epoch. Without it, the
    # model saved afterwards would carry the weights from up to 10 epochs past the best one.
    tf.keras.callbacks.EarlyStopping(patience=10, monitor='val_loss', restore_best_weights=True),
    #optional: tf.keras.callbacks.TensorBoard(log_dir=logs_dir)
]

# steps_per_epoch / validation_steps bound each epoch, because the generators are
# presumably endless (as Keras directory iterators are).
history = model.fit(
    train_generator,
    steps_per_epoch=EPOCH_STEP_TRAIN,
    validation_data=validation_generator,
    validation_steps=EPOCH_STEP_VALIDATION,
    epochs=NUM_OF_EPOCHS,
    callbacks=callbacks,
)

In [None]:
# Create the output folder if needed; writing the .h5 file into a missing directory fails.
os.makedirs('../data/models', exist_ok=True)
model.save('../data/models/' + model_name + '.h5')

In [None]:
#optional: %load_ext tensorboard

#optional: %tensorboard --logdir {logs_dir}

In [None]:
# Training curves for the loss and each metric; both are custom (non-string) callables.
plot_history(
    model_name,
    history,
    metrics,
    loss,
    custom_loss=True,
    custom_metrics=True,
    figsize=(18, 8),
    labelsize=13,
    path='../data/plots/' + model_name,
)

In [None]:
print("Evaluating on validation data")
# batch_size is deliberately omitted. For generator input Keras takes the batch size from the
# generator itself, and the docs say not to pass it.
evaluation = model.evaluate(validation_generator, steps=EPOCH_STEP_VALIDATION, return_dict=True)
evaluation

In [None]:
# Persist the {metric name: value} dict returned by evaluate() as JSON text.
os.makedirs('../data/evals', exist_ok=True)
with open('../data/evals/' + model_name + '_eval.txt', 'w') as file:
    json.dump(evaluation, file)

In [None]:
# Qualitative check: predictions on a few training samples.
show_prediction(model=model, datagen=train_generator, num=5)

In [None]:
# Qualitative check on held-out data: predictions on validation samples.
show_prediction(model=model, datagen=validation_generator, num=15)