In [1]:
import json
import glob
import os

In [26]:
# Build the training configuration dictionary that is later written to
# `config_filename` as JSON.
config_filename = "sppin_config.json"

config = dict()

model_config = dict()
model_config["name"] = "DynUNet"  # network model name from MONAI
# set the network hyper-parameters
model_config["in_channels"] = 4  # 4 input images per case (T1-Gd, T2, DWI b0, DWI b100)
model_config["out_channels"] = 1   # single output channel (tumor)
model_config["spatial_dims"] = 3   # 3D input images
model_config["deep_supervision"] = False  # do not check outputs of lower layers
# One stride entry per resolution level; the trailing [:-1] drops the last
# level, leaving 6 levels (the first has stride 1, the other 5 downsample by 2).
model_config["strides"] = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]][:-1]
# number of filters per level, truncated to the number of levels actually used
model_config["filters"] = [64, 96, 128, 192, 256, 384, 512, 768, 1024][:len(model_config["strides"])]
model_config["kernel_size"] = [[3, 3, 3]] * len(model_config["strides"])  # size of the convolution kernels per level
model_config["upsample_kernel_size"] = model_config["strides"][1:]  # should be the same as the strides (minus the first level)

# put the model config in the main config
config["model"] = model_config

config["optimizer"] = {'name': 'Adam',
                       'lr': 0.001}  # initial learning rate

# define the loss
config["loss"] = {'name': 'DiceLoss',  # from MONAI
                  'include_background': True,  # we do not have a label for the background, so this should be true (by "include background" MONAI means include channel 0)
                  'sigmoid': True,  # transform the model logits to activations
                  'batch': False}

# set the cross validation parameters
# NOTE: there must be no trailing comma after the closing brace; a trailing
# comma would turn this value into a one-element tuple wrapping the dict.
config["cross_validation"] = {'folds': 5,  # number of cross validation folds
                              'seed': 25}  # seed to make the generation of cross validation folds consistent across different trials
# set the scheduler parameters
config["scheduler"] = {'name': 'ReduceLROnPlateau',
                       'patience': 20,  # wait 20 epochs with no improvement before reducing the learning rate
                       'factor': 0.5,   # multiply the learning rate by 0.5
                       'min_lr': 1e-08}  # stop reducing the learning rate once it gets to 1e-8

# set the dataset parameters
config["dataset"] = {'name': 'SegmentationDatasetPersistent',  # 'Persistent' means that it will save the preprocessed outputs generated during the first epoch
# However, using 'Persistent' also increases the time of the first epoch compared to the other epochs, which should run faster
  'desired_shape': [192, 192, 192],  # resize the images to this shape, increase this to get higher resolution images (increases computation time and memory usage)
  'labels': [1],  # 1: tumor
  'orientation': 'RAS',  # force all the images to be the same orientation (Right-Anterior-Superior)
  'normalization': 'NormalizeIntensityD',  # z score normalize the input images to zero mean unit standard deviation
  'normalization_kwargs': {'channel_wise': True, "nonzero": False},  # perform the normalization channel wise and include the background
  'resample': True,  # resample the images when resizing them, otherwise the resize could crop out regions of interest
  'crop_foreground': True,  # crop the foreground of the images
  'foreground_percentile': 0.9,  # aggressive foreground cropping to make sure the empty space is taken out of the images
  'training':  # the following arguments will only be applied to the training data.
    {
    'spatial_augmentations': [{'name': 'RandFlipD', 'spatial_axis': 0, 'prob': 0.5},
                              {'name': 'RandFlipD', 'spatial_axis': 1, 'prob': 0.5},
                              {'name': 'RandRotateD', 'prob': 0.5, 'range_x': 0.2, 'range_y': 0.2, 'range_z': 0.2}],
    'intensity_augmentations': [{'name': 'RandScaleIntensityD', 'factors': 0.1, 'prob': 1.0},
                                {'name': 'RandShiftIntensityD', 'offsets': 0.1, 'prob': 1.0}],
    }
                    }
