# MRNet Preprocessing and Data Augmentation

In [1]:
import os
import platform
from glob import glob

import numpy as np
import pandas as pd
import utils

In [3]:
# Dataset locations. Building every path with os.path.join yields the native
# separator on each OS (a literal 'Data/MRNet-v1.0' would mix '/' and '\\'
# once joined with further components on Windows).
mrnet_dataset_dir = os.path.join('Data', 'MRNet-v1.0')
mrnet_train_path = os.path.join(mrnet_dataset_dir, 'train')
mrnet_valid_path = os.path.join(mrnet_dataset_dir, 'valid')
mrnet_planes = ['axial', 'coronal', 'sagittal']

In [4]:
# On Windows, swap the forward slashes in the configured paths for backslashes
if platform.system() == "Windows":
    mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path = (
        path.replace('/', '\\')
        for path in (mrnet_dataset_dir, mrnet_train_path, mrnet_valid_path)
    )

In [5]:
# Split name -> directory holding that split's MRI volumes
mrnet_datasets = dict(train=mrnet_train_path, valid=mrnet_valid_path)

In [6]:
# Diagnosis tasks; each one has a <split>-<label>.csv file
mrnet_labels = 'abnormal acl meniscus'.split()

In [7]:
# TRAIN DATASET
# Each train-<label>.csv has no header row: the first column is the case id
# (read as str so ids such as '0000' keep their leading zeros) and the second
# is the 0/1 label for that task. os.path.join supplies the native separator,
# so no per-OS branching is needed.
train_label_dfs = {
    label: pd.read_csv(os.path.join(mrnet_dataset_dir, f"train-{label}.csv"),
                       header=None,
                       names=['Case', column],
                       dtype={'Case': str, column: np.int64})
    for label, column in [('abnormal', 'Abnormal'), ('acl', 'ACL'), ('meniscus', 'Meniscus')]
}
train_abnormal_df = train_label_dfs['abnormal']
train_acl_df = train_label_dfs['acl']
train_meniscus_df = train_label_dfs['meniscus']

# Inner join on the case id -> columns Case, Abnormal, ACL, Meniscus.
# validate='one_to_one' raises if any CSV lists a case id more than once.
train_df = (train_abnormal_df
            .merge(train_acl_df, on='Case', validate='one_to_one')
            .merge(train_meniscus_df, on='Case', validate='one_to_one'))

In [8]:
# VALID DATASET
# Each valid-<label>.csv has no header row: the first column is the case id
# (read as str so ids such as '1130' keep any leading zeros) and the second
# is the 0/1 label for that task. os.path.join supplies the native separator,
# so no per-OS branching is needed.
valid_label_dfs = {
    label: pd.read_csv(os.path.join(mrnet_dataset_dir, f"valid-{label}.csv"),
                       header=None,
                       names=['Case', column],
                       dtype={'Case': str, column: np.int64})
    for label, column in [('abnormal', 'Abnormal'), ('acl', 'ACL'), ('meniscus', 'Meniscus')]
}
valid_abnormal_df = valid_label_dfs['abnormal']
valid_acl_df = valid_label_dfs['acl']
valid_meniscus_df = valid_label_dfs['meniscus']

# Inner join on the case id -> columns Case, Abnormal, ACL, Meniscus.
# validate='one_to_one' raises if any CSV lists a case id more than once.
valid_df = (valid_abnormal_df
            .merge(valid_acl_df, on='Case', validate='one_to_one')
            .merge(valid_meniscus_df, on='Case', validate='one_to_one'))

