This repository contains the official implementation of CountXplain, a novel approach for interpretable cell counting that combines prototype-based learning with density map estimation. The method provides both accurate counting predictions and interpretable explanations for the model's decisions.
CountXplain addresses the critical need for interpretability in cell counting applications by introducing a prototype-based architecture that:
- Learns interpretable prototypes that represent typical cellular patterns
- Provides density map estimation for spatial understanding of cell distributions
- Offers model explanations through prototype similarity visualization
- Maintains high counting accuracy while ensuring interpretability
The model consists of two main components:
- Counting Model (CSRNet): A density estimation network based on a VGG-16 frontend with dilated convolutions
- Prototype Network: Learns a set of prototypes that capture representative cellular patterns and provides interpretability
Key features:
- Prototype-based interpretability: Visual explanations through learned prototypes
- Density map estimation: Spatial understanding of cell distributions
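The core of prototype-based interpretability is scoring each spatial location of a feature map against each learned prototype, producing per-prototype similarity maps that can be visualized. A hedged sketch (the similarity function and tensor layout here are assumptions, not necessarily the repository's choices):

```python
import torch

def prototype_similarity(features: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Score each spatial location of a feature map against each prototype.

    features:   (B, C, H, W) backbone feature map
    prototypes: (P, C) learned prototype vectors
    returns:    (B, P, H, W) similarity maps (negative squared L2 distance;
                higher means more similar)
    """
    B, C, H, W = features.shape
    feats = features.permute(0, 2, 3, 1).reshape(B, H * W, C)          # (B, HW, C)
    dists = torch.cdist(feats, prototypes.unsqueeze(0).expand(B, -1, -1))  # (B, HW, P)
    sims = -dists.pow(2)
    return sims.permute(0, 2, 1).reshape(B, -1, H, W)
```

Upsampling a prototype's similarity map to the input resolution highlights which image regions the model matched against that prototype.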
- Python 3.8+
- PyTorch
- PyTorch Lightning
- OpenCV
- NumPy
- Matplotlib
- Weights & Biases (for logging)
- H5PY (for density map storage)
- Clone the repository:
```shell
git clone https://github.com/yourusername/countxplain.git
cd countxplain
```

- Install dependencies:

```shell
pip install torch torchvision pytorch-lightning opencv-python numpy matplotlib wandb h5py tqdm scipy pandas
```

Organize your dataset with the following structure:
```
Dataset/
├── trainval/
│   ├── images/      # Training images (.png)
│   └── densities/   # Ground truth density maps (.h5)
└── test/
    ├── images/      # Test images (.png)
    └── densities/   # Test density maps (.h5)
```
First, train the base counting model (CSRNet):
```shell
python train_counting_model.py --dataset DCC --model_name csrnet --batch_size 2 --lr 0.001
```

Train the CountXplain model with prototypes:
```shell
python train_push.py --dataset DCC --num_prototypes 20 --fg_coef 1 --diversity_coef 1 --proto_to_feature_coef 1 --batch_size 2 --lr 0.001
```

- `--num_prototypes`: Number of prototypes to learn (default: 20)
- `--fg_coef`: Weight for density estimation loss
- `--diversity_coef`: Prototype diversity loss coefficient
- `--proto_to_feature_coef`: Prototype-to-feature alignment coefficient
- `--batch_size`: Training batch size
- `--lr`: Learning rate
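The three loss coefficients can be read as weights in a combined training objective. The sketch below is a hypothetical interpretation of those terms (MSE density loss, a diversity term pushing prototypes apart, and an alignment term pulling each prototype toward real feature vectors); the repository's exact loss formulation may differ:

```python
import torch
import torch.nn.functional as F

def combined_loss(pred_density, gt_density, prototypes, features,
                  fg_coef=1.0, diversity_coef=1.0, proto_to_feature_coef=1.0):
    """Hypothetical combined objective weighted by the CLI coefficients.

    prototypes: (P, C) prototype vectors
    features:   (N, C) feature vectors sampled from training images
    """
    # Density estimation loss: MSE between predicted and ground-truth maps
    density_loss = F.mse_loss(pred_density, gt_density)

    # Diversity: encourage a large minimum distance between distinct prototypes
    pd = torch.cdist(prototypes, prototypes)
    pd = pd + torch.eye(prototypes.size(0)) * (pd.max().detach() + 1.0)  # mask diagonal
    diversity_loss = -pd.min(dim=1).values.mean()

    # Prototype-to-feature alignment: each prototype should lie near some real feature
    p2f_loss = torch.cdist(prototypes, features).min(dim=1).values.mean()

    return (fg_coef * density_loss
            + diversity_coef * diversity_loss
            + proto_to_feature_coef * p2f_loss)
```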
If you use this code in your research, please cite:
```bibtex
@article{countxplain2024,
  title={CountXplain: Interpretable Cell Counting with Prototype-Based Density Map Estimation},
  author={[Authors]},
  journal={[Journal]},
  year={2024}
}
```

This project is licensed under the Creative Commons Attribution 4.0 International License - see the license.txt file for details.