This is the official implementation of MAG-MS.
For the official implementation of MAGNET: A Modality-Agnostic Networks for Medical Image Segmentation, please check out the `stable-1.1` branch.
MAG-MS is designed to be compatible with MAGNET (v1). The new MAGNET (v2) used in MAG-MS is designed to support multi-modality self-distillation and multi-modality feature distillation.
- Python >= 3.9
- PyTorch >= 1.12.1
- torchmanager >= 1.1
- MONAI >= 1.1
Use the package manager pip to install MAG-MS.
pip install magms
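After installing, it can help to confirm that the environment meets the minimum versions listed in the requirements. The following is a minimal sketch using only the standard library. The distribution names (`torch`, `torchmanager`, `monai`) and the helper names are assumptions for illustration, and the version parsing deliberately ignores pre-release and local suffixes such as `+cu117`.

```python
import re
import sys
from importlib import metadata

# Minimum versions from the requirements list; distribution names are assumed.
REQUIREMENTS = {"torch": "1.12.1", "torchmanager": "1.1", "monai": "1.1"}


def parse_version(version: str) -> tuple[int, ...]:
    """Turn '1.12.1+cu117' into (1, 12, 1); anything after the numeric prefix is dropped."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare numeric release segments, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment() -> dict[str, str]:
    """Return a human-readable status per requirement."""
    report = {"python": "ok" if sys.version_info >= (3, 9) else "too old"}
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
            continue
        report[name] = "ok" if meets_minimum(installed, minimum) else f"{installed} < {minimum}"
    return report


if __name__ == "__main__":
    for name, status in check_environment().items():
        print(f"{name}: {status}")
```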
- Load datasets
training_dataset = ...
validation_dataset = ...
- Simply build MAGNET (UNETR backbone) with the `magnet.build` function, or use the `magnet.build_v2` (UNETR backbone) / `magnet.build_v2_unet` (3D UNet backbone) function for the new MAGNET used in MAG-MS
import torch
import magnet
from typing import Any, Callable, Sequence, Union
num_modalities: int = ...
num_classes: int = ...
img_size: Union[int, Sequence[int]] = ...
target_dict = ...
model = magnet.build_v2(num_modalities, num_classes, img_size, target_dict=target_dict)
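The `img_size` argument accepts either a single integer or one size per spatial dimension. A plausible reading, assuming `magnet` follows the convention used by MONAI's UNETR, is that an integer is expanded to a cubic volume. The helper below is hypothetical and only illustrates that convention; it is not part of `magnet`.

```python
from collections.abc import Sequence
from typing import Union


def normalize_img_size(img_size: Union[int, Sequence[int]], spatial_dims: int = 3) -> tuple[int, ...]:
    """Expand an int to one entry per spatial dimension; validate a sequence."""
    if isinstance(img_size, int):
        size = (img_size,) * spatial_dims
    else:
        size = tuple(int(s) for s in img_size)
    if len(size) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} spatial dimensions, got {len(size)}")
    if any(s <= 0 for s in size):
        raise ValueError(f"image sizes must be positive, got {size}")
    return size


print(normalize_img_size(96))              # (96, 96, 96)
print(normalize_img_size((128, 128, 64)))  # (128, 128, 64)
```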
- Or use the deeper `magnet.nn` framework to customize the MAGNET backbone
encoder1: torch.nn.Module = ...
encoder2: torch.nn.Module = ...
fusion: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = magnet.nn.MAGNET2(encoder1, encoder2, fusion=fusion, decoder=decoder)
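To make the encoder/fusion/decoder split concrete, here is a toy sketch in plain PyTorch. It is not `magnet.nn.MAGNET2` and makes no claim about the interfaces that class expects: `ToyModalityAgnosticNet`, `MeanFusion`, `conv_encoder`, and the `available` argument are all hypothetical names. It assumes one encoder per modality, a fusion step that averages the features of whichever modalities are present, and a single shared decoder.

```python
from typing import Optional, Sequence

import torch
from torch import nn


class MeanFusion(nn.Module):
    """Average feature maps over whichever modalities are present."""

    def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(features), dim=0).mean(dim=0)


class ToyModalityAgnosticNet(nn.Module):
    """Hypothetical stand-in, NOT magnet.nn.MAGNET2: one encoder per modality,
    a fusion module, and a single shared decoder."""

    def __init__(self, encoders: Sequence[nn.Module], fusion: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoders = nn.ModuleList(encoders)
        self.fusion = fusion
        self.decoder = decoder

    def forward(self, x: torch.Tensor, available: Optional[Sequence[int]] = None) -> torch.Tensor:
        # x has shape (batch, modalities, D, H, W); `available` lists usable modality indices.
        indices = range(len(self.encoders)) if available is None else available
        features = [self.encoders[i](x[:, i:i + 1]) for i in indices]
        return self.decoder(self.fusion(features))


def conv_encoder(out_channels: int = 8) -> nn.Module:
    # Single-channel 3D input per modality; padding keeps the spatial size.
    return nn.Sequential(nn.Conv3d(1, out_channels, kernel_size=3, padding=1), nn.ReLU())


num_modalities, num_classes = 2, 3
model = ToyModalityAgnosticNet(
    encoders=[conv_encoder() for _ in range(num_modalities)],
    fusion=MeanFusion(),
    decoder=nn.Conv3d(8, num_classes, kernel_size=1),
)

x = torch.randn(1, num_modalities, 16, 16, 16)
print(model(x).shape)                 # torch.Size([1, 3, 16, 16, 16]), all modalities
print(model(x, available=[1]).shape)  # torch.Size([1, 3, 16, 16, 16]), second modality only
```

Because the fusion step averages rather than concatenates, the decoder sees the same feature shape no matter how many modalities are supplied, which is what lets the toy model run on a subset of its inputs.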
- Define the MAG-MS loss function
main_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
kldiv_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
mse_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
self_distillation_loss_fn = magnet.losses.MAGSelfDistillationLoss(main_loss_fn, kldiv_loss_fn)
feature_distillation_loss_fn = magnet.losses.MAGFeatureDistillationLoss(self_distillation_loss_fn, mse_loss_fn)
loss_fn = feature_distillation_loss_fn
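The nesting above combines three kinds of terms: a main segmentation loss, a KL-divergence term for self-distillation, and an MSE term for feature distillation. The sketch below shows one common way such terms can be combined in plain PyTorch. It is not the `magnet.losses` implementation: the function name, the unweighted sum, the `temperature` parameter, and the choice to detach the fused (teacher) outputs are all assumptions made for illustration.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def toy_distillation_loss(
    student_logits: Sequence[torch.Tensor],    # one (B, C, ...) tensor per modality
    teacher_logits: torch.Tensor,              # fused prediction, (B, C, ...)
    student_features: Sequence[torch.Tensor],  # one feature map per modality
    teacher_features: torch.Tensor,            # fused feature map
    target: torch.Tensor,                      # integer labels, (B, ...)
    temperature: float = 1.0,                  # hypothetical hyperparameter
) -> torch.Tensor:
    # Main term: supervised loss on the fused output and on every single-modality output.
    loss = F.cross_entropy(teacher_logits, target)
    loss = loss + sum(F.cross_entropy(s, target) for s in student_logits)

    # Self-distillation term: pull each student's class distribution toward the teacher's.
    # The teacher is detached so gradients flow only through the students here.
    # Note that "batchmean" sums over classes and voxels, then divides by batch size.
    teacher_prob = F.softmax(teacher_logits.detach() / temperature, dim=1)
    for s in student_logits:
        student_log_prob = F.log_softmax(s / temperature, dim=1)
        loss = loss + F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")

    # Feature-distillation term: match intermediate features with mean squared error.
    for f in student_features:
        loss = loss + F.mse_loss(f, teacher_features.detach())
    return loss


torch.manual_seed(0)
target = torch.randint(0, 3, (2, 4, 4, 4))
students = [torch.randn(2, 3, 4, 4, 4, requires_grad=True) for _ in range(2)]
teacher = torch.stack(students).mean(dim=0)
features = [torch.randn(2, 8, 4, 4, 4, requires_grad=True) for _ in range(2)]
fused = torch.stack(features).mean(dim=0)

loss = toy_distillation_loss(students, teacher, features, fused, target)
loss.backward()  # gradients reach every single-modality branch
print(loss.item() > 0)
```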
- Compile manager and train/test
optimizer = ...
metric_fns = ...
epochs = ...
callbacks = ...
manager = magnet.Manager(model, optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
manager.fit(training_dataset, epochs, val_dataset=validation_dataset, callbacks=callbacks)
summary = manager.test(validation_dataset)
print(summary)
- Use `magnet.MonaigManager` instead of `Manager`
- Post-processing support with `post_labels` and `post_predicts`
post_labels = [...]
post_predicts = [...]
manager = magnet.MonaigManager(model, post_labels=post_labels, post_predicts=post_predicts, optimizer=optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
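As an illustration of what the two post-processing lists might contain, the sketch below builds them from plain callables. It assumes each list is applied in order to the labels or predictions before metrics are computed; that ordering, the `apply_all` helper, and the two transforms are hypothetical and not taken from `magnet`. In practice, MONAI transforms such as `AsDiscrete` are a likely choice for the list entries.

```python
from typing import Callable, Sequence

import torch
import torch.nn.functional as F

Transform = Callable[[torch.Tensor], torch.Tensor]


def apply_all(transforms: Sequence[Transform], x: torch.Tensor) -> torch.Tensor:
    """Apply transforms left to right (assumed order, for illustration)."""
    for transform in transforms:
        x = transform(x)
    return x


def to_one_hot(num_classes: int) -> Transform:
    # (B, ...) integer labels -> (B, C, ...) one-hot floats.
    def transform(labels: torch.Tensor) -> torch.Tensor:
        one_hot = F.one_hot(labels.long(), num_classes)  # class axis is appended last
        return one_hot.movedim(-1, 1).float()
    return transform


def argmax_channels(logits: torch.Tensor) -> torch.Tensor:
    # (B, C, ...) logits -> (B, ...) predicted class indices.
    return logits.argmax(dim=1)


num_classes = 3
post_labels = [to_one_hot(num_classes)]
post_predicts = [argmax_channels, to_one_hot(num_classes)]

logits = torch.randn(2, num_classes, 4, 4, 4)
labels = torch.randint(0, num_classes, (2, 4, 4, 4))
print(apply_all(post_predicts, logits).shape)  # torch.Size([2, 3, 4, 4, 4])
print(apply_all(post_labels, labels).shape)    # torch.Size([2, 3, 4, 4, 4])
```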
@article{he2023modality,
title={Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation},
author={He, Qisheng and Summerfield, Nicholas and Dong, Ming and Glide-Hurst, Carri},
journal={arXiv preprint arXiv:2306.03730},
year={2023}
}