In [None]:
import torch

# Sanity-check the PyTorch/CUDA install: report the visible GPUs and run a tiny
# element-wise add on the GPU to confirm kernels actually execute there.
# NOTE(review): the rest of this notebook trains with TensorFlow/Keras, so this
# cell only verifies the PyTorch side of the GPU stack.
if not torch.cuda.is_available():
    print("CUDA is not available. Please check your installation.")
else:
    print(f"CUDA is available. PyTorch version: {torch.__version__}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

    # Two random 3x3 tensors created on the CPU, moved to the default GPU, then summed;
    # the device printouts confirm the result stays on the GPU.
    x = torch.randn(3, 3).cuda()
    y = torch.randn(3, 3).cuda()
    z = x + y

    print("Tensor 'x' device: ", x.device)
    print("Tensor 'y' device: ", y.device)
    print("Tensor 'z' device: ", z.device)
    print("Tensor operation result: ", z)


In [None]:
import os
import pickle
import numpy as np
import tensorflow as tf
import keras
from matplotlib import pyplot as plt
import glob
import random
from tensorflow.keras.optimizers import Adam

import sys
sys.path.append(r'E:\BRATS DATA CODES\UNET CODES')
from brats2020_custom_data_generator import imageLoader

In [None]:
train_img_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\train\images/"
train_mask_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\train\masks/"

# Visual spot-check: show the three input channels and the label mask for one
# random slice of one random preprocessed training volume.

# os.listdir returns entries in arbitrary order, so sort both listings to keep
# index i of the image list and index i of the mask list pointing at the same case
# (assumes the image_N.npy / mask_N.npy naming used for the validation files below).
img_list = sorted(os.listdir(train_img_dir))
msk_list = sorted(os.listdir(train_mask_dir))
assert len(img_list) == len(msk_list), "image and mask directories differ in file count"

num_images = len(img_list)

img_num = random.randint(0, num_images - 1)
test_img = np.load(train_img_dir + img_list[img_num])
test_mask = np.load(train_mask_dir + msk_list[img_num])
# Collapse the last (per-class) axis of the mask into integer class labels.
test_mask = np.argmax(test_mask, axis=3)

# random.randint includes BOTH endpoints, so randint(0, depth) could return an index
# one past the last slice; randrange(depth) draws from 0 .. depth-1.
n_slice = random.randrange(test_mask.shape[2])
plt.figure(figsize=(12, 8))

plt.subplot(221)
plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')
plt.title('Image flair')
plt.subplot(222)
plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')
plt.title('Image t1ce')
plt.subplot(223)
plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')
plt.title('Image t2')
plt.subplot(224)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()

In [None]:
# Per-class weights intended for a weighted Dice loss (class 0 presumably being
# background, given its much smaller weight — confirm against the preprocessing).
# NOTE(review): an earlier note listed the last weight as 26.21, but the value
# actually assigned here is 26.31 — confirm which one is intended.
# NOTE(review): the loss cells further down reassign all four weights to 0.25,
# so these values are not what the model is trained with.
wt0, wt1, wt2, wt3 = 0.26, 22.53, 22.53, 26.31

In [None]:

train_img_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\train\images/"
train_mask_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\train\masks/"

val_img_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\val\images/"
val_mask_dir = r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\val\masks/"

# The image and mask lists are handed to imageLoader as separate, parallel lists,
# which (presumably) pairs them by position. os.listdir order is unspecified, so
# sort every listing to make image_N and mask_N land at the same index.
train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

val_img_list = sorted(os.listdir(val_img_dir))
val_mask_list = sorted(os.listdir(val_mask_dir))

In [None]:
# Volumes per batch, shared by the training and validation generators.
batch_size = 2

# One generator per split; each yields (images, masks) batches.
train_img_datagen = imageLoader(
    train_img_dir, train_img_list,
    train_mask_dir, train_mask_list,
    batch_size,
)
val_img_datagen = imageLoader(
    val_img_dir, val_img_list,
    val_mask_dir, val_mask_list,
    batch_size,
)

In [None]:
# Pull one batch from the training generator and plot a random slice of a random
# sample, to confirm the generator delivers aligned images and masks.
img, msk = next(train_img_datagen)

img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
test_mask = msk[img_num]
# Collapse the last (per-class) axis of the mask into integer class labels.
test_mask = np.argmax(test_mask, axis=3)

# random.randint includes BOTH endpoints, so randint(0, depth) could return an index
# one past the last slice; randrange(depth) draws from 0 .. depth-1.
n_slice = random.randrange(test_mask.shape[2])
plt.figure(figsize=(12, 8))

plt.subplot(221)
plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')
plt.title('Image flair')
plt.subplot(222)
plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')
plt.title('Image t1ce')
plt.subplot(223)
plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')
plt.title('Image t2')
plt.subplot(224)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()


In [None]:
import segmentation_models_3D as sm
from tensorflow.keras.optimizers import Adam

# Equal weighting of all four classes in the Dice term.
# NOTE(review): this replaces the 0.26 / 22.53 / 22.53 / 26.31 weights set earlier.
wt0 = wt1 = wt2 = wt3 = 0.25

# Combined objective: class-weighted Dice loss plus categorical focal loss.
# The explicit `1 *` is kept on purpose: it shapes the combined loss name that is
# later looked up in load_model's custom_objects.
dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3]))
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

