
feiaxyt/EMC-Click


EMC-Click: Efficient Mask Correction for Click-Based Interactive Image Segmentation (CVPR2023)

The PyTorch implementation of "Efficient Mask Correction for Click-Based Interactive Image Segmentation" (CVPR 2023).

Abstract

The goal of click-based interactive image segmentation is to extract target masks given positive and negative clicks as input. Every time a new click is placed, existing methods run the whole segmentation network to obtain a corrected mask. This is inefficient, since several clicks may be needed to reach satisfactory accuracy. To this end, we propose an efficient method that corrects the mask with a lightweight mask correction network. The whole network maintains a low computational cost from the second click onward, even with a large backbone. However, a simple correction network with limited capacity is unlikely to achieve performance comparable to that of a classic segmentation network. We therefore propose a click-guided self-attention module and a click-guided correlation module that effectively exploit the click information to boost performance. First, several templates are selected based on their semantic similarity to the click features. Then the self-attention module propagates the template information to other pixels, while the correlation module directly uses the templates to obtain target outlines. With the efficient architecture and the two click-guided modules, our method shows favorable performance and efficiency compared to existing methods.
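To make the efficiency argument concrete, here is a deliberately tiny toy in plain Python, assuming each pixel is represented by a feature vector. It is not the EMC-Click architecture: the "backbone" merely L2-normalizes the features, the self-attention module is omitted, and the correlation step is reduced to cosine similarity against click-selected templates. What it does illustrate is the control flow described above: the expensive feature extraction runs once per image, and every later click only reselects templates (the k pixels most similar to each click) and recomputes a cheap correlation-based mask. The class name ToyMaskCorrector, the parameter k, and the 0.5 fallback threshold are all hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vector = List[float]


def normalize(v: Sequence[float]) -> Vector:
    """L2-normalize a feature vector (zero vectors are returned unchanged)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n else list(v)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product; equals cosine similarity for normalized vectors."""
    return sum(x * y for x, y in zip(a, b))


class ToyMaskCorrector:
    """Hypothetical toy: expensive features once per image, cheap correction per click."""

    def __init__(self, pixels: List[Vector], k: int = 2):
        self.pixels = pixels
        self.k = k  # templates kept per click (hypothetical parameter)
        self.features: Optional[List[Vector]] = None
        self.backbone_calls = 0

    def _backbone(self) -> List[Vector]:
        # Stand-in for the heavy backbone: it only normalizes the input features.
        self.backbone_calls += 1
        return [normalize(p) for p in self.pixels]

    def _templates(self, clicks: Sequence[int]) -> List[int]:
        # For each click, keep the k pixels whose features are most similar to the click's.
        chosen: List[int] = []
        for c in clicks:
            click_feat = self.features[c]
            ranked = sorted(range(len(self.features)),
                            key=lambda i: dot(self.features[i], click_feat),
                            reverse=True)
            for i in ranked[: self.k]:
                if i not in chosen:
                    chosen.append(i)
        return chosen

    def correct(self, pos_clicks: Sequence[int], neg_clicks: Sequence[int]) -> List[int]:
        if self.features is None:  # only the first call pays for the backbone
            self.features = self._backbone()
        pos_t = self._templates(pos_clicks)
        neg_t = self._templates(neg_clicks)
        mask = []
        for f in self.features:
            # Toy correlation: best similarity to any positive vs. any negative template.
            pos = max((dot(f, self.features[t]) for t in pos_t), default=0.0)
            neg = max((dot(f, self.features[t]) for t in neg_t), default=0.5)
            mask.append(1 if pos > neg else 0)
        return mask


corrector = ToyMaskCorrector([[1.0, 0.0], [0.8, 0.6], [0.6, 0.8], [0.0, 1.0]])
print(corrector.correct(pos_clicks=[0], neg_clicks=[]))   # [1, 1, 1, 1]
print(corrector.correct(pos_clicks=[0], neg_clicks=[3]))  # [1, 1, 0, 0]
print(corrector.backbone_calls)                           # 1
```

The second call fixes the over-segmented first mask using only the cached features, which is the property the paper exploits with a far more capable correction network.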



Environment setup

  1. Install the requirements by executing
pip install -r requirements.txt
  2. Prepare the dataset and pretrained backbone weights by following Data_Weight_Preparation.md.

Evaluation

Download the pretrained checkpoints from Releases and put them into the ./weights directory.
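Before launching eight evaluation processes, it can help to confirm that the checkpoints named in the --checkpoint flag of the command below are actually in place. The following is a minimal sketch; the helper name missing_checkpoints is hypothetical and not part of the repository, and the file names are simply copied from that command.

```python
from pathlib import Path
from typing import List, Sequence

# File names copied from the --checkpoint flag of the evaluation command.
EXPECTED_CHECKPOINTS = ("hr18s.pth", "hr18.pth", "hr32.pth", "segb0.pth", "segb3.pth")


def missing_checkpoints(model_dir: str = "./weights",
                        names: Sequence[str] = EXPECTED_CHECKPOINTS) -> List[str]:
    """Return the expected checkpoint files that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in names if not (root / name).is_file()]


missing = missing_checkpoints()
print("missing:", ", ".join(missing) if missing else "none")
```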

Run

python -m torch.distributed.launch --master_port=4321 --nproc_per_node=8 scripts/evaluate_model.py EMC-Click \
    --model_dir='./weights/' \
    --checkpoint=hr18s.pth,hr18.pth,hr32.pth,segb0.pth,segb3.pth \
    --n-clicks=20 \
    --gpus=0,1,2,3,4,5,6,7 \
    --target-iou=0.9 \
    --thresh=0.5 \
    --eval-mode='emc-click' \
    --datasets=GrabCut,Berkeley,SBD,DAVIS,PascalVOC

to evaluate all five checkpoints on the GrabCut, Berkeley, SBD, DAVIS, and PascalVOC datasets.
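The flags --n-clicks, --target-iou and --thresh suggest the number-of-clicks (NoC) protocol used by RITM, on which this code is based: each predicted probability map is binarized at --thresh, its IoU with the ground truth is recorded after every click, and NoC is the first click at which the IoU reaches --target-iou, capped at --n-clicks when the target is never reached. The sketch below is our reading of that protocol in plain Python, not code from scripts/evaluate_model.py; the function names are hypothetical, and the strict ">" in binarization is an assumption.

```python
from typing import List, Sequence, Tuple


def binarize(probs: Sequence[float], thresh: float = 0.5) -> List[int]:
    """Turn a probability map into a binary mask (assumed: strictly greater than thresh)."""
    return [1 if p > thresh else 0 for p in probs]


def iou(pred: Sequence[int], gt: Sequence[int]) -> float:
    """Intersection over union of two binary masks; 1.0 if both are empty."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter / union if union else 1.0


def clicks_to_target(ious: Sequence[float], target_iou: float = 0.9, n_clicks: int = 20) -> int:
    """1-based index of the first click whose IoU reaches target_iou; n_clicks if never reached."""
    for k, value in enumerate(ious[:n_clicks], start=1):
        if value >= target_iou:
            return k
    return n_clicks


def summarize(per_sample_ious: Sequence[Sequence[float]],
              target_iou: float = 0.9, n_clicks: int = 20) -> Tuple[float, int]:
    """Mean NoC over samples and number of failures (samples that never reach target_iou)."""
    nocs = [clicks_to_target(s, target_iou, n_clicks) for s in per_sample_ious]
    failures = sum(1 for s in per_sample_ious
                   if not any(v >= target_iou for v in s[:n_clicks]))
    return sum(nocs) / len(nocs), failures


# Two samples: one reaches IoU 0.9 at click 3, the other never does (counted as 20 clicks).
print(summarize([[0.70, 0.85, 0.92], [0.50, 0.60]]))  # (11.5, 1)
```

Capping failed samples at --n-clicks is what makes a mean NoC comparable across methods: a model cannot lower its score by giving up early.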


Train

Run

CONFIG=models/emcclick/hrnet18s_att_cclvis.py
EXP_NAME=hrnet18s_att_cclvis
nGPUS=4
nBS=64
nWORKERS=4
PORT=`expr $RANDOM + 5000`

python -m torch.distributed.launch --nproc_per_node=$nGPUS --master_port=$PORT \
    train.py $CONFIG \
    --ngpus=$nGPUS \
    --workers=$nWORKERS \
    --batch-size=$nBS \
    --exp-name=$EXP_NAME

to train with the hrnet18s backbone.

You can find a template script in ./trainval_scripts/train_xxx.sh.
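The training command picks its rendezvous port with expr $RANDOM + 5000. Since bash's $RANDOM ranges from 0 to 32767, this yields a port between 5000 and 37767, which can occasionally collide with a port already in use. A common alternative, swapped in here and not part of the repository, is to let the operating system pick a free port by binding to port 0. The function name find_free_port is hypothetical.

```python
import socket


def find_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for an unused TCP port by binding to port 0.

    The socket is closed before returning, so another process could in
    principle take the port before torch.distributed binds it; the window
    is small but not zero.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0 lets the kernel choose any free port
        return s.getsockname()[1]


print(find_free_port())  # the printed port differs from run to run
```

In a shell script, the same idea fits on one line, for example PORT=$(python -c "import socket; s = socket.socket(); s.bind(('127.0.0.1', 0)); print(s.getsockname()[1]); s.close()").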


Acknowledgement

The code is built on RITM and ClickSEG. We sincerely thank their contributors.


License

The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source.


Citation

If you find this work useful for your research, please cite our paper:

@inproceedings{emcclick,
  title={Efficient Mask Correction for Click-Based Interactive Image Segmentation},
  author={Du, Fei and Yuan, Jianlong and Wang, Zhibin and Wang, Fan},
  booktitle={CVPR},
  year={2023}
}
