# Segmentation

In [None]:
import math
import os
import sys
sys.path.append('../')

import tensorflow as tf 
import numpy as np


from MRIsegm.datagenerators import create_segmentation_generator
from MRIsegm.metrics import dice_coef
from MRIsegm.losses import DiceBCEloss, soft_dice_loss
from MRIsegm.models import unet
from MRIsegm.graphics import show_dataset, plot_history, show_prediction

## Constants

In [None]:
# Random seed passed to the data generators.
SEED = 666
BATCH_SIZE_TRAIN = 4
BATCH_SIZE_VALIDATION = 4

# Spatial size images/masks are resized to, and the model input size.
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMG_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

NUM_OF_EPOCHS = 100

# Number of samples in each split (hard-coded: keep in sync with the data folders).
NUM_TRAIN = 406
NUM_VALIDATION = 80

# Steps per epoch = number of batches needed to visit every sample once.
# Ceil division counts the final partial batch (406 / 4 -> 102 steps, not 101),
# so a Keras "epoch" matches one pass over the data -- assuming the generator
# yields a smaller last batch, as Keras directory iterators do (TODO confirm for
# create_segmentation_generator).
EPOCH_STEP_TRAIN = math.ceil(NUM_TRAIN / BATCH_SIZE_TRAIN)
EPOCH_STEP_VALIDATION = math.ceil(NUM_VALIDATION / BATCH_SIZE_VALIDATION)

data_dir_training = '../data/training'
data_dir_train_img = os.path.join(data_dir_training, 'img')
data_dir_train_mask = os.path.join(data_dir_training, 'mask')

data_dir_validation = '../data/validation'
data_dir_validation_img = os.path.join(data_dir_validation, 'img')
data_dir_validation_mask = os.path.join(data_dir_validation, 'mask')


## Generators

In [None]:
# Training: rescale pixel values by 1/255 (i.e. to [0, 1] assuming 8-bit input)
# plus light augmentation. Image and mask must get the *same* transform settings
# so each augmented image stays aligned with its mask.
# NOTE(review): rotation interpolates pixels, so rotated masks may contain
# non-binary values -- consider thresholding masks downstream.
data_gen_args_img = {
    'rescale': 1. / 255,
    'rotation_range': 5,
    'horizontal_flip': True,
}
data_gen_args_mask = dict(data_gen_args_img)  # independent copy, same settings

# Validation: rescaling only, no augmentation.
val_data_gen_args_img = {'rescale': 1. / 255}
val_data_gen_args_mask = dict(val_data_gen_args_img)  # independent copy

In [None]:
# Training generator: images and masks from the training folders, with the
# augmentation settings above. Both generators get the same SEED (presumably
# so image/mask augmentations stay in sync -- confirm in MRIsegm.datagenerators).
train_generator = create_segmentation_generator(
    data_dir_train_img,
    data_dir_train_mask,
    BATCH_SIZE_TRAIN,
    IMG_SIZE,
    SEED,
    data_gen_args_img,
    data_gen_args_mask,
)

# Validation generator: rescaling only.
validation_generator = create_segmentation_generator(
    data_dir_validation_img,
    data_dir_validation_mask,
    BATCH_SIZE_VALIDATION,
    IMG_SIZE,
    SEED,
    val_data_gen_args_img,
    val_data_gen_args_mask,
)

### Show training data

In [None]:
# Visual sanity check of the training generator output (3 is passed through
# to show_dataset).
show_dataset(
    train_generator,
    3,
)

### Show validation data

In [None]:
# Visual sanity check of the validation generator output (3 is passed through
# to show_dataset).
show_dataset(
    validation_generator,
    3,
)

## Model

In [None]:
# Build the U-Net for IMAGE_HEIGHT x IMAGE_WIDTH inputs; n_levels and
# initial_features are passed straight through to MRIsegm.models.unet.
model = unet(
    IMAGE_HEIGHT,
    IMAGE_WIDTH,
    n_levels=4,
    initial_features=32,
)

# Training configuration. These three names are reused later to build the
# run name and to label the history plots.
optimizer = 'adam'
loss = soft_dice_loss
metrics = [dice_coef]

model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
)
model.summary()

