
MI-Seg

MI-Seg is a framework based on the MONAI library for cross-modality clinical image segmentation using conditional models and interleaved training.
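A minimal sketch of what interleaved training could look like is shown below. It assumes that interleaving means alternating mini-batches from the per-modality data loaders (e.g. CT and MR) and tagging each batch with its modality index, so that conditional layers can pick their modality-specific parameters. The function `interleave_modalities` is hypothetical and not part of the MI-Seg code.

```python
from itertools import zip_longest
from typing import Iterable, Iterator, Sequence, Tuple, TypeVar

T = TypeVar("T")
_MISSING = object()  # marks an exhausted loader inside zip_longest


def interleave_modalities(loaders: Sequence[Iterable[T]]) -> Iterator[Tuple[int, T]]:
    """Yield (modality_index, batch) pairs, taking one batch per modality in turn.

    HYPOTHETICAL helper: when one loader runs out, the others keep alternating.
    """
    for round_batches in zip_longest(*loaders, fillvalue=_MISSING):
        for modality, batch in enumerate(round_batches):
            if batch is not _MISSING:
                yield modality, batch
```

With `zip_longest`, a shorter loader simply drops out once exhausted, so batches of the larger modality are not discarded; whether MI-Seg handles unequal dataset sizes this way is an assumption.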

Table of Contents
  1. About The Project
  2. Getting Started
  3. Usage
  4. Roadmap
  5. Contributing
  6. License
  7. Contact
  8. Acknowledgments

Citation

Our paper was accepted at ICCVW 2023 and is available in the ICCV 2023 Workshops proceedings and on arXiv. Please cite our work as:

  @InProceedings{Bastico_2023_ICCV,
    author    = {Bastico, Matteo and Ryckelynck, David and Cort\'e, Laurent and Tillier, Yannick and Decenci\`ere, Etienne},
    title     = {A Simple and Robust Framework for Cross-Modality Medical Image Segmentation Applied to Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
    month     = {October},
    year      = {2023},
    pages     = {4128-4138}
  }

Built With

Our released implementation is tested on:

  • Ubuntu 22.04
  • Python 3.10.8
  • PyTorch 1.13.1 and PyTorch Lightning 1.8.6
  • Ray 2.2.0
  • NVIDIA CUDA 11.7
  • MONAI 1.1.0
  • Optuna 3.1.0

Getting Started

Prerequisites

  • Clone the repository
  • Create and activate your conda environment:
    conda create -n MI-Seg python=3.10.8
    conda activate MI-Seg

Usage

Dataset

The dataset used in our experiments can be downloaded upon access request. Download and unzip it into the /dataset/MM-WHS folder.

[Optional] Convert the labels and perform N4 bias-field correction of the MRIs using the provided notebook load_data.ipynb.

You should end up with a data structure similar to the following (sub-folders are not shown):

  MM-WHS
  ├── ct_train	# CT training folder
  │   ├── ct_train_1001_image.nii.gz # Image
  │   ├── ct_train_1001_label.nii.gz # Label
  │   ...
  ├── ct_test
  ├── mr_train
  ├── mr_test
  ...
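Given the layout above, image/label pairs can be collected by file name. The sketch below assumes every case follows the `<case>_image.nii.gz` / `<case>_label.nii.gz` naming shown in the tree; `pair_cases` is a hypothetical helper, not a function of the repository. It returns a list of `{"image": ..., "label": ...}` dictionaries, the shape consumed by MONAI dictionary-based transforms such as `LoadImaged(keys=["image", "label"])`.

```python
import re
from pathlib import Path
from typing import Dict, List

# ASSUMED naming, e.g. "ct_train_1001_image.nii.gz" and "ct_train_1001_label.nii.gz".
_PATTERN = re.compile(r"^(?P<case>.+)_(?P<kind>image|label)\.nii\.gz$")


def pair_cases(folder: str) -> List[Dict[str, str]]:
    """Return one {'image': path, 'label': path} dict per case having both files."""
    found: Dict[str, Dict[str, str]] = {}
    for path in sorted(Path(folder).iterdir()):
        match = _PATTERN.match(path.name)
        if match:  # ignore files that do not follow the naming scheme
            found.setdefault(match["case"], {})[match["kind"]] = str(path)
    # Keep complete pairs only (e.g. test folders may come without labels).
    return [
        found[case]
        for case in sorted(found)
        if "image" in found[case] and "label" in found[case]
    ]
```

For example, `pair_cases("/dataset/MM-WHS/ct_train")` would list the CT training cases.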

The splits we used for cross-validation are provided in CT_fold1.json and CT_fold2.json.
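The internal format of CT_fold1.json and CT_fold2.json is not described here. Assuming they follow the Decathlon-style datalist layout commonly used with MONAI (a top-level key such as `"training"` mapping to a list of `{"image": ..., "label": ...}` entries with paths relative to the dataset folder), one split could be read with the standard library as sketched below. `load_fold` and the key names are assumptions; if the files do follow that layout, `monai.data.load_decathlon_datalist` should also read them directly.

```python
import json
from pathlib import Path
from typing import Any, Dict, List


def load_fold(json_path: str, base_dir: str, key: str = "training") -> List[Dict[str, Any]]:
    """Read one split of a fold file and resolve relative paths against base_dir.

    ASSUMPTION: Decathlon-style layout, e.g.
    {"training": [{"image": "...", "label": "..."}], "validation": [...]}
    """
    with open(json_path, encoding="utf-8") as f:
        fold = json.load(f)
    items = []
    for entry in fold[key]:
        # Join string values (file paths) to base_dir; keep other values unchanged.
        items.append({
            name: str(Path(base_dir) / value) if isinstance(value, str) else value
            for name, value in entry.items()
        })
    return items
```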

Training

To train a model, use the provided train.py script. Single training runs are based on PyTorch Lightning, and all Trainer arguments can be passed to the script (see the PyTorch Lightning Trainer documentation). Additionally, we provide model-, data- and logger-specific arguments. For the full list of available arguments, run python train.py --help.

An example of C-Swin-UNETR training on a single GPU:

python train.py --model_name=swin_unetr --out_channels=6 --feature_size=48 --num_heads=3 --accelerator=gpu --devices=1 --max_epochs=2500 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --lr=1e-4 --batch_size=1 --patches_training_sample=1
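The command above sets `--encoder_norm_name=instance_cond` and `--vit_norm_name=instance_cond`. We read `instance_cond` as a conditional instance normalization: features are instance-normalized and then scaled and shifted with affine parameters selected by the input modality, so all remaining weights are shared across CT and MR. This interpretation is an assumption; the sketch below only illustrates the computation for a single channel of a single sample in plain Python, with hypothetical names.

```python
import math
from typing import List, Sequence


def conditional_instance_norm(values: Sequence[float], modality: int,
                              gamma: Sequence[float], beta: Sequence[float],
                              eps: float = 1e-5) -> List[float]:
    """Normalize one channel of one sample, then apply the affine parameters
    (gamma[modality], beta[modality]) of its imaging modality.

    HYPOTHETICAL illustration of conditional instance normalization.
    """
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance, as in InstanceNorm
    scale = gamma[modality] / math.sqrt(var + eps)
    return [(v - mean) * scale + beta[modality] for v in values]
```

In a PyTorch model the same idea would typically combine `torch.nn.InstanceNorm3d(num_channels, affine=False)` with per-modality parameter tables (for example a `torch.nn.Embedding(num_modalities, num_channels)` each for gamma and beta).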

The available models are unet, unetr, swin_unetr and pre_swin_unetr (for the latter, the MONAI pre-trained model must be provided with --pre_swin).

Furthermore, we use WandB to log experiments, and its settings can be passed as arguments. In the previous example WandB runs in online mode, so you need to provide your login and API key. To change the WandB mode, set --wandb_mode=offline.
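The model and logger options described above can be illustrated with a hypothetical `argparse` sketch. It is not the repository's train.py (which also exposes all PyTorch Lightning Trainer arguments); it only encodes the model names listed above, the requirement that `pre_swin_unetr` comes with `--pre_swin`, and an assumed `--wandb_mode` switch. The defaults are illustrative.

```python
import argparse
from typing import List, Optional

MODELS = ["unet", "unetr", "swin_unetr", "pre_swin_unetr"]


def build_parser() -> argparse.ArgumentParser:
    """HYPOTHETICAL subset of the training options, grouped like the README describes."""
    parser = argparse.ArgumentParser(description="Illustrative subset of train.py options")
    model = parser.add_argument_group("model")
    model.add_argument("--model_name", choices=MODELS, default="swin_unetr")
    model.add_argument("--pre_swin", default=None,
                       help="path to the MONAI pre-trained Swin weights")
    logger = parser.add_argument_group("logger")
    logger.add_argument("--wandb_mode", choices=["online", "offline"], default="online")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # pre_swin_unetr is only usable together with pre-trained weights.
    if args.model_name == "pre_swin_unetr" and args.pre_swin is None:
        parser.error("--pre_swin is required when --model_name=pre_swin_unetr")
    return args
```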

Note: when training Swin-UNETR-based models with gradient checkpointing (--use_checkpoint) to save memory, AMP should be disabled (--no_amp).

Testing

Our pre-trained models (see Pre-Trained Models) can be tested with the test.py script. Provide the path to the model weights with --checkpoint (note that the weights must be stored under the state_dict key).

Example:

python test.py --out_channels=6 --model_name=swin_unetr --num_workers=2 --feature_size=48  --num_heads=3 --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --checkpoint=experiments/<path>
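test.py expects the weights under the `state_dict` key. If a weights file holds a bare state dict instead, it could be wrapped as in the minimal sketch below; `ensure_state_dict_key` is a hypothetical helper. It does not rename parameters, so whether the parameter names also need a prefix (as in PyTorch Lightning checkpoints) depends on how test.py loads them, which is not covered here.

```python
from typing import Any, Dict


def ensure_state_dict_key(checkpoint: Any) -> Dict[str, Any]:
    """Wrap bare model weights so they sit under the 'state_dict' key.

    Checkpoints that already contain a 'state_dict' entry are returned unchanged.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError("expected a state dict or a checkpoint dictionary")
    if "state_dict" in checkpoint:
        return checkpoint
    return {"state_dict": checkpoint}
```

With PyTorch, for example: `torch.save(ensure_state_dict_key(torch.load("weights.pt", map_location="cpu")), "weights_wrapped.pt")` (file names are placeholders).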

Hyper-parameter Optimization

Hyper-parameter optimization is based on Optuna. Currently, the script supports automatic setup of distributed tuning ONLY in Slurm environments; to run in other multi-GPU environments, it must be adapted by the user.

The hyper-parameter search space is set automatically for each model, as described in our paper, and tuning can be started as shown below. The script runs 10 trials (--n_trials=10) with the TPE sampler and the ASHA pruner, and saves the study in the MI-Seg.log log file (on Slurm) or in MI-Seg.sqlite (otherwise).

python -u tune.py --num_workers=2 --out_channels=6 --no_include_background --criterion=generalized_dice_focal --scheduler=warmup_cosine --model_name=swin_unetr --n_trials=10 --study_name=c-swin-unetr --max_epochs=2500 --check_val_every_n_epoch=50 --batch_size=1 --patches_training_sample=4 --iters_to_accumulate=4 --cycles=0.5 --storage_name=MI-Seg --min_lr=1e-5 --max_lr=1e-3 --vit_norm_name=instance_cond --encoder_norm_name=instance_cond  --port=23456
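The ASHA pruner decides, each time a trial reports an intermediate score at a rung, whether that trial may continue. In Optuna, TPE and ASHA are available as `optuna.samplers.TPESampler` and `optuna.pruners.SuccessiveHalvingPruner`; we assume tune.py relies on these but have not verified it. The sketch below reproduces the promotion rule in plain Python: a trial survives a rung only if its score is within the top 1/`reduction_factor` of all scores recorded at that rung (itself included), and while fewer than `reduction_factor` trials have reached the rung, only the best so far survives. `asha_promotable` is a hypothetical name; `reduction_factor=4` mirrors Optuna's default, and `maximize=True` assumes a validation metric such as Dice.

```python
from typing import List


def asha_promotable(value: float, rung_values: List[float],
                    reduction_factor: int = 4, maximize: bool = True) -> bool:
    """Return True if a trial reaching a rung may continue (ASHA promotion rule).

    rung_values: scores of the other trials already recorded at this rung.
    """
    competing = sorted(rung_values + [value])
    # Number of trials allowed to continue from this rung, at least one.
    n_promoted = max(len(competing) // reduction_factor, 1)
    if maximize:
        return value >= competing[-n_promoted]
    return value <= competing[n_promoted - 1]
```

Pruning weak trials early is what makes a budget such as --max_epochs=2500 per trial affordable: most configurations stop at an early rung and only the promising ones train to the end.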

The script can be run multiple times with the same --storage_name to resume a previous tuning.

To open the dashboard of studies stored in log files (i.e., not in an RDB), we provide the utils/run_server.py --path=<storage> script. The dashboard of the tuning presented in our paper is stored in experiments/optuna/MI-Seg.log and can be opened with:

python utils/run_server.py --path=experiments/optuna/MI-Seg.log

Pre-Trained Models

The best pre-trained weights for Conditional UNet and Conditional Swin-UNETR resulting from our hyper-parameter optimization are available for download.

For instance, to segment the test dataset with the provided weights, run the following for Conditional UNet:

python predict_whs.py --model=unet_vanilla --encoder_norm_name=instance_cond --feature_size 16 64 128 256 512 --num_res_units=3 --strides 1 2 2 2 1 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result

or for Conditional Swin-UNETR:

python -u predict_whs.py --model=swin_unetr --encoder_norm_name=instance_cond --vit_norm_name=instance_cond --feature_size=36 --num_heads=4 --out_channels=8 --checkpoint=path/to/weights.pt --result_dir=path/to/result

Roadmap

  • Implement layer normalization (LN) for MONAI convolutional layers (for testing purposes)
  • Implement distributed tuning in non-Slurm environments

See the open issues for a full list of proposed features (and known issues).

Contributing

If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again!

  1. Fork the Project
  2. Create your Feature Branch (git checkout -b feature/my_feature)
  3. Commit your Changes (git commit -m 'Add my_feature')
  4. Push to the Branch (git push origin feature/my_feature)
  5. Open a Pull Request

License

Distributed under the MIT License. See LICENSE.txt for more information.

Contact

Matteo Bastico - @matteobastico - matteo.bastico@minesparis.psl.eu

Project Link: https://github.com/matteo-bastico/MI-Seg

Acknowledgments

This work was supported by the H2020 European Project ...
