# Slicing Techniques

In this notebook we will pre-process the images with different slicing techniques.

#### Imports

In [None]:
import os
import random
import glob
import re

import pandas as pd

import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import cv2

import albumentations as A

from tqdm import tqdm

import wandb

import imageio

import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

#### Seed

In [None]:
def set_seed(seed):
    """Seed the random number generators used in this notebook.

    Args:
        seed: Value passed to `random.seed` and `np.random.seed`; its string
            form is stored in the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


set_seed(42)

#### Load Data Set and Select Patients

In [None]:
# Label table; later cells look up `MGMT_value` by `BraTS21ID`.
df_train = pd.read_csv("../train_labels.csv")
# Scan ids of the patients visualised in this notebook; they are used as
# folder names when globbing the slice files and cast to int for the lookup.
sample_patients = ['00002', '00457', '00601', '00003', '00222', '00397', '00121', '00804', '00266', '00581']

## 1. No Pre-Processing

#### Config

In [None]:
# Root directory below which `load_images` globs `<split>/<scan_id>/<mri_type>/`.
PATH = '..'

# Settings of this run; passed to `wandb.init` and read back via `wandb.config`.
config = {
    # Pre-processing
    "REMOVE_BLACK_BOUNDARIES": False,
    "DICOM": False,

    # Albumentation
    "RRC_SIZE": 256,
    "RRC_MIN_SCALE": 0.85,
    "RRC_RATIO": (1., 1.),
    "CLAHE_CLIP_LIMIT": 2.0,
    "CLAHE_TILE_GRID_SIZE": (8, 8),
    "CLAHE_PROB": 0.50,
    "BRIGHTNESS_LIMIT": (-0.2, 0.2),
    "BRIGHTNESS_PROB": 0.40,
    "HUE_SHIFT": (-15, 15),
    "SAT_SHIFT": (-15, 15),
    "VAL_SHIFT": (-15, 15),
    "HUE_PROB": 0.64,
    "COARSE_MAX_HOLES": 16,
    "COARSE_PROB": 0.7,

    # Logging
    "VERBOSE": False,
    "NAME": "00_EDA_Slicing",
}

#### wandb

In [None]:
# Authenticate with Weights & Biases, then open a run that records `config`.
wandb.login()
run = wandb.init(entity="uzk-wim", project='rsna-miccai-slicing', config=config, mode="online")
# From here on `config` is the run's config object and is read via
# attribute access (e.g. `config.NAME`) instead of dictionary keys.
config = wandb.config
# Name the run after the NAME entry of the configuration.
wandb.run.name = f"{config.NAME}"

### 1. Loading Images

#### 1.1 Utilities

In [None]:
# Augmentation pipeline that `get_3d_image` applies to every 2D slice when
# its `aug` argument is true. All parameters come from `config`.
# NOTE(review): the positional (height, width) form of RandomResizedCrop
# differs between albumentations releases -- confirm against the installed one.
# NOTE(review): HueSaturationValue is presumably meant for colour images while
# the slices here are read as grayscale -- confirm before running with aug=True.
train_transform = A.Compose([
    A.RandomResizedCrop(
        config.RRC_SIZE, config.RRC_SIZE,            
        scale=(config.RRC_MIN_SCALE, 1.0),
        ratio=config.RRC_RATIO,
        p=1.0
    ),
    A.CLAHE(
        clip_limit=config.CLAHE_CLIP_LIMIT,
        tile_grid_size=config.CLAHE_TILE_GRID_SIZE,
        p=config.CLAHE_PROB
    ),
    A.RandomBrightnessContrast(
        brightness_limit=config.BRIGHTNESS_LIMIT,
        p=config.BRIGHTNESS_PROB
    ),
    A.HueSaturationValue(
        hue_shift_limit=config.HUE_SHIFT, 
        sat_shift_limit=config.SAT_SHIFT, 
        val_shift_limit=config.VAL_SHIFT, 
        p=config.HUE_PROB
    ),
    A.CoarseDropout(
        max_holes=config.COARSE_MAX_HOLES,
        p=config.COARSE_PROB),
])

# Pipeline that `get_3d_image` applies to every 2D slice when `aug` is false:
# only the resized crop to RRC_SIZE x RRC_SIZE, always applied (p=1.0).
# NOTE(review): the crop scale is drawn from (RRC_MIN_SCALE, 1.0), so the
# non-augmented output is presumably still random per call -- confirm that
# this is intended for the visualisations.
valid_transform = A.Compose([
    A.RandomResizedCrop( 
        config.RRC_SIZE, config.RRC_SIZE,            
        scale=(config.RRC_MIN_SCALE, 1.0),
        ratio=config.RRC_RATIO,
        p=1.0
    )
])

In [None]:
def dicom_2_image(file, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return its pixel data as an 8-bit image.

    Args:
        file: Path (or file-like object) of the DICOM file.
        voi_lut: If true, pass the pixel data through `apply_voi_lut`
            (VOI LUT, if provided by the DICOM device, transforms the raw
            data to a "human-friendly" view).
        fix_monochrome: If true, invert images whose
            `PhotometricInterpretation` is "MONOCHROME1".

    Returns:
        uint8 array min-max scaled to 0..255. A slice in which all pixels
        have the same value is returned as all zeros.
    """
    dicom = pydicom.dcmread(file)
    pixels = dicom.pixel_array
    img = apply_voi_lut(pixels, dicom) if voi_lut else pixels
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    img = img - np.min(img)
    peak = np.max(img)
    if peak == 0:
        # Constant slice: dividing by zero would produce NaNs before the
        # uint8 cast, so return an explicit black image instead.
        return np.zeros(img.shape, dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)

