In [1]:
import os
import sys
import numpy as np
import nibabel as nib
import shutil
import torch

from functools import partial

sys.path.append('/home/edwardsb/repositories/private-mri-sandbox/2D_data_creation')
from nifti_utils import load_nifti, save_nifti, visualize_full_3d_channel


In [2]:
#################################
#################################
################################

In [3]:
# Folder holding the BraTS samples

allow_deletion = True

# ONE BELOW IS THE MAIN ONE USED
source_dirpath = '/raid/datasets/BraTS22/BraTS2022_Training/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021'
# source_dirpath = '/raid/edwardsb/projects/RANO/FAKE_withlabel3_ExampleBraTS22_Labels0-4'

# destination folder to hold the medperf postopp style data format
# (should be cleared of all items before re-running if already have run once)


# ONE BELOW IS MAIN ONE USED
# postopp_pardir = '/raid/edwardsb/projects/RANO/test_data_links_micahtest'
# postopp_pardir = '/raid/edwardsb/projects/RANO/test_data_links_testforhasan'

# working to create mulitple collaborator data paths
num_cols = 5
third_channel = 'square' # 'thresholdbrainandsquare' # 'square'  'thresholdbrain'
size='hundred'  #'thousand'

if size == 'thousand':
    samples_per_col = 200
elif size == 'hundred':
    samples_per_col = 20
else:
    raise ValueError(f"Size not supported.")


sort_by_tumor_size = False


if sort_by_tumor_size:
    tag='sorted'
else:
    tag=''




postopp_pardirs = {col_num: f'/raid/edwardsb/projects/RANO/test_{size}_BraTS20_3{third_channel + tag}_{col_num}' for col_num in range(num_cols)}

for path in postopp_pardirs.values():
    if os.path.exists(path):
        if not allow_deletion: 
            raise ValueError(f"You need to delete folders such as: {path}")
        else:
            shutil.rmtree(path)


In [4]:
src_mods = ['t1.nii.gz', 't2.nii.gz', 't1ce.nii.gz', 'flair.nii.gz', 'seg.nii.gz']
dst_mods = ['t1n.nii.gz', 't2w.nii.gz', 't1c.nii.gz', 't2f.nii.gz', 'final_seg.nii.gz']

flair_mod_src = src_mods[-2]
seg_mod_src = src_mods[-1]

flair_mod_dst = dst_mods[-2]
seg_mod_dst = dst_mods[-1]

im_mods_src = src_mods[:-1]

In [5]:


"""
Creating some data links that mimic the folder structure expected after data prep

"""


# Commenting this out for now below that creates the data, running all in the notebook in order to visualize some samples for various institutions (thresholds)




def get_patient_flair_intensity_range(patient, pardir):
    flair_path = os.path.join(pardir, patient, patient + '_' + flair_mod_src)
    flair_array = load_nifti(flair_path)
    return np.amin(flair_array).item(), np.amax(flair_array).item()


def get_flair_intensity_range(patients, pardir):
    overall_min = np.inf
    overall_max = -np.inf
    for patient in patients:
        single_min, single_max = get_patient_flair_intensity_range(patient=patient, pardir=pardir)
        if single_min < overall_min:
            overall_min = single_min
        if single_max > overall_max:
            overall_max = single_max
    return overall_min, overall_max

def place_class_three(data):
    # finds a spot that will not be cropped away to place class 3 so that we have at least one class three label
    from scipy.ndimage import binary_fill_holes
    nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
    for c in range(data.shape[0]):
        this_mask = data[c] != 0
        nonzero_mask = nonzero_mask | this_mask
    nonzero_mask = binary_fill_holes(nonzero_mask)
    mask = nonzero_mask

    mask_voxel_coords = np.where(mask != 0)
    minzidx = int(np.min(mask_voxel_coords[0]))
    maxzidx = int(np.max(mask_voxel_coords[0])) + 1
    minxidx = int(np.min(mask_voxel_coords[1]))
    maxxidx = int(np.max(mask_voxel_coords[1])) + 1
    minyidx = int(np.min(mask_voxel_coords[2]))
    maxyidx = int(np.max(mask_voxel_coords[2])) + 1

    # now pick coordinates within these limits (these limits are the ones NNUnet will use to crop)

    delx = maxxidx - minxidx
    dely = maxyidx - minyidx
    delz = maxzidx - minzidx

    if (delx < 2) or (dely < 2) or (delz < 2):
        raise ValueError(f"Surprised, but evidently this brain is flat with delx:{delx}, dely:{dely} and delz:{delz}.")

    _x = minxidx + int(np.floor(delx/2))
    _y = minyidx + int(np.floor(dely/2))
    _z = minzidx + int(np.floor(delz/2))

    return _x, _y, _z

bigger_patient_pool = list(os.listdir(source_dirpath))

# Here we order by tumor size if indicated by sort_by_tumor_size variable above

def get_tumor_size(patient, pardir):
    seg_path = os.path.join(pardir, patient, patient + '_' + seg_mod_src)
    label_array = load_nifti(seg_path)
    return np.sum(label_array).item()

