This is a deep learning workflow for Progressive Multifocal Leukoencephalopathy (PML) lesion and brain segmentation from brain MRI using Python 3, Keras, and TensorFlow. The model generates segmentation masks of the brain parenchyma and lesions in PML patients.
This work was conducted at the Neuroimmunology Clinic (NIC) and the Translational Neuroradiology Section (TNS) of the National Institute of Neurological Disorders and Stroke (NINDS), in collaboration with colleagues at the National Institute of Mental Health and the Henry Jackson Foundation at the National Institutes of Health. It used the computational resources of the NIH HPC Biowulf cluster, and the software is distributed under the GNU General Public License v3.0.
If this repository is helpful for your research, please cite the following article:
Al-Louzi O, Roy S, Osuorah I, et al. Progressive multifocal leukoencephalopathy lesion and brain parenchymal segmentation from MRI using serial deep convolutional neural networks. Neuroimage Clin. 2020;28:102499. doi:10.1016/j.nicl.2020.102499 https://www.sciencedirect.com/science/article/pii/S2213158220303363
The basic code skeleton was adapted from the following source (https://arxiv.org/abs/1803.09172), with several notable changes. We created two tailored scripts for training PML brain parenchymal extraction and lesion segmentation models. The training/validation split is now performed at the atlas level, which removes patch overlap effects during sampling. We also added TensorBoard logging and automatic generation of training/validation accuracy and loss graphs at the end of model training. For testing, we fixed a previous bug in image padding for different patch sizes and added a new 4D image padding function. We also replaced the slice-by-slice approach previously used to generate predictions on unseen images with a moving 3D window applied serially across the image volume, which allows higher-resolution images to fit into available GPU memory. In addition, the method now supports 3 different trainable network architectures:
- 3D Unet
- Feature pyramid network-ResNet50 (with bottleneck ResNet modules)
- Panoptic feature pyramid network-ResNet50 (with preactivated ResNet modules)
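The moving 3D window and 4D padding described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the JCnet implementation. It assumes a channels-last volume of shape (X, Y, Z, C), zero padding up to the patch size plus a whole number of strides, and averaging of probabilities where windows overlap. It also assumes a hypothetical `predict_fn` that maps a batch of patches of shape (N, px, py, pz, C) to probabilities of shape (N, px, py, pz). The names `pad_4d` and `sliding_window_predict`, and the default stride, are also illustrative.

```python
import itertools

import numpy as np


def pad_4d(volume, psize, stride):
    """Zero-pad a 4D (X, Y, Z, C) volume so windows tile every spatial axis.

    Each spatial size is padded to at least the patch size, then up to
    patch + k * stride for the smallest integer k that covers the volume.
    Returns the padded array and the original spatial shape for cropping.
    """
    orig_shape = volume.shape[:3]
    pads = []
    for n, p, s in zip(orig_shape, psize, stride):
        target = max(n, p)
        # round the extent beyond the first window up to a multiple of stride
        target = p + int(np.ceil((target - p) / s)) * s
        pads.append((0, target - n))
    pads.append((0, 0))  # the channel axis is never padded
    return np.pad(volume, pads, mode="constant"), orig_shape


def sliding_window_predict(volume, predict_fn, psize=(64, 64, 64),
                           stride=(32, 32, 32), batch_size=4):
    """Build a 3D probability map by sliding a 3D window over the volume."""
    padded, orig_shape = pad_4d(volume, psize, stride)
    sums = np.zeros(padded.shape[:3], dtype=np.float32)
    counts = np.zeros(padded.shape[:3], dtype=np.float32)

    # corner coordinates of every window position along each axis
    starts = [range(0, n - p + 1, s)
              for n, p, s in zip(padded.shape[:3], psize, stride)]
    corners = list(itertools.product(*starts))

    # only batch_size patches are passed to the model at a time
    for i in range(0, len(corners), batch_size):
        batch_corners = corners[i:i + batch_size]
        patches = np.stack([
            padded[x:x + psize[0], y:y + psize[1], z:z + psize[2], :]
            for x, y, z in batch_corners])
        probs = predict_fn(patches)
        for (x, y, z), prob in zip(batch_corners, probs):
            sums[x:x + psize[0], y:y + psize[1], z:z + psize[2]] += prob
            counts[x:x + psize[0], y:y + psize[1], z:z + psize[2]] += 1

    averaged = sums / np.maximum(counts, 1)  # mean over overlapping windows
    # crop the padding away so the output matches the input grid
    return averaged[:orig_shape[0], :orig_shape[1], :orig_shape[2]]
```

Because only `batch_size` patches are processed at once, peak memory depends on the patch size rather than on the full image. With a trained Keras model that outputs a single channel, `predict_fn` might be `lambda p: model.predict(p)[..., 0]`; that output layout is an assumption.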
Prerequisites before running JCnet
A few standard MRI preprocessing steps are necessary before training or testing a JCnet model:
- Bias field correction - can use either N4 bias correction or MICO.
- Skull-stripping - we recommend the publicly available MONSTR skull-stripping algorithm for PML cases.
- Transformation to the standard MNI-ICBM 152 atlas space (the atlas is publicly available for download).
- Co-registration of the different MRI contrasts (e.g., T1-weighted, fluid-attenuated inversion recovery, T2-weighted, and proton density images).
Hardware and software requirements:
- Operating System: Linux.
- CPU: we recommend a processor with at least 8 cores, a clock speed of at least 2 GHz, and multithreading support.
- RAM: 64+ GB recommended (depending on the size of the training dataset and the maximum number of training patches per subject).
- GPU: we recommend a dedicated graphics card with at least 8 GB of VRAM (e.g., NVIDIA RTX 2080 Ti, Titan X, or V100). If our current pre-trained models do not fit into GPU memory during testing, we recommend downscaling the network parameters in this order: base filters, then patch size, then architecture. Pre-trained models with downscaled parameters can be provided upon request.
- Python v3.6
- Keras v2.2.4
- TensorFlow GPU version v1.13+ (TensorFlow 2 is not currently supported)
- Several open-source Python packages; see requirements.txt. To install the Python dependencies, point pip to this file from a terminal:
pip3 install -r requirements.txt
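Because every contrast must share the same voxel grid after co-registration and MNI transformation, a quick check before training or testing can catch misaligned inputs. The sketch below is illustrative and not part of JCnet: `check_coregistered` is a hypothetical helper that compares image shapes and 4x4 voxel-to-world affine matrices, using an arbitrary tolerance of 1e-3.

```python
import numpy as np


def check_coregistered(shapes, affines, atol=1e-3):
    """Return a list of problems if the contrasts do not share one voxel grid.

    shapes  : dict mapping a contrast name (e.g. "T1", "FL") to a 3-tuple shape
    affines : dict mapping the same names to 4x4 voxel-to-world matrices
    An empty list means every contrast matches the first one.
    """
    problems = []
    names = list(shapes)
    ref = names[0]  # every other contrast is compared against the first
    for name in names[1:]:
        if tuple(shapes[name]) != tuple(shapes[ref]):
            problems.append(f"{name} shape {tuple(shapes[name])} != "
                            f"{ref} shape {tuple(shapes[ref])}")
        if not np.allclose(affines[name], affines[ref], atol=atol):
            problems.append(f"{name} affine differs from {ref} affine")
    return problems
```

With nibabel, the inputs could be collected per contrast from `img = nibabel.load(path)` as `img.shape[:3]` and `img.affine`.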
Training call examples
# Brain Extraction training:
python JCnet_BrainExtraction_Train.py --atlasdir /path/to/atlas/dir/ --natlas 31 \
    --psize 64 64 64 --maxpatch 1000 --batchsize 8 --basefilters 32 \
    --modalities T1 FL T2 PD --epoch 50 \
    --outdir /path/to/output/dir/to/save/models/ --save 1 \
    --gpuids 0 1 2 3 --loss focal --model FPN
# Lesion Segmentation training:
python JCnet_LesionSeg_Train.py --atlasdir /path/to/atlas/dir/ --natlas 31 \
    --psize 64 64 64 --maxpatch 1000 --batchsize 8 --basefilters 32 \
    --modalities T1 FL T2 PD --epoch 50 \
    --outdir /path/to/output/dir/to/save/models/ --save 1 \
    --gpuids 0 1 2 3 --loss focal --model FPN
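The atlas-level training/validation split can be illustrated with a minimal sketch. Whole atlases are assigned to training or validation first, and patches are sampled only afterwards within each group, so a validation patch can never overlap a training patch drawn from the same image. Everything below is hypothetical rather than taken from the JCnet code: the function names, the 20% validation fraction, uniform random patch corners, and a common volume shape for all atlases. The `natlas`, `psize`, and `maxpatch` arguments only loosely mirror the `--natlas`, `--psize`, and `--maxpatch` flags.

```python
import numpy as np


def split_atlases(natlas, val_fraction=0.2, seed=0):
    """Randomly assign whole atlases (subjects) to training or validation."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(natlas)
    n_val = max(1, int(round(val_fraction * natlas)))
    train_ids = sorted(int(i) for i in order[n_val:])
    val_ids = sorted(int(i) for i in order[:n_val])
    return train_ids, val_ids


def build_patch_index(atlas_ids, volume_shape, psize, maxpatch, seed=0):
    """List (atlas_id, corner) pairs, sampling patches only from atlas_ids.

    Corners are drawn uniformly so that each patch lies inside the volume;
    volume_shape must be at least psize along every axis.
    """
    rng = np.random.RandomState(seed)
    # exclusive upper bound for each corner coordinate
    highs = [n - p + 1 for n, p in zip(volume_shape, psize)]
    index = []
    for atlas_id in atlas_ids:
        corners = rng.randint(0, highs, size=(maxpatch, 3))
        index.extend((atlas_id, tuple(int(c) for c in corner))
                     for corner in corners)
    return index
```

For example, `split_atlases(31)` places 6 atlases in validation and 25 in training; patches for each group would then be drawn separately with `build_patch_index`.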
Testing call examples
# Brain Extraction testing:
python JCnet_BrainExtraction_Test.py \
    --models /path/to/model/files/ending/in/*Orient012*.h5 \
             /path/to/model/files/ending/in/*Orient120*.h5 \
             /path/to/model/files/ending/in/*Orient201*.h5 \
    --images /path/to/T1/niftifile/*.nii.gz \
             /path/to/FL/niftifile/*.nii.gz \
             /path/to/T2/niftifile/*.nii.gz \
             /path/to/PD/niftifile/*.nii.gz \
    --modalities T1 FL T2 PD --psize 64 64 64 \
    --outdir /path/to/output/dir/to/save/results/ --threshold 0.5
# Lesion Segmentation testing:
python JCnet_LesionSeg_Test.py \
    --models /path/to/model/files/ending/in/*Orient012*.h5 \
             /path/to/model/files/ending/in/*Orient120*.h5 \
             /path/to/model/files/ending/in/*Orient201*.h5 \
    --images /path/to/T1/niftifile/*.nii.gz \
             /path/to/FL/niftifile/*.nii.gz \
             /path/to/T2/niftifile/*.nii.gz \
             /path/to/PD/niftifile/*.nii.gz \
    --modalities T1 FL T2 PD --psize 64 64 64 \
    --outdir /path/to/output/dir/to/save/results/ --threshold 0.35
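The testing calls pass three models whose file names contain `Orient012`, `Orient120`, and `Orient201`. One plausible reading, which is an assumption here and not stated in this README, is that each model works on the volume with its spatial axes permuted in that order, and that the three probability maps are combined before `--threshold` is applied. The sketch below shows that idea with NumPy only. `ORIENTATIONS`, `ensemble_predict`, the `predict_fns` mapping, mean-averaging, and the `>=` comparison are all hypothetical choices.

```python
import numpy as np

# spatial axis orders suggested by the model file names (assumption)
ORIENTATIONS = {"Orient012": (0, 1, 2), "Orient120": (1, 2, 0),
                "Orient201": (2, 0, 1)}


def ensemble_predict(volume, predict_fns, threshold=0.5):
    """Average per-orientation probability maps and binarize the result.

    volume      : 4D array (X, Y, Z, C) of co-registered contrasts
    predict_fns : dict mapping an orientation name to a function that takes
                  a 4D volume in that orientation and returns a 3D map
    Returns (binary mask as uint8, mean probability map).
    """
    probs = []
    for name, fn in predict_fns.items():
        axes = ORIENTATIONS[name]
        # permute the spatial axes and keep channels last
        permuted = np.transpose(volume, axes + (3,))
        prob = fn(permuted)
        # argsort gives the inverse permutation, restoring the original grid
        probs.append(np.transpose(prob, tuple(np.argsort(axes))))
    mean_prob = np.mean(probs, axis=0)
    return (mean_prob >= threshold).astype(np.uint8), mean_prob
```

In practice each function would wrap a trained model, for example one loaded with `keras.models.load_model`; a custom loss such as focal loss may need to be supplied through its `custom_objects` argument. The 0.5 and 0.35 values in the commands above would correspond to the `threshold` argument.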