def remove_black_boundaries(img):
    """Crop the black columns on the left and right side of a 2D slice.

    Only columns are cropped; the row extent of the non-black area is used
    solely for the minimum-size check.

    Args:
        img: 2D array in which pixels with value 0 (or below) count as black.

    Returns:
        `img` restricted to the columns between the first and last column
        that contain a non-black pixel (both inclusive). `img` is returned
        unchanged when it has no non-black pixel or when the non-black area
        spans 10 pixels or fewer in either direction.
    """
    rows, cols = np.where(img > 0)
    if rows.size == 0:
        return img
    row_lo, row_hi = rows.min(), rows.max()
    col_lo, col_hi = cols.min(), cols.max()
    if (row_hi - row_lo) > 10 and (col_hi - col_lo) > 10:
        # `col_hi` is the index of the last non-black column, so the slice
        # end must be col_hi + 1 to keep that column.
        img = img[:, col_lo:col_hi + 1]
    return img

def get_3d_image(mri_type, aug, dicom):
    """Load every slice of one series and stack the slices into a 3D image.

    Args:
        mri_type: Ordered list of the slice file paths of one series
            (despite its name this is not the modality string).
        aug: If true, each slice goes through `train_transform`, otherwise
            through `valid_transform`.
        dicom: If true, slices are read with `dicom_2_image`; otherwise they
            are read as grayscale images with OpenCV.

    Returns:
        Array of the transformed slices with the slice axis moved from the
        first to the third position.

    Raises:
        ValueError: If `mri_type` contains no files.
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    if not mri_type:
        raise ValueError("No slice files given; cannot build a 3D image.")
    transform = train_transform if aug else valid_transform
    # Collect all the 2D images which form the 3D image
    mri_img = []
    for file in mri_type:
        if dicom:
            img = dicom_2_image(file)
        else:
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread reports failure through its return value.
                raise FileNotFoundError(f"Could not read image file: {file}")
        # Remove black boundaries
        if config.REMOVE_BLACK_BOUNDARIES:
            img = remove_black_boundaries(img)
        mri_img.append(np.array(transform(image=img)["image"]))
    # From depth x height x width to height x width x depth
    mri_img = np.rollaxis(np.array(mri_img), 0, 3)
    if config.VERBOSE:
        print(f"Shape of mri_img: {mri_img.shape}")
    return mri_img

def load_images(scan_id, mri_type, aug=True, split="train", dicom=False):
    """Load one MRI series of a scan as a 3D image.

    Args:
        scan_id: Name of the scan folder below `{PATH}/{split}`.
        mri_type: "FLAIR", "T1w", "T1wCE" or "T2w". Any other value is
            treated as "T2w".
        aug: Forwarded to `get_3d_image`.
        split: Name of the folder below `PATH` that holds the scans.
        dicom: If true, `*.dcm` files are collected, otherwise `*.png`
            files; the flag is also forwarded to `get_3d_image`.

    Returns:
        The result of `get_3d_image` for the sorted slice files.
    """
    file_ext = "dcm" if dicom else "png"
    if config.VERBOSE:
        print(f"Scan id {scan_id}")

    # Unknown modality names keep falling back to the T2w folder.
    folder = mri_type if mri_type in ("FLAIR", "T1w", "T1wCE") else "T2w"

    def digits_as_int(file_path):
        # Sort key: all digits of the path concatenated and read as one
        # integer. Raw string, because "\D" is not a valid string escape.
        return int(re.sub(r'\D', '', file_path))

    # Ascending sort
    files = sorted(
        glob.glob(f"{PATH}/{split}/{scan_id}/{folder}/*.{file_ext}"),
        key=digits_as_int,
    )
    return get_3d_image(files, aug, dicom)

#### 1.2 Load, Save & Store Images (in Weights and Biases) 

In [None]:
# Table with one row per sample patient: id, label and one GIF per modality.
data_no_prep = wandb.Table(columns=['patient_id', 'target', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])
path = "../tmp/tmp-gifs-no-prep"
# Make sure the output directory exists before GIFs are written into it.
os.makedirs(path, exist_ok=True)

for patient in sample_patients:
    gif_files = []
    # Same order as the table columns.
    for modality in ("FLAIR", "T1w", "T1wCE", "T2w"):
        volume = load_images(patient, mri_type=modality, aug=False, dicom=config.DICOM)
        # One GIF frame per index of the third (slice) axis.
        frames = [volume[:, :, i] for i in range(volume.shape[2])]
        gif_file = f'{path}/{patient}_{modality.lower()}.gif'
        imageio.mimsave(gif_file, frames)
        gif_files.append(gif_file)

    target = df_train.loc[df_train.BraTS21ID == int(patient)].MGMT_value.values[0]
    data_no_prep.add_data(int(patient), target, *[wandb.Image(f) for f in gif_files])

wandb.log({'No Pre-Processing Samples': data_no_prep})

In [None]:
# Close the current Weights & Biases run; the next section opens a new one.
wandb.finish()

## 2. Remove Black Boundaries

#### Config

In [None]:
# Root directory below which `load_images` globs `<split>/<scan_id>/<mri_type>/`.
PATH = '..'

# Settings of this run; passed to `wandb.init` and read back via `wandb.config`.
config = {
    # Pre-processing
    "REMOVE_BLACK_BOUNDARIES": True,
    "DICOM": False,

    # Albumentation
    "RRC_SIZE": 256,
    "RRC_MIN_SCALE": 0.85,
    "RRC_RATIO": (1., 1.),
    "CLAHE_CLIP_LIMIT": 2.0,
    "CLAHE_TILE_GRID_SIZE": (8, 8),
    "CLAHE_PROB": 0.50,
    "BRIGHTNESS_LIMIT": (-0.2, 0.2),
    "BRIGHTNESS_PROB": 0.40,
    "HUE_SHIFT": (-15, 15),
    "SAT_SHIFT": (-15, 15),
    "VAL_SHIFT": (-15, 15),
    "HUE_PROB": 0.64,
    "COARSE_MAX_HOLES": 16,
    "COARSE_PROB": 0.7,

    # Logging
    "VERBOSE": False,
    "NAME": "00_EDA_Slicing-Remove-Black",
}

#### wandb

In [None]:
# Open a new Weights & Biases run that records this section's `config`.
run = wandb.init(entity="uzk-wim", project='rsna-miccai-slicing', config=config, mode="online")
# From here on `config` is the run's config object and is read via
# attribute access (e.g. `config.NAME`) instead of dictionary keys.
config = wandb.config
# Name the run after the NAME entry of the configuration.
wandb.run.name = f"{config.NAME}"

### 1. Loading Images

#### 1.1 Utilities

In [None]:
def dicom_2_image(file, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return its pixel data as an 8-bit image.

    Args:
        file: Path (or file-like object) of the DICOM file.
        voi_lut: If true, pass the pixel data through `apply_voi_lut`
            (VOI LUT, if provided by the DICOM device, transforms the raw
            data to a "human-friendly" view).
        fix_monochrome: If true, invert images whose
            `PhotometricInterpretation` is "MONOCHROME1".

    Returns:
        uint8 array min-max scaled to 0..255. A slice in which all pixels
        have the same value is returned as all zeros.
    """
    dicom = pydicom.dcmread(file)
    pixels = dicom.pixel_array
    img = apply_voi_lut(pixels, dicom) if voi_lut else pixels
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    img = img - np.min(img)
    peak = np.max(img)
    if peak == 0:
        # Constant slice: dividing by zero would produce NaNs before the
        # uint8 cast, so return an explicit black image instead.
        return np.zeros(img.shape, dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)