config["training"] = {'batch_size': 2,  # number of image/label pairs to read at a time during training
  'validation_batch_size': 2,  # number of image/label pairs to read at a time during validation
  'amp': False,  # don't set this to true unless the model you are using is setup to use automatic mixed precision (AMP)
  'early_stopping_patience': None,  # epochs without validation-loss improvement before stopping early (None presumably disables it - confirm against the training code)
  'n_epochs': 1000,  # number of training epochs, reduce this if you don't want training to run as long
  'save_every_n_epochs': None,  # save the model every n epochs (otherwise only the latest model will be saved)
  'save_last_n_models': None,  # save the last n models
  'save_best': True}  # save the model that has the best validation loss

In [92]:
# get the training filenames

# Input channels in the order they are stacked for the network. Each entry is
# (modality name parsed from the filename, substrings that must all appear in
# the basename when the parsed name is not available). An empty tuple means
# there is no fallback for that modality.
IMAGE_MODALITIES = [
    ("T1_gd", ()),
    ("T2", ("T2",)),
    ("DWI_b0", ("DWI", "b0")),
    ("DWI_b100", ("DWI", "b100")),
]
MIN_FILES_PER_VISIT = 5  # directories with fewer files are skipped as incomplete


def _find_modality(filenames, modalities, key, fallback_tokens):
    """Return the file belonging to one modality, or None if it cannot be found.

    Parameters:
        filenames: file paths of one visit directory.
        modalities: modality names parsed from `filenames` (same order, lower case).
        key: modality name to look up (compared case-insensitively).
        fallback_tokens: substrings that must all occur in a file's basename for
            it to be accepted when `key` is not among `modalities`.

    Returns:
        The matching path. When several files satisfy the fallback, the last
        one is returned. None is returned when nothing matches, so the caller
        can never silently reuse a file from a previously processed directory.
    """
    key = key.lower()
    if key in modalities:
        return filenames[modalities.index(key)]
    if not fallback_tokens:
        return None
    # Match on the basename only: every file of a visit shares the same
    # directory, so a match on the directory part would accept every file.
    matches = [fn for fn in filenames
               if all(token in os.path.basename(fn) for token in fallback_tokens)]
    return matches[-1] if matches else None


config["training_filenames"] = list()
ground_truth_filenames = sorted(glob.glob("./aligned/*/*/*NB*.nii*"))
for label_filename in ground_truth_filenames:
    visit_dir = os.path.dirname(label_filename)
    filenames = [os.path.abspath(fn) for fn in sorted(glob.glob(os.path.join(visit_dir, "*")))]

    # Skip incomplete visits before looking anything up, so that a directory
    # lacking an image is skipped rather than raising on the first lookup.
    if len(filenames) < MIN_FILES_PER_VISIT:
        continue

    # NOTE(review): the modality is parsed from the absolute path, so the
    # [3:-1] slice assumes no directory on that path contains an underscore
    # - confirm, or parse os.path.basename(fn) with adjusted indices instead.
    feature_modalities = ["_".join(fn.split("_")[3:-1]).strip("8 ").lower() for fn in filenames]

    if "T2".lower() not in feature_modalities:
        # diagnostic: show what was parsed when the T2 name could not be recognised
        print(feature_modalities)

    image_filenames = []
    for modality, fallback_tokens in IMAGE_MODALITIES:
        image_filename = _find_modality(filenames, feature_modalities, modality, fallback_tokens)
        if image_filename is None or not os.path.exists(image_filename):
            raise FileNotFoundError(
                "Could not find the {} image in {} (parsed modalities: {})".format(
                    modality, visit_dir, feature_modalities))
        image_filenames.append(image_filename)

    # NOTE(review): image paths are absolute while the label path stays
    # relative to the current working directory, as in the original cell.
    config["training_filenames"].append({"image": image_filenames, "label": label_filename})
with open(config_filename, "w") as op:
    json.dump(config, op, indent=4)

['dwi_b0', 'dwi_b100', 'nb', 't1_gd', '']
