In [None]:
######################################################
# IMPORTS
######################################################

# Overall Imports
import random
import os
import time

# Computational Imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dataloading Imports
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
from sklearn.model_selection import StratifiedKFold, train_test_split

# Image Viewing Imports
import matplotlib.pyplot as plt
import skimage
from skimage.segmentation import mark_boundaries
import skimage.exposure as exposure
import SimpleITK as sitk

# Other Imports (don't seem necessary now but maybe)
# import pickle
# import cv2

In [None]:
######################################################
# HELPER FUNCTIONS
######################################################

# Loads a medical image from the specified path and converts it to a NumPy array
def load_img(img_path):
    """Read the image file at ``img_path`` with SimpleITK and return it as a NumPy array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(img_path))

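# Normalizes the non-zero voxels IN PLACE: min-max scales them, multiplies by 100, then replaces the
# remaining non-zero values with their histogram-equalized values (so the result is NOT in [0, 100])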
def normalize(img_np):
    """Min-max scale and histogram-equalize the non-zero entries of ``img_np``.

    The array is modified in place and the same object is returned.
    Zero entries are left untouched throughout.

    Steps:
      1. Min-max scale the non-zero entries using their own min and max
         (the 1e-7 term guards against division by zero).
      2. Multiply the whole array by 100.
      3. Overwrite the entries that are still non-zero with
         ``exposure.equalize_hist`` of those entries.

    Note that step 1 maps the smallest non-zero value to exactly 0, so it is
    excluded from step 3. Raises ValueError if there are no non-zero entries.
    """
    foreground = img_np != 0
    foreground_values = img_np[foreground]
    img_max = foreground_values.max()
    img_min = foreground_values.min()
    img_np[foreground] = (foreground_values - img_min) / ((img_max - img_min) + 1e-7)
    img_np *= 100
    # The mask has to be rebuilt: entries equal to the minimum are now 0.
    still_nonzero = img_np != 0
    img_np[still_nonzero] = exposure.equalize_hist(img_np[still_nonzero])
    return img_np

# Pads or crops the input image and its segmentation mask to the target size, using random slicing for cropping
def pad_or_crop_image(image, seg=None, target_size=(128, 144, 144), rng_generator=None):
    """Randomly crop and/or zero-pad ``image`` (and ``seg``) to ``target_size``.

    Args:
        image: array whose shape unpacks as (c, z, y, x); the first axis is
            never cropped or padded.
        seg: optional array indexed as (z, y, x); it receives exactly the same
            crop and padding as the last three axes of ``image``.
        target_size: desired (z, y, x) size.
        rng_generator: ``numpy.random.Generator`` used to pick the crop and
            pad offsets. When None, a fresh ``np.random.default_rng(0)`` is
            created for this call, so a default call always yields the same
            offsets. (A generator built in the signature would be created
            once and shared by all calls, making results depend on how many
            times the function had been called before.)

    Returns:
        The resized image, or ``(image, seg)`` when ``seg`` is given.
    """
    if rng_generator is None:
        rng_generator = np.random.default_rng(0)

    _, z, y, x = image.shape
    spatial_dims = (z, y, x)

    # Draw order matters for reproducibility: all crop offsets first (z, y, x),
    # then all pad offsets (z, y, x).
    crop_slices = tuple(
        get_crop_slice(target, dim, rng_generator)
        for target, dim in zip(target_size, spatial_dims)
    )
    image = image[(slice(None),) + crop_slices]
    if seg is not None:
        seg = seg[crop_slices]

    pad_plans = [
        get_left_right_idx_should_pad(target, dim, rng_generator)
        for target, dim in zip(target_size, spatial_dims)
    ]
    padlist = [(0, 0)]  # channel dim is never padded
    padlist += [(plan[1], plan[2]) if plan[0] else (0, 0) for plan in pad_plans]

    image = np.pad(image, padlist)
    if seg is None:
        return image
    return image, np.pad(seg, padlist[1:])

# Determines whether padding is needed for an image dimension and generates random padding values for the left and right sides
def get_left_right_idx_should_pad(target_size, dim, rng_generator):
    """Decide whether one axis needs padding and, if so, how to split it.

    Returns ``[False]`` when ``dim`` already reaches ``target_size``.
    Otherwise returns ``(True, left, right)``, where ``left`` is drawn
    uniformly from 0..missing (inclusive) with ``rng_generator`` and
    ``left + right`` equals the number of missing entries.
    """
    if dim >= target_size:
        return [False]
    if dim < target_size:
        missing = target_size - dim
        pad_before = rng_generator.integers(0, missing + 1)
        return True, pad_before, missing - pad_before

# Computes a slice for cropping or retaining parts of the image to match the target size for a specific dimension
def get_crop_slice(target_size, dim, rng_generator):
    """Return the slice that keeps ``target_size`` entries of an axis of length ``dim``.

    When the axis is not longer than the target, the whole axis is kept
    (``slice(0, dim)``) and no random number is drawn. Otherwise the start of
    the window is drawn uniformly from 0..surplus (inclusive) with
    ``rng_generator``.
    """
    if dim <= target_size:
        return slice(0, dim)
    if dim > target_size:
        surplus = dim - target_size
        start = rng_generator.integers(0, surplus + 1)
        trimmed_at_end = surplus - start
        return slice(start, dim - trimmed_at_end)

In [None]:
######################################################
# DATA PROCESSING FUNCTION
######################################################

def get_data_and_labels(data_dir, stop_idx=None, rng_generator=None):
    """Load and preprocess every BraTS subject folder found under ``data_dir``.

    For each subject the four modalities (flair, t1, t1ce, t2) are stacked,
    trimmed by 30 slices at both ends of the z axis, cropped to the bounding
    box of the non-zero voxels, clipped to the 1st/99th percentiles of the
    non-zero voxels, passed through ``normalize`` and finally padded/cropped
    in-plane to 128 x 128 with ``pad_or_crop_image``. Subjects whose result
    is not (4, 95, 128, 128) are reported and left out.

    Args:
        data_dir: folder holding ``name_mapping.csv`` and one sub-folder per
            subject.
        stop_idx: number of subject folders to visit (in sorted name order);
            None visits all of them.
        rng_generator: ``numpy.random.Generator`` driving the random pad/crop
            offsets. When None, a fresh ``np.random.default_rng(0)`` is
            created for this call so that repeated calls give identical
            results.

    Returns:
        Tuple ``(data, labels, data_grades)`` of NumPy arrays:
        ``data`` stacks the per-subject (4, 95, 128, 128) float32 volumes,
        ``labels`` stacks the matching segmentation volumes, and
        ``data_grades`` has one row per subject holding the grade repeated
        once per modality (0 for 'LGG', 1 otherwise).
    """
    # Build the generator per call: one created in the signature is shared by
    # every call and keeps advancing, so re-running would not reproduce the
    # same pad/crop offsets.
    if rng_generator is None:
        rng_generator = np.random.default_rng(0)

    # Loads from name_mapping.csv the HGG/LGG and the Patient ID
    df = pd.read_csv(os.path.join(data_dir, "name_mapping.csv"))
    df = df[['Grade', 'BraTS_2020_subject_ID']]

    # File names look like "BraTS20_Training_001_flair.nii.gz"; the "seg"
    # file is loaded separately below as the segmentation annotation.
    data_suffixes = ["flair", "t1", "t1ce", "t2"]
    # os.listdir returns entries in arbitrary order; sorting makes the order of
    # the returned arrays (and anything split from them) independent of the
    # filesystem.
    dir_names = sorted(
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )
    total_dirs = len(dir_names)

    # Decides how many people to process and when to stop
    if stop_idx is None:
        stop_idx = total_dirs

    expected_shape = (95, 128, 128)

    data, labels, data_grades = [], [], []
    for dir_idx, dir_name in enumerate(dir_names):

        # Stop early (before announcing a folder that will not be processed)
        if dir_idx == stop_idx:
            break
        print(f"Working on folder {dir_idx}/{total_dirs}: {dir_name}")

        # Extract their grade/label; folders unknown to the CSV are skipped
        curr_df = df[df['BraTS_2020_subject_ID'] == dir_name]
        if curr_df.empty:
            print(f"{dir_name} is not listed in name_mapping.csv. Skipping.")
            continue
        grade = curr_df.iloc[0]['Grade']

        # Load and stack the four modalities as one float32 array
        subject_dir = os.path.join(data_dir, dir_name)
        data_images = [
            load_img(os.path.join(subject_dir, dir_name + "_" + data_suffix + ".nii.gz"))
            for data_suffix in data_suffixes
        ]
        stacked_data_images = np.stack(data_images).astype('float32')
        stacked_data_images = stacked_data_images[:, 30:-30, :, :]  # crops in the z direction

        # Nothing to crop or normalise if every voxel is zero
        if not stacked_data_images.any():
            print(f"{dir_name} has no non-zero voxels after the z crop. Skipping.")
            continue

        # Segmentation annotation, trimmed in z exactly like the images
        loaded_label = load_img(os.path.join(subject_dir, dir_name + "_seg.nii.gz"))
        loaded_label = loaded_label[30:-30, :, :]

        # This part of the code is heavily based on https://github.com/lescientifik/open_brats2020
        ############################################################
        # Crop at boundaries of brain
        z_idxs, y_idxs, x_idxs = np.nonzero(np.sum(stacked_data_images, axis=0))
        z_min, y_min, x_min = [max(0, int(np.min(arr) - 1)) for arr in (z_idxs, y_idxs, x_idxs)]
        z_max, y_max, x_max = [int(np.max(arr) + 1) for arr in (z_idxs, y_idxs, x_idxs)]
        cropped_data = stacked_data_images[:, z_min:z_max, y_min:y_max, x_min:x_max]
        cropped_label = loaded_label[z_min:z_max, y_min:y_max, x_min:x_max]
        # Clip at 1 and 99 percentiles of the non-zero voxels, keeping zeros at zero
        zero_indices = cropped_data == 0
        percentiles = np.percentile(cropped_data[np.nonzero(cropped_data)], [1, 99])
        clipped_data = np.clip(cropped_data, percentiles[0], percentiles[1])
        clipped_data[zero_indices] = 0
        # Normalise and Reshape (z is left as is; only y and x are resized)
        normalized_data = normalize(clipped_data)
        reshaped_data, reshaped_label = pad_or_crop_image(
            normalized_data,
            cropped_label,
            target_size=(normalized_data.shape[1], 128, 128),
            rng_generator=rng_generator,
        )
        ############################################################

        # Append everything that will be returned: images, seg_notes, and HGG/LGG info
        if reshaped_data.shape[1:] == expected_shape:
            data.append(reshaped_data)
            labels.append(reshaped_label)
            curr_grade_val = 0 if grade == 'LGG' else 1
            data_grades.append([curr_grade_val] * reshaped_data.shape[0])
        else:
            # Report and move on; no interactive pause so batch runs do not hang.
            print("reshaped_data does not have the expected shape. Not appending.")
            print(dir_name, reshaped_data.shape)

    return np.array(data), np.array(labels), np.array(data_grades)

In [None]:
######################################################
# UNLOAD AND PROCESS DATA (and save!)
######################################################

# Path to BraTS 2020 dataset folder (must contain name_mapping.csv and one sub-folder per subject)
data_dir = "/Users/felicialiu/Desktop/ESC499/Code/MICCAI_BraTS2020_TrainingData"

# Run the function to extract all data and SAVE (no need to process again later).
# `data` holds the preprocessed image volumes, `labels` the segmentation annotations and
# `data_grades` the per-modality grade (0 = LGG, 1 = otherwise) of each subject.
data, labels, data_grades = get_data_and_labels(data_dir)

# Save the processed numpy data so we don't need to process over and over;
# the three arrays are stored under the keys 'data', 'labels' and 'data_grades'.
processed_data_path = "/Users/felicialiu/Desktop/ESC499/Code/PROCESSED_DATA_FULL.npz"
np.savez(processed_data_path, data=data, labels=labels, data_grades=data_grades)

Working on folder 0/369: BraTS20_Training_082
Working on folder 1/369: BraTS20_Training_244
Working on folder 2/369: BraTS20_Training_076
Working on folder 3/369: BraTS20_Training_049
Working on folder 4/369: BraTS20_Training_071
Working on folder 5/369: BraTS20_Training_243
Working on folder 6/369: BraTS20_Training_085
Working on folder 7/369: BraTS20_Training_288
Working on folder 8/369: BraTS20_Training_047
Working on folder 9/369: BraTS20_Training_275
Working on folder 10/369: BraTS20_Training_281
Working on folder 11/369: BraTS20_Training_078
Working on folder 12/369: BraTS20_Training_286
Working on folder 13/369: BraTS20_Training_272
Working on folder 14/369: BraTS20_Training_040
Working on folder 15/369: BraTS20_Training_219
Working on folder 16/369: BraTS20_Training_014
Working on folder 17/369: BraTS20_Training_226
Working on folder 18/369: BraTS20_Training_221
Working on folder 19/369: BraTS20_Training_013
Working on folder 20/369: BraTS20_Training_228
Working on folder 21/36

In [None]:
######################################################
# (reload) SPLIT DATA (and save!)
######################################################

# Reload the data from the .npz file (keys: 'data', 'labels', 'data_grades')
processed_data_path = "/Users/felicialiu/Desktop/ESC499/Code/PROCESSED_DATA_FULL.npz"
loaded_data = np.load(processed_data_path)
data, labels, data_grades = loaded_data['data'], loaded_data['labels'], loaded_data['data_grades']

print("All Unloaded Data:", data.shape, data_grades.shape)

# Split a fixed percentage of data for the test set.
# The split is done per patient (first axis), stratified on the grade rows, with a fixed
# random_state so it can be repeated.
test_size = 0.15  # Adjust this percentage as needed
trainval_data, test_data, trainval_grades, test_grades = train_test_split(
    data, data_grades, test_size=test_size, stratify=data_grades, random_state=42)

def reshape_datagrades_voxellabels(data, grades, name, n_splits=5):
    """Flatten per-patient stacks into single volumes and build CV folds.

    Args:
        data: array whose first two axes are (patient, image-per-patient) and
            whose last three axes form one volume.
        grades: 2-D array with one grade per (patient, image).
        name: label used in the printed messages; stratified folds are only
            built when it equals "TRAIN AND VAL".
        n_splits: number of StratifiedKFold splits.

    Returns:
        ``(voxels, labels, train_inds, val_inds)``. ``voxels`` has one volume
        per row, ``labels`` is the flattened ``grades``, and the two lists
        hold, per fold, the row indices into ``voxels`` of the training and
        validation volumes. Both lists are empty for other ``name`` values.
        All images of a patient always land on the same side of a fold.
    """
    print(f"\nOG {name} Data:", data.shape, grades.shape)

    # Each scan becomes its own datapoint
    num_patients, images_per_patient = data.shape[:2]
    depth, height, width = data.shape[-3], data.shape[-2], data.shape[-1]
    voxels = data.reshape(-1, depth, height, width)
    labels = grades.flatten()

    # One id per patient; the first grade of each patient drives the stratification
    patient_ids = np.arange(num_patients)
    patient_grades = grades[:, 0]

    train_inds, val_inds = [], []
    if name == "TRAIN AND VAL":
        # Patient id owning each flattened volume (the reshape keeps a
        # patient's images contiguous)
        volume_owner = np.repeat(patient_ids, images_per_patient)
        folds = StratifiedKFold(n_splits=n_splits).split(patient_ids, patient_grades)
        for train_patients, val_patients in folds:
            train_inds.append(np.flatnonzero(np.isin(volume_owner, train_patients)))
            val_inds.append(np.flatnonzero(np.isin(volume_owner, val_patients)))

    print(f"Flattened {name} Data:", voxels.shape, labels.shape)
    return voxels, labels, train_inds, val_inds

# Flatten both partitions; only the train/val partition gets cross-validation fold indices
voxels, labels, train_inds, val_inds = reshape_datagrades_voxellabels(trainval_data, trainval_grades, "TRAIN AND VAL", n_splits=5)
tvoxels, tlabels, _, _ = reshape_datagrades_voxellabels(test_data, test_grades, "TEST", n_splits=5)

# Save it to reload as a checkpoint!
# Each fold's index array is stored under its own key ('train_inds_0', 'val_inds_0', ...)
split_data_path = "/Users/felicialiu/Desktop/ESC499/Code/SPLIT_DATA_FULL.npz"
np.savez(split_data_path, 
         voxels=voxels, 
         labels=labels, 
         tvoxels=tvoxels, 
         tlabels=tlabels, 
         **{f'train_inds_{i}': arr for i, arr in enumerate(train_inds)},  # Unpack train_inds
         **{f'val_inds_{i}': arr for i, arr in enumerate(val_inds)})  # Unpack val_inds

All Unloaded Data: (365, 4, 95, 128, 128) (365, 4)

OG TRAIN AND VAL Data: (310, 4, 95, 128, 128) (310, 4)
Flattened TRAIN AND VAL Data: (1240, 95, 128, 128) (1240,)

OG TEST Data: (55, 4, 95, 128, 128) (55, 4)
Flattened TEST Data: (220, 95, 128, 128) (220,)


In [None]:
# Reload the data from the .npz file; this time the 'labels' entry (segmentation
# annotations) is kept under the name seg_labels
processed_data_path = "/Users/felicialiu/Desktop/ESC499/Code/PROCESSED_DATA_FULL.npz"
loaded_data = np.load(processed_data_path)
data, seg_labels, data_grades = loaded_data['data'], loaded_data['labels'], loaded_data['data_grades']

print("All Unloaded Data:", data.shape, seg_labels.shape, data_grades.shape)

# Split a fixed percentage of data for the test set.
# NOTE(review): same test_size, stratify and random_state as the image split, presumably so the
# segmentations fall into the same train/val and test partitions as the images - confirm.
test_size = 0.15  # Adjust this percentage as needed
trainval_segs, test_segs, trainval_grades, test_grades = train_test_split(
    seg_labels, data_grades, test_size=test_size, stratify=data_grades, random_state=42)

def create_seg_levels(segmentation_maps, copies_per_patient=4, min_slice_sum=6):
    """Find the first and last tumour-bearing slice of every segmentation.

    A slice (an entry along the first axis of a segmentation) counts as
    tumour-bearing when the sum of its values is greater than
    ``min_slice_sum``.

    Args:
        segmentation_maps: iterable of segmentation volumes.
        copies_per_patient: how many times each ``[first, last]`` pair is
            repeated in the output. The default of 4 presumably matches the
            four image volumes each patient contributes after flattening -
            verify against the caller.
        min_slice_sum: threshold a slice sum must exceed.

    Returns:
        Array with ``copies_per_patient`` consecutive ``[first, last]`` rows
        per segmentation, in input order.

    Raises:
        IndexError: if a segmentation has no slice above the threshold.
    """
    all_segmentations = []
    for seg_idx, segmentation in enumerate(segmentation_maps):
        slice_sums = np.array([np.sum(plane) for plane in segmentation])
        glioma_slices = np.flatnonzero(slice_sums > min_slice_sum)
        if glioma_slices.size == 0:
            raise IndexError(
                f"segmentation {seg_idx} has no slice whose sum exceeds {min_slice_sum}"
            )
        bounds = np.array([glioma_slices[0], glioma_slices[-1]])
        all_segmentations.extend([bounds] * copies_per_patient)
    return np.array(all_segmentations)

# [first, last] tumour-slice indices for the train/val partition (four rows per patient)
segs = create_seg_levels(trainval_segs)
print(segs.shape)

# Same for the test partition
tsegs = create_seg_levels(test_segs)
print(tsegs.shape)

# Save it to reload as a checkpoint!
split_segs_path = "/Users/felicialiu/Desktop/ESC499/Code/SPLIT_SEGS_FULL.npz"
np.savez(split_segs_path, 
         segs=segs, 
         tsegs=tsegs)

In [None]:
######################################################
# (reload) CHECK RELOADED DATA IS ALL GOOD!
######################################################

# Reload from the files of split data
# NOTE(review): this reads SPLIT_DATA.npz, whereas the split cell in this notebook writes
# SPLIT_DATA_FULL.npz - confirm this older checkpoint is really the one intended here.
# NOTE(review): allow_pickle=True permits loading pickled objects; only use it on trusted files.
split_data_path = "/Users/felicialiu/Desktop/ESC499/Code/SPLIT_DATA.npz"
loaded_data = np.load(split_data_path, allow_pickle=True)
voxels_reloaded = loaded_data['voxels']
labels_reloaded = loaded_data['labels']
tvoxels_reloaded = loaded_data['tvoxels']
tlabels_reloaded = loaded_data['tlabels']
# Every key other than the 4 read above is a 'train_inds_i' / 'val_inds_i' entry, two per fold
num_folds = (len(loaded_data.files)-4)//2
train_inds_reloaded = [loaded_data[f'train_inds_{i}'] for i in range(num_folds)]
val_inds_reloaded = [loaded_data[f'val_inds_{i}'] for i in range(num_folds)]

# Summarise the shapes and the per-fold train/val sizes as a sanity check
print(f"TRAIN/VAL DATA: {voxels_reloaded.shape} {labels_reloaded.shape}")
print(f"TEST DATA:      {tvoxels_reloaded.shape} {tlabels_reloaded.shape}")
for i in range(len(train_inds_reloaded)):
    print(f"fold #{i}: train/val [{len(train_inds_reloaded[i])}/{len(val_inds_reloaded[i])}]")

TRAIN/VAL DATA: (1240, 95, 128, 128) (1240,)
TEST DATA:      (220, 95, 128, 128) (220,)
fold #0: train/val [992/248]
fold #1: train/val [992/248]
fold #2: train/val [992/248]
fold #3: train/val [992/248]
fold #4: train/val [992/248]


In [None]:
######################################################
# (reload) CHECK RELOADED DATA IS ALL GOOD!
######################################################

# Reload from the files of split data (image volumes + grades) and split segmentation levels
# NOTE(review): allow_pickle=True permits loading pickled objects; only use it on trusted files.
split_data_path = "/Users/felicialiu/Desktop/ESC499/Code/SPLIT_DATA_FULL.npz"
split_segs_path = "/Users/felicialiu/Desktop/ESC499/Code/SPLIT_SEGS_FULL.npz"
loaded_data = np.load(split_data_path, allow_pickle=True)
loaded_segs = np.load(split_segs_path, allow_pickle=True)
voxels_reloaded = loaded_data['voxels']
labels_reloaded = loaded_data['labels']
segs_reloaded = loaded_segs['segs']
tvoxels_reloaded = loaded_data['tvoxels']
tlabels_reloaded = loaded_data['tlabels']
tsegs_reloaded = loaded_segs['tsegs']
# Every key other than voxels/labels/tvoxels/tlabels is a 'train_inds_i' / 'val_inds_i' entry, two per fold
num_folds = (len(loaded_data.files)-4)//2
train_inds_reloaded = [loaded_data[f'train_inds_{i}'] for i in range(num_folds)]
val_inds_reloaded = [loaded_data[f'val_inds_{i}'] for i in range(num_folds)]

# Summarise the shapes and the per-fold train/val sizes as a sanity check
print(f"TRAIN/VAL DATA: {voxels_reloaded.shape} {segs_reloaded.shape} {labels_reloaded.shape}")
print(f"TEST DATA:      {tvoxels_reloaded.shape} {tsegs_reloaded.shape} {tlabels_reloaded.shape}")
for i in range(len(train_inds_reloaded)):
    print(f"fold #{i}: train/val [{len(train_inds_reloaded[i])}/{len(val_inds_reloaded[i])}]")

In [None]:
######################################################
# STATS of the DATASET
######################################################

# These patients all didn't have the (4, 95, 128, 128) datastructure (hence are currently not included)
# BraTS20_Training_349 (4, 94, 128, 128)
# BraTS20_Training_360 (4, 85, 128, 128)
# BraTS20_Training_366 (4, 87, 128, 128)
# BraTS20_Training_357 (4, 85, 128, 128)

# All Unloaded Data: (365, 4, 95, 128, 128) (365, 4)

# OG TRAIN AND VAL Data: (310, 4, 95, 128, 128) (310, 4)
# Flattened TRAIN AND VAL Data: (1240, 95, 128, 128) (1240,)

# OG TEST Data: (55, 4, 95, 128, 128) (55, 4)
# Flattened TEST Data: (220, 95, 128, 128) (220,)

In [None]:
######################################################
# OTHER INFO FROM JAY
######################################################

### From what I remember, the pad_or_crop_image() function is bugged in that its rng_generator
# does not guarantee consistent padding and cropping across multiple runs. So you either have
# to save the data after generation, or fix the bug so that the rng state is reproducible across runs

# ### data_grades is the labels you're interested in. The "labels" variable are segmentation annotations 
# that are not relevant to you at the moment, but keep the code just in case your thesis enters segmentation 
# eventually (probably not though)

# image_zero_threshold = 0.5

    # '''
    # data = np.concatenate(data, axis=0)
    # labels = np.concatenate(labels, axis=0)
    # data_grades = np.array(data_grades)
            
    # idxs_to_delete = []
    # for data_idx in range(len(data)):
    #     if np.count_nonzero(data[data_idx] == 0) / data[data_idx].size > image_zero_threshold: 
    #         raise
    #         idxs_to_delete.append(data_idx)
    # data = np.delete(data, idxs_to_delete, axis=0)
    # labels = np.delete(labels, idxs_to_delete, axis=0)
    # data_grades = np.delete(data_grades, idxs_to_delete, axis=0)
    # '''
    ###################################