def remove_black_boundaries(img):
    """Crop the black columns on the left and right side of a 2D slice.

    Only columns are cropped; the row extent of the non-black area is used
    solely for the minimum-size check.

    Args:
        img: 2D array in which pixels with value 0 (or below) count as black.

    Returns:
        `img` restricted to the columns between the first and last column
        that contain a non-black pixel (both inclusive). `img` is returned
        unchanged when it has no non-black pixel or when the non-black area
        spans 10 pixels or fewer in either direction.
    """
    rows, cols = np.where(img > 0)
    if rows.size == 0:
        return img
    row_lo, row_hi = rows.min(), rows.max()
    col_lo, col_hi = cols.min(), cols.max()
    if (row_hi - row_lo) > 10 and (col_hi - col_lo) > 10:
        # `col_hi` is the index of the last non-black column, so the slice
        # end must be col_hi + 1 to keep that column.
        img = img[:, col_lo:col_hi + 1]
    return img

def get_3d_image(mri_type, aug, dicom):
    """Load every slice of one series and stack the slices into a 3D image.

    Args:
        mri_type: Ordered list of the slice file paths of one series
            (despite its name this is not the modality string).
        aug: If true, each slice goes through `train_transform`, otherwise
            through `valid_transform`.
        dicom: If true, slices are read with `dicom_2_image`; otherwise they
            are read as grayscale images with OpenCV.

    Returns:
        Array of the transformed slices with the slice axis moved from the
        first to the third position.

    Raises:
        ValueError: If `mri_type` contains no files.
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    if not mri_type:
        raise ValueError("No slice files given; cannot build a 3D image.")
    transform = train_transform if aug else valid_transform
    # Collect all the 2D images which form the 3D image
    mri_img = []
    for file in mri_type:
        if dicom:
            img = dicom_2_image(file)
        else:
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread reports failure through its return value.
                raise FileNotFoundError(f"Could not read image file: {file}")
        # Remove black boundaries
        if config.REMOVE_BLACK_BOUNDARIES:
            img = remove_black_boundaries(img)
        mri_img.append(np.array(transform(image=img)["image"]))
    # From depth x height x width to height x width x depth
    mri_img = np.rollaxis(np.array(mri_img), 0, 3)
    if config.VERBOSE:
        print(f"Shape of mri_img: {mri_img.shape}")
    return mri_img

def load_images(scan_id, mri_type, aug=True, split="train", dicom=False):
    """Load one MRI series of a scan as a 3D image.

    Args:
        scan_id: Name of the scan folder below `{PATH}/{split}`.
        mri_type: "FLAIR", "T1w", "T1wCE" or "T2w". Any other value is
            treated as "T2w".
        aug: Forwarded to `get_3d_image`.
        split: Name of the folder below `PATH` that holds the scans.
        dicom: If true, `*.dcm` files are collected, otherwise `*.png`
            files; the flag is also forwarded to `get_3d_image`.

    Returns:
        The result of `get_3d_image` for the sorted slice files.
    """
    file_ext = "dcm" if dicom else "png"
    if config.VERBOSE:
        print(f"Scan id {scan_id}")

    # Unknown modality names keep falling back to the T2w folder.
    folder = mri_type if mri_type in ("FLAIR", "T1w", "T1wCE") else "T2w"

    def digits_as_int(file_path):
        # Sort key: all digits of the path concatenated and read as one
        # integer. Raw string, because "\D" is not a valid string escape.
        return int(re.sub(r'\D', '', file_path))

    # Ascending sort
    files = sorted(
        glob.glob(f"{PATH}/{split}/{scan_id}/{folder}/*.{file_ext}"),
        key=digits_as_int,
    )
    return get_3d_image(files, aug, dicom)

#### 1.2 Load, Save & Store Images (in Weights and Biases) 

In [None]:
# Table with one row per sample patient: id, label and one GIF per modality.
data_remove_black = wandb.Table(columns=['patient_id', 'target', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])
path = "../tmp/tmp-gifs-remove-black"
# Make sure the output directory exists before GIFs are written into it.
os.makedirs(path, exist_ok=True)

for patient in sample_patients:
    gif_files = []
    # Same order as the table columns.
    for modality in ("FLAIR", "T1w", "T1wCE", "T2w"):
        volume = load_images(patient, mri_type=modality, aug=False, dicom=config.DICOM)
        # One GIF frame per index of the third (slice) axis.
        frames = [volume[:, :, i] for i in range(volume.shape[2])]
        gif_file = f'{path}/{patient}_{modality.lower()}.gif'
        imageio.mimsave(gif_file, frames)
        gif_files.append(gif_file)

    target = df_train.loc[df_train.BraTS21ID == int(patient)].MGMT_value.values[0]
    data_remove_black.add_data(int(patient), target, *[wandb.Image(f) for f in gif_files])

wandb.log({'Removed Black Pixels Samples': data_remove_black})

In [None]:
# Close the current Weights & Biases run; the next section opens a new one.
wandb.finish()

## 3. Remove Black Boundaries II

#### Config

In [None]:
# Root directory below which `load_images` globs `<split>/<scan_id>/<mri_type>/`.
PATH = '..'

# Settings of this run; passed to `wandb.init` and read back via `wandb.config`.
config = {
    # Pre-processing
    "REMOVE_BLACK_BOUNDARIES": True,
    "DICOM": False,

    # Albumentation
    "RRC_SIZE": 256,
    "RRC_MIN_SCALE": 0.85,
    "RRC_RATIO": (1., 1.),
    "CLAHE_CLIP_LIMIT": 2.0,
    "CLAHE_TILE_GRID_SIZE": (8, 8),
    "CLAHE_PROB": 0.50,
    "BRIGHTNESS_LIMIT": (-0.2, 0.2),
    "BRIGHTNESS_PROB": 0.40,
    "HUE_SHIFT": (-15, 15),
    "SAT_SHIFT": (-15, 15),
    "VAL_SHIFT": (-15, 15),
    "HUE_PROB": 0.64,
    "COARSE_MAX_HOLES": 16,
    "COARSE_PROB": 0.7,

    # Logging
    "VERBOSE": False,
    "NAME": "00_EDA_Slicing-Remove-Black-II",
}

#### wandb

In [None]:
# Open a new Weights & Biases run that records this section's `config`.
run = wandb.init(entity="uzk-wim", project='rsna-miccai-slicing', config=config, mode="online")
# From here on `config` is the run's config object and is read via
# attribute access (e.g. `config.NAME`) instead of dictionary keys.
config = wandb.config
# Name the run after the NAME entry of the configuration.
wandb.run.name = f"{config.NAME}"

### 1. Loading Images

#### 1.1 Utilities

In [None]:
def dicom_2_image(file, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return its pixel data as an 8-bit image.

    Args:
        file: Path (or file-like object) of the DICOM file.
        voi_lut: If true, pass the pixel data through `apply_voi_lut`
            (VOI LUT, if provided by the DICOM device, transforms the raw
            data to a "human-friendly" view).
        fix_monochrome: If true, invert images whose
            `PhotometricInterpretation` is "MONOCHROME1".

    Returns:
        uint8 array min-max scaled to 0..255. A slice in which all pixels
        have the same value is returned as all zeros.
    """
    dicom = pydicom.dcmread(file)
    pixels = dicom.pixel_array
    img = apply_voi_lut(pixels, dicom) if voi_lut else pixels
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    img = img - np.min(img)
    peak = np.max(img)
    if peak == 0:
        # Constant slice: dividing by zero would produce NaNs before the
        # uint8 cast, so return an explicit black image instead.
        return np.zeros(img.shape, dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)

