In [None]:
# Setup cell: optional dependency installs, Kaggle credentials, and every import the
# preprocessing section needs. `scaler` is one shared MinMaxScaler instance; later
# cells call `scaler.fit_transform` on it once per volume.

# Uncomment to install dependencies into this kernel's environment
# (`%pip` targets the running kernel; `shutil` is part of the standard library).
#%pip install nibabel
#%pip install matplotlib
#%pip install tifffile
#%pip install -U scikit-learn
#%pip install split-folders[full]
#%pip install pandas==1.5.2  # pinned because pandas 2.0 removed DataFrame.append
#%pip install kaggle

import csv
import glob
import os
import random
import shutil
import subprocess
import tarfile
import zipfile
from getpass import getpass

# Kaggle credentials are read from the environment (or prompted for) instead of being
# hard-coded in the notebook. They are set before `import kaggle`, as the original
# %env lines were. (Kaggle can alternatively use ~/.kaggle/kaggle.json.)
# NOTE: the API key that used to be committed here should be revoked on kaggle.com.
if not os.environ.get("KAGGLE_USERNAME"):
    os.environ["KAGGLE_USERNAME"] = input("Kaggle username: ")
if not os.environ.get("KAGGLE_KEY"):
    os.environ["KAGGLE_KEY"] = getpass("Kaggle API key: ")

import kaggle  # needs the credentials above
import keras
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd  # Import the pandas library
import splitfolders
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.utils import to_categorical
# tifffile.imsave is deprecated in favour of imwrite; keep the old name used below.
from tifffile import imwrite as imsave

scaler = MinMaxScaler()

In [None]:
# Download the BraTS 2021 Task 1 archive from Kaggle into ./data and unzip it there.
current_directory = os.getcwd()

# Creating a path for the data directory (created if missing)
data_path = os.path.join(current_directory, "data")
os.makedirs(data_path, exist_ok=True)

# Path of the archive the Kaggle CLI writes into data_path
zip_file_path = os.path.join(data_path, "brats-2021-task1.zip")

# Skip the large download when the archive is already present, so re-running is cheap.
if not os.path.exists(zip_file_path):
    # check=True stops here if the download fails, instead of failing later with a
    # confusing "file not found" from zipfile.
    subprocess.run(
        ["kaggle", "datasets", "download", "-d", "dschettler8845/brats-2021-task1", "-p", data_path],
        check=True,
    )

# Extracting the contents of the zip file to the data directory
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(data_path)


In [None]:
# Extract BraTS2021_Training_Data.tar (unzipped by the previous cell) into data/dataset,
# then point data_path at that folder.
# Paths are rebuilt from current_directory rather than from data_path, so re-running this
# cell does not look for the tar inside data/dataset or create data/dataset/dataset.
raw_data_path = os.path.join(current_directory, "data")
dataset_path = os.path.join(raw_data_path, "dataset")
os.makedirs(dataset_path, exist_ok=True)

tar_file_path = os.path.join(raw_data_path, "BraTS2021_Training_Data.tar")
with tarfile.open(tar_file_path, 'r') as tar_ref:
    # The archive is downloaded from the internet: use the "data" extraction filter
    # (rejects absolute paths / links escaping the target) when this Python has it.
    if hasattr(tarfile, "data_filter"):
        tar_ref.extractall(dataset_path, filter="data")
    else:
        tar_ref.extractall(dataset_path)

# All later cells read patient folders from here.
data_path = dataset_path

In [None]:
# Keep a random ~10% subset of patients to keep the experiment tractable; the other
# patient folders are deleted from disk.
# WARNING: destructive and cumulative - each re-run deletes ~90% of what is left.
DELETE_FRACTION = 0.9
subset_rng = np.random.default_rng(42)  # fixed seed so the retained subset is reproducible

# os.listdir order is arbitrary; sort so the seeded draws map to the same folders.
patient_dirs = sorted(os.listdir(data_path))

for patient_dir in patient_dirs:
    entry_path = os.path.join(data_path, patient_dir)

    # macOS metadata file
    if patient_dir == ".DS_Store":
        os.remove(entry_path)
        continue

    # Only patient folders are candidates; rmtree would fail on a plain file.
    if not os.path.isdir(entry_path):
        continue

    if subset_rng.random() < DELETE_FRACTION:
        # Removing the patient directory and its contents
        shutil.rmtree(entry_path)


In [None]:
# Optional check (can be skipped): write the names of the patient folders that survived
# the subsampling to a CSV file, to verify that most of the data was removed.
directory = data_path  # the extracted dataset folder (data/dataset)
csv_file = 'output.csv'  # Replace with the desired path and filename for the CSV file

# Folder names only, sorted for a stable, readable listing
folder_names = sorted(
    name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))
)

with open(csv_file, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Folder Name'])  # column header
    writer.writerows([name] for name in folder_names)  # one folder name per row

print('CSV file has been created successfully.')


In [None]:
# Per-patient preprocessing check: min-max scale each modality, stack FLAIR, T1ce and T2
# into one 3-channel volume, crop it to 128x128x128 and save it next to the patient's
# NIfTI files as combined255.tif / combined255.npy.
# NOTE(review): the saved values are floats in [0, 1] despite the "255" in the file name,
# and no later cell in this notebook reads these files.
SHOW_PREVIEW = False  # set True to plot a random slice before and after cropping

# Crop window that removes most of the black background (128 voxels per axis).
CROP = (slice(56, 184), slice(56, 184), slice(13, 141))


def _load_scaled(path):
    """Load a NIfTI volume and min-max scale it with the shared ``scaler``.

    The volume is flattened to (-1, last_axis) before scaling, so each index along the
    last axis is scaled independently; the original shape is restored afterwards.
    """
    volume = nib.load(path).get_fdata()
    flat = volume.reshape(-1, volume.shape[-1])
    return scaler.fit_transform(flat).reshape(volume.shape)


def _show_slices(volumes, titles, rows, cols):
    """Plot the same random slice (along the third axis) of each volume in a rows x cols grid."""
    n_slice = random.randint(0, volumes[0].shape[2] - 1)  # randint's upper bound is inclusive
    plt.figure(figsize=(12, 8))
    for position, (volume, title) in enumerate(zip(volumes, titles), start=1):
        plt.subplot(rows, cols, position)
        plt.imshow(volume[:, :, n_slice], cmap=None if title == 'Mask' else 'gray')
        plt.title(title)
    plt.show()


for patient_dir in sorted(os.listdir(data_path)):
    patient_path = os.path.join(data_path, patient_dir)
    # Only patient folders hold scans; skip stray files such as macOS .DS_Store.
    if not os.path.isdir(patient_path):
        continue
    scan_prefix = os.path.join(patient_path, patient_dir)  # files are <patient>_<modality>.nii.gz

    test_image_flair = _load_scaled(scan_prefix + '_flair.nii.gz')
    test_image_t1ce = _load_scaled(scan_prefix + '_t1ce.nii.gz')
    test_image_t2 = _load_scaled(scan_prefix + '_t2.nii.gz')

    # Remap label 4 to 3 so the labels are contiguous (0-3), matching the 4-class
    # one-hot encoding used later.
    test_mask = nib.load(scan_prefix + '_seg.nii.gz').get_fdata().astype(np.uint8)
    test_mask[test_mask == 4] = 3

    if SHOW_PREVIEW:
        # T1 is only loaded for this plot; it is not part of the saved volume.
        test_image_t1 = _load_scaled(scan_prefix + '_t1.nii.gz')
        _show_slices(
            [test_image_flair, test_image_t1, test_image_t1ce, test_image_t2, test_mask],
            ['Image flair', 'Image t1', 'Image t1ce', 'Image t2', 'Mask'],
            2, 3,
        )

    # Channel order: 0 = FLAIR, 1 = T1ce, 2 = T2. CROP applies to the three spatial axes.
    combined_x = np.stack([test_image_flair, test_image_t1ce, test_image_t2], axis=3)[CROP]
    test_mask = test_mask[CROP]

    if SHOW_PREVIEW:
        _show_slices(
            [combined_x[..., 0], combined_x[..., 1], combined_x[..., 2], test_mask],
            ['Image flair', 'Image t1ce', 'Image t2', 'Mask'],
            2, 2,
        )

    # Save the combined image as .tif and .npy files in the patient's directory
    imsave(os.path.join(patient_path, 'combined255.tif'), combined_x)
    np.save(os.path.join(patient_path, 'combined255.npy'), combined_x)

In [None]:
# Build the training inputs for every modality combination.
# For each patient: remap and crop the mask, and skip the patient if less than 1% of the
# cropped mask is non-background. Otherwise load and min-max scale all four modalities
# ONCE, then save one stacked image per combination plus the one-hot mask.
# (Previously every NIfTI file was re-read and re-scaled for each combination using it.)
t1_list = sorted(glob.glob(os.path.join(data_path, 'BraTS2021_*', 'BraTS2021_*_t1.nii.gz')))
t2_list = sorted(glob.glob(os.path.join(data_path, 'BraTS2021_*', 'BraTS2021_*_t2.nii.gz')))
t1ce_list = sorted(glob.glob(os.path.join(data_path, 'BraTS2021_*', 'BraTS2021_*_t1ce.nii.gz')))
flair_list = sorted(glob.glob(os.path.join(data_path, 'BraTS2021_*', 'BraTS2021_*_flair.nii.gz')))
mask_list = sorted(glob.glob(os.path.join(data_path, 'BraTS2021_*', 'BraTS2021_*_seg.nii.gz')))
modality_lists = {'t1': t1_list, 't1ce': t1ce_list, 't2': t2_list, 'flair': flair_list}

# The lists are matched by position, so verify every entry belongs to the same patient
# (a missing file would otherwise silently pair scans from different patients).
patient_ids = [os.path.basename(os.path.dirname(path)) for path in mask_list]
for modality, paths in modality_lists.items():
    assert [os.path.basename(os.path.dirname(path)) for path in paths] == patient_ids, \
        f"{modality} files do not line up with the segmentation masks"

# Defines the list of combinations of images used.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]

# Folder name kept for the cells below, although the inputs have 1-4 channels.
OUTPUT_ROOT = 'BraTS2021_TrainingData/input_data_3channels'
# Crop to 128^3 (divisible by 64 for later 64x64x64 patches); removes the outer black area.
CROP = (slice(56, 184), slice(56, 184), slice(13, 141))
MIN_TUMOUR_FRACTION = 0.01  # minimum share of non-background voxels needed to keep a volume

# Create images/ and masks/ subdirectories for each combination
for combo in combinations:
    combo_name = "_".join(combo)
    os.makedirs(f'{OUTPUT_ROOT}/{combo_name}/images', exist_ok=True)
    os.makedirs(f'{OUTPUT_ROOT}/{combo_name}/masks', exist_ok=True)

for img in range(len(mask_list)):
    print("Now preparing image and masks number: ", img)

    temp_mask = nib.load(mask_list[img]).get_fdata().astype(np.uint8)
    temp_mask[temp_mask == 4] = 3  # Reassign mask values 4 to 3
    temp_mask = temp_mask[CROP]

    # count_nonzero is correct even when label 0 is absent (the old counts[0] was not).
    # Checked before loading the images, so skipped volumes cost a single read.
    if np.count_nonzero(temp_mask) / temp_mask.size <= MIN_TUMOUR_FRACTION:
        print("I am useless")
        continue

    print("Save Me")
    temp_mask = to_categorical(temp_mask, num_classes=4)

    # Load, scale and crop each modality once; every combination reuses these arrays.
    scaled = {}
    for modality, paths in modality_lists.items():
        volume = nib.load(paths[img]).get_fdata()
        volume = scaler.fit_transform(volume.reshape(-1, volume.shape[-1])).reshape(volume.shape)
        scaled[modality] = volume[CROP]

    for combo in combinations:
        combo_name = "_".join(combo)
        temp_combined_images = np.stack([scaled[modality] for modality in combo], axis=3)
        np.save(f'{OUTPUT_ROOT}/{combo_name}/images/image_{img}.npy', temp_combined_images)
        np.save(f'{OUTPUT_ROOT}/{combo_name}/masks/mask_{img}.npy', temp_mask)


In [None]:
# Split each combination folder into train (75%) and val (25%) sets under input_data_128/.
# split-folders treats `images` and `masks` as two "classes" and splits each separately.
# With the same seed and matching file names (image_N / mask_N), the two splits
# presumably stay paired - TODO verify if split-folders is upgraded.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]

output_folder = 'BraTS2021_TrainingData/input_data_128/'

combo_names = ["_".join(modalities) for modalities in combinations]
for combo_name in combo_names:
    input_folder = f'BraTS2021_TrainingData/input_data_3channels/{combo_name}'
    output_combo_folder = f'{output_folder}/{combo_name}'
    splitfolders.ratio(
        input_folder,
        output=output_combo_folder,
        seed=42,
        ratio=(.75, .25),
        group_prefix=None,  # default values
    )


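**Restart the kernel here.** Running the training cells below in the same session as the preprocessing above causes an error; the next cells re-import everything that is needed.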

In [None]:
#!pip install segmentation-models-3D

In [None]:
# Re-import everything after the kernel restart (training section of the notebook).
import csv
import glob
import os
import random
import shutil
import subprocess
import tarfile
import zipfile

import keras
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd  # Import the pandas library
import splitfolders
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.utils import to_categorical
# tifffile.imsave is deprecated in favour of imwrite; keep the old name.
from tifffile import imwrite as imsave

scaler = MinMaxScaler()

In [None]:
def load_img(img_dir, img_list):
    """
    Load the ``.npy`` files named in ``img_list`` from ``img_dir`` and stack them.

    Args:
    - img_dir (str): Directory containing the files (a trailing separator is optional).
    - img_list (list): File names; entries that do not end in ``.npy`` are skipped.

    Returns:
    - images (ndarray): The loaded arrays stacked along a new leading axis, in the
      order of ``img_list`` (shape ``(0,)`` when nothing is loaded).
    """
    # endswith() instead of split('.')[1]: that crashed on names without a dot and
    # misread names containing several dots.
    images = [
        np.load(os.path.join(img_dir, image_name))
        for image_name in img_list
        if image_name.endswith('.npy')
    ]
    return np.array(images)  # Convert the list of images to a NumPy array

In [None]:
# The data loader used for training and validation: endless (images, masks) batches.
def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    """
    Build an endless generator of ``(X, Y)`` batches for ``model.fit``.

    ``img_list[i]`` and ``mask_list[i]`` must refer to the same sample, so pass both
    lists in matching (e.g. sorted) order. The last batch of each pass may be smaller
    than ``batch_size``; after it, the generator starts again from the beginning.

    Args:
    - img_dir (str): Directory path where the image files are located.
    - img_list (list): List of image filenames.
    - mask_dir (str): Directory path where the mask files are located.
    - mask_list (list): List of mask filenames.
    - batch_size (int): Number of samples per batch.

    Returns:
    - A generator that yields ``(images, masks)`` tuples built with ``load_img``.

    Raises:
    - ValueError: if the lists are empty or differ in length, or ``batch_size < 1``.
      With an empty list the previous version looped forever without yielding.
    """
    n_samples = len(img_list)
    if n_samples == 0:
        raise ValueError("img_list is empty; the generator would never yield a batch")
    if len(mask_list) != n_samples:
        raise ValueError(f"img_list has {n_samples} entries but mask_list has {len(mask_list)}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    def _batches():
        # Loop forever: Keras pulls steps_per_epoch batches per epoch from this generator.
        while True:
            for start in range(0, n_samples, batch_size):
                stop = start + batch_size  # slicing clamps the final partial batch
                yield (load_img(img_dir, img_list[start:stop]),
                       load_img(mask_dir, mask_list[start:stop]))

    return _batches()

In [None]:
# Build one endless training generator per modality combination.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]
batch_size = 8

train_img_datagen_list = []  # one generator per combination, same order as `combinations`
train_img_list_list = []  # the matching training image file lists

for combo in combinations:
    combo_name = "_".join(combo)  # folder names join the modalities with '_'
    train_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/images/"
    train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"
    # os.listdir order is arbitrary; sort both lists so image_N lines up with mask_N.
    train_img_list = sorted(os.listdir(train_img_dir))
    train_mask_list = sorted(os.listdir(train_mask_dir))

    train_img_datagen = imageLoader(train_img_dir, train_img_list, train_mask_dir, train_mask_list, batch_size)

    train_img_datagen_list.append(train_img_datagen)
    train_img_list_list.append(train_img_list)


In [None]:
# Pull one batch from the last combination's training generator and pick a random
# sample and slice to inspect. Each mask is one-hot over the last axis; argmax turns it
# back into a label volume. Image channels follow the modality order of the combination.
preview_combo = combinations[-1]
img, msk = next(train_img_datagen_list[-1])  # Get the next batch of images and masks
img_num = random.randint(0, img.shape[0] - 1)  # Choose a random image from the batch
test_img = img[img_num]
test_mask = np.argmax(msk[img_num], axis=3)  # one-hot -> label volume
n_slice = random.randint(0, test_mask.shape[2] - 1)  # randint's upper bound is inclusive

SHOW_PREVIEW = False  # set True to plot every channel of the slice plus its mask
if SHOW_PREVIEW:
    n_panels = len(preview_combo) + 1
    plt.figure(figsize=(4 * n_panels, 4))
    for channel, modality in enumerate(preview_combo):
        plt.subplot(1, n_panels, channel + 1)
        plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
        plt.title(f'Image {modality}')
    plt.subplot(1, n_panels, n_panels)
    plt.imshow(test_mask[:, :, n_slice])
    plt.title('Mask')
    plt.show()


In [None]:
# Load one random training sample straight from disk (not through the generator) and
# optionally plot a slice of each channel together with its mask.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]
combo_index = 0  # Change the index to select a different combination
combo = combinations[combo_index]
combo_name = "_".join(combo)

train_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/images/"
train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"

# Sorted for a stable order (os.listdir order is arbitrary).
img_list = sorted(os.listdir(train_img_dir))
msk_list = sorted(os.listdir(train_mask_dir))
num_images = len(img_list)

img_num = random.randint(0, num_images - 1)  # Choose a random image index
img_name = img_list[img_num]
test_img = np.load(train_img_dir + img_name)
# Find the mask by name (image_N.npy -> mask_N.npy) rather than by list position.
test_mask = np.load(train_mask_dir + img_name.replace('image_', 'mask_', 1))
test_mask = np.argmax(test_mask, axis=3)  # one-hot -> label volume

SHOW_PREVIEW = False  # set True to plot the sample
if SHOW_PREVIEW:
    n_slice = random.randint(0, test_mask.shape[2] - 1)  # randint's upper bound is inclusive
    n_panels = len(combo) + 1
    plt.figure(figsize=(4 * n_panels, 4))
    for channel, modality in enumerate(combo):
        plt.subplot(1, n_panels, channel + 1)
        plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
        plt.title(f'Image {modality}')
    plt.subplot(1, n_panels, n_panels)
    plt.imshow(test_mask[:, :, n_slice])
    plt.title('Mask')
    plt.show()

In [None]:
# Compute inverse-frequency class weights from the voxel label counts of the training masks.
# NOTE(review): the training cell below currently overrides wt0..wt3 with 0.25 each -
# confirm which weights are intended.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]

columns = ['0', '1', '2', '3']  # one column per class label
n_classes = len(columns)

# The same mask_N.npy is written for every combination, so count each distinct mask
# file once instead of once per combination.
mask_paths = {}
for combo in combinations:
    combo_name = "_".join(combo)
    train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"
    for path in sorted(glob.glob(train_mask_dir + '*.npy')):
        mask_paths.setdefault(os.path.basename(path), path)

records = []
for mask_name in sorted(mask_paths):
    labels = np.argmax(np.load(mask_paths[mask_name]), axis=3)  # one-hot -> label volume
    # bincount keeps a slot for every class even when a label is absent; zipping the
    # np.unique counts shifted later counts into the wrong columns in that case.
    counts = np.bincount(labels.ravel(), minlength=n_classes)
    records.append(dict(zip(columns, counts)))

if not records:
    raise ValueError("No training masks found under BraTS2021_TrainingData/input_data_128/*/train/masks/")

# Build the frame once (DataFrame.append was removed in pandas 2.0 and was quadratic).
df = pd.DataFrame(records, columns=columns)

# Total voxel count per label
label_totals = df[columns].sum()
label_0, label_1, label_2, label_3 = (label_totals[c] for c in columns)
total_labels = label_0 + label_1 + label_2 + label_3

# Class weights calculation: n_samples / (n_classes * n_samples_for_class)
wt0, wt1, wt2, wt3 = (
    round((total_labels / (n_classes * count)), 2)
    for count in (label_0, label_1, label_2, label_3)
)


In [None]:
# Train/val folders and file lists for a single combination.
combinations = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]

combo_index = 0  # Change the index to select a different combination
combo_name = "_".join(combinations[combo_index])

train_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/images/"
train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"

val_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/images/"
val_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/masks/"

# os.listdir order is arbitrary; sort so image_N lines up with mask_N in each pair of lists.
train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

val_img_list = sorted(os.listdir(val_img_dir))
val_mask_list = sorted(os.listdir(val_mask_dir))


In [None]:
batch_size = 8  # Define the batch size for the data generators
SHOW_PREVIEW = False  # set True to plot one random slice per combination

# For each combination, create train/val generators and pull one batch as a sanity check.
for combo in combinations:
    combo_name = "_".join(combo)
    train_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/images/"
    train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"
    val_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/images/"
    val_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/masks/"

    # os.listdir order is arbitrary; sort so image_N lines up with mask_N.
    train_img_list = sorted(os.listdir(train_img_dir))
    train_mask_list = sorted(os.listdir(train_mask_dir))
    val_img_list = sorted(os.listdir(val_img_dir))
    val_mask_list = sorted(os.listdir(val_mask_dir))

    train_img_datagen = imageLoader(train_img_dir, train_img_list,
                                    train_mask_dir, train_mask_list, batch_size)
    val_img_datagen = imageLoader(val_img_dir, val_img_list,
                                  val_mask_dir, val_mask_list, batch_size)

    # Verify the generator by getting a batch of images and masks
    img, msk = next(train_img_datagen)
    img_num = random.randint(0, img.shape[0] - 1)  # Choose a random image from the batch
    test_img = img[img_num]
    test_mask = np.argmax(msk[img_num], axis=3)  # one-hot -> label volume

    if SHOW_PREVIEW:
        n_slice = random.randint(0, test_mask.shape[2] - 1)  # randint's upper bound is inclusive
        n_panels = len(combo) + 1
        plt.figure(figsize=(4 * n_panels, 4))
        # Channels follow the modality order of the combination.
        for channel, modality in enumerate(combo):
            plt.subplot(1, n_panels, channel + 1)
            plt.imshow(test_img[:, :, n_slice, channel], cmap='gray')
            plt.title(f'Image {modality}')
        plt.subplot(1, n_panels, n_panels)
        plt.imshow(test_mask[:, :, n_slice])
        plt.title('Mask')
        plt.show()


In [None]:
# Print the installed TensorFlow and Keras versions.
# Note: `from tensorflow import keras` rebinds the name `keras` to tf.keras in this kernel.
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(keras.__version__)


In [None]:

#!pip install segmentation-models-3D
import keras
from keras.callbacks import ModelCheckpoint
import segmentation_models_3D as sm

# Local module (simple_3d_unet.py) that builds the 3D UNet; imported once, not per iteration.
from simple_3d_unet import simple_unet_model

# Class weights for the Dice loss.
# NOTE(review): these equal weights override the frequency-based wt0..wt3 computed in the
# class-weight cell above - confirm which set is intended.
wt0, wt1, wt2, wt3 = 0.25, 0.25, 0.25, 0.25

dice_loss = sm.losses.DiceLoss(class_weights=np.array([wt0, wt1, wt2, wt3]))  # weighted Dice loss
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)  # Combine the losses

LR = 0.0001  # learning rate
EPOCHS = 100
IMG_SIZE = 128  # volumes were cropped to 128x128x128 during preprocessing
N_CLASSES = 4  # labels 0-3 after remapping 4 -> 3

# List of combinations for training
combo_list = [('t1',), ('t1ce',), ('t2',), ('flair',), ('t1', 't1ce'), ('t1', 't2'), ('t1', 'flair'), ('t1ce', 't2'), ('t1ce', 'flair'), ('t2', 'flair'), ('t1', 't1ce', 't2'), ('t1', 't1ce', 'flair'), ('t1', 't2', 'flair'), ('t1ce', 't2', 'flair'), ('t1', 't1ce', 't2', 'flair')]

for combo in combo_list:
    # Folder names join the modalities with '_' (e.g. 't1_t1ce'). Formatting the tuple
    # directly gave "('t1', 't1ce')", which points at a folder that does not exist.
    combo_name = "_".join(combo)
    n_channels = len(combo)  # one input channel per modality in this combination

    train_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/images/"
    train_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/train/masks/"
    val_img_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/images/"
    val_mask_dir = f"BraTS2021_TrainingData/input_data_128/{combo_name}/val/masks/"

    # os.listdir order is arbitrary; sort so image_N lines up with mask_N.
    train_img_list = sorted(os.listdir(train_img_dir))
    train_mask_list = sorted(os.listdir(train_mask_dir))
    val_img_list = sorted(os.listdir(val_img_dir))
    val_mask_list = sorted(os.listdir(val_mask_dir))

    # Fresh generators for THIS combination. Previously the generators left over from an
    # earlier cell were reused, so every model trained on the same inputs.
    train_img_datagen = imageLoader(train_img_dir, train_img_list,
                                    train_mask_dir, train_mask_list, batch_size)
    val_img_datagen = imageLoader(val_img_dir, val_img_list,
                                  val_mask_dir, val_mask_list, batch_size)

    # Ceiling division: imageLoader yields a final partial batch, so each file is seen
    # once per epoch, and the step count is never 0 for small datasets.
    steps_per_epoch = -(-len(train_img_list) // batch_size)
    val_steps_per_epoch = -(-len(val_img_list) // batch_size)

    # Release the previous model's graph/memory before building the next one.
    keras.backend.clear_session()

    # New optimizer and metric objects for each model. Recent Keras optimizers may refuse
    # to update variables other than those of the first model they were used with.
    optim = keras.optimizers.Adam(LR)
    metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

    # Input channels must match the number of modalities in this combination.
    model = simple_unet_model(IMG_HEIGHT=IMG_SIZE, IMG_WIDTH=IMG_SIZE, IMG_DEPTH=IMG_SIZE,
                              IMG_CHANNELS=n_channels, num_classes=N_CLASSES)
    model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

    model.summary()  # prints itself; wrapping it in print() also printed "None"
    print(model.input_shape)
    print(model.output_shape)

    # Create a checkpoint to save the model with the best validation accuracy
    best_model_path = f'brats_3d_{combo_name}_best.hdf5'
    best_model_checkpoint = ModelCheckpoint(best_model_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')

    history = model.fit(train_img_datagen,
                        steps_per_epoch=steps_per_epoch,
                        epochs=EPOCHS,
                        verbose=1,
                        validation_data=val_img_datagen,
                        validation_steps=val_steps_per_epoch,
                        callbacks=[best_model_checkpoint])

    # Write one tab-separated line of metrics per epoch
    log_file_path = f'training_logs_{combo_name}.txt'
    with open(log_file_path, 'w') as log_file:
        log_file.write('Epoch\tTrain_Loss\tTrain_Accuracy\tVal_Loss\tVal_Accuracy\n')
        per_epoch = zip(history.history['loss'], history.history['accuracy'],
                        history.history['val_loss'], history.history['val_accuracy'])
        for epoch, (train_loss, train_acc, val_loss, val_acc) in enumerate(per_epoch, start=1):
            log_file.write(f'{epoch}\t{train_loss}\t{train_acc}\t{val_loss}\t{val_acc}\n')

    # Save the final model
    final_model_path = f'brats_3d_{combo_name}_final.hdf5'
    model.save(final_model_path)


In [None]:
import os  # Import the os module

# Free disk space by removing the ./data folder (downloaded archives and extracted NIfTI
# volumes). The .npy inputs used for training live in BraTS2021_TrainingData/ and are kept.
subfolder_name = 'data'  # Specify the name of the subfolder to delete

current_directory = os.getcwd()
subfolder_path = os.path.join(current_directory, subfolder_name)

# Guard so the cell can be re-run after the folder is already gone.
if os.path.isdir(subfolder_path):
    shutil.rmtree(subfolder_path)
else:
    print(f'Nothing to delete: {subfolder_path} does not exist.')


In [None]:
# Plot the training vs. validation curves (loss, then accuracy) from the most recent
# `history` returned by model.fit; one figure per quantity.
curves = history.history
loss, val_loss = curves['loss'], curves['val_loss']
acc, val_acc = curves['accuracy'], curves['val_accuracy']
epochs = range(1, len(loss) + 1)  # x-axis: 1-based epoch numbers

panels = [
    (loss, val_loss, 'Training loss', 'Validation loss', 'Training and validation loss', 'Loss'),
    (acc, val_acc, 'Training accuracy', 'Validation accuracy', 'Training and validation accuracy', 'Accuracy'),
]
for train_curve, val_curve, train_label, val_label, panel_title, y_label in panels:
    plt.plot(epochs, train_curve, 'y', label=train_label)
    plt.plot(epochs, val_curve, 'r', label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()


In [None]:
from keras.models import load_model  # Import the load_model function from keras.models

# Checkpoint to resume from. The training loop saves brats_3d_<combo>_final.hdf5 (and
# ..._best.hdf5); this picks the last combination trained. The previously hard-coded
# 'saved_models/brats_3d_100epochs_simple_unet_weighted_dice.hdf5' is not produced by
# this notebook - substitute it here if you have such a pretrained file.
MODEL_PATH = f'brats_3d_{combo_name}_final.hdf5'

# custom_objects maps the names stored in the file back to the loss/metric objects;
# 'dice_loss_plus_1focal_loss' is presumably the name segmentation_models_3D gives to
# dice_loss + 1 * focal_loss.
my_model = load_model(MODEL_PATH,
                      custom_objects={'dice_loss_plus_1focal_loss': total_loss,
                                      'iou_score': sm.metrics.IOUScore(threshold=0.5)})

# Continue training for one more epoch. The generators and step counts are those left by
# the training loop, i.e. they belong to the last combination trained.
history2 = my_model.fit(train_img_datagen,
                        steps_per_epoch=steps_per_epoch,
                        epochs=1,
                        verbose=1,
                        validation_data=val_img_datagen,
                        validation_steps=val_steps_per_epoch)
# Continue training the model for 1 additional epoch using the loaded model and data generators


In [None]:
# Evaluate a saved model: mean IoU on one validation batch, then visualise one prediction.
# For predictions the model does not need to be compiled.
MODEL_PATH = f'brats_3d_{combo_name}_final.hdf5'  # same naming as the training loop
my_model = load_model(MODEL_PATH, compile=False)

# Built-in Keras IoU metric (TensorFlow >= 2.0)
from keras.metrics import MeanIoU

n_classes = 4
batch_size = 8  # Check IoU for a batch of images

# Sort so image_N pairs with mask_N (os.listdir order is arbitrary).
eval_img_list = sorted(val_img_list)
eval_mask_list = sorted(val_mask_list)
test_img_datagen = imageLoader(val_img_dir, eval_img_list,
                               val_mask_dir, eval_mask_list, batch_size)
test_image_batch, test_mask_batch = next(test_img_datagen)

# Batches are (batch, H, W, D, classes): argmax over axis 4 gives label volumes.
test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)
test_pred_batch = my_model.predict(test_image_batch)
test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)

IOU_keras = MeanIoU(num_classes=n_classes)
IOU_keras.update_state(test_mask_batch_argmax, test_pred_batch_argmax)  # (y_true, y_pred)
print("Mean IoU =", IOU_keras.result().numpy())

# Predict on a single validation volume. The old hard-coded
# input_data_128/val/images/image_82.npy path has no combination subfolder (and image 82
# may not be in the split), so use the first validation file of this combination.
test_img_name = eval_img_list[0]
test_img = np.load(os.path.join(val_img_dir, test_img_name))
test_mask = np.load(os.path.join(val_mask_dir, test_img_name.replace('image_', 'mask_', 1)))
test_mask_argmax = np.argmax(test_mask, axis=3)

test_img_input = np.expand_dims(test_img, axis=0)  # add the batch axis
test_prediction = my_model.predict(test_img_input)
test_prediction_argmax = np.argmax(test_prediction, axis=4)[0, :, :, :]

# Plot one slice: input image, ground-truth label and prediction
n_slice = 55
display_channel = min(1, test_img.shape[-1] - 1)  # channel 1 when present; single-modality inputs only have 0
plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:, :, n_slice, display_channel], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(test_mask_argmax[:, :, n_slice])
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(test_prediction_argmax[:, :, n_slice])
plt.show()
