This repository is part of a final project for the "Deep Learning for Advanced Computer Vision 224C" course at the University of California, Santa Cruz. It reproduces and enhances the TransUNet model, as detailed in the paper "TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation". The project's initial phase involved training a baseline TransUNet model, followed by targeted modifications to improve its segmentation performance.
Authors: Marian Zlateva (mzlateva@ucsc.edu) and Marzia Binta Nizam (manizam@ucsc.edu)
Details of our experiments can be found here.
This study focuses on enhancing the TransUNet model for medical image segmentation by improving its generalization capabilities. The baseline TransUNet model achieved a Mean Dice score of 0.769 and a median Hausdorff Distance of 32.87. Incorporating Channel Attention, Dual Attention mechanisms, and CutMix data augmentation yielded significant improvements, culminating in a Mean Dice score of 0.823 and a reduced median Hausdorff Distance of 19.74. These modifications improved the model's ability to accurately segment complex anatomical structures.
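To illustrate the CutMix augmentation mentioned above, here is a minimal NumPy sketch of the general technique, not this project's actual implementation. For segmentation we assume the label map is mixed with the same rectangular patch as the image, so pixels and labels stay aligned (unlike classification CutMix, which mixes one-hot targets by patch area):

```python
import numpy as np

def cutmix(image_a, label_a, image_b, label_b, alpha=1.0, rng=None):
    """Paste a random rectangular patch from sample B into sample A,
    applying the same patch to the label map so annotations stay aligned."""
    rng = rng or np.random.default_rng()
    h, w = image_a.shape[-2:]
    lam = rng.beta(alpha, alpha)                      # kept-area ratio ~ Beta(alpha, alpha)
    cut_h = int(h * np.sqrt(1.0 - lam))               # patch size from (1 - lam)
    cut_w = int(w * np.sqrt(1.0 - lam))
    cy, cx = rng.integers(h), rng.integers(w)         # random patch center
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    mixed_img, mixed_lbl = image_a.copy(), label_a.copy()
    mixed_img[..., y1:y2, x1:x2] = image_b[..., y1:y2, x1:x2]
    mixed_lbl[..., y1:y2, x1:x2] = label_b[..., y1:y2, x1:x2]
    return mixed_img, mixed_lbl
```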
The following table summarizes the performance improvements across various organs:
Model | Average Dice | Median HD95 | Aorta | Gallbladder | Kidney (L) | Kidney (R) | Liver | Pancreas | Spleen | Stomach |
---|---|---|---|---|---|---|---|---|---|---|
TransUNet (Baseline) | 0.769 | 32.87 | 0.868 | 0.596 | 0.814 | 0.740 | 0.945 | 0.542 | 0.873 | 0.778 |
TransUNet (ours) | 0.823 | 19.74 | 0.882 | 0.631 | 0.860 | 0.831 | 0.946 | 0.693 | 0.907 | 0.833 |
Note: All values are rounded to three decimal places for clarity.
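Part of the gain above comes from the Channel Attention module. As a rough sketch only, here is a squeeze-and-excitation-style channel attention block in NumPy; the weight shapes `w1`, `w2` and the reduction ratio `r` are illustrative assumptions, not this project's exact module:

```python
import numpy as np

def channel_attention(feat, w1, w2):
    """Reweight channels of a (C, H, W) feature map by learned importance.

    w1: (C // r, C) and w2: (C, C // r) are the two FC layers of the
    squeeze-and-excitation bottleneck (r = reduction ratio).
    """
    squeeze = feat.mean(axis=(1, 2))                  # global average pool -> (C,)
    hidden = np.maximum(w1 @ squeeze, 0.0)            # FC + ReLU -> (C // r,)
    scale = 1.0 / (1.0 + np.exp(-(w2 @ hidden)))      # FC + sigmoid -> (C,)
    return feat * scale[:, None, None]                # per-channel reweighting
```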
Below are the segmentation visualizations comparing the baseline and enhanced models:
Figure 1: Visual comparison of segmentation performance between baseline TransUNet and TransUNet with Dual Attention and CutMix.
Please prepare an environment with python=3.7, and then install the dependencies with the following command (following the original TransUNet repo):
pip install -r requirements.txt
You can view our environment specification here.
Download the Google pre-trained ViT models following the original repository's instructions:
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz &&
mkdir ./model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ./model/vit_checkpoint/imagenet21k/{MODEL_NAME}.npz
The experiments were conducted on the Synapse multi-organ segmentation dataset. Please refer to the original repository for the data preparation.
Run the training script on the Synapse dataset. We used a batch size of 8 due to limited GPU access, but the original code supports multiple GPUs as well.
CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16
Run the test script on the Synapse dataset. It supports testing on both 2D images and 3D volumes.
python test.py --dataset Synapse --vit_name R50-ViT-B_16 --is_savenii
You can download our trained model from here to test.
Please refer to this notebook for visualizing the predictions.
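For a quick look without the notebook, the sketch below shows one common way to alpha-blend a predicted label map onto a grayscale CT slice using plain NumPy. The palette and label ordering here are illustrative assumptions, not the notebook's actual mapping:

```python
import numpy as np

# Illustrative colors for a few labels; the real Synapse label order
# and colors should be taken from the visualization notebook.
PALETTE = np.array([
    [0.0, 0.0, 0.0],  # 0: background (never blended)
    [1.0, 0.0, 0.0],  # 1: e.g. aorta
    [0.0, 1.0, 0.0],  # 2: e.g. gallbladder
    [0.0, 0.0, 1.0],  # 3: e.g. left kidney
])

def overlay(image_slice, pred_slice, palette=PALETTE, alpha=0.5):
    """Alpha-blend a color-coded label map onto a grayscale (H, W) slice."""
    img = (image_slice - image_slice.min()) / (np.ptp(image_slice) + 1e-8)
    rgb = np.stack([img] * 3, axis=-1)        # (H, W) -> (H, W, 3)
    color = palette[pred_slice]               # per-pixel color lookup
    fg = pred_slice > 0                       # blend only organ pixels
    rgb[fg] = (1 - alpha) * rgb[fg] + alpha * color[fg]
    return rgb
```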
We appreciate the developers of TransUNet and the provider of the Synapse multi-organ segmentation dataset. We are grateful to Professor Yuyin Zhou for her invaluable guidance and insightful suggestions throughout the duration of this project. 😃 😃