def remove_black_boundaries(img):
    """Drop every row and column of a 2D slice that is completely black.

    Rows and columns are tested individually, so all-black rows or columns
    in the interior of the slice are removed as well, not only the border.

    Args:
        img: 2D array in which the value 0 counts as black.

    Returns:
        The slice without its all-black rows and columns. A slice without
        any non-black pixel is returned unchanged, because cropping it would
        leave an empty 0x0 array for the following transforms.
    """
    keep_rows = img.any(axis=1)
    if not keep_rows.any():
        return img
    cropped = img[keep_rows]
    # Columns are filtered through the transpose, as rows were above.
    return cropped.T[cropped.any(axis=0)].T

def get_3d_image(mri_type, aug, dicom):
    """Load every slice of one series and stack the slices into a 3D image.

    Args:
        mri_type: Ordered list of the slice file paths of one series
            (despite its name this is not the modality string).
        aug: If true, each slice goes through `train_transform`, otherwise
            through `valid_transform`.
        dicom: If true, slices are read with `dicom_2_image`; otherwise they
            are read as grayscale images with OpenCV.

    Returns:
        Array of the transformed slices with the slice axis moved from the
        first to the third position.

    Raises:
        ValueError: If `mri_type` contains no files.
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    if not mri_type:
        raise ValueError("No slice files given; cannot build a 3D image.")
    transform = train_transform if aug else valid_transform
    # Collect all the 2D images which form the 3D image
    mri_img = []
    for file in mri_type:
        if dicom:
            img = dicom_2_image(file)
        else:
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread reports failure through its return value.
                raise FileNotFoundError(f"Could not read image file: {file}")
        # Remove black boundaries
        if config.REMOVE_BLACK_BOUNDARIES:
            img = remove_black_boundaries(img)
        mri_img.append(np.array(transform(image=img)["image"]))
    # From depth x height x width to height x width x depth
    mri_img = np.rollaxis(np.array(mri_img), 0, 3)
    if config.VERBOSE:
        print(f"Shape of mri_img: {mri_img.shape}")
    return mri_img

def load_images(scan_id, mri_type, aug=True, split="train", dicom=False):
    """Load one MRI series of a scan as a 3D image.

    Args:
        scan_id: Name of the scan folder below `{PATH}/{split}`.
        mri_type: "FLAIR", "T1w", "T1wCE" or "T2w". Any other value is
            treated as "T2w".
        aug: Forwarded to `get_3d_image`.
        split: Name of the folder below `PATH` that holds the scans.
        dicom: If true, `*.dcm` files are collected, otherwise `*.png`
            files; the flag is also forwarded to `get_3d_image`.

    Returns:
        The result of `get_3d_image` for the sorted slice files.
    """
    file_ext = "dcm" if dicom else "png"
    if config.VERBOSE:
        print(f"Scan id {scan_id}")

    # Unknown modality names keep falling back to the T2w folder.
    folder = mri_type if mri_type in ("FLAIR", "T1w", "T1wCE") else "T2w"

    def digits_as_int(file_path):
        # Sort key: all digits of the path concatenated and read as one
        # integer. Raw string, because "\D" is not a valid string escape.
        return int(re.sub(r'\D', '', file_path))

    # Ascending sort
    files = sorted(
        glob.glob(f"{PATH}/{split}/{scan_id}/{folder}/*.{file_ext}"),
        key=digits_as_int,
    )
    return get_3d_image(files, aug, dicom)

#### 1.2 Load, Save & Store Images (in Weights and Biases) 

In [None]:
# Table with one row per sample patient: id, label and one GIF per modality.
data_remove_black_ii = wandb.Table(columns=['patient_id', 'target', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])
path = "../tmp/tmp-gifs-remove-black-ii"
# Make sure the output directory exists before GIFs are written into it.
os.makedirs(path, exist_ok=True)

for patient in sample_patients:
    gif_files = []
    # Same order as the table columns.
    for modality in ("FLAIR", "T1w", "T1wCE", "T2w"):
        volume = load_images(patient, mri_type=modality, aug=False, dicom=config.DICOM)
        # One GIF frame per index of the third (slice) axis.
        frames = [volume[:, :, i] for i in range(volume.shape[2])]
        gif_file = f'{path}/{patient}_{modality.lower()}.gif'
        imageio.mimsave(gif_file, frames)
        gif_files.append(gif_file)

    target = df_train.loc[df_train.BraTS21ID == int(patient)].MGMT_value.values[0]
    data_remove_black_ii.add_data(int(patient), target, *[wandb.Image(f) for f in gif_files])

wandb.log({'Removed Black Pixels II Samples': data_remove_black_ii})

In [None]:
# Close the current Weights & Biases run; the next section opens a new one.
wandb.finish()

## 4. Remove Black Boundaries + Middle Slices

#### Config

In [None]:
# Root directory below which `load_images` globs `<split>/<scan_id>/<mri_type>/`.
PATH = '..'

# Settings of this run; passed to `wandb.init` and read back via `wandb.config`.
config = {
    # Pre-processing
    "SLICE_NUMBER": 32,  # >= 30
    "REMOVE_BLACK_BOUNDARIES": True,
    "DICOM": False,

    # Albumentation
    "RRC_SIZE": 256,
    "RRC_MIN_SCALE": 0.85,
    "RRC_RATIO": (1., 1.),
    "CLAHE_CLIP_LIMIT": 2.0,
    "CLAHE_TILE_GRID_SIZE": (8, 8),
    "CLAHE_PROB": 0.50,
    "BRIGHTNESS_LIMIT": (-0.2, 0.2),
    "BRIGHTNESS_PROB": 0.40,
    "HUE_SHIFT": (-15, 15),
    "SAT_SHIFT": (-15, 15),
    "VAL_SHIFT": (-15, 15),
    "HUE_PROB": 0.64,
    "COARSE_MAX_HOLES": 16,
    "COARSE_PROB": 0.7,

    # Logging
    "VERBOSE": False,
    "NAME": "00_EDA_Slicing-Middle",
}

#### wandb

In [None]:
# Open a new Weights & Biases run that records this section's `config`.
run = wandb.init(entity="uzk-wim", project='rsna-miccai-slicing', config=config, mode="online")
# From here on `config` is the run's config object and is read via
# attribute access (e.g. `config.NAME`) instead of dictionary keys.
config = wandb.config
# Name the run after the NAME entry of the configuration.
wandb.run.name = f"{config.NAME}"

### 1. Loading Images

#### 1.1 Utilities

In [None]:
def dicom_2_image(file, voi_lut=True, fix_monochrome=True):
    """Read a DICOM file and return its pixel data as an 8-bit image.

    Args:
        file: Path (or file-like object) of the DICOM file.
        voi_lut: If true, pass the pixel data through `apply_voi_lut`
            (VOI LUT, if provided by the DICOM device, transforms the raw
            data to a "human-friendly" view).
        fix_monochrome: If true, invert images whose
            `PhotometricInterpretation` is "MONOCHROME1".

    Returns:
        uint8 array min-max scaled to 0..255. A slice in which all pixels
        have the same value is returned as all zeros.
    """
    dicom = pydicom.dcmread(file)
    pixels = dicom.pixel_array
    img = apply_voi_lut(pixels, dicom) if voi_lut else pixels
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    img = img - np.min(img)
    peak = np.max(img)
    if peak == 0:
        # Constant slice: dividing by zero would produce NaNs before the
        # uint8 cast, so return an explicit black image instead.
        return np.zeros(img.shape, dtype=np.uint8)
    return (img / peak * 255).astype(np.uint8)

def remove_black_boundaries(img):
    """Drop every row and column of a 2D slice that is completely black.

    Rows and columns are tested individually, so all-black rows or columns
    in the interior of the slice are removed as well, not only the border.

    Args:
        img: 2D array in which the value 0 counts as black.

    Returns:
        The slice without its all-black rows and columns. A slice without
        any non-black pixel is returned unchanged, because cropping it would
        leave an empty 0x0 array for the following transforms.
    """
    keep_rows = img.any(axis=1)
    if not keep_rows.any():
        return img
    cropped = img[keep_rows]
    # Columns are filtered through the transpose, as rows were above.
    return cropped.T[cropped.any(axis=0)].T

def get_idxs(mri_type):
    """Return the index window around the middle of a series.

    The window reaches `config.SLICE_NUMBER // 2` positions to either side
    of the middle index, so an odd `SLICE_NUMBER` covers one slice fewer.

    Args:
        mri_type: Sequence of slice files; only its length is used.

    Returns:
        Tuple `(minimum_idx, maximum_idx)`. The lower bound is clamped to 0;
        the upper bound is not clamped and may exceed the series length.
    """
    half_window = config.SLICE_NUMBER // 2
    middle = len(mri_type) // 2
    minimum_idx = max(middle - half_window, 0)
    # Not clamped: this value may point past the last slice.
    maximum_idx = middle + half_window
    if config.VERBOSE:
        print(f"Minimum {minimum_idx}")
        print(f"Maximum {maximum_idx}")
    return minimum_idx, maximum_idx

def get_3d_image(mri_type, aug, dicom):
    """Load the middle slices of one series and stack them into a 3D image.

    Args:
        mri_type: Ordered list of the slice file paths of one series
            (despite its name this is not the modality string).
        aug: If true, each slice goes through `train_transform`, otherwise
            through `valid_transform`.
        dicom: If true, slices are read with `dicom_2_image`; otherwise they
            are read as grayscale images with OpenCV.

    Returns:
        Array of the transformed slices selected by `get_idxs`, with the
        slice axis moved from the first to the third position.

    Raises:
        ValueError: If `mri_type` contains no files.
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    if not mri_type:
        raise ValueError("No slice files given; cannot build a 3D image.")
    transform = train_transform if aug else valid_transform
    minimum_idx, maximum_idx = get_idxs(mri_type)
    # Collect the selected 2D images which form the 3D image
    mri_img = []
    for file in mri_type[minimum_idx:maximum_idx]:
        if dicom:
            img = dicom_2_image(file)
        else:
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread reports failure through its return value.
                raise FileNotFoundError(f"Could not read image file: {file}")
        # Remove black boundaries
        if config.REMOVE_BLACK_BOUNDARIES:
            img = remove_black_boundaries(img)
        mri_img.append(np.array(transform(image=img)["image"]))
    # From depth x height x width to height x width x depth
    mri_img = np.rollaxis(np.array(mri_img), 0, 3)
    if config.VERBOSE:
        print(f"Shape of mri_img: {mri_img.shape}")
    return mri_img