if sort_by_tumor_size:
    bigger_patient_pool.sort(key=partial(get_tumor_size, pardir=source_dirpath))
    
    # thin out so as to keep full range of sizes (make sure you have enough to grab samples_per_col * num_cols)
    step = int(  np.floor( len(bigger_patient_pool)/(samples_per_col * num_cols) )  )
    bigger_patient_pool = bigger_patient_pool[0::step]

if 'thresholdbrain' in third_channel:
    # Here we profile image intensities
    min_intensity, max_intensity = get_flair_intensity_range(patients=bigger_patient_pool, pardir=source_dirpath)

    delta_intensity = (max_intensity - min_intensity) / (num_cols + 1)
    intensity_thresholds = [min_intensity + (idx * delta_intensity) for idx in range(num_cols)]

    print(f"Intensities were min: {min_intensity}, max:{max_intensity} making thresholds: {intensity_thresholds}")


for col_num in range(num_cols):
    patients = bigger_patient_pool[col_num * samples_per_col:(col_num + 1) * samples_per_col]
    postopp_pardir = postopp_pardirs[col_num]
    
    for src_mod, dst_mod in zip(src_mods, dst_mods):
        if src_mod == 'seg.nii.gz':
            subdir = 'labels'
        else:
            subdir = 'data'
        
        for patient in patients:
            dst_pat_dir = os.path.join(postopp_pardir, subdir, patient)
            if not os.path.exists(dst_pat_dir):
                os.makedirs(dst_pat_dir)
            src_pat_dir = os.path.join(source_dirpath, patient)
            for timestamp in ['2008.03.88', '2008.12.99']:
                dst_pattim_dir = os.path.join(dst_pat_dir, timestamp)
                if not os.path.exists(dst_pattim_dir):
                    os.makedirs(dst_pattim_dir)
        
                


                src_path = os.path.join(src_pat_dir, patient + '_' + src_mod)
                if dst_mod == 'final_seg.nii.gz':
                    dst_path = os.path.join(dst_pattim_dir, patient + '_' + timestamp + '_' + dst_mod)
                    label_array = load_nifti(src_path)
                    # writing in label 3 parts so as there is no missing label
                    # shape of this guy is: 240, 240, 155, and the unique label values are: 0.0, 1.0, 2.0, 4.0
                    if third_channel == 'square':
                        label_array[110:130,110:130,70:80] = 3.0
                    elif third_channel == 'thresholdbrain':
                        threshold = intensity_thresholds[col_num]
                        flair_path = os.path.join(src_pat_dir, patient + '_' + flair_mod_src)
                        flair_array = load_nifti(flair_path)
                        mask = flair_array > threshold
                        label_array[mask] = 3.0
                    elif third_channel == 'thresholdbrainandsquare':
                        threshold = intensity_thresholds[col_num]
                        im_paths_src = [os.path.join(src_pat_dir, patient + '_' + _mod) for _mod in im_mods_src]
                        im_array = np.stack([load_nifti(im_path) for im_path in im_paths_src], axis=0)
                        flair_array = load_nifti(im_paths_src[-1])
                        # zero labels within brain intensity threshold
                        mask = np.logical_and((flair_array>threshold), (label_array==0.0))
                        # within the zero plain bounds of all channels we place a 3 label to be sure we have at least one
                        _x, _y, _z = place_class_three(data=im_array)
                        mask[_x, _y, _z] = True
                        label_array[mask] = 3.0

                    else:
                        raise ValueError(f"third_channel:{third_channel} is not a value we recognize (only supporting 'sqaure', 'thresholdbrain', and 'thresholdbrainandsquare').")
                    save_nifti(array=label_array, path=dst_path, metadata_from=nib.load(src_path))
                else:
                    dst_path = os.path.join(dst_pattim_dir, patient + '_' + timestamp + '_brain_' + dst_mod)
                shutil.copyfile(src=src_path, dst=dst_path)




In [6]:
inst = 1
pat_num = 1
interval = 200

pats = os.listdir(os.path.join(postopp_pardirs[inst],'data'))
example_pat = pats[pat_num]

time = '2008.03.88'

In [7]:
example_featpath= os.path.join(postopp_pardirs[inst],'data', example_pat, time, example_pat + '_' + time + '_brain_' + flair_mod_dst)
example_labpath = os.path.join(postopp_pardirs[inst],'labels', example_pat, time, example_pat + '_' + time + '_' + seg_mod_dst)

In [8]:
example_featpath, example_labpath

('/raid/edwardsb/projects/RANO/test_hundred_BraTS20_3square_1/data/BraTS2021_00351/2008.03.88/BraTS2021_00351_2008.03.88_brain_t2f.nii.gz',
 '/raid/edwardsb/projects/RANO/test_hundred_BraTS20_3square_1/labels/BraTS2021_00351/2008.03.88/BraTS2021_00351_2008.03.88_final_seg.nii.gz')

In [9]:
from IPython.display import HTML
%matplotlib inline

In [10]:
anim = visualize_full_3d_channel(path=example_featpath, interval=interval)
HTML(anim.to_html5_video())

Video min:0.0, max:671.0 corner:0.0


In [11]:
anim = visualize_full_3d_channel(path=example_labpath, interval=interval)
HTML(anim.to_html5_video())

Video min:0.0, max:4.0 corner:0.0
