
Debugging and Explaining Metric Learning Approach: An Influence Function Perspective (NeurIPS 2022)

Introduction

Deep metric learning (DML) learns a generalizable embedding space of a dataset, in which semantically similar samples are mapped close together. Recent state-of-the-art methods have largely evolved from pairwise-based approaches to proxy-based approaches, yet many recent works achieve only marginal improvements on the classical benchmark datasets. This calls for explanation approaches for DML that help us understand why a trained model confuses dissimilar samples.

This question motivates us to design an influence-function-based explanation framework for investigating existing datasets, consisting of:

  • Scalable training-sample attribution:
    • We propose an empirical influence function (EIF) to identify which training samples contribute to the generalization errors and to quantify how much each of them contributes (a minimal sketch is given after this list).
  • Dataset relabelling recommendation:
    • We further identify potentially "buggy" training samples with mistaken labels and generate relabelling recommendations for them.
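
For intuition, below is a minimal first-order sketch of gradient-similarity attribution. It is not the EIF implementation shipped in Influence_function/; model, train_loss_fn, test_error_fn, train_sample and test_pair are all placeholders. The idea: approximate the contribution of a training sample to a confusing test pair by the inner product between the parameter gradients of the training loss and of the test-pair error.

import torch

def first_order_influence(model, train_loss_fn, test_error_fn, train_sample, test_pair):
    # Illustrative sketch only; the actual EIF lives in Influence_function/.
    params = [p for p in model.parameters() if p.requires_grad]

    # Gradient of the test error on one confusing (dissimilar but close) pair.
    g_test = torch.autograd.grad(test_error_fn(model, *test_pair), params)

    # Gradient of the training loss on a single training sample.
    g_train = torch.autograd.grad(train_loss_fn(model, *train_sample), params)

    # First-order (TracIn-style) approximation: a gradient-descent step on this
    # training sample changes the test error by roughly -lr * <g_test, g_train>,
    # so a negative inner product marks the sample as harmful for this pair.
    return sum((gt * gr).sum() for gt, gr in zip(g_test, g_train)).item()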

Links to our website and paper

Requirements

pip install -r requirements.txt

Download datasets from

Caltech_birds2011[1]

Cars196[2]

In-shop Clothes Retrieval[3]

Put them under mnt/datasets/

Training details

  • We follow the train-test split provided by the original datasets
  • We use the same hyperparameters as specified in Proxy-NCA++, except that for In-Shop we reduce the batch size to 32*3 due to limited GPU resources.

Project Structure

|__ config/: training config json files
|__ dataset/: define dataloader
|__ mnt/datasets/
   |__ CARS_196/
   |__ CUB200_2011/
   |__ inshop/
|__ evaluation/: evaluation script for recall@k, NMI etc.
|__ experiments/: scripts for experiments
|__ Influence_function/: implementation of IF and EIF
|__ train.py: normal training script
|__ train_noisy_data.py: noisy-data training script
|__ train_sample_reweight.py: re-weighted training script

Instructions

  • Training the original models
    • Training the DML models with Proxy-NCA++ loss or with SoftTriple loss
python train.py --dataset [cub|cars|inshop] \
--loss-type ProxyNCA_prob_orig \
--seed [0|1|2|3|4] \
--config [config/cub_ProxyNCA_prob_orig.json|config/cars_ProxyNCA_prob_orig.json|config/inshop_ProxyNCA_prob_orig.json]
python train.py --dataset [cub|cars|inshop] \
--loss-type SoftTriple \
--seed [0|1|2|3|4] \
--config [config/cub_SoftTriple.json|config/cars_SoftTriple.json|config/inshop_SoftTriple.json]
  • Training the models with mislabelled data
    • Training the DML models with Proxy-NCA++ loss or with SoftTriple loss
python train_noisy_data.py --dataset [cub_noisy|cars_noisy|inshop_noisy] \
--loss-type ProxyNCA_prob_orig_noisy_0.1 \
--seed [0|1|2|3|4] \
--mislabel_percentage 0.1 \
--config [config/cub_ProxyNCA_prob_orig.json|config/cars_ProxyNCA_prob_orig.json|config/inshop_ProxyNCA_prob_orig.json]
python train_noisy_data.py --dataset [cub_noisy|cars_noisy|inshop_noisy] \
--loss-type SoftTriple_noisy_0.1 \
--seed [0|1|2|3|4] \
--mislabel_percentage 0.1 \
--config [config/cub_SoftTriple.json|config/cars_SoftTriple.json|config/inshop_SoftTriple.json]
  • DML training experiment (Table 1): comparing EIF and IF

       See experiments/EIF_group_confusion.py, experiments/IF_group_confusion.py, experiments/EIF_pair_confusion.py, experiments/IF_pair_confusion.py

  • Mislabelled detection experiment

       See experiments/EIFvsIF_mislabel_evaluation.py

  • Field study

       See experiments/sample_recommendation_evaluation.py
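
As a rough illustration of how influence scores can drive the mislabel-detection and field-study experiments above, the sketch below first ranks training samples by a precomputed harmfulness score and then suggests a relabelling via the nearest class proxy in embedding space. The names influence_scores, embeddings and proxies are placeholders; the actual procedures are in the experiment scripts listed above.

import torch

def flag_suspects(influence_scores, k=50):
    # influence_scores: shape (n_train,), assumed larger = more harmful.
    # Return the indices of the k most suspicious training samples.
    return torch.topk(torch.as_tensor(influence_scores), k).indices

def recommend_labels(embeddings, proxies):
    # embeddings: (n, d) suspect-sample embeddings; proxies: (num_classes, d)
    # class proxies, both assumed L2-normalised. Recommend the class whose
    # proxy is most similar to each sample's embedding.
    return (embeddings @ proxies.t()).argmax(dim=1)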

Results

Please consider citing our paper:

@article{liu2022debugging,
  title={Debugging and Explaining Metric Learning Approaches: An Influence Function Based Perspective},
  author={Liu, Ruofan and Lin, Yun and Yang, Xianglin and Dong, Jin Song},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  pages={7824--7837},
  year={2022}
}

References

[1] Wah, C., Branson, S., Welinder, P., Perona, P., & Belongie, S. (2011). The Caltech-UCSD Birds-200-2011 dataset.

[2] Krause, J., Stark, M., Deng, J., & Fei-Fei, L. (2013). 3D object representations for fine-grained categorization. In Proceedings of the IEEE International Conference on Computer Vision Workshops (pp. 554-561).

[3] Liu, Z., Luo, P., Qiu, S., Wang, X., & Tang, X. (2016). DeepFashion: Powering robust clothes recognition and retrieval with rich annotations. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 1096-1104).
