In [None]:
import glob
import os
import warnings

import nibabel as nib
import numpy as np
import tensorflow as tf

from monai.utils import first

from monai.transforms import Compose, RandFlip, RandRotate, RandAdjustContrast, LoadImaged, Orientationd, RandSpatialCropd, ScaleIntensityRanged, Spacingd, Resized, ToTensord, ScaleIntensityd, RandRotated, Flipd, RandAffined
from monai.transforms import EnsureChannelFirstd, RandAdjustContrastd, RandFlipd



### Create a function that loads the scans

In [4]:
def load_data(path, patch_size=(64,64,64)):
    """
    Load the MRI modalities and segmentation mask for every patient folder under ``path``.

    Each sub-directory of ``path`` is treated as one patient. It must contain exactly one
    file per modality, matched by filename suffix:

    - ``*_t1.nii``    T1 - native format
    - ``*_t1ce.nii``  T1CE - contrast enhanced image of T1
    - ``*_t2.nii``    T2 - weighted
    - ``*_flair.nii`` T2 Fluid Attenuated Inversion Recovery (FLAIR) volume
    - ``*_seg.nii``   segmentation mask

    Folders are visited in sorted order so the output order is deterministic.
    Non-directory entries are ignored. A folder that is missing a modality, or has more
    than one match for it, is skipped with a ``UserWarning``. This guarantees that index
    ``i`` of every returned list refers to the same patient.

    Args:
        path: directory containing one sub-folder per patient.
        patch_size: tuple containing the patch size of images for training. Currently
            unused; kept so existing callers keep working.

    Returns:
        tuple: ``(mask_all, t1_all, t1ce_all, t2_all, flair_all)``. These are five
        equal-length lists holding the arrays returned by nibabel's ``get_fdata()``,
        one entry per complete patient folder.
    """
    suffixes = ("t1", "t1ce", "t2", "flair", "seg")
    volumes = {suffix: [] for suffix in suffixes}

    for mri_folder in sorted(os.listdir(path)):
        file_path = os.path.join(path, mri_folder)
        if not os.path.isdir(file_path):
            # e.g. metadata CSVs stored next to the patient folders
            continue

        # Escape the folder part so names containing glob metacharacters ([, *, ?) still match.
        folder_pattern = glob.escape(file_path)
        matches = {
            suffix: sorted(glob.glob(os.path.join(folder_pattern, f"*_{suffix}.nii")))
            for suffix in suffixes
        }

        # Appending per modality independently would shift every later patient in the
        # shorter list, so an incomplete folder is rejected as a whole.
        bad = [suffix for suffix, files in matches.items() if len(files) != 1]
        if bad:
            found = {suffix: len(matches[suffix]) for suffix in bad}
            warnings.warn(
                f"Skipping {file_path!r}: expected exactly one file per modality, "
                f"found {found}"
            )
            continue

        for suffix, (nii_path,) in matches.items():
            volumes[suffix].append(nib.load(nii_path).get_fdata())

    return volumes["seg"], volumes["t1"], volumes["t1ce"], volumes["t2"], volumes["flair"]

## Load all the scans

In [5]:
# Root directory of the BraTS2020 patient folders. Set the BRATS_DIR environment variable
# to point elsewhere instead of editing this machine-specific default path.
DATA_DIR = os.environ.get(
    "BRATS_DIR",
    "/Users/dolan/Dropbox/MSSE/277B_ML/277B_final/277B_final/BraTS2020_scans",
)
mask_data, t1_data, t1ce_data, t2_data, flair_data = load_data(DATA_DIR)

## Scan augmentation

We augment the scans with random rotations, random flips, and random contrast adjustments.

Random rotations and flips simulate patients lying in slightly different positions in the scanner.

Random contrast adjustment simulates the fact that different scanners produce varying contrast.


Then save them as:

`patient_n_augmented/` (folder) — note that there is no `t1_aug`, since T1 is not augmented

- `patient_n_t1ce_aug.nii`
- `patient_n_t2_aug.nii`
- `patient_n_flair_aug.nii`
- `patient_n_mask_aug.nii`

In [12]:
# MONAI augmentation pipeline. It is applied to ONE patient at a time, given as a dict
# with keys "t1ce", "t2", "flair" and "mask".
#
# Dictionary ("...d") transforms are used so that each random flip/rotation is drawn once
# and applied identically to every modality AND the mask, keeping the labels aligned.
# The array versions (RandFlip, RandRotate, RandAdjustContrast) cannot take a dict at all.
# (An earlier draft also planned LoadImaged, Spacingd, Orientationd, ScaleIntensityRanged
# and RandAffined; they are not part of the current pipeline.)
AUG_IMAGE_KEYS = ["t1ce", "t2", "flair"]
AUG_ALL_KEYS = AUG_IMAGE_KEYS + ["mask"]
AUG_SEED = 0  # fixed seed so the augmentation is reproducible across kernel restarts
MAX_ROTATION_DEG = 10