In [9]:
def preprocess_mri_vols(cases, overwrite=False):
    """
    Preprocess MRNet MRI volumes and store them under 'Preprocessed_Data'.

    The output path is the input path with its first component replaced by
    'Preprocessed_Data', e.g. 'Data/MRNet-v1.0/train/sagittal/0000.npy' is
    written to 'Preprocessed_Data/MRNet-v1.0/train/sagittal/0000.npy'.
    Each volume is cast to float64 before ``utils.preprocess_mri`` is applied.

    Args:
        cases (list): Paths of .npy volumes in the MRNet dataset. They are
            processed in sorted order; the caller's list is not modified.
        overwrite (bool, optional): Re-process volumes whose output file
            already exists. Defaults to False, which skips them.
    """
    for case in sorted(cases):
        case_path = os.path.normpath(case).split(os.sep)
        # NOTE(review): assumes `case` is a relative path whose first component
        # is the data root (e.g. 'Data'); an absolute path would lose its root here.
        case_path[0] = 'Preprocessed_Data'
        preprocessed_case_path = os.path.join(*case_path)

        if not overwrite and os.path.exists(preprocessed_case_path):
            # Already done: skip without loading the raw volume at all.
            continue

        mri_vol = np.load(case).astype(np.float64)
        preprocessed_mri_vol = utils.preprocess_mri(mri_vol)
        os.makedirs(os.path.join(*case_path[:-1]), exist_ok=True)
        np.save(preprocessed_case_path, preprocessed_mri_vol)

In [10]:
def _augment_mri_vol(mri_vol, aug_ind, acl_positive):
    """
    Apply the transform used for augmented copy number ``aug_ind`` of a case.

    ACL-negative cases get a single copy: horizontal flip, then rotation.
    ACL-positive cases get three copies: 0 = flip, 1 = rotation,
    2 = flip followed by rotation.
    """
    if acl_positive and aug_ind == 0:
        return utils.random_horizontal_flip(mri_vol)
    if acl_positive and aug_ind == 1:
        return utils.random_rotation(mri_vol)
    return utils.random_rotation(utils.random_horizontal_flip(mri_vol))


