In [None]:
import os

if not os.path.exists("/content/BraTS2020_TrainingData") or not os.path.exists("/content/BraTS2020_ValidationData"):
  !unzip "/content/drive/MyDrive/Individual Project Final/brats_dataset.zip" > /dev/null

In [None]:
import numpy as np
from dataset import get_train_test_ids_completed, create_training_gen, create_test_gen
from unet import create_unet
from unet_model import UNetModel

# Initial seed used by every experiment below that needs to be repeatable.
# For example, it makes the train/test data split come out the same way on each run.
seed = 42


# Training custom models

This section demonstrates training a U-Net model on the BraTS 2020 training data, which is loaded and split internally into a training set and a test set. The MRI modalities used to train the model should be specified; if they are not, FLAIR and T1ce are used by default. The loss function can also be specified. If categorical cross entropy loss is used, the U-Net's output activation is changed from sigmoid to softmax and the class labels are one-hot encoded. Otherwise, a separate binary mask is created for each class, and each class should be evaluated individually in the loss function. The losses are implemented in the MulticlassMetrics class, which also includes other metrics that can be used to evaluate the performance of the model. The default loss is a combined loss function: it averages the combination of binary cross entropy loss and dice loss, both calculated on a per-class basis, which is why the sigmoid activation is used.

In [None]:
# Multilabel U-Net trained on FLAIR, T1ce and T2 inputs, using the combined
# (dice + binary cross entropy) loss.

# Re-seed NumPy's global RNG first so this run is repeatable.
np.random.seed(seed)

model = UNetModel(
    models_path="/content/models",
    loss='combined_loss',
    modalities=['flair', 't1ce', 't2'],
    slice_interval=50,
)
# NOTE(review): the meaning of train_model's positional arguments is not visible
# here; pass them by keyword once the signature is confirmed.
model.train_model(0, 1, 1, 2)

Constructing UnetModel.
Using multilabel metrics




In [None]:
# Resume training from a previously saved model, then evaluate it.
# NOTE(review): load_model(0) presumably reloads the model saved by the first training
# cell (same models_path, loss, modalities and slice_interval), so that cell must have
# been run first; confirm what the integer argument selects.
# The object is a model, not a Keras History, hence the name `resumed_model`.
np.random.seed(seed)

resumed_model = UNetModel(models_path="/content/models", loss='combined_loss', modalities=['flair', 't1ce', 't2'], slice_interval=50)
resumed_model.load_model(0)
resumed_model.train_model(0, 1, 1, 6)
resumed_model.evaluate_model()

Constructing UnetModel.
Using multilabel metrics






In [None]:
# Training a multilabel u-net model using FLAIR and T1ce modalities,
#.. with a categorical cross entropy loss function
# Labels will be one-hot encoded (and the u-net output uses softmax instead of sigmoid)

# Re-seed so this experiment is repeatable on its own instead of depending on
# the RNG state left behind by earlier cells
np.random.seed(seed)

model = UNetModel(models_path="/content/models", loss='categorical_crossentropy', modalities=['flair', 't1ce'], slice_interval=50)
model.train_model(0, 1, 2, 6)
model.evaluate_model()

Constructing UnetModel.
Using multilabel metrics
Epoch 1/2
Epoch 2/2






In [None]:
# Training a binary classifier u-net model using the FLAIR modality only,
#.. with a combined dice and binary cross entropy loss function
# NOTE(review): only 'flair' is passed below; add 't1ce'/'t2' if a multi-modality
# binary model was intended.

# Re-seed so this experiment is repeatable on its own
np.random.seed(seed)

model = UNetModel(models_path="/content/models",
                  loss='combined_loss',
                  binary=True,
                  modalities=['flair'],
                  slice_interval=50)
model.train_model(0, 1, 2, 6)

Constructing UnetModel.
Using binary metrics
Epoch 1/2
Epoch 2/2




# Loading pre-trained weights

In [None]:
# Load a U-Net with pre-trained weights (obtained from kaggle.com) and evaluate it.
# The saved model is loaded uncompiled and then compiled through UNetModel,
# presumably so that the wrapper's configured loss/metrics are attached; confirm.
pretrained_path = "/content/drive/MyDrive/Individual Project Final/saved_models/pre-trained_ratislav/model_per_class.h5"

model = UNetModel(
    models_path="/content/models",
    loss='categorical_crossentropy',
    modalities=['flair', 't1ce'],
)
model.load_model(model_path=pretrained_path, compile=False)
model.compile_model()
model.evaluate_model()



Constructing UnetModel.
Using multilabel metrics


In [None]:
# Load the saved binary (tumor vs. background) classifier trained on FLAIR and T1ce.
# A single segment class is requested here; the cell output reports
# "Using binary metrics" for this configuration.
binary_model_path = "/content/drive/MyDrive/Individual Project Final/saved_models/binary/flair_t1ce/models/model_job4"

model = UNetModel(
    models_path="/content/models",
    loss='combined_loss',
    modalities=['flair', 't1ce'],
    segment_classes={1: 'tumor'},
)
model.load_model(model_path=binary_model_path, compile=False)
model.compile_model()
model.evaluate_model()

Constructing UnetModel.
Using binary metrics
 1/37 [..............................] - ETA: 7:44 - loss: 1.0290 - accuracy: 0.8626 - sensitivity: 0.9752 - specificity: 0.8615 - dice_loss: 0.8817 - combined_loss: 1.0290

KeyboardInterrupt: ignored

# LIME Explanations

In [None]:
# Generate LIME explanations for the pre-trained multilabel model.
# Requires the `lime` package; on a fresh Colab runtime, install it first (e.g. `%pip install lime`).
# Imported before the slow model load so that a missing dependency fails fast.
from lime_xai import LimeXAI

N_EXPLAIN_INDICES = 20            # explain indices 0 .. N_EXPLAIN_INDICES - 1
N_LIME_SAMPLES = 1000             # second argument to LimeXAI.explain; presumably the number of LIME perturbation samples, TODO confirm
DUPLICATE_CHANNEL_MODES = (0, 1, "mean")  # each value is passed as `duplicate_channel`

model = UNetModel(models_path="/content/models", loss='categorical_crossentropy', modalities=['flair', 't1ce'], seed=seed)
model.load_model(model_path="/content/drive/MyDrive/Individual Project Final/saved_models/pre-trained_ratislav/model_per_class.h5", compile=False)
model.compile_model()

# Named `lime_explainer` rather than `lime` so it does not shadow the `lime` package.
lime_explainer = LimeXAI(model, seed)

for i in range(N_EXPLAIN_INDICES):
  print(i)
  for mode in DUPLICATE_CHANNEL_MODES:
    lime_explainer.explain(i, N_LIME_SAMPLES, visualize=True, duplicate_channel=mode)