generate_transforms = Compose([
    # nibabel's get_fdata() volumes carry no channel axis; MONAI spatial transforms expect
    # channel-first input.
    EnsureChannelFirstd(keys=AUG_ALL_KEYS, channel_dim="no_channel"),
    RandFlipd(keys=AUG_ALL_KEYS, spatial_axis=(0, 1, 2), prob=0.5),
    # MONAI rotation ranges are in radians. The previous value of 10 meant ~573 degrees,
    # which does not match the intended "slightly different positions".
    RandRotated(
        keys=AUG_ALL_KEYS,
        range_x=np.deg2rad(MAX_ROTATION_DEG),
        range_y=np.deg2rad(MAX_ROTATION_DEG),
        prob=0.5,
        # Nearest-neighbour for the mask so label values are not interpolated into
        # in-between (non-label) values.
        mode=["bilinear"] * len(AUG_IMAGE_KEYS) + ["nearest"],
    ),
    # Contrast is an intensity change: apply it to the images only, never to the mask.
    RandAdjustContrastd(keys=AUG_IMAGE_KEYS, prob=0.5, gamma=(0.8, 1.2)),
])
generate_transforms.set_random_state(seed=AUG_SEED)



In [None]:
# NOTE(review): superseded draft of augment_data (image/label dict keys), kept only as an
# inert string literal for reference -- it is never executed. The active augment_data is
# defined in a later cell; consider deleting this cell once it is no longer needed.
'''
def augment_data(images, masks, transforms):
    """
    Apply augmentations to images and masks.
    
    Args:
        images (list): List of loaded images (e.g., NumPy arrays).
        masks (list): List of corresponding masks (e.g., NumPy arrays).
        transforms (Compose): MONAI Compose object with augmentations.
    
    Returns:
        list: Augmented images.
        list: Augmented masks.
    """
    augmented_images = []
    augmented_masks = []

    for img, mask in zip(images, masks):
        augmented = transforms({"image": img, "label": mask})
        augmented_images.append(augmented["image"])
        augmented_masks.append(augmented["label"])

    return augmented_images, augmented_masks
    '''


<class 'numpy.memmap'>


In [13]:
def augment_data(t1ce,t2,flair,mask, transforms=None):
    """
    Apply the random augmentation pipeline to every patient, one sample at a time.

    The four inputs are parallel sequences: index ``i`` is the same patient in each.
    Each patient's volumes are passed to ``transforms`` together as a single dict, so a
    dictionary-based pipeline can apply the same random flip/rotation to the images and
    the mask.

    Args:
        t1ce, t2, flair, mask: equal-length sequences of per-patient volumes.
        transforms: callable that takes and returns a dict with keys "t1ce", "t2",
            "flair" and "mask". Defaults to the notebook-level ``generate_transforms``.

    Returns:
        tuple: ``(t1ce_aug_all, t2_aug_all, flair_aug_all, mask_aug_all)``. Each is a
        list with one augmented entry per patient, in input order.

    Raises:
        ValueError: if the four input sequences differ in length.
    """
    if transforms is None:
        # Looked up at call time, so re-running the pipeline cell takes effect here.
        transforms = generate_transforms

    lengths = (len(t1ce), len(t2), len(flair), len(mask))
    if len(set(lengths)) > 1:
        # zip() would silently truncate and drop patients.
        raise ValueError(
            f"t1ce, t2, flair and mask must have the same length, got {lengths}"
        )

    t1ce_aug_all = []
    t2_aug_all = []
    flair_aug_all = []
    mask_aug_all = []

    for t1ce_img, t2_img, flair_img, mask_img in zip(t1ce, t2, flair, mask):
        channel_dict = {"t1ce": t1ce_img, "t2": t2_img, "flair": flair_img, "mask": mask_img}
        augmented = transforms(channel_dict)

        t1ce_aug_all.append(augmented["t1ce"])
        t2_aug_all.append(augmented["t2"])
        flair_aug_all.append(augmented["flair"])
        mask_aug_all.append(augmented["mask"])

    return t1ce_aug_all, t2_aug_all, flair_aug_all, mask_aug_all
    

In [None]:
# Augment the loaded modalities (T1 is not part of the augmentation).
(t1ce_aug, t2_aug, flair_aug, mask_aug) = augment_data(
    t1ce=t1ce_data, t2=t2_data, flair=flair_data, mask=mask_data
)

: 

In [None]:
def augment_generator(t1ce,t2,flair,mask, batch_size=1):
    """
    Yield aligned consecutive batches from four parallel sequences.

    Args:
        t1ce, t2, flair, mask: sliceable sequences (lists or arrays) of per-patient data.
            The number of batches is driven by ``len(t1ce)``.
        batch_size: number of items per batch (default 1). The last batch may be shorter.

    Yields:
        tuple: ``(t1ce_batch, t2_batch, flair_batch, mask_batch)``, each a slice
        ``[start:start + batch_size]`` of the corresponding input.

    Raises:
        ValueError: (on first iteration) if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for start in range(0, len(t1ce), batch_size):
        # Slices, not tuple indices: t2[i, i + n] would index a single element (arrays)
        # or raise TypeError (lists).
        window = slice(start, start + batch_size)
        yield t1ce[window], t2[window], flair[window], mask[window]

In [None]:
# Fresh, independent accumulators (four distinct lists) for the batched loop below.
t1ce_aug, t2_aug, flair_aug, mask_aug = [], [], [], []
for t1ce_batch, t2_batch, flair_batch, mask_batch in augment_generator(t1ce_data, t2_data, flair_data, mask_data):
    
    