Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation


This is the official implementation of MAG-MS.

For the official implementation of MAGNET: A Modality-Agnostic Networks for Medical Image Segmentation, please check out the stable-1.1 branch.

MAG-MS is compatible with MAGNET (v1). The new MAGNET (v2) used in MAG-MS adds support for multi-modality self-distillation and multi-modality feature distillation.

Prerequisites

Installation

Use the package manager pip to install MAG-MS.

pip install magms
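
After installing, it can be useful to confirm that the package is visible to the current Python environment. The snippet below is a small sketch that uses only the standard library; the helper name `installed_version` is hypothetical, and the distribution name `magms` is taken from the pip command above.

```python
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


# Prints a version string if `pip install magms` succeeded, otherwise None.
print(installed_version("magms"))
```

If this prints None, the package was most likely installed into a different environment than the one running the script.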

Get Started

  1. Load datasets
training_dataset = ...
validation_dataset = ...
  2. Simply build MAGNET (UNETR backbone) with the magnet.build function, or use magnet.build_v2 (UNETR backbone) or magnet.build_v2_unet (3D UNet backbone) for the new MAGNET used in MAG-MS
num_modalities: int = ...
num_classes: int = ...
img_size: Union[int, Sequence[int]] = ...
target_dict = ...
model = magnet.build_v2(num_modalities, num_classes, img_size, target_dict=target_dict)
  3. Or use the lower-level magnet.nn framework to customize the MAGNET backbone
encoder1: torch.nn.Module = ...
encoder2: torch.nn.Module = ...
fusion: torch.nn.Module = ...
decoder: torch.nn.Module = ...
model = magnet.nn.MAGNET2(encoder1, encoder2, fusion=fusion, decoder=decoder)
  4. Define the MAG-MS loss function
main_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
kldiv_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
mse_loss_fn: list[Callable[[Any, Any], torch.Tensor]] = ...
self_distillation_loss_fn = magnet.losses.MAGSelfDistillationLoss(main_loss_fn, kldiv_loss_fn)
feature_distillation_loss_fn = magnet.losses.MAGFeatureDistillationLoss(self_distillation_loss_fn, mse_loss_fn)
loss_fn = feature_distillation_loss_fn
  5. Compile the manager, then train and test
optimizer = ...
metric_fns = ...

epochs = ...
callbacks = ...

manager = magnet.Manager(model, optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
manager.fit(training_dataset, epochs, val_dataset=validation_dataset, callbacks=callbacks)
summary = manager.test(validation_dataset)
print(summary)
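
To make the roles of the main, KL-divergence, and MSE losses in the loss-definition step more concrete, here is a minimal sketch of how a self-distillation term and a feature-distillation term are commonly combined. It illustrates the general technique under stated assumptions and is not the code inside magnet.losses: it assumes the fused (all-modality) prediction acts as the teacher for each single-modality prediction, uses cross-entropy as a stand-in main loss, and works on a single pixel with plain Python lists so it runs without PyTorch. All function names and the `kd_weight` and `feat_weight` parameters are hypothetical.

```python
import math
from typing import Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable softmax over one vector of class logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Stand-in main loss: negative log-probability of the true class."""
    return -math.log(softmax(logits)[target])


def kl_divergence(teacher: Sequence[float], student: Sequence[float]) -> float:
    """KL(teacher || student) between two probability vectors."""
    return sum(t * math.log(t / s) for t, s in zip(teacher, student) if t > 0)


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two feature vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(
    fused_logits: Sequence[float],
    modality_logits: Sequence[Sequence[float]],
    fused_feature: Sequence[float],
    modality_features: Sequence[Sequence[float]],
    target: int,
    kd_weight: float = 1.0,    # hypothetical weight of the KL term
    feat_weight: float = 1.0,  # hypothetical weight of the MSE term
) -> float:
    # Assumption: the fused prediction is the teacher distribution.
    teacher = softmax(fused_logits)
    loss = cross_entropy(fused_logits, target)
    for logits in modality_logits:
        # Each single-modality output is supervised by the label...
        loss += cross_entropy(logits, target)
        # ...and pulled toward the fused prediction (self-distillation).
        loss += kd_weight * kl_divergence(teacher, softmax(logits))
    for feature in modality_features:
        # Feature distillation: match single-modality features to fused ones.
        loss += feat_weight * mse(feature, fused_feature)
    return loss


fused = [2.0, 0.5, -1.0]
agree = distillation_loss(fused, [fused, fused], [0.1, 0.2], [[0.1, 0.2]] * 2, target=0)
differ = distillation_loss(fused, [[-1.0, 0.5, 2.0], fused], [0.1, 0.2], [[0.9, 0.2]] * 2, target=0)
print(agree < differ)
```

When every single-modality output and feature matches the fused one, both distillation terms vanish and only the main losses remain; any disagreement adds a penalty, which is the behavior the comparison at the end checks.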

MONAI Support

  • Use magnet.MonaigManager instead of Manager
  • Post-processing is supported through the post_labels and post_predicts arguments
post_labels = [...]
post_predicts = [...]

manager = magnet.MonaigManager(model, post_labels=post_labels, post_predicts=post_predicts, optimizer=optimizer, loss_fn=loss_fn, metric_fns=metric_fns)
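
The README does not say what `post_labels` and `post_predicts` contain. The sketch below assumes each is a list of callables applied in order, to labels and predictions respectively, before metrics are computed. In a MONAI workflow these would typically be transforms such as `monai.transforms.AsDiscrete`. The helpers `argmax_channel`, `one_hot`, and `apply_chain` are hypothetical stand-ins written in plain Python so the example runs without MONAI.

```python
from typing import Any, Callable, Sequence


def argmax_channel(scores: Sequence[float]) -> int:
    """Pick the index of the highest class score for one pixel."""
    return max(range(len(scores)), key=lambda i: scores[i])


def one_hot(index: int, num_classes: int = 3) -> list[int]:
    """Encode a class index as a one-hot vector."""
    return [1 if i == index else 0 for i in range(num_classes)]


def apply_chain(steps: Sequence[Callable[[Any], Any]], value: Any) -> Any:
    """Apply each post-processing step in order (assumed behavior)."""
    for step in steps:
        value = step(value)
    return value


# Predictions: raw class scores -> class index -> one-hot vector.
post_predicts = [argmax_channel, one_hot]
# Labels: class index -> one-hot vector, so both sides share a format.
post_labels = [one_hot]

print(apply_chain(post_predicts, [0.1, 2.3, -0.4]))
print(apply_chain(post_labels, 1))
```

The point of the two lists is that predictions and labels usually start in different formats, so each needs its own chain to end up comparable.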

Cite this work

@article{he2023modality,
  title={Modality-Agnostic Learning for Medical Image Segmentation Using Multi-modality Self-distillation},
  author={He, Qisheng and Summerfield, Nicholas and Dong, Ming and Glide-Hurst, Carri},
  journal={arXiv preprint arXiv:2306.03730},
  year={2023}
}