def load_images(scan_id, mri_type, aug=True, split="train", dicom=False):
    """Load one MRI series of a scan as a 3D image.

    Args:
        scan_id: Name of the scan folder below `{PATH}/{split}`.
        mri_type: "FLAIR", "T1w", "T1wCE" or "T2w". Any other value is
            treated as "T2w".
        aug: Forwarded to `get_3d_image`.
        split: Name of the folder below `PATH` that holds the scans.
        dicom: If true, `*.dcm` files are collected, otherwise `*.png`
            files; the flag is also forwarded to `get_3d_image`.

    Returns:
        The result of `get_3d_image` for the sorted slice files.
    """
    file_ext = "dcm" if dicom else "png"
    if config.VERBOSE:
        print(f"Scan id {scan_id}")

    # Unknown modality names keep falling back to the T2w folder.
    folder = mri_type if mri_type in ("FLAIR", "T1w", "T1wCE") else "T2w"

    def digits_as_int(file_path):
        # Sort key: all digits of the path concatenated and read as one
        # integer. Raw string, because "\D" is not a valid string escape.
        return int(re.sub(r'\D', '', file_path))

    # Ascending sort
    files = sorted(
        glob.glob(f"{PATH}/{split}/{scan_id}/{folder}/*.{file_ext}"),
        key=digits_as_int,
    )
    return get_3d_image(files, aug, dicom)

