Multi-task and Adversarial CNN Training: Learning Interpretable Pathology Features Improves CNN Generalization
Since physicians remain accountable for the diagnosis, it is fundamental that CNNs base their decisions on relevant pathology features.
Building on top of successful existing techniques such as multi-task learning, domain-adversarial training and concept-based interpretability, we address the challenge of introducing diagnostic factors into the training objectives.
The architecture in this repo learns, end-to-end, an uncertainty-based weighted combination of multi-task and adversarial losses.
This can be used to encourage the model to focus on pathology features such as nuclei density and pleomorphism (i.e. variations in nuclei size and appearance), while discarding misleading features such as staining differences.
Explore the docs »
View Examples
·
Report Bug
Our Regression Concept Vectors toolbox (https://github.com/maragraziani/rcvtool) generates explanations about how relevant a given concept is to the decision making of a CNN classifier. It does not, however, offer any way to act on the training process and modify how a concept is learned. The architecture in this repository aims at filling this gap, allowing us to discourage the learning of confounding concepts, e.g. domain, staining or watermarks, and to encourage the learning of discriminant concepts. The architecture merges the developmental efforts of three successful techniques, namely multi-task learning [1], adversarial training [2] and high-level concept learning in internal network features [3,4]. It is trained on the histopathology task of breast cancer classification, with the aim of enforcing the learning of diagnostic features that match the physicians' diagnostic procedure, such as nuclei morphology and density.
- 1 Caruana, Rich. "Multitask Learning." Machine Learning 28.1 (1997): 41-75.
- 2 Ganin, Yaroslav, et al. "Domain-Adversarial Training of Neural Networks." The Journal of Machine Learning Research 17.1 (2016): 2096-2030.
- 3 Kim, Been, et al. "Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)." International Conference on Machine Learning. PMLR, 2018.
- 4 Graziani, Mara, Vincent Andrearczyk, and Henning Müller. "Regression Concept Vectors for Bidirectional Explanations in Histopathology." Understanding and Interpreting Machine Learning in Medical Image Computing Applications. Springer, Cham, 2018. 124-132.
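For readers unfamiliar with uncertainty-based loss weighting, the sketch below illustrates the general idea of combining the task, concept and adversarial losses with learnable log-variances, in the spirit of Kendall et al. It is written against the Keras API used by this repo but it is not the repository's own implementation; the class name and usage pattern are assumptions.

```python
# Illustrative sketch of uncertainty-based loss weighting (NOT the repository's code).
# total_loss = sum_i exp(-s_i) * L_i + s_i, with one trainable log-variance s_i per task.
import keras.backend as K
from keras.layers import Layer

class UncertaintyWeightedLoss(Layer):
    """Combines several per-task losses with learned log-variances."""

    def __init__(self, loss_fns, **kwargs):
        self.loss_fns = loss_fns  # one Keras loss function per task
        super(UncertaintyWeightedLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        # One trainable log-variance per task, initialised to zero (unit weight).
        self.log_vars = self.add_weight(name='log_vars',
                                        shape=(len(self.loss_fns),),
                                        initializer='zeros',
                                        trainable=True)
        super(UncertaintyWeightedLoss, self).build(input_shape)

    def call(self, inputs):
        # inputs = [y_true_1, y_pred_1, y_true_2, y_pred_2, ...]
        total = 0.
        for i, loss_fn in enumerate(self.loss_fns):
            y_true, y_pred = inputs[2 * i], inputs[2 * i + 1]
            task_loss = K.mean(loss_fn(y_true, y_pred))
            total = total + K.exp(-self.log_vars[i]) * task_loss + self.log_vars[i]
        # Register the combined loss; the model is then compiled with loss=None.
        self.add_loss(total, inputs=inputs)
        return total
```

In such a setup the targets are fed to the model as additional Input tensors, and the adversarial (domain) branch passes its features through a gradient reversal layer before its classification head (see the sketch in the usage section below).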
To get a local copy up and running, follow these simple steps.
This code was developed with TensorFlow 1.8 and Keras 2.2.4. Standard packages (e.g. numpy, scikit-learn, pandas, matplotlib) are needed to replicate the experiments. The complete list of dependencies is in requirements.txt. Follow the instructions in Installation to set up the environment.
Installation should take ~20 minutes on a normal laptop. Follow the steps below.
- Clone the repo
git clone https://github.com/maragraziani/multitask_adversarial
- Install Python packages with pip and the correct versions of TensorFlow and Keras.
pip install -r requirements.txt
pip install https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
pip install keras==2.2.4 setproctitle
pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn
- Install further dependencies for the histopathology application
cd lib/TASK_2_UC1
git clone https://github.com/medgift/PROCESS_L2.git
mv PROCESS_L2/* .
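Optionally, the installed versions can be checked from a Python shell (a quick sanity check, not part of the repository scripts):

```python
# Sanity check: confirm the expected framework versions are installed.
import tensorflow as tf
import keras

print('TensorFlow:', tf.__version__)  # expected: 1.8.0
print('Keras:', keras.__version__)    # expected: 2.2.4
```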
Training may take a few hours (typically 5-6) depending on the model being trained and on the chosen configuration.
Run the command below, replacing the elements in the brackets with the desired values (e.g. EXPERIMENT_NAME=BASELINE, SEED=1001). Run with the [-h] option for help.
bash routines/train_baseline.sh [EXPERIMENT_NAME] [SEED]
Expected outcome (in results/):
results/[EXPERIMENT_NAME]
├── [EXPERIMENT_NAME]_log.txt
├── best_model.h5
├── lr_monitor.log
├── normalizer.npy
├── normalizing_patch.npy
├── seed.txt
└── training_log.npy
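The snippet below is an illustrative way to inspect a finished run; the exact contents of training_log.npy are an assumption (per-epoch metrics stored as a NumPy object), and best_model.h5 may require custom_objects if custom layers were used.

```python
# Illustrative only: file contents are assumptions, not documented in this README.
import numpy as np
from keras.models import load_model

exp_dir = 'results/BASELINE'  # the EXPERIMENT_NAME used at training time

history = np.load(exp_dir + '/training_log.npy', allow_pickle=True)  # assumed per-epoch metrics
print(history)

model = load_model(exp_dir + '/best_model.h5')  # may need custom_objects={...} for custom layers
model.summary()
```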
Run the command below, replacing the elements in the brackets with the desired values (e.g. EXPERIMENT_NAME=MTA, SEED=1001, CONCEPT_LIST="ncount, narea"). Run with the [-h] option for help.
bash routines/train_uncertainty_weighted_multitask.sh [EXPERIMENT_NAME] [SEED] [CONCEPT_LIST]
The expected outcome (to be found in results/) looks as follows:
results/[EXPERIMENT_NAME]
├── [EXPERIMENT_NAME]_log.txt
├── best_model.h5
├── lr_monitor.log
├── normalizer.npy
├── normalizing_patch.npy
├── seed.txt
├── training_log.npy
├── ERR.log
├── loss1_rec.log
├── loss2_rec.log
├── train_by_epoch.txt
├── val_by_epoch.txt
└── val_acc_log.npy
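As a rough way to monitor the two weighted losses, the sketch below assumes that loss1_rec.log and loss2_rec.log contain one numeric value per line (the log format is an assumption, not documented here):

```python
# Illustrative only: assumes each *.log file holds one loss value per line.
import matplotlib.pyplot as plt

exp_dir = 'results/MTA'
for name in ('loss1_rec.log', 'loss2_rec.log'):
    with open(exp_dir + '/' + name) as f:
        values = [float(line.split()[-1]) for line in f if line.strip()]
    plt.plot(values, label=name)
plt.xlabel('logged step')
plt.ylabel('loss')
plt.legend()
plt.show()
```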
To train the multi-task adversarial model, which includes an adversarial branch (adversarial to the WSI acquisition center):
Run the command below, replacing the elements in the brackets with the desired values (e.g. EXPERIMENT_NAME=MTA, SEED=1001, CONCEPT_LIST="domain, ncount, narea"). Run with the [-h] option for help.
bash routines/train_uncertainty_weighted_mta.sh [EXPERIMENT_NAME] [SEED] [CONCEPT_LIST]
Expected outcome:
results/[EXPERIMENT_NAME]
├── [EXPERIMENT_NAME]_log.txt
├── best_model.h5
├── lr_monitor.log
├── normalizer.npy
├── normalizing_patch.npy
├── seed.txt
├── training_log.npy
├── ERR.log
├── loss1_rec.log
├── loss2_rec.log
├── train_by_epoch.txt
├── val_by_epoch.txt
└── val_acc_log.npy
The main Python script is train_multitask_adversarial.py. For example, to train the CNN with domain-adversarial training and the additional learning target of nuclei count, it is called as follows:
python train_multitask_adversarial.py 0 DOMAIN-COUNT domain count
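The domain-adversarial branch relies on gradient reversal [2]. The layer below is a minimal, self-contained sketch of the idea (illustrative only; the repository's actual implementation may differ):

```python
# Minimal gradient reversal layer (Ganin et al. [2]); illustrative sketch, not the repo's exact code.
import keras.backend as K
from keras.layers import Layer

class GradientReversal(Layer):
    """Identity in the forward pass; scales gradients by -hp_lambda in the backward pass."""

    def __init__(self, hp_lambda=1.0, **kwargs):
        self.hp_lambda = hp_lambda
        super(GradientReversal, self).__init__(**kwargs)

    def call(self, x):
        # Forward value: -l*x + (1+l)*x = x.  Backward gradient: -l (stop_gradient blocks the rest).
        return -self.hp_lambda * x + K.stop_gradient((1.0 + self.hp_lambda) * x)
```

The domain classification head is attached on top of GradientReversal()(shared_features), so that minimising the domain loss pushes the shared features towards invariance to the acquisition center.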
For more examples, please refer to the notebooks folder.
To replicate baseline results (line ID 1 in Table 2)
bash routines/replicate_baseline.sh
To replicate multi-task adversarial results (line IDs 2 to 8 in Table 2)
bash routines/replicate_mta.sh
To replicate Figure 4, see the notebook notebooks/visualize/UMAP_p3.ipynb
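If the notebook is not at hand, the figure can be approximated conceptually with a UMAP projection of CNN features; the snippet below is a generic, self-contained sketch with placeholder data (it is not the notebook's code):

```python
# Generic UMAP sketch with placeholder data (replace with real penultimate-layer CNN features).
import numpy as np
import umap                      # pip install umap-learn
import matplotlib.pyplot as plt

features = np.random.rand(500, 256).astype(np.float32)   # placeholder for extracted CNN features
labels = np.random.randint(0, 3, size=500)                # placeholder, e.g. acquisition center id

embedding = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=0).fit_transform(features)
plt.scatter(embedding[:, 0], embedding[:, 1], s=3, c=labels)
plt.title('UMAP of CNN features (placeholder data)')
plt.show()
```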
Distributed under the MIT License. See LICENSE for more information.
Mara Graziani - @mormontre - mara.graziani@hevs.ch
If you use this software, please cite it as below:
cff-version: 1.1.0
message: "If you use this software, please cite it as below."
authors:
- family-names: Graziani
  given-names: Mara
title: "maragraziani/multitask_adversarial: Official Release"
version: 0.2
date-released: 2017-12-18
Please also cite the paper associated with this work:
Mara Graziani, Sebastián Otálora, Stéphane Marchand-Maillet, Henning Müller, Vincent Andrearczyk. Learning Interpretable Pathology Features by Multi-task and Adversarial Training Improves CNN Generalization. Under Review.