Skip to content

Commit

Permalink
ENH: changes the preprocessing methods for the brats data.
Browse files Browse the repository at this point in the history
Removes windowing and background detection.
  • Loading branch information
ellisdg committed Nov 17, 2017
1 parent ab19319 commit e9f288f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 53 deletions.
20 changes: 9 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ that are mentioned in the paper. If you figure out a way to apply these to a 3D
welcome!

The code was written to be trained using the
[BRATS](https://sites.google.com/site/braintumorsegmentation/home/brats2015) data set for brain tumors, but it can
[BRATS](http://www.med.upenn.edu/sbia/brats2017.html) data set for brain tumors, but it can
be easily modified to be used in other 3D applications. To adapt the network, you might have to play with the input size
to get something that works for your data.

Expand All @@ -15,29 +15,27 @@ I used [Bohdan Pavlyshenko](https://www.kaggle.com/bpavlyshenko)'s
segmentation as a base for this 3D U-Net.

## How to Train Using BRATS Data
1. Download the [BRATS 2015 data set](https://sites.google.com/site/braintumorsegmentation/home/brats2015).
1. Download the BRATS 2017 [GBM](https://app.box.com/s/926eijrcz4qudona5vkz4z5o9qfm772d) and
[LGG](https://app.box.com/s/ssfkb6u8fg3dmal0v7ni0ckbqntsc8fy) data. Place the unzipped folders in the
```brats/data/original``` folder.
2. Install dependencies:
nibabel,
keras,
pytables,
nilearn
nilearn,
SimpleITK (for preprocessing only)
3. Install [ANTs N4BiasFieldCorrection](https://github.com/stnava/ANTs/releases) and add the location of the ANTs
binaries to the PATH environmental variable.
4. Convert the data to nifti format and perform image wise normalization and correction:
```
$ cd brats
```
Import the conversion function:
Import the conversion function and run the preprocessing:
```
$ python
>>> from preprocess import convert_brats_data
>>> convert_brats_data("data/original", "data/preprocessed")
```
Import the configuration dictionary:
```
>>> from config import config
>>> convert_brats_data("/path/to/BRATS/BRATS2015_Training", config["data_dir"])
```
Where ```config["data_dir"]``` is the location where the raw BRATS data will be converted to.

5. Run the training:
```
$ cd ..
Expand Down
115 changes: 76 additions & 39 deletions brats/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""
Tools for converting, normalizing, and fixing the brats data.
Correcting the bias requires that N4BiasFieldCorrection be installed!
"""


import glob
import os
import warnings
import shutil

import SimpleITK as sitk
import numpy as np
Expand All @@ -21,12 +21,19 @@ def append_basename(in_file, append):
return os.path.join(dirname, base + append + "." + ext)


def get_background_mask(in_folder, out_file):
def get_background_mask(in_folder, out_file, truth_name="GlistrBoost_ManuallyCorrected"):
"""
This function computes a common background mask for all of the data in a subject folder.
:param in_folder: a subject folder from the BRATS dataset.
:param out_file: an image containing a mask that is 1 where the image data for that subject contains the background.
:param truth_name: how the truth file is labeled in the subject folder
:return: the path to the out_file
"""
background_image = None
for name in config["modalities"] + [".OT"]:
for name in config["all_modalities"] + [truth_name]:
image = sitk.ReadImage(get_image(in_folder, name))
if background_image:
if name == ".OT" and not (image.GetOrigin() == background_image.GetOrigin()):
if name == truth_name and not (image.GetOrigin() == background_image.GetOrigin()):
image.SetOrigin(background_image.GetOrigin())
background_image = sitk.And(image == 0, background_image)
else:
Expand All @@ -50,11 +57,26 @@ def window_intensities(in_file, out_file, min_percent=1, max_percent=99):


def correct_bias(in_file, out_file):
    """
    Corrects the bias using ANTs N4BiasFieldCorrection. If running ANTs raises an IOError, a RuntimeWarning is
    issued and the correction is attempted with SimpleITK instead.
    :param in_file: input file path
    :param out_file: output file path
    :return: file path to the bias corrected image
    """
    correct = N4BiasFieldCorrection()
    correct.inputs.input_image = in_file
    correct.inputs.output_image = out_file
    try:
        done = correct.run()
        return done.outputs.output_image
    except IOError:
        # The IOError is interpreted here as "the ANTs executable is not on the PATH"
        # (presumably how the wrapper reports a missing binary -- confirm against the interface in use).
        warnings.warn(RuntimeWarning("ANTs N4BiasFieldCorrection could not be found. "
                                     "Will try using SimpleITK for bias field correction"
                                     " which will take much longer. To fix this problem, add N4BiasFieldCorrection"
                                     " to your PATH system variable."
                                     " (example: export PATH=${PATH}:/path/to/ants/bin)"))
        # Fallback: read the image, correct it with SimpleITK, and write the result to out_file.
        output_image = sitk.N4BiasFieldCorrection(sitk.ReadImage(in_file))
        sitk.WriteImage(output_image, out_file)
        return os.path.abspath(out_file)


def rescale(in_file, out_file, minimum=0, maximum=20000):
Expand All @@ -64,7 +86,11 @@ def rescale(in_file, out_file, minimum=0, maximum=20000):


def get_image(subject_folder, name):
    """
    Finds the ".nii.gz" file in a subject folder whose file name contains the given label.
    Files whose name ends exactly with ``name + ".nii.gz"`` are preferred, so that a label such as "t1" is not
    confused with a longer label such as "t1Gd". If there is no such file, any file matching
    ``"*" + name + "*.nii.gz"`` is accepted.
    :param subject_folder: folder containing the image files for a single subject
    :param name: label identifying the image (for example a modality name or the truth label)
    :return: path to the matching file; when several files match, the first in sorted order
    :raises RuntimeError: if no file in the folder matches
    """
    exact_card = os.path.join(subject_folder, "*" + name + ".nii.gz")
    file_card = os.path.join(subject_folder, "*" + name + "*.nii.gz")
    # glob returns matches in arbitrary order, so sort to make the choice deterministic
    matches = sorted(glob.glob(exact_card)) or sorted(glob.glob(file_card))
    if not matches:
        raise RuntimeError("Could not find file matching {}".format(file_card))
    return matches[0]


def background_to_zero(in_file, background_file, out_file):
Expand All @@ -78,43 +104,54 @@ def check_origin(in_file, in_file2):
image2 = sitk.ReadImage(in_file2)
if not image.GetOrigin() == image2.GetOrigin():
image.SetOrigin(image2.GetOrigin())
sitk.WriteImage(image, in_file)


def normalize_image(in_file, out_file, background_mask):
converted = convert_image_format(in_file, append_basename(out_file, "_converted"))
initial_rescale = rescale(converted, append_basename(out_file, "_initial_rescale"))
zeroed = background_to_zero(initial_rescale, background_mask, append_basename(out_file, "_zeroed"))
windowed = window_intensities(zeroed, append_basename(out_file, "_windowed"))
corrected = correct_bias(windowed, append_basename(out_file, "_corrected"))
rescaled = rescale(corrected, out_file, maximum=1)
for f in [converted, initial_rescale, zeroed, windowed, corrected]:
os.remove(f)
return rescaled
sitk.WriteImage(image, in_file)


def convert_brats_folder(in_folder, out_folder, background_mask):
for name in config["modalities"] + [".OT"]:
image_file = get_image(in_folder, name)
if name == ".OT":
out_file = os.path.abspath(os.path.join(out_folder, "truth.nii.gz"))
converted = convert_image_format(image_file, out_file)
check_origin(converted, background_mask)
else:
out_file = os.path.abspath(os.path.join(out_folder, name + ".nii.gz"))
normalize_image(image_file, out_file, background_mask)
def normalize_image(in_file, out_file, bias_correction=True):
    """
    Writes a rescaled copy of an image to out_file, optionally bias correcting it first.
    :param in_file: path to the input image
    :param out_file: path the rescaled image is written to
    :param bias_correction: if True, the image is passed through correct_bias before being rescaled
    :return: the value returned by rescale for the output image
    """
    if not bias_correction:
        return rescale(in_file, out_file, maximum=1)
    # The bias corrected image is only an intermediate file: it is removed once it has been rescaled.
    bias_corrected_file = correct_bias(in_file, append_basename(out_file, "_corrected"))
    normalized_file = rescale(bias_corrected_file, out_file, maximum=1)
    os.remove(bias_corrected_file)
    return normalized_file


def convert_brats_data(brats_folder, out_folder):
def convert_brats_folder(in_folder, out_folder, truth_name="GlistrBoost_ManuallyCorrected",
                         no_bias_correction_modalities=None):
    """
    Normalizes every modality image of a single subject and copies the subject's truth file.
    :param in_folder: subject folder containing the original images
    :param out_folder: folder the normalized images and "truth.nii.gz" are written to
    :param truth_name: how the truth file is labeled in the subject folder
    :param no_bias_correction_modalities: modalities for which bias correction is skipped. If None or empty, bias
    correction is performed on all of the modalities.
    :return: None
    """
    skip_bias_correction = no_bias_correction_modalities or ()
    for name in config["all_modalities"]:
        image_file = get_image(in_folder, name)
        out_file = os.path.abspath(os.path.join(out_folder, name + ".nii.gz"))
        # Bias correction is the default; it is skipped only for the modalities that were explicitly excluded.
        normalize_image(image_file, out_file, bias_correction=name not in skip_bias_correction)
    # copy the truth file, falling back to the part of the label before the first underscore
    try:
        truth_file = get_image(in_folder, truth_name)
    except RuntimeError:
        truth_file = get_image(in_folder, truth_name.split("_")[0])
    out_file = os.path.abspath(os.path.join(out_folder, "truth.nii.gz"))
    shutil.copy(truth_file, out_file)
    # compare the origin of the copied truth file with that of the first modality image
    check_origin(out_file, get_image(in_folder, config["all_modalities"][0]))


def convert_brats_data(brats_folder, out_folder, overwrite=False, no_bias_correction_modalities=("flair",)):
    """
    Preprocesses the BRATS data and writes it to a given output folder. Every directory found two levels below
    brats_folder is treated as a subject folder, and the last two levels of its path are reproduced under out_folder.
    :param brats_folder: folder containing the original brats data
    :param out_folder: output folder to which the preprocessed data will be written
    :param overwrite: set to True to redo the preprocessing for subjects whose output folder already exists
    :param no_bias_correction_modalities: performing bias correction could reduce the signal of certain modalities. If
    concerned about a reduction in signal for a specific modality, specify by including the given modality in a list
    or tuple.
    :return: None
    """
    for subject_folder in glob.glob(os.path.join(brats_folder, "*", "*")):
        if not os.path.isdir(subject_folder):
            continue
        parent_name = os.path.basename(os.path.dirname(subject_folder))
        new_subject_folder = os.path.join(out_folder, parent_name, os.path.basename(subject_folder))
        already_exists = os.path.exists(new_subject_folder)
        # a subject that already has an output folder is only redone when overwrite is set
        if already_exists and not overwrite:
            continue
        if not already_exists:
            os.makedirs(new_subject_folder)
        convert_brats_folder(subject_folder, new_subject_folder,
                             no_bias_correction_modalities=no_bias_correction_modalities)
6 changes: 3 additions & 3 deletions brats/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
config["patch_shape"] = (64, 64, 64)
config["labels"] = (1, 2, 3, 4)
config["n_labels"] = len(config["labels"])
config["modalities"] = ["T1", "T1c", "Flair", "T2"]
config["training_modalities"] = ["T1", "T1c", "Flair"] # set this to the modalities that you want the model to use
config["all_modalities"] = ["t1", "t1Gd", "flair", "t2"]
config["training_modalities"] = config["all_modalities"] # change this if you want to only use some of the modalities
config["nb_channels"] = len(config["training_modalities"])
if "patch_shape" in config and config["patch_shape"] is not None:
config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))
Expand All @@ -38,7 +38,7 @@
config["validation_patch_overlap"] = 0
config["training_patch_start_offset"] = (16, 16, 16)

config["hdf5_file"] = os.path.abspath("brats_data.hdf5")
config["hdf5_file"] = os.path.abspath("brats_data.h5")
config["model_file"] = os.path.abspath("tumor_segmentation_model.h5")
config["training_file"] = os.path.abspath("training_ids.pkl")
config["validation_file"] = os.path.abspath("validation_ids.pkl")
Expand Down

0 comments on commit e9f288f

Please sign in to comment.