# Reported during training alongside the loss.
metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

LR = 1e-4
optim = Adam(learning_rate=LR)

In [None]:
import glob
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
import segmentation_models_3D as sm

# Uniform class weights for the Dice term (four classes).
wt0 = wt1 = wt2 = wt3 = 0.25

# Training objective = weighted Dice + focal. Keep `1 * focal_loss` as written: it
# determines the combined loss name stored with the saved model.
dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3]))
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

LR = 1e-4
optim = Adam(learning_rate=LR)

# Directory for the per-epoch weight checkpoints; created if it does not exist yet.
checkpoint_dir = 'E:/saved_models/checkpoints/'
os.makedirs(checkpoint_dir, exist_ok=True)


def create_model():
    """Build the 3D U-Net used for BraTS segmentation.

    The builder is imported lazily from the project-local ``brats2020_3d_unet``
    module so this cell only needs it on the import path when it actually runs.

    Returns:
        The model produced by ``simple_unet_model`` for 128x128x128 inputs with
        3 channels and 4 output classes (presumably uncompiled — it is compiled
        right below).
    """
    from brats2020_3d_unet import simple_unet_model
    return simple_unet_model(IMG_HEIGHT=128,
                             IMG_WIDTH=128,
                             IMG_DEPTH=128,
                             IMG_CHANNELS=3,
                             num_classes=4)


model = create_model()
model.compile(optimizer=optim, loss=total_loss, metrics=metrics)
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" to the output.
model.summary()

# Save the weights after every epoch. Keras fills {epoch} with the 1-based epoch
# number, i.e. the count of epochs completed when the file is written.
checkpoint_path = checkpoint_dir + 'weights_epoch_{epoch:02d}.hdf5'
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_path,
                                      save_weights_only=True,
                                      monitor='val_loss',
                                      mode='min',
                                      save_best_only=False,
                                      verbose=1)


def _latest_epoch_checkpoint(directory):
    """Find the highest-numbered ``weights_epoch_NN.hdf5`` file in ``directory``.

    Returns:
        ``(path, epoch)`` for the newest checkpoint, or ``(None, 0)`` if none exist.
    """
    prefix, suffix = 'weights_epoch_', '.hdf5'
    best_path, best_epoch = None, 0
    for fname in os.listdir(directory):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        number = fname[len(prefix):-len(suffix)]
        if number.isdecimal() and int(number) > best_epoch:
            best_path, best_epoch = os.path.join(directory, fname), int(number)
    return best_path, best_epoch


# tf.train.latest_checkpoint only understands TensorFlow-format checkpoints (it reads
# the 'checkpoint' state file), so it never sees the .hdf5 files written above and
# resuming silently never happened. Scan the directory for them instead.
latest_checkpoint, last_epoch = _latest_epoch_checkpoint(checkpoint_dir)
print("Latest checkpoint:", latest_checkpoint)  # Debug print
if latest_checkpoint:
    # The file number is the number of completed epochs, which is exactly the
    # 0-based index fit() should start at; adding 1 would skip an epoch.
    initial_epoch = last_epoch
    print(f"Loading weights from {latest_checkpoint}. Resuming from epoch {initial_epoch}.")
    model.load_weights(latest_checkpoint)
else:
    initial_epoch = 0
    print("No previous checkpoints found. Starting training from epoch 0.")

# Steps per epoch use floor division, so a trailing partial batch is not counted.
steps_per_epoch = len(train_img_list) // batch_size
val_steps_per_epoch = len(val_img_list) // batch_size

print("Training process started...")
# Train up to epoch 96 in total, starting from `initial_epoch` (non-zero when resuming
# from a checkpoint) and checkpointing after every epoch.
history = model.fit(
    train_img_datagen,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_img_datagen,
    validation_steps=val_steps_per_epoch,
    epochs=96,
    initial_epoch=initial_epoch,
    callbacks=[checkpoint_callback],
    verbose=1,
)
print("Training process completed!")

# Persist the full trained model (not just the weights) in HDF5 format.
model_save_path = r'E:\saved_models\brats_3d_2.0.hdf5'
model.save(model_save_path)
print("Model saved successfully at:", model_save_path)

