This is an EfficientNet-B0 reimplementation of CheXNet. The model takes a chest X-ray image as input and outputs the probability of each of 14 thoracic diseases, along with a likelihood map of pathologies. This implementation has been updated with modern deep learning techniques, including EfficientNet as the backbone architecture, mixed precision training, and additional evaluation metrics.
- Modern Architecture: Replaced DenseNet121 with EfficientNetB0 for better performance and efficiency.
- Mixed Precision Training: Utilized PyTorch's `autocast` and `GradScaler` for faster training and reduced memory usage.
- Advanced Evaluation Metrics: Added F1-score, precision, and recall alongside AUROC for comprehensive model evaluation.
- Improved Localization: Enhanced localization maps using Grad-CAM for better interpretability.
- User-Friendly API: Added functions for easy model loading, prediction, and response generation.
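The mixed precision feature can be illustrated with a minimal training-loop sketch. It assumes a model that returns 14 logits, a `DataLoader` yielding `(images, targets)` with multi-hot targets, and `BCEWithLogitsLoss` as the multi-label loss; `train_one_epoch` is a hypothetical name, not necessarily what `train.py` uses.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, scaler, device):
    """Run one epoch of (optionally) mixed-precision training; returns mean loss per image."""
    model.train()
    # Multi-label task: one independent sigmoid + BCE term per pathology.
    criterion = nn.BCEWithLogitsLoss()
    use_amp = device.type == "cuda"  # float16 autocast only pays off on a GPU
    total_loss = 0.0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True).float()
        optimizer.zero_grad(set_to_none=True)
        # Run eligible ops in float16; autocast keeps numerically sensitive ops in float32.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, targets)
        # Scale the loss to avoid float16 gradient underflow; step() unscales before updating.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)
```

On a GPU, create the scaler with `torch.cuda.amp.GradScaler(enabled=True)` (newer PyTorch releases also expose `torch.amp.GradScaler("cuda")`). On a CPU-only machine, pass `enabled=False`: `scale()` then returns the loss unchanged and `step()` simply calls `optimizer.step()`.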
The ChestX-ray14 dataset comprises 112,120 frontal-view chest X-ray images of 30,805 unique patients with 14 disease labels. To evaluate the model, we randomly split the dataset into training (70%), validation (10%), and test (20%) sets. The partitioned image names and their corresponding labels are stored in the `labels` directory.
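A minimal sketch of a 70/10/20 split follows. One caveat: ChestX-ray14 contains several images per patient, so a purely image-level shuffle can place the same patient in both the training and test sets. This sketch therefore shuffles patient IDs rather than images, which is a deliberate variation on a plain random split. It assumes the dataset's `<patient>_<followup>.png` naming (e.g. `00000001_000.png`), and `split_by_patient` is a hypothetical helper.

```python
import random
from collections import defaultdict


def split_by_patient(image_names, fractions=(0.7, 0.1, 0.2), seed=0):
    """Randomly split image names into train/val/test lists with no patient overlap.

    Assumes ChestX-ray14-style names such as '00000001_000.png', where the text
    before the first underscore is the patient ID.
    """
    by_patient = defaultdict(list)
    for name in image_names:
        by_patient[name.split("_", 1)[0]].append(name)

    # Shuffle patients (not images) reproducibly, then cut by the given fractions.
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    n_train = round(len(patients) * fractions[0])
    n_val = round(len(patients) * fractions[1])
    groups = (
        patients[:n_train],
        patients[n_train:n_train + n_val],
        patients[n_train + n_val:],  # test takes the remainder (~fractions[2])
    )
    return tuple([img for pid in group for img in by_patient[pid]] for group in groups)
```

Because the fractions are applied to patients, the image-level proportions come out only approximately 70/10/20.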
- Python 3.8+
- PyTorch 1.10+
- Torchvision
- NumPy
- Scikit-learn
- PIL (Pillow)
- CUDA (optional but recommended for GPU acceleration)
Install the required dependencies using:
pip install torch torchvision numpy scikit-learn pillow
git clone https://github.com/your-repo/ModelRay.git
cd ModelRay
- Download the ChestX-ray14 dataset from the NIH website.
- Extract the images to the `Model-Ray02/images` directory.
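Once the images are in place, a PyTorch `Dataset` can pair each image with its labels. The format of the files under `labels` is not documented here, so this sketch assumes one image per line: the image name followed by 14 space-separated 0/1 flags in the table's pathology order. The class name `ChestXrayDataset` and that file layout are hypothetical.

```python
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class ChestXrayDataset(Dataset):
    """Chest X-ray images paired with 14-dimensional multi-hot label vectors.

    ASSUMED (hypothetical) list-file format, one image per line:
        <image_name> <l1> <l2> ... <l14>    with each label 0 or 1
    """

    def __init__(self, image_dir, list_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.samples = []
        with open(list_file) as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # skip blank lines
                self.samples.append((parts[0], [int(v) for v in parts[1:]]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        name, labels = self.samples[idx]
        # X-rays are grayscale; replicate to 3 channels for ImageNet-pretrained backbones.
        image = Image.open(os.path.join(self.image_dir, name)).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(labels, dtype=torch.float32)
```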
To train the model, run:
python train.py
Training Configuration:
- Batch size: 64
- Learning rate: 0.001
- Optimizer: AdamW
- Mixed precision training: Enabled
- Data augmentation: Random rotations, flips, and color jittering
To evaluate the model on the test set, run:
python evaluate.py
This will compute the following metrics for each pathology:
- AUROC
- F1-score
- Precision
- Recall
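Per-pathology metrics can be computed with scikit-learn as sketched below. AUROC uses the raw probabilities, while F1-score, precision, and recall need binary predictions; the fixed 0.5 threshold is an assumption, since the threshold used by `evaluate.py` is not stated. `per_class_metrics` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score


def per_class_metrics(y_true, y_prob, class_names, threshold=0.5):
    """Return {class_name: {auroc, f1, precision, recall}} for multi-label outputs.

    y_true: [N, C] array of 0/1 labels; y_prob: [N, C] array of sigmoid probabilities.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = (y_prob >= threshold).astype(int)  # assumed fixed decision threshold
    results = {}
    for i, name in enumerate(class_names):
        t, p, s = y_true[:, i], y_pred[:, i], y_prob[:, i]
        results[name] = {
            # AUROC is undefined when only one class appears in the ground truth.
            "auroc": roc_auc_score(t, s) if len(np.unique(t)) == 2 else float("nan"),
            "f1": f1_score(t, p, zero_division=0),
            "precision": precision_score(t, p, zero_division=0),
            "recall": recall_score(t, p, zero_division=0),
        }
    return results
```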
To make predictions on a new chest X-ray image, use the `predict.py` script:
python predict.py --image_path path_to_your_image.jpg
This will output the predicted probabilities for each of the 14 pathologies.
We followed the training strategy described in the original CheXNet paper and achieved comparable performance. Below is a per-pathology AUROC comparison of our implementation with other state-of-the-art models:
Pathology | Wang et al. | Yao et al. | CheXNet | Our Implemented ModelRay | Our Improved Model |
---|---|---|---|---|---|
Atelectasis | 0.716 | 0.772 | 0.8094 | 0.8294 | 0.8311 |
Cardiomegaly | 0.807 | 0.904 | 0.9248 | 0.9165 | 0.9220 |
Effusion | 0.784 | 0.859 | 0.8638 | 0.8870 | 0.8891 |
Infiltration | 0.609 | 0.695 | 0.7345 | 0.7143 | 0.7146 |
Mass | 0.706 | 0.792 | 0.8676 | 0.8597 | 0.8627 |
Nodule | 0.671 | 0.717 | 0.7802 | 0.7873 | 0.7883 |
Pneumonia | 0.633 | 0.713 | 0.7680 | 0.7745 | 0.7820 |
Pneumothorax | 0.806 | 0.841 | 0.8887 | 0.8726 | 0.8844 |
Consolidation | 0.708 | 0.788 | 0.7901 | 0.8142 | 0.8148 |
Edema | 0.835 | 0.882 | 0.8878 | 0.8932 | 0.8992 |
Emphysema | 0.815 | 0.829 | 0.9371 | 0.9254 | 0.9343 |
Fibrosis | 0.769 | 0.767 | 0.8047 | 0.8304 | 0.8385 |
Pleural Thickening | 0.708 | 0.765 | 0.8062 | 0.7831 | 0.7914 |
Hernia | 0.767 | 0.914 | 0.9164 | 0.9104 | 0.9206 |
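A single summary number is often reported as the mean AUROC over the 14 pathologies. The snippet below only averages two columns of the table; it adds no new results.

```python
from statistics import mean

# Per-pathology AUROC values copied from the table, in row order.
IMPLEMENTED = [0.8294, 0.9165, 0.8870, 0.7143, 0.8597, 0.7873, 0.7745,
               0.8726, 0.8142, 0.8932, 0.9254, 0.8304, 0.7831, 0.9104]
IMPROVED = [0.8311, 0.9220, 0.8891, 0.7146, 0.8627, 0.7883, 0.7820,
            0.8844, 0.8148, 0.8992, 0.9343, 0.8385, 0.7914, 0.9206]


def mean_auroc(values):
    """Unweighted mean over the 14 pathologies."""
    if len(values) != 14:
        raise ValueError("expected one AUROC per pathology")
    return mean(values)


print(f"Our Implemented ModelRay: {mean_auroc(IMPLEMENTED):.4f}")
print(f"Our Improved Model:       {mean_auroc(IMPROVED):.4f}")
```

This prints a mean AUROC of 0.8427 for the implemented model and 0.8481 for the improved model.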
Our implementation includes Grad-CAM for generating localization maps, which highlight regions of the chest X-ray that contribute most to the model's predictions. To generate localization maps, run:
python localize.py --image_path path_to_your_image.jpg
This will save the localization map as an image in the local directory.
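The Grad-CAM computation behind the localization maps can be sketched as follows: capture a convolutional layer's activations and their gradients with respect to one pathology's logit, average the gradients spatially to get per-channel weights, apply a ReLU to the weighted sum of activations, and upsample to the input size. This is a generic Grad-CAM sketch, not the code in `localize.py`; using `model.features[-1]` as the target layer for torchvision's EfficientNet-B0 is an assumption, and `grad_cam` and `save_heatmap` are hypothetical names.

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def grad_cam(model, target_layer, x, class_idx):
    """Return an [H, W] Grad-CAM heatmap in [0, 1] for one image x of shape [1, C, H, W]."""
    activations, gradients = [], []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # Tensor hook: receives d(logit)/d(output) during the backward pass.
        output.register_hook(gradients.append)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        model.zero_grad()
        logits = model(x)
        logits[0, class_idx].backward()
    finally:
        handle.remove()

    acts, grads = activations[0].detach(), gradients[0]        # both [1, K, h, w]
    weights = grads.mean(dim=(2, 3), keepdim=True)             # [1, K, 1, 1]
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))    # [1, 1, h, w]
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam - cam.min()
    return (cam / (cam.max() + 1e-8))[0, 0]


def save_heatmap(cam, path):
    """Save an [H, W] heatmap in [0, 1] as an 8-bit grayscale image."""
    Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)).save(path)
```

For example, `grad_cam(model, model.features[-1], x, class_idx=1)` would give the Cardiomegaly map if the outputs follow the table's pathology order.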