In [None]:
# Build a descriptive run name, reused for checkpoint/CSV/eval/plot file names:
#   <model.name>_<H>_<W>_OPT=<optimizer>_LOSS=<loss>
model_name = f'{model.name}_{IMAGE_HEIGHT}_{IMAGE_WIDTH}'

# `optimizer` may be a string identifier (e.g. 'adam') or an optimizer instance.
if isinstance(optimizer, str):
    model_name += f'_OPT={optimizer}'
else:
    # Prefer the public `name` attribute; fall back to the private `_name`
    # used by older tf.keras optimizer classes.
    model_name += '_OPT=' + str(getattr(optimizer, 'name', None) or optimizer._name)

# `loss` may be a string, a plain function (has `__name__`), or a
# tf.keras.losses.Loss instance (has `.name` but no `__name__`).
if isinstance(loss, str):
    model_name += f'_LOSS={loss}'
else:
    model_name += '_LOSS=' + str(getattr(loss, '__name__', None) or loss.name)

print('model name: ', model_name)

In [None]:
# optional: logs_dir = os.path.join('../data/models/logs', model_name)
csv_dir = '../data/CSV/'

# Create the output folders up front. Otherwise the CSV logger (opened when
# training starts) or the first checkpoint write fails on a fresh checkout.
os.makedirs('../data/models/checkpoints/', exist_ok=True)
os.makedirs(csv_dir, exist_ok=True)

callbacks = [
    # Keep only the best model seen so far (ModelCheckpoint monitors
    # 'val_loss' by default).
    tf.keras.callbacks.ModelCheckpoint(
        '../data/models/checkpoints/' + model_name + '_checkpoint' + '.h5',
        save_best_only=True,
    ),
    # Per-epoch loss/metrics log, overwritten on every run.
    tf.keras.callbacks.CSVLogger(csv_dir + model_name + '.csv', separator=',', append=False),
    # restore_best_weights=True so the model saved/evaluated after fit() uses the
    # best-val_loss weights, not the ones from `patience` epochs later.
    # NOTE(review): in some TF versions the restore only happens when early
    # stopping actually triggers -- confirm for the installed version.
    tf.keras.callbacks.EarlyStopping(patience=10, monitor='val_loss', restore_best_weights=True),
    # optional: tf.keras.callbacks.TensorBoard(log_dir=logs_dir)
]

history = model.fit(
    train_generator,
    steps_per_epoch=EPOCH_STEP_TRAIN,
    validation_data=validation_generator,
    validation_steps=EPOCH_STEP_VALIDATION,
    epochs=NUM_OF_EPOCHS,
    callbacks=callbacks,
)

In [None]:
# Persist the final trained model; the '.h5' suffix selects the HDF5 format.
model.save(f'../data/models/{model_name}.h5')

In [None]:
# optional: %load_ext tensorboard

# optional: !tensorboard --logdir {logs_dir}

In [None]:
print("Evaluating on validation data")
# The generator already yields batches, so `batch_size` must not be passed
# (Keras documents it as unsupported for generators, datasets and Sequences).
# `steps` limits evaluation to EPOCH_STEP_VALIDATION batches.
evaluation = model.evaluate(
    validation_generator,
    steps=EPOCH_STEP_VALIDATION,
    return_dict=True,
)
print(evaluation)

In [None]:
import json

# Write the evaluation metrics dict (metric name -> value) as one line of JSON.
with open(f'../data/evals/{model_name}_eval.txt', 'w', encoding='utf-8') as eval_file:
    json.dump(evaluation, eval_file)

## Plots

In [None]:
# Plot training curves for the configured loss and metrics and save them under
# ../data/plots/<model_name> (the path is passed through to plot_history).
plot_history(
    model_name,
    history,
    metrics,
    loss,
    custom_loss=True,
    custom_metrics=True,
    figsize=(18, 8),
    labelsize=13,
    path=f'../data/plots/{model_name}',
)

## Predictions

### Training images prediction

In [None]:
# Compare model predictions against ground truth on training samples (num=5).
show_prediction(
    model=model,
    datagen=train_generator,
    num=5,
)

### Validation images prediction

In [None]:
# Compare model predictions against ground truth on validation samples (num=10).
show_prediction(
    model=model,
    datagen=validation_generator,
    num=10,
)