# Pickle the per-epoch metric dict so the curves can be re-plotted without retraining.
history_save_path = r'E:\saved_models\brats_3d_2.0_history.pkl'
with open(history_save_path, 'wb') as fh:
    pickle.dump(history.history, fh)
print("History saved successfully at:", history_save_path)


In [None]:
# Training vs. validation curves from the fit() history: one figure for loss,
# one for accuracy (yellow = training, red = validation).
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs = range(1, len(loss) + 1)

for series, colour, label in ((loss, 'y', 'Training loss'),
                              (val_loss, 'r', 'Validation loss')):
    plt.plot(epochs, series, colour, label=label)
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

acc, val_acc = history.history['accuracy'], history.history['val_accuracy']

for series, colour, label in ((acc, 'y', 'Training accuracy'),
                              (val_acc, 'r', 'Validation accuracy')):
    plt.plot(epochs, series, colour, label=label)
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
import numpy as np
import segmentation_models_3D as sm
from tensorflow.keras.models import load_model

# Rebuild the loss/metric objects used at training time so load_model can resolve
# the custom names stored in the saved file.
wt0 = wt1 = wt2 = wt3 = 0.25
dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3]))
focal_loss = sm.losses.CategoricalFocalLoss()
# Registered below under 'dice_loss_plus_1focal_loss', the name stored for this combined loss.
total_loss = dice_loss + (1 * focal_loss)

metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

# NOTE(review): this is a different file from the 'brats_3d_2.0.hdf5' saved by the
# training cell above — confirm which model is meant to be resumed.
model_path = r'E:\saved_models\Brats_3d.hdf5'  # raw string keeps the backslashes literal

my_model = load_model(
    model_path,
    custom_objects={
        'dice_loss_plus_1focal_loss': total_loss,
        'iou_score': sm.metrics.IOUScore(threshold=0.5),
        'DiceLoss': sm.losses.DiceLoss,
        'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
    },
)

# Continue training the reloaded model for a few more epochs on the same generators.
history2 = my_model.fit(
    train_img_datagen,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_img_datagen,
    validation_steps=val_steps_per_epoch,
    epochs=5,
    verbose=1,
)


In [None]:
# Reload the model for inference only; compile=False skips rebuilding the custom
# loss/metric objects, which predict() does not need.
my_model = load_model(model_path,
                      compile=False)


# Verify IoU on a batch of validation images with Keras' built-in MeanIoU
# (only works on TF > 2.0).
from keras.metrics import MeanIoU

# Use a dedicated name instead of reassigning `batch_size`: overwriting the training
# batch size here would silently change steps_per_epoch if the training cells were
# re-run after this one.
test_batch_size = 1
test_img_datagen = imageLoader(val_img_dir, val_img_list,
                               val_mask_dir, val_mask_list, test_batch_size)

test_image_batch, test_mask_batch = next(test_img_datagen)

# Collapse the class axis (axis 4 of the batched volumes) into integer labels.
test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)
test_pred_batch = my_model.predict(test_image_batch)
test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)

n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)
# update_state takes (y_true, y_pred). Mean IoU is symmetric, so the number is the
# same either way, but this keeps the confusion matrix correctly oriented.
IOU_keras.update_state(test_mask_batch_argmax, test_pred_batch_argmax)
print("Mean IoU =", IOU_keras.result().numpy())


In [None]:
# Run the model on one specific validation case and check the label volume's shape.
img_num = 200

test_img = np.load(r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\val\images\image_" + str(img_num) + ".npy")
test_mask = np.load(r"E:\Brats Dataset\BraTS2020_TrainingData\input_data_128\val\masks\mask_" + str(img_num) + ".npy")
# Integer class label per voxel (argmax over the last, per-class axis).
test_mask_argmax = np.argmax(test_mask, axis=3)

# predict() expects a batch, so add a leading batch axis of size 1, then drop it again
# after taking the per-voxel argmax over the class axis.
test_img_input = test_img[np.newaxis, ...]
test_prediction = my_model.predict(test_img_input)
test_prediction_argmax = np.argmax(test_prediction, axis=4)[0]

print(test_mask_argmax.shape)

In [None]:
from matplotlib import pyplot as plt
import random

# Side-by-side view of one slice (along axis 2): input channel 1 (t1ce in the
# channel order used by the earlier visualization cells), ground truth, prediction.
# random.randint includes BOTH endpoints, so randint(0, depth) could return an index
# one past the last slice; randrange(depth) draws from 0 .. depth-1.
# Replace with a fixed index (e.g. n_slice = 20) for a reproducible figure.
n_slice = random.randrange(test_prediction_argmax.shape[2])
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(test_mask_argmax[:, :, n_slice])
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(test_prediction_argmax[:, :, n_slice])
plt.show()