In [None]:
import os
import numpy as np
from custom_func import imageLoader
import keras
from matplotlib import pyplot as plt
import segmentation_models_3D as sm
from custom_models import simple_unet_model

In [None]:
# Data location: set the BRATS_DATA_DIR environment variable to point at the
# preprocessed 128^3 BraTS2020 data (folder containing train/ and val/).
# Falls back to the original author's local path when the variable is unset.
BASE_PATH = os.environ.get(
    "BRATS_DATA_DIR",
    "C:/Users/ygrae/Desktop/BRATS2020/BraTS2020_TrainingData/input_data128/",
)

# Training parameters
BATCH_SIZE = 2    # volumes per batch
EPOCHS = 5        # passes over the training split
LR = 0.0001       # Adam learning rate

In [None]:
# Train and validation directories (BASE_PATH comes from the config cell).
train_img_dir = os.path.join(BASE_PATH, 'train/images/')
train_mask_dir = os.path.join(BASE_PATH, 'train/masks/')

val_img_dir = os.path.join(BASE_PATH, 'val/images/')
val_mask_dir = os.path.join(BASE_PATH, 'val/masks/')

# File lists. os.listdir returns entries in arbitrary, filesystem-dependent
# order, and image/mask lists are built independently. Sorting both makes the
# order deterministic so entry i of the image list lines up with entry i of the
# mask list (assumes image and mask files share a numbering scheme - TODO
# confirm against the preprocessing step that wrote them).
train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

val_img_list = sorted(os.listdir(val_img_dir))
val_mask_list = sorted(os.listdir(val_mask_dir))

# imageLoader receives images and masks as separate lists and presumably pairs
# them by position, so a count mismatch means the pairing is broken.
assert len(train_img_list) == len(train_mask_list), (
    f"train split: {len(train_img_list)} images vs {len(train_mask_list)} masks"
)
assert len(val_img_list) == len(val_mask_list), (
    f"val split: {len(val_img_list)} images vs {len(val_mask_list)} masks"
)

# Custom batch generators consumed by model.fit in the training cell.
train_img_datagen = imageLoader(train_img_dir, train_img_list,
                                train_mask_dir, train_mask_list, BATCH_SIZE)

val_img_datagen = imageLoader(val_img_dir, val_img_list,
                              val_mask_dir, val_mask_list, BATCH_SIZE)

In [None]:
# Training objective: class-weighted Dice loss plus categorical focal loss,
# with the focal term scaled by 1. Every one of the 4 classes gets the same
# Dice weight.
wt0 = wt1 = wt2 = wt3 = 0.25
dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3]))
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Optimizer, plus the metrics reported each epoch (IoU uses threshold=0.5).
optim = keras.optimizers.Adam(learning_rate=LR)
metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

In [None]:
# Steps per epoch: number of full batches available in each split.
steps_per_epoch = len(train_img_list) // BATCH_SIZE
val_steps_per_epoch = len(val_img_list) // BATCH_SIZE

# Floor division yields 0 when a split has fewer files than BATCH_SIZE, which
# would ask fit() to draw zero batches for that split. Fail early with a clear
# message instead of an obscure error from inside training.
if steps_per_epoch == 0 or val_steps_per_epoch == 0:
    raise ValueError(
        f"Each split needs at least BATCH_SIZE={BATCH_SIZE} files; found "
        f"{len(train_img_list)} training and {len(val_img_list)} validation images."
    )

# Initialize model for 128x128x128 volumes. IMG_CHANNELS must match the channel
# count of the preprocessed volumes, and num_classes must match the length of
# the class-weight array given to DiceLoss (4 entries).
model = simple_unet_model(IMG_HEIGHT=128,
                          IMG_WIDTH=128,
                          IMG_DEPTH=128,
                          IMG_CHANNELS=3,
                          num_classes=4)

model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

# Fit model on the generators; the History object is used by the plotting cell.
history = model.fit(train_img_datagen,
                    steps_per_epoch=steps_per_epoch,
                    epochs=EPOCHS,
                    verbose=1,
                    validation_data=val_img_datagen,
                    validation_steps=val_steps_per_epoch,
                    )

# Save trained model (native Keras format).
model.save('brats_3d.keras')

In [None]:
# Plot training vs. validation learning curves (loss, accuracy and IoU) per epoch.
def _plot_curves(epoch_idx, train_vals, val_vals, name, ylabel):
    """Draw and show one training/validation learning-curve figure.

    Args:
        epoch_idx: x values (1-based epoch numbers).
        train_vals: per-epoch training values of the metric.
        val_vals: per-epoch validation values of the metric.
        name: metric name used in the title and legend, e.g. "loss".
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots()
    ax.plot(epoch_idx, train_vals, 'y', label=f'Training {name}')
    ax.plot(epoch_idx, val_vals, 'r', label=f'Validation {name}')
    ax.set_title(f'Training and validation {name}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()


history_dict = history.history
loss, val_loss = history_dict['loss'], history_dict['val_loss']
acc, val_acc = history_dict['accuracy'], history_dict['val_accuracy']
epochs = range(1, len(loss) + 1)

_plot_curves(epochs, loss, val_loss, 'loss', 'Loss')
_plot_curves(epochs, acc, val_acc, 'accuracy', 'Accuracy')

# IoU is compiled in as a metric but was never plotted. Its history key depends
# on the metric's name; 'iou_score' is presumably the default for
# sm.metrics.IOUScore - TODO confirm against history.history.keys().
iou_key = 'iou_score'
if iou_key in history_dict and 'val_' + iou_key in history_dict:
    _plot_curves(epochs, history_dict[iou_key], history_dict['val_' + iou_key],
                 'IoU', 'IoU')
else:
    print(f"No '{iou_key}' entries in history; available keys: {sorted(history_dict)}")