def augment_mri_vols(dataset, labels, aug_flip_prob=0.95, overwrite=False):
    """
    Augment sagittal MRI volumes of an MRNet split to oversample rarer cases.

    Every case with an ACL tear (ACL == 1) gets three augmented copies. A case
    without an ACL tear (ACL == 0) gets one copy when ``np.random.rand()`` is
    >= ``aug_flip_prob``, i.e. with probability ``1 - aug_flip_prob`` (5% by
    default). Each copy is preprocessed with ``utils.preprocess_mri`` and
    saved as ``Preprocessed_Data/<...>/sagittal/aug/<case>-aug-<k>.npy``.
    Its labels are copied from the source case and written to
    ``<split>-aug.csv`` next to the ``train``/``valid`` directory.

    Args:
        dataset (str): Path to either the train or valid MRNet split, e.g.
            'Data/MRNet-v1.0/train'. The first path component of each volume
            path is replaced by 'Preprocessed_Data' to build the output path.
        labels (pandas.DataFrame): Labels with columns
            ['Case', 'Abnormal', 'ACL', 'Meniscus'].
        aug_flip_prob (float, optional): Threshold on ``np.random.rand()``;
            an ACL-negative case is augmented when the draw is >= this value.
        overwrite (bool, optional): Regenerate augmented volumes that already
            exist on disk. Existing ones are still listed in the CSV.
    """
    aug_labels_list = []
    plane = 'sagittal'
    # os.path.join picks the native separator, so no per-OS branching is needed.
    cases = sorted(glob(os.path.join(dataset, plane, '*.npy')))
    for case in cases:
        case_path = os.path.normpath(case).split(os.sep)
        orig_sagittal = os.path.join(*case_path)
        file_name = case_path[-1]
        dot_index = file_name.index('.')
        case_id, extension = file_name[:dot_index], file_name[dot_index:]

        # Augmented files live under Preprocessed_Data/<...>/<plane>/aug/.
        aug_dir = os.path.join('Preprocessed_Data', *case_path[1:-1], 'aug')

        # Copies inherit the source case's labels, the same for every plane and task.
        case_labels = labels.loc[labels['Case'] == case_id,
                                 ['Abnormal', 'ACL', 'Meniscus']].values.tolist()[0]
        acl_positive = case_labels[1] == 1

        # Draw for every case (ACL-positive ones included) so the sequence of
        # random numbers consumed does not depend on the labels.
        selected = np.random.rand() >= aug_flip_prob
        if acl_positive:
            n_copies = 3  # always oversample ACL tears
        elif selected and case_labels[1] == 0:
            n_copies = 1
        else:
            continue

        mri_vol = None  # loaded lazily, only if some copy must be (re)generated
        for aug_ind in range(n_copies):
            aug_case_id = f"{case_id}-aug-{aug_ind}"
            aug_sagittal = os.path.join(aug_dir, f"{aug_case_id}{extension}")
            if overwrite or not os.path.exists(aug_sagittal):
                if mri_vol is None:
                    mri_vol = np.load(orig_sagittal).astype(np.float64)
                # Hand the transforms a copy in case they modify their input in place.
                aug_mri_vol = _augment_mri_vol(mri_vol.copy(), aug_ind, acl_positive)
                os.makedirs(aug_dir, exist_ok=True)
                np.save(aug_sagittal, utils.preprocess_mri(aug_mri_vol))
            # Record the label whether the file was just written or already
            # existed, so re-running without `overwrite` still yields a full CSV.
            aug_labels_list.append([aug_case_id] + case_labels)

    aug_df = pd.DataFrame(aug_labels_list, columns=labels.columns)
    # The CSV sits next to the split directory, e.g. Data/MRNet-v1.0/train-aug.csv.
    # os.path.split keeps the root of absolute paths intact.
    split_parent, split_name = os.path.split(os.path.normpath(dataset))
    if split_name in ('train', 'valid'):
        # to_csv writes the DataFrame index as an unnamed first column.
        # NOTE(review): consider index=False if no reader relies on that column.
        aug_df.to_csv(os.path.join(split_parent, f"{split_name}-aug.csv"))
    print(f"For {dataset.upper()} dataset we have {len(aug_labels_list)} augmented samples.")

In [11]:
def preprocess_mri_vols_for_plane(dataset, plane, overwrite=False):
    """
    Preprocess every .npy volume of one plane of an MRNet split.

    Args:
        dataset (str): Path to either the train or valid MRNet split.
        plane (str): MRNet plane: 'axial', 'coronal' or 'sagittal'.
        overwrite (bool, optional): Passed through to ``preprocess_mri_vols``;
            re-process volumes whose output already exists. Defaults to False.
    """
    # os.path.join picks the native separator, so no per-OS branching is needed.
    cases = sorted(glob(os.path.join(dataset, plane, '*.npy')))
    preprocess_mri_vols(cases, overwrite=overwrite)
    print(f"For {dataset.upper()} {plane} plane we have {len(cases)} samples.")

In [12]:
# Preprocess the sagittal plane of the training split only
preprocess_mri_vols_for_plane(dataset=mrnet_datasets['train'], plane='sagittal')

For DATA\MRNET-V1.0\TRAIN sagittal plane we have 1130 samples.


In [13]:
# Preprocess the sagittal plane of the validation split only
preprocess_mri_vols_for_plane(dataset=mrnet_datasets['valid'], plane='sagittal')

For DATA\MRNET-V1.0\VALID sagittal plane we have 120 samples.


In [14]:
# Augment the training split using its merged label table
augment_mri_vols(dataset=mrnet_datasets['train'], labels=train_df)

For DATA\MRNET-V1.0\TRAIN datset we have 664 augmented samples.


In [15]:
# Augment the validation split using its merged label table
augment_mri_vols(dataset=mrnet_datasets['valid'], labels=valid_df)

For DATA\MRNET-V1.0\VALID datset we have 166 augmented samples.