#### 1.2 Load, Save & Store Images (in Weights and Biases) 

In [None]:
# Table with one row per sample patient: id, label and one GIF per modality.
data_middle = wandb.Table(columns=['patient_id', 'target', 'FLAIR', 'T1w', 'T1wCE', 'T2w'])
path = "../tmp/tmp-gifs-middle"
# Make sure the output directory exists before GIFs are written into it.
os.makedirs(path, exist_ok=True)

for patient in sample_patients:
    gif_files = []
    # Same order as the table columns.
    for modality in ("FLAIR", "T1w", "T1wCE", "T2w"):
        volume = load_images(patient, mri_type=modality, aug=False, dicom=config.DICOM)
        # One GIF frame per index of the third (slice) axis.
        frames = [volume[:, :, i] for i in range(volume.shape[2])]
        gif_file = f'{path}/{patient}_{modality.lower()}.gif'
        imageio.mimsave(gif_file, frames)
        gif_files.append(gif_file)

    target = df_train.loc[df_train.BraTS21ID == int(patient)].MGMT_value.values[0]
    data_middle.add_data(int(patient), target, *[wandb.Image(f) for f in gif_files])

wandb.log({'Middle Samples': data_middle})

In [None]:
# Close the current Weights & Biases run.
wandb.finish()

Done