PyTorch implementation of the MICCAI 2024 paper "Language-Enhanced Local-Global Aggregation Network for Multi-Organ Trauma Detection".
Abdominal trauma is one of the leading causes of death in the elderly population and increasingly poses a global challenge. However, interpreting CT scans for abdominal trauma is considerably challenging for deep learning models. Trauma may exist in various organs presenting different shapes and morphologies. In addition, a thorough comprehension of visual cues and various types of trauma is essential, demanding a high level of domain expertise. To address these issues, this paper introduces a language-enhanced local-global aggregation network that aims to fully utilize both global contextual information and local organ-specific information inherent in images for accurate trauma detection. Furthermore, the network is enhanced by text embedding from Large Language Models (LLM). This LLM-based text embedding possesses substantial medical knowledge, enabling the model to capture anatomical relationships of intra-organ and intra-trauma connections. We have conducted experiments on one public dataset of RSNA Abdominal Trauma Detection (ATD) and one in-house dataset. Compared with existing state-of-the-art methods, the F1-score of organ-level trauma detection improves from 51.4% to 62.5% when evaluated on the public dataset and from 61.9% to 65.2% on the private cohort, demonstrating the efficacy of our proposed approach for multi-organ trauma detection.
Package | Version |
---|---|
einops | 0.7.0 |
monai | 1.3.0 |
nibabel | 5.2.1 |
numpy | 1.23.5 |
pandas | 2.0.3 |
positional_encodings | 6.0.3 |
scikit_learn | 1.3.2 |
scipy | 1.13.0 |
SimpleITK | 2.3.1 |
torch | 2.1.2 |
torchvision | 0.16.2 |
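For convenience, the pinned versions above can be installed in one step with pip (a command assembled from the table; adjust the torch/torchvision builds to match your CUDA setup):

python -m pip install einops==0.7.0 monai==1.3.0 nibabel==5.2.1 numpy==1.23.5 pandas==2.0.3 positional_encodings==6.0.3 scikit_learn==1.3.2 scipy==1.13.0 SimpleITK==2.3.1 torch==2.1.2 torchvision==0.16.2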
Our local vision encoder uses the pre-trained weights [link] from the CLIP-Driven Universal Model [1]. Please download them before running the code.
[1] J. Liu, Y. Zhang, J. Chen, J. Xiao, Y. Lu, B. Landman, Y. Yuan, A. Yuille, Y. Tang, and Z. Zhou. Clip-driven universal model for organ segmentation and tumor detection. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 21152–21164, 2023.
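A minimal sketch of loading the downloaded checkpoint into the local vision encoder is shown below; the class name `LocalEncoder` and its import path are placeholders for the actual encoder defined under `code/models`, and the checkpoint key layout is an assumption:

```python
import torch

# Placeholder import: substitute the actual local vision encoder class
# defined under code/models in this repository.
from models import LocalEncoder  # assumption

encoder = LocalEncoder()
ckpt = torch.load("unet.pth", map_location="cpu")   # downloaded CLIP-Driven Universal Model weights [1]
state_dict = ckpt.get("state_dict", ckpt)           # unwrap if the weights are nested
# strict=False tolerates segmentation-head keys that the encoder does not use.
missing, unexpected = encoder.load_state_dict(state_dict, strict=False)
print(f"loaded with {len(missing)} missing / {len(unexpected)} unexpected keys")
```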
The text embeddings we use consist of organ-wise prompt and category-wise prompt embeddings. Please download them for the corresponding stages (links below).
The organ-wise prompt is used in both the training and inference stages. It is composed of specific organ names combined with the medical template, as shown in the main figure.
Types | Download |
---|---|
Organs | link |
The category-wise prompt is involved exclusively in the training process to guide the predictions. It is generated according to the label and follows a template similar to the organ-wise prompt. We tested three types of prompts: Fine-grained, Position, and Category, listed below. We finally selected the Category-type prompt, as shown in the main figure.
Types | Download |
---|---|
Fine-grained | link |
Position | link |
Category | link |
If you want to design your own prompts, you can follow the text encoding method of the CLIP-Driven Universal Model [1], e.g. as sketched below.
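A minimal sketch of that encoding scheme, assuming the OpenAI `clip` package and an illustrative prompt template; the exact wording used for the released embeddings may differ:

```python
import clip
import torch

# Illustrative prompts only; the released four_organ.pth embeddings may use
# a different medical template and organ list.
organs = ["liver", "spleen", "kidney", "bowel"]
prompts = [f"A computerized tomography of a {o}" for o in organs]

model, _ = clip.load("ViT-B/32", device="cpu")
with torch.no_grad():
    tokens = clip.tokenize(prompts)
    text_embeddings = model.encode_text(tokens)   # shape: (len(organs), 512)

torch.save(text_embeddings, "my_organ_prompt.pth")
```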
1. Download the RSNA-ATD dataset: https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/data.
2. Convert the DICOM series to NIfTI volumes:
python dicom2nii.py
3. Train a segmentation model (e.g. TransUNet [2]) to obtain segmentation maps of all organs.
[2] J. Chen, Y. Lu, Q. Yu, X. Luo, E. Adeli, Y. Wang, L. Lu, A. Yuille, and Y. Zhou. TransUNet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306, 2021.
4. Preprocess the CT volumes with MONAI (see the sketch after this list):
python monai_preprocessing.py
5. Crop the preprocessed volumes to the target input size:
python crop_to_size.py
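A minimal sketch of the kind of MONAI pipeline steps 4-5 could correspond to; the transform choices, dictionary keys, spacing, intensity window, and crop size here are illustrative assumptions rather than the exact settings in `monai_preprocessing.py` / `crop_to_size.py`:

```python
import numpy as np
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
    Spacingd, ScaleIntensityRanged, CenterSpatialCropd,
)

# Assumed keys and parameters for illustration only.
transforms = Compose([
    LoadImaged(keys=["image", "seg"]),
    EnsureChannelFirstd(keys=["image", "seg"]),
    Orientationd(keys=["image", "seg"], axcodes="RAS"),
    Spacingd(keys=["image", "seg"], pixdim=(1.5, 1.5, 1.5), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    CenterSpatialCropd(keys=["image", "seg"], roi_size=(96, 96, 96)),
])

sample = transforms({"image": "Patient-1.nii.gz", "seg": "Patient-1_seg.nii.gz"})
np.savez("Patient-1.npz",
         image=np.asarray(sample["image"]),
         seg=np.asarray(sample["seg"]))
```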
.
├── code
│ ├── dataset
│ ├── models
│ ├── runs
│ │ └── our_checkpoints
│ ...
│ ├── label.csv
│ ├── train_data.txt
│ ├── val_data.txt
│ ├── test_data.txt
│ ├── unet.pth (pre-trained weights)
│ ├── four_organ.pth (organ-wise prompt embedding)
│ └── Trauma_Label.pth (category-wise prompt embedding)
└── preprocessed_data
    ├── Global_method
    │   ├── Patient-1.npz
    │   └── *.npz
    └── Local_method
        ├── Patient-1.npz
        └── *.npz
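Once preprocessing is done, each `.npz` file can be inspected to verify its contents; the array names stored inside are not documented here, so check `data.files` for the actual keys:

```python
import numpy as np

# Inspect one preprocessed case (path follows the directory layout above).
data = np.load("preprocessed_data/Global_method/Patient-1.npz")
print(data.files)                       # actual array names stored in the file
for key in data.files:
    print(key, data[key].shape, data[key].dtype)
```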
To train the model, run:
python GLFF_train.py --model_name local_prompt_global_prompt --alfa 0.9 --prompt_loss True
If you want to reproduce our experimental results, you can use our model weights [link] and run the command below:
python inference_global_local.py --model_name local_prompt_global_prompt \
--model_path "./runs/our_checkpoints/model_best.pt" \
--local_path "/research/d1/rshr/qxhu/PublicDataset/Jianxun/our_methods" \
--global_path "/research/d1/rshr/qxhu/PublicDataset/Jianxun/baseline_methods" \
--label_path "./label.csv" \
--test_list "./test_data.txt"
If this repository is useful for your research, please cite:
@inproceedings{yu2024lelgan,
  title={Language-Enhanced Local-Global Aggregation Network for Multi-Organ Trauma Detection},
  author={Yu, Jianxun and Hu, Qixin and Jiang, Meirui and Wang, Yaning and Wong, Chin Ting and Wang, Jing and Zhang, Huimao and Dou, Qi},
  booktitle={International Conference on Medical Image Computing and Computer Assisted Intervention},
  year={2024}
}
For any questions, please contact jianxyu98@gmail.com.