Train model with nnUnet or MONAI #75

Closed
jcohenadad opened this issue Jul 11, 2023 · 4 comments · Fixed by spinalcordtoolbox/spinalcordtoolbox#4554

Remove the dependency on ivadomed and train the model with nnUNet or MONAI.

Nilser3 commented Aug 3, 2023

Hi @jcohenadad

Description

MS lesion segmentation in 3T MP2RAGE images (UNI contrast) from Basel and Marseille using nnUnet.

MP2RAGE dataset description (N=221)

| Center | Resolution (mm) | Train / validation | Test | Total (N) |
|---|---|---|---|---|
| Basel | 1.0 x 1.0 x 1.0 | 142 | 38 | 180 |
| Marseille | 1.0 x 0.9375 x 0.9375 | 34 | 7 | 41 |
| Total | | 176 | 45 | 221 |

Cross-validation (CV) in 5 folds, preserving the heterogeneity of centers (see the sketch below for one way to pin such splits).
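
One way to pin center-stratified folds in nnUNet v2 (an assumption on my part; I do not know whether this is how the splits were actually fixed) is to write a custom splits_final.json into the preprocessed dataset folder before training. Subject IDs below are placeholders:

# Sketch: custom fold file read by nnUNet v2 at training time.
# The real file must contain 5 entries and list every training case; IDs here are hypothetical.
cat > "${nnUNet_preprocessed}/Dataset706_ms_lesion_MP2RAGE_dil24/splits_final.json" << 'EOF'
[
  {"train": ["sub-P001", "sub-MRS01"], "val": ["sub-P002", "sub-MRS02"]},
  {"train": ["sub-P002", "sub-MRS02"], "val": ["sub-P001", "sub-MRS01"]}
]
EOF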

Data Preprocessing Pipeline

For cropping, I added (24 + 5) x 2 pixels of margin around the SC mask (Fig. 1).

Preprocessing pipeline
# SC segmentation
sct_deepseg_sc -i UNI_image.nii.gz -o seg.nii.gz -c t1
# Dilate the SC mask (5 pixels) along XYZ
sct_maths -i seg.nii.gz -dilate 5 -shape ball -o seg_dil5.nii.gz
# Dilate the dilated SC mask (24 pixels) in the XY plane
sct_maths -i seg_dil5.nii.gz -dilate 24 -dim 2 -shape disk -o seg_dil24.nii.gz
# Crop the UNI image and the lesion mask to the seg_dil24 mask
sct_crop_image -i UNI_image.nii.gz -m seg_dil24.nii.gz -o UNIT1-crop.nii.gz
sct_crop_image -i UNIT1_lesion-manual.nii.gz -m seg_dil24.nii.gz -o UNIT1-crop_lesion-manual.nii.gz

Fig. 1: Marseille UNI image with the seg_dil24 mask [image]

This way, more pixels are kept in the XY plane, so that more contextual information survives the network's successive downsampling stages (p. 25 of the nnUNet paper).
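
As a rough sanity check (my own arithmetic from the plans below, not a claim from the paper): the configurations use four stride-2 poolings per axis, so an 80 x 80 in-plane patch shrinks to 80 / 2^4 = 5 x 5 at the coarsest level; a crop containing only the cord (on the order of 10-15 voxels across) would collapse to about one pixel there, so the added margin is what keeps usable context at the bottleneck.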

Experiment planning

After applying nnUNetv2_plan_and_preprocess, we get the nnUNet plans shown below.
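
For reference, the planning step looks like this (a sketch; the folder paths are placeholders and the dataset ID 706 is inferred from the dataset_name in the plans below):

# hypothetical locations of the nnUNet folders
export nnUNet_raw=/path/to/nnUNet_raw
export nnUNet_preprocessed=/path/to/nnUNet_preprocessed
export nnUNet_results=/path/to/nnUNet_results
# fingerprint extraction, experiment planning and preprocessing
nnUNetv2_plan_and_preprocess -d 706 --verify_dataset_integrity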

nnUNetPlans.json file
{
    "dataset_name": "Dataset706_ms_lesion_MP2RAGE_dil24",
    "plans_name": "nnUNetPlans",
    "original_median_spacing_after_transp": [
        1.0,
        1.0,
        1.0
    ],
    "original_median_shape_after_transp": [
        99,
        77,
        73
    ],
    "image_reader_writer": "SimpleITKIO",
    "transpose_forward": [
        0,
        1,
        2
    ],
    "transpose_backward": [
        0,
        1,
        2
    ],
    "configurations": {
        "2d": {
            "data_identifier": "nnUNetPlans_2d",
            "preprocessor_name": "DefaultPreprocessor",
            "batch_size": 521,
            "patch_size": [
                80,
                80
            ],
            "median_image_size_in_voxels": [
                76.0,
                73.0
            ],
            "spacing": [
                1.0,
                1.0
            ],
            "normalization_schemes": [
                "ZScoreNormalization"
            ],
            "use_mask_for_norm": [
                false
            ],
            "UNet_class_name": "PlainConvUNet",
            "UNet_base_num_features": 32,
            "n_conv_per_stage_encoder": [
                2,
                2,
                2,
                2,
                2
            ],
            "n_conv_per_stage_decoder": [
                2,
                2,
                2,
                2
            ],
            "num_pool_per_axis": [
                4,
                4
            ],
            "pool_op_kernel_sizes": [
                [
                    1,
                    1
                ],
                [
                    2,
                    2
                ],
                [
                    2,
                    2
                ],
                [
                    2,
                    2
                ],
                [
                    2,
                    2
                ]
            ],
            "conv_kernel_sizes": [
                [
                    3,
                    3
                ],
                [
                    3,
                    3
                ],
                [
                    3,
                    3
                ],
                [
                    3,
                    3
                ],
                [
                    3,
                    3
                ]
            ],
            "unet_max_num_features": 512,
            "resampling_fn_data": "resample_data_or_seg_to_shape",
            "resampling_fn_seg": "resample_data_or_seg_to_shape",
            "resampling_fn_data_kwargs": {
                "is_seg": false,
                "order": 3,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_seg_kwargs": {
                "is_seg": true,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_probabilities": "resample_data_or_seg_to_shape",
            "resampling_fn_probabilities_kwargs": {
                "is_seg": false,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "batch_dice": true
        },
        "3d_fullres": {
            "data_identifier": "nnUNetPlans_3d_fullres",
            "preprocessor_name": "DefaultPreprocessor",
            "batch_size": 7,
            "patch_size": [
                112,
                80,
                80
            ],
            "median_image_size_in_voxels": [
                99.0,
                76.0,
                73.0
            ],
            "spacing": [
                1.0,
                1.0,
                1.0
            ],
            "normalization_schemes": [
                "ZScoreNormalization"
            ],
            "use_mask_for_norm": [
                false
            ],
            "UNet_class_name": "PlainConvUNet",
            "UNet_base_num_features": 32,
            "n_conv_per_stage_encoder": [
                2,
                2,
                2,
                2,
                2
            ],
            "n_conv_per_stage_decoder": [
                2,
                2,
                2,
                2
            ],
            "num_pool_per_axis": [
                4,
                4,
                4
            ],
            "pool_op_kernel_sizes": [
                [
                    1,
                    1,
                    1
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ]
            ],
            "conv_kernel_sizes": [
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ]
            ],
            "unet_max_num_features": 320,
            "resampling_fn_data": "resample_data_or_seg_to_shape",
            "resampling_fn_seg": "resample_data_or_seg_to_shape",
            "resampling_fn_data_kwargs": {
                "is_seg": false,
                "order": 3,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_seg_kwargs": {
                "is_seg": true,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_probabilities": "resample_data_or_seg_to_shape",
            "resampling_fn_probabilities_kwargs": {
                "is_seg": false,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "batch_dice": false
        }
    },
    "experiment_planner_used": "ExperimentPlanner",
    "label_manager": "LabelManager",
    "foreground_intensity_properties_per_channel": {
        "0": {
            "max": 3140.0,
            "mean": 1592.6723636903955,
            "median": 1609.0,
            "min": 193.0,
            "percentile_00_5": 581.0,
            "percentile_99_5": 2502.0,
            "std": 382.50158483925566
        }
    }
}

In the nnUNetPlans.json file we can see the details of the 3d_fullres configuration, in particular the strides (pool_op_kernel_sizes) between the layers.

Based on this information, our 3D U-Net would look like this:

Fig. 2: resulting 3D U-Net architecture [image]

Trainings are currently running!
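
For reference, the corresponding training commands look like this (a sketch; dataset ID 706 assumed, as above):

# train the 5 folds of the 3d_fullres configuration
for fold in 0 1 2 3 4; do
    nnUNetv2_train 706 3d_fullres $fold
done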

Nilser3 commented Dec 21, 2023

Training the 3D nnUNet

Following the approach described above, a 3D nnUNet network was trained. We will call this model algo 2.

Preprocessing pipeline
#!/bin/bash
subjects_basel=( sub-P001 sub-P002 sub-P003 sub-P004 ...)
subjects_marseille=( sub-MRS00 sub-MRS01 sub-MRS02 sub-MRS03 ...)

bids_file="basel-mrs-mp2rage"
bids_basel_input="basel-mp2rage"
bids_marseille_input="marseille-3T-mp2rage"

mkdir -p tmp

for subject in "${subjects_basel[@]}"
    do
    # Reorient the image, lesion mask and SC mask to RPI
    sct_image -i ../$bids_basel_input/$subject/anat/$subject"_UNIT1.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1.nii.gz"
    sct_image -i ../$bids_basel_input/derivatives/labels/$subject/anat/$subject"_UNIT1_lesion-manualNeuroPoly.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1_lesion.nii.gz"
    sct_image -i ../$bids_basel_input/derivatives/labels/$subject/anat/$subject"_UNIT1_label-SC_seg.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz"

    # Crop with dilation around the SC (30 pixels in the axial plane, 5 pixels along Z)
    mkdir -p $bids_file/$subject/ses-M0/anat $bids_file/derivatives/labels/$subject/ses-M0/anat
    sct_crop_image -i tmp/$subject"_ses-M0_UNIT1.nii.gz" -m tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -o $bids_file/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1.nii.gz" -dilate 30x30x5
    sct_crop_image -i tmp/$subject"_ses-M0_UNIT1_lesion.nii.gz" -m tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -o $bids_file/derivatives/labels/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1_label-lesion_rater2.nii.gz" -dilate 30x30x5
done

for subject in "${subjects_marseille[@]}"
    do
    # Reorient the image, lesion mask and SC mask to RPI (M0 shown; M24 handled analogously)
    sct_image -i ../$bids_marseille_input/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1.nii.gz"
    sct_image -i ../$bids_marseille_input/derivatives/labels/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1_label-lesion_rater2.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1_lesion.nii.gz"
    sct_image -i ../$bids_marseille_input/derivatives/labels/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -setorient RPI -o tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz"

    # Crop with dilation around the SC (30 pixels in the axial plane, 5 pixels along Z)
    mkdir -p $bids_file/$subject/ses-M0/anat $bids_file/derivatives/labels/$subject/ses-M0/anat
    sct_crop_image -i tmp/$subject"_ses-M0_UNIT1.nii.gz" -m tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -o $bids_file/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1.nii.gz" -dilate 30x30x5
    sct_crop_image -i tmp/$subject"_ses-M0_UNIT1_lesion.nii.gz" -m tmp/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -o $bids_file/derivatives/labels/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1_label-lesion_rater2.nii.gz" -dilate 30x30x5
done
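
The cropped images and labels then have to be copied into the nnUNet raw-data layout before planning and training. A minimal sketch for a single subject, assuming the dataset name from the plans above (paths are placeholders):

# nnUNet v2 expects imagesTr/<CASE>_0000.nii.gz and labelsTr/<CASE>.nii.gz
dataset_dir="${nnUNet_raw}/Dataset706_ms_lesion_MP2RAGE_dil24"
mkdir -p "$dataset_dir/imagesTr" "$dataset_dir/labelsTr"
cp $bids_file/sub-P001/ses-M0/anat/sub-P001_ses-M0_UNIT1.nii.gz "$dataset_dir/imagesTr/sub-P001_ses-M0_0000.nii.gz"
cp $bids_file/derivatives/labels/sub-P001/ses-M0/anat/sub-P001_ses-M0_UNIT1_label-lesion_rater2.nii.gz "$dataset_dir/labelsTr/sub-P001_ses-M0.nii.gz"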

Table 1: Marseille and Basel dataset for the nnUNet model

| Center | Resolution (mm) | Train / validation (M0) | Train / validation (M24) | Test (M0) | Test (M24) | Total (N) |
|---|---|---|---|---|---|---|
| Basel | 1.0 x 1.0 x 1.0 | 141 | 0 | 39 | 0 | 180 |
| Marseille | 1.0 x 0.9375 x 0.9375 | 24 | 11 | 3 | 3 | 41 |
| Total | | 176 (M0 + M24) | | 45 (M0 + M24) | | 221 |

The Marseille testing set consists of 3 subjects, each with M0 and M24 sessions (so training and testing contain different subjects).

nnUNetTrainer__nnUNetPlans__3d_fullres/plans.json
        "3d_fullres": {
            "data_identifier": "nnUNetPlans_3d_fullres",
            "preprocessor_name": "DefaultPreprocessor",
            "batch_size": 7,
            "patch_size": [
                112,
                80,
                80
            ],
            "median_image_size_in_voxels": [
                99.0,
                77.0,
                75.0
            ],
            "spacing": [
                1.0,
                1.0,
                1.0
            ],
            "normalization_schemes": [
                "ZScoreNormalization"
            ],
            "use_mask_for_norm": [
                false
            ],
            "UNet_class_name": "PlainConvUNet",
            "UNet_base_num_features": 32,
            "n_conv_per_stage_encoder": [
                2,
                2,
                2,
                2,
                2
            ],
            "n_conv_per_stage_decoder": [
                2,
                2,
                2,
                2
            ],
            "num_pool_per_axis": [
                4,
                4,
                4
            ],
            "pool_op_kernel_sizes": [
                [
                    1,
                    1,
                    1
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ],
                [
                    2,
                    2,
                    2
                ]
            ],
            "conv_kernel_sizes": [
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ],
                [
                    3,
                    3,
                    3
                ]
            ],
            "unet_max_num_features": 320,
            "resampling_fn_data": "resample_data_or_seg_to_shape",
            "resampling_fn_seg": "resample_data_or_seg_to_shape",
            "resampling_fn_data_kwargs": {
                "is_seg": false,
                "order": 3,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_seg_kwargs": {
                "is_seg": true,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "resampling_fn_probabilities": "resample_data_or_seg_to_shape",
            "resampling_fn_probabilities_kwargs": {
                "is_seg": false,
                "order": 1,
                "order_z": 0,
                "force_separate_z": null
            },
            "batch_dice": false

Models to segment lesions in MP2RAGE data

  • algo 1: Ivadomed model, release r20230210, trained on basel-mp2rage with ensembling/bagging approaches for soft prediction and binarization at a 0.2 threshold
  • algo 2: 3D nnUNet single-class model trained on marseille-3T-mp2rage and basel-mp2rage with 5 folds (see the inference sketch below)
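
For reference, running algo 2 on a test folder looks like this (a sketch; input/output folders are placeholders, dataset ID 706 assumed as above):

# ensemble the 5 trained folds of the 3d_fullres configuration
nnUNetv2_predict -i /path/to/imagesTs -o /path/to/predictions_algo2 -d 706 -c 3d_fullres -f 0 1 2 3 4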

Preliminary results on testing dataset (Table 1)

Preprocessing pipeline for `algo-1` testing
#!/bin/bash
subjects=( sub-P001     sub-P002        sub-P003        ... )

for subject in "${subjects[@]}"
  do
  sct_maths -i ../../mrs_basel/basel-mrs-mp2rage/derivatives/labels/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1_label-SC_seg.nii.gz" -dilate 5 -shape ball -o $subject"_UNIT1_label-SC_seg_dil5.nii.gz" 
  sct_maths -i $subject"_UNIT1_label-SC_seg_dil5.nii.gz" -dilate 32 -dim 2 -shape disk -o $subject"_UNIT1_label-SC_seg_dil32.nii.gz" 
  sct_crop_image -i ../../mrs_basel/basel-mrs-mp2rage/$subject/ses-M0/anat/$subject"_ses-M0_UNIT1.nii.gz" -m $subject"_UNIT1_label-SC_seg_dil32.nii.gz" -o $subject"_UNIT1_crop.nii.gz" 
  # Run prediction 
  list_seed=(01 02 03 04 05)
  for seed in ${list_seed[@]}; do
        ivadomed_segment_image -i $subject"_UNIT1_crop.nii.gz" -m ../../ivadomed_env/model_r20230210/model_seg_lesion_mp2rage_r20230210_dil32_seed${seed}/model_seg_ms_lesion_mp2rage/ -s _pred${seed}
  done
  sct_image -i  $subject"_UNIT1_crop"_pred*.nii.gz -concat t -o  $subject"_UNIT1_crop_predMean.nii.gz"
  sct_maths -i  $subject"_UNIT1_crop_predMean.nii.gz" -mean t -o  $subject"_UNIT1_crop_Mean.nii.gz" 
  sct_maths -i $subject"_UNIT1_crop_Mean.nii.gz" -bin 0.2 -o $subject"_UNIT1_lesion_bin.nii.gz" 
done

[image: preliminary results]
Subjects from Basel (sub-Basel) were part of the algo-1 training set.

QC for testing dataset

Legend of QC masks

  • lesion_rater2.nii.gz -> rater 2 (Nilser) binary segmentation #267
  • label-lesion_algo1.nii.gz -> prediction of "algo 1" with binarization at 0.2
  • label-lesion_algo2.nii.gz -> prediction of "algo 2"
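
For reference, QC entries for these overlays can be generated with something like the following (a sketch; file names are placeholders, and I am assuming the standard sct_qc lesion report was used):

# overlay the lesion prediction (-d) on the image, with the cord segmentation (-s) for context
sct_qc -i sub-XXX_UNIT1.nii.gz -s sub-XXX_UNIT1_label-SC_seg.nii.gz -d sub-XXX_UNIT1_label-lesion_algo2.nii.gz -p sct_deepseg_lesion -qc qc_report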

Testing in other MP2RAGE datasets:

nih-ms-mp2rage (no GT)

sub-nih003 [image]

QC for all data

Legend of QC masks

  • label-lesion_algo1.nii.gz -> prediction of "algo 1" with binarization at 0.2
  • label-lesion_algo2.nii.gz -> prediction of "algo 2"

marseille-7T-iso-mp2rage (no GT)

  • Isotropic images, resolution = 0.7 x 0.7 x 0.7 mm; no resampling was applied
  • To make the XY dimensions comparable to the algo-2 training patch_size, the following crop was applied:
    • sct_crop_image -i IMAGE.nii.gz -m IMAGE_sc.nii.gz -o IMAGE_sc_crop.nii.gz -dilate 43x43x0

sub-MRS05 ses-M0 [image]

marseille-7T-aniso-mp2rage (no GT)

  • Anisotropic images, resolution = 0.4 x 0.4 x 4 mm; no resampling was applied
  • To make the XY dimensions comparable to the algo-2 training patch_size, the following crop was applied:
    • sct_crop_image -i IMAGE.nii.gz -m IMAGE_sc.nii.gz -o IMAGE_sc_crop.nii.gz -dilate 100x100x0

sub-MRS00 ses-M0 run-1 [image]

Next steps:

  1. Single-class model vs. multi-class model
  2. GT for MS lesions on nih-ms-mp2rage, marseille-7T-iso-mp2rage and marseille-7T-aniso-mp2rage

Nilser3 commented Jan 5, 2024

Single-class vs. multi-class model

  • algo 2 single-class -> algo 2 (described above)
  • algo 2 multi-class -> region-based nnUNet multi-class model (SC, MS lesion), using the same parameters, data and splits as algo 2 (a sketch of the region-based label declaration follows below)
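
For context, a region-based multi-class setup in nnUNet v2 is typically declared through the labels section of dataset.json, with the SC region defined as the union of cord and lesion voxels. A minimal sketch (the dataset name, channel name and case count are assumptions, not the actual file used here):

# write a minimal region-based dataset.json (hypothetical dataset folder)
cat > "${nnUNet_raw}/Dataset707_ms_lesion_MP2RAGE_multiclass/dataset.json" << 'EOF'
{
  "channel_names": {"0": "UNIT1"},
  "labels": {
    "background": 0,
    "sc": [1, 2],
    "lesion": [2]
  },
  "regions_class_order": [1, 2],
  "numTraining": 176,
  "file_ending": ".nii.gz"
}
EOF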

Results


  • Dice single-class MS lesion mean: 0.62 (STD: 0.25)
  • Dice multi-class MS lesion mean: 0.61 (STD: 0.24)
  • Dice multi-class SC mean: 0.964 (STD: 0.015)
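
(For context, per-subject Dice values like these can be computed with SCT's sct_dice_coefficient, as in the sketch below with placeholder file names; I do not know which tool was actually used for the evaluation.)

# Dice between a prediction and the rater-2 ground truth for one subject
sct_dice_coefficient -i sub-XXX_UNIT1_label-lesion_algo2.nii.gz -d sub-XXX_UNIT1_label-lesion_rater2.nii.gz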

jcohenadad commented Jan 19, 2024

algo 2 seems like the better starting point for creating the GT for:

  • NIH dataset
  • marseille-7T-iso-mp2rage
  • marseille-7T-aniso-mp2rage
