Writer: Masahiro Mitsuhara
Maintainer: Tsubasa Hirakawa
This repository is a PyTorch implementation of the Attention Branch Network (ABN) for multitask learning.
In this repository, we use the attribute classification task on the CelebA dataset.
Please note that the model structure differs from the original Chainer implementation, because we conducted further experiments and sought better models with PyTorch. If you want to use the original ABN or reproduce the results of the original ABN paper (CVPR), please use the original Chainer implementation.
- 09 Jun 2019: First release of Multitask ABN implemented with Chainer. It can be found here (in a different repository).
- 11 Dec 2020: Implemented Multitask ABN with PyTorch. We also improved the network architecture over the original Chainer implementation to achieve better results. Available at tag:v0.1.
- 15 Jul 2022: Updated the PyTorch implementation for PyTorch 1.11.0. This version reproduces the original implementation in PyTorch. Available at tag:v1.0.
- 23 May 2023: Updated the PyTorch implementation for PyTorch 2.1.0. We modified the network structure (the output shape of the perception branch, so that BCEWithLogitsLoss can be used) and added Weighted Focal Loss [1].
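BCEWithLogitsLoss expects one raw (pre-sigmoid) logit per binary target, which here means one logit per attribute. As a rough illustration only (not the repository's code), the framework-free sketch below computes a weighted binary focal loss from such logits for a single sample. The per-attribute weights, exp(1 - r) for a positive label and exp(r) for a negative label, where r is the fraction of positive training samples for that attribute, are an assumption chosen so that rarer label values receive larger weights; the weighting used in [1] and in this repository may differ. The names weighted_focal_loss and pos_ratios are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def weighted_focal_loss(logits, targets, pos_ratios, gamma=2.0):
    """Mean weighted binary focal loss over the attributes of one sample.

    logits:     raw scores, one per attribute (no sigmoid applied,
                as with BCEWithLogitsLoss)
    targets:    0/1 labels, one per attribute
    pos_ratios: hypothetical per-attribute fraction of positive training samples
    gamma:      focusing parameter; gamma=0 removes the focal term
    """
    total = 0.0
    for z, y, r in zip(logits, targets, pos_ratios):
        prob = sigmoid(z)
        # Probability the model assigns to the correct label.
        p_t = prob if y == 1 else 1.0 - prob
        # Assumed weighting: the rarer a label value is, the larger its weight.
        w = math.exp(1.0 - r) if y == 1 else math.exp(r)
        # Focal term (1 - p_t)^gamma down-weights well-classified attributes.
        total += -w * (1.0 - p_t) ** gamma * math.log(max(p_t, 1e-12))
    return total / len(logits)


# Example: 3 attributes; the second one is misclassified.
loss = weighted_focal_loss([2.0, -1.5, 0.3], [1, 1, 0], [0.4, 0.1, 0.5])
print(f"{loss:.4f}")
```

With gamma=0 and all weights equal to 1 (e.g., r = 1 for a positive label), the expression reduces to ordinary binary cross-entropy on the logit, which is a convenient sanity check.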
If you find this repository useful, please cite the following references.
@inproceedings{fukui2019,
author = {Hiroshi Fukui and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi},
title = {Attention Branch Network: Learning of Attention Mechanism for Visual Explanation},
booktitle = {Computer Vision and Pattern Recognition},
year = {2019},
pages = {10705-10714}
}
@article{fukui2018,
author = {Hiroshi Fukui and Tsubasa Hirakawa and Takayoshi Yamashita and Hironobu Fujiyoshi},
title = {Attention Branch Network: Learning of Attention Mechanism for Visual Explanation},
journal = {arXiv preprint arXiv:1812.10025},
year = {2018}
}
You can find our papers as follows:
Please see docker/README.md.
Please see data/README.md.
Please see the script_local directory.
We can choose the network model and related settings with the following options:
- --model [model name]: network model. Please choose one of the following options:
  - ResNet: resnet18, resnet34, resnet50, resnet101, resnet152
  - Multitask ABN (V1): mtabn_v1_resnet18, mtabn_v1_resnet34, mtabn_v1_resnet50, mtabn_v1_resnet101, mtabn_v1_resnet152
  - Multitask ABN (V2): mtabn_v2_resnet18, mtabn_v2_resnet34, mtabn_v2_resnet50, mtabn_v2_resnet101, mtabn_v2_resnet152
- --pretrained: use ImageNet-pretrained model parameters as the initial parameters.
- --residual_attention (ABN only): use the residual attention mechanism, where $M(x_i)$ is the attention map for input $x_i$ and $g_c(x_i)$ is the $c$-th channel of the feature map. The differences are as follows:
  - No residual attention: $f_c'(x_i) = M(x_i) \cdot g_c(x_i)$
  - Residual attention: $f_c'(x_i) = (1 + M(x_i)) \cdot g_c(x_i)$
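All model names listed above follow one pattern: a family prefix followed by the ResNet depth. The sketch below declares the documented options with Python's standard argparse module. It is a hypothetical parser, not the repository's actual argument handling; build_parser, is_abn, and the check that --residual_attention is used only with ABN models are illustrative assumptions.

```python
import argparse

# Every model name is a family prefix followed by a ResNet depth.
FAMILIES = ["resnet", "mtabn_v1_resnet", "mtabn_v2_resnet"]
DEPTHS = [18, 34, 50, 101, 152]
MODEL_NAMES = [f"{family}{depth}" for family in FAMILIES for depth in DEPTHS]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented options."""
    parser = argparse.ArgumentParser(description="Multitask ABN on CelebA (sketch)")
    parser.add_argument("--model", choices=MODEL_NAMES, required=True,
                        help="network model")
    parser.add_argument("--pretrained", action="store_true",
                        help="initialize with ImageNet-pretrained parameters")
    parser.add_argument("--residual_attention", action="store_true",
                        help="use residual attention (ABN only)")
    return parser


def is_abn(model_name: str) -> bool:
    """True for the Multitask ABN variants (assumed: names start with 'mtabn_')."""
    return model_name.startswith("mtabn_")


# Example: parse an explicit argument list instead of sys.argv.
args = build_parser().parse_args(["--model", "mtabn_v2_resnet50", "--residual_attention"])
# Assumed sanity check: residual attention only applies to ABN models.
if args.residual_attention and not is_abn(args.model):
    raise SystemExit("--residual_attention is only valid for ABN models")
print(args.model, args.pretrained, args.residual_attention)  # prints: mtabn_v2_resnet50 False True
```

Using choices lets argparse reject misspelled model names before any network is built.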
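The two residual-attention formulas differ only in whether the attention map is offset by 1. The framework-free sketch below applies a spatial attention map $M(x_i)$ to every channel $g_c(x_i)$ of a small feature map in both modes. Nested lists stand in for tensors, and apply_attention is a hypothetical name, not a function from the repository.

```python
def apply_attention(features, attention, residual=False):
    """Weight every channel of a feature map by a spatial attention map.

    features:  list of C channels, each an H x W nested list (g_c(x_i))
    attention: H x W nested list (M(x_i)), assumed to lie in [0, 1]
    residual:  False -> M * g_c,  True -> (1 + M) * g_c
    """
    out = []
    for channel in features:
        out.append([
            [((1.0 + m) if residual else m) * g for g, m in zip(g_row, m_row)]
            for g_row, m_row in zip(channel, attention)
        ])
    return out


g = [[[1.0, 2.0], [3.0, 4.0]]]  # one channel, 2 x 2
M = [[0.0, 0.5], [1.0, 0.0]]    # attention map
print(apply_attention(g, M))                 # [[[0.0, 1.0], [3.0, 0.0]]]
print(apply_attention(g, M, residual=True))  # [[[1.0, 3.0], [6.0, 4.0]]]
```

With residual attention, positions where $M(x_i) = 0$ keep their original feature values instead of being zeroed out, so the attention map can emphasize regions without completely suppressing the rest.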
In eval_celeba.py, we provide the following option for selecting the type of attention map visualization.
- --attention_type [pos/neg/both]: the type of attention map.
  - pos: visualize only positive values.
  - neg: visualize only negative values.
  - both: visualize both positive and negative values.
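How eval_celeba.py actually renders these maps is not described here. Assuming the attention map holds signed values, the sketch below only illustrates which values each --attention_type setting would keep before rendering (for example, as a heat map). filter_attention is a hypothetical helper, not part of the repository.

```python
def filter_attention(attention, attention_type):
    """Keep only the attention values selected by attention_type.

    attention:      H x W nested list of signed attention values
    attention_type: "pos", "neg", or "both" (as in --attention_type)
    """
    if attention_type not in ("pos", "neg", "both"):
        raise ValueError(f"unknown attention_type: {attention_type!r}")

    def keep(a):
        if attention_type == "pos":
            return max(a, 0.0)  # negative values are set to 0
        if attention_type == "neg":
            return min(a, 0.0)  # positive values are set to 0
        return a                # "both": keep every value

    return [[keep(a) for a in row] for row in attention]


attention_map = [[0.2, -0.3], [0.0, 0.5]]
for t in ("pos", "neg", "both"):
    print(t, filter_attention(attention_map, t))
```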
We show examples of each attention type.
- pos (only positive values)
- neg (only negative values)
- both (both positive and negative values)
We provide pre-trained models. Please download them from the following link.
NOTE
We have prepared pre-trained models for each version (see 1.1 Change Log for more details). Please use the appropriate one.
- [1] N. Sarafianos et al., "Deep Imbalanced Attribute Classification using Visual Attention Aggregation," in ECCV, 2018.
- Machine Perception & Robotics Group
- Decision-making Analysis by Attention Mechanism and Applications (MPRG Tutorial)