This is the repository for Bayesian Invariant Risk Minimization (BIRM), CVPR 2022 (oral).
Our implementation is based on the source code of IRM, ColoredObject, and CifarMnist.
Our code works with the following environment:
- python=3.7.0
- torch=1.3.1
- h5py==2.8.0
To install the necessary packages for the project, please run: pip install -r requirements.txt.
- To perform BIRM on CMNIST (with 20K training data), run the command sh auto_CMNIST.sh. The expected test accuracy is 67.0±1.8.
- To perform BIRM on ColoredObject, first run sh prepare_coco_dataset.sh to download the MSCOCO dataset and preprocess the images (this may take several hours, please be patient). Then run the command sh auto_ColoredObject.sh to train BIRM. The expected test accuracy is 78.1±0.6.
- To perform BIRM on CifarMnist, run the command sh auto_CifarMnist.sh. The expected test accuracy is 59.3±2.3.
Important arguments:
- dataset: one of CMNIST, ColoredObject, and CifarMnist;
- l2_regularizer_weight: the weight decay coefficient;
- lr: the learning rate;
- opt: the optimizer. By default, we use adam for CMNIST and sgd for ColoredObject and CifarMnist;
- data_num: (only valid for CMNIST) the number of training samples;
- penalty_weight: the weight of the BIRM penalty;
- penalty_anneal_iters: the number of steps for which we train with ERM only, after which the BIRM penalty is applied;
- step_gamma: the step decay ratio; 0.1 means the learning rate is multiplied by 0.1 at the middle of the training steps.
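How penalty_anneal_iters, penalty_weight, and step_gamma interact over training can be summarized with a small sketch. It assumes a loop indexed by step over steps total iterations, a penalty weight of zero during the ERM warm-up, and a single learning-rate decay at the midpoint. The helper names are hypothetical and are not taken from the training scripts, so check the scripts for the exact logic (for instance, the original IRM colored-MNIST example also rescales the loss once the penalty is switched on, which is omitted here).

```python
# Hypothetical helpers illustrating the README's description of the
# schedule-related arguments; not the repository's actual training code.

def learning_rate_at(step, lr, steps, step_gamma):
    """Step decay: multiply lr by step_gamma once, at the midpoint of training."""
    if step >= steps // 2:
        return lr * step_gamma
    return lr


def penalty_weight_at(step, penalty_weight, penalty_anneal_iters):
    """ERM warm-up: no BIRM penalty before penalty_anneal_iters, full weight after."""
    if step < penalty_anneal_iters:
        return 0.0
    return penalty_weight


def total_loss(erm_loss, birm_penalty, step, penalty_weight, penalty_anneal_iters):
    """Combine the ERM loss with the (warm-up aware) BIRM penalty term."""
    weight = penalty_weight_at(step, penalty_weight, penalty_anneal_iters)
    return erm_loss + weight * birm_penalty


if __name__ == "__main__":
    steps, lr = 1000, 0.1
    for step in (0, 499, 500, 999):
        print(step,
              learning_rate_at(step, lr, steps, step_gamma=0.1),
              penalty_weight_at(step, penalty_weight=1000.0, penalty_anneal_iters=200))
```

In this sketch the penalty is a hard switch rather than a gradual ramp, which mirrors the wording "train ERM first, after which BIRM penalty will be applied".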
- CMNIST: the most popular dataset in the IRM literature. The invariant feature is the shape of the MNIST digit and the spurious feature is the attached color.
- ColoredObject: following Faruk-Ahmed, we construct ColoredObject by superimposing objects extracted from MSCOCO on a colored background (the spurious feature).
- CifarMnist: following Shah, we construct each image in CifarMnist by concatenating two component images: a CIFAR-10 image (invariant) and an MNIST image (spurious).
Refer to Section 5 of our paper for a detailed description of the datasets.
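For reference, the CMNIST environments used throughout the IRM literature follow the recipe of Arjovsky et al. (2019): a binary label derived from the digit (0 for digits 0-4, 1 for digits 5-9) with 25% label noise, and a color that copies the label but is flipped with an environment-specific probability (0.1 and 0.2 for the two training environments and 0.9 for the test environment in that paper). The sketch below assumes numpy arrays of grayscale images; the function name and signature are hypothetical, it is not this repository's data pipeline, and details such as image resolution may differ.

```python
import numpy as np


def make_cmnist_environment(images, digits, color_flip_prob, label_flip_prob=0.25, rng=None):
    """Color grayscale digit images following the IRM-style CMNIST recipe.

    images: float array of shape (N, H, W); digits: int array of shape (N,).
    Returns x of shape (N, 2, H, W), binary labels y, color ids, and a flag c
    indicating whether the color (spurious feature) agrees with the label.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = len(digits)
    # Invariant signal: digits 0-4 -> 0, digits 5-9 -> 1, then flip with label_flip_prob.
    y = (digits >= 5).astype(np.int64)
    y = np.where(rng.random(n) < label_flip_prob, 1 - y, y)
    # Spurious signal: the color copies the label, flipped with an environment-specific probability.
    colors = np.where(rng.random(n) < color_flip_prob, 1 - y, y)
    # Two color channels; zero out the channel that does not match each sample's color.
    x = np.stack([images, images], axis=1).astype(np.float32)
    x[np.arange(n), 1 - colors] = 0.0
    c = (colors == y).astype(np.int64)
    return x, y, colors, c
```

Calling it once per environment (e.g. with color_flip_prob set to 0.1, 0.2, and 0.9) yields datasets in which the color is highly predictive during training but anti-correlated with the label at test time, while the digit shape stays equally predictive everywhere.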
We provide an interface for you to include your own data. You need to inherit the class IRMDataProvider and re-implement the functions fetch_train and fetch_test. The main function calls fetch_train to get the training data at each step. fetch_train should return the following values:
- train_x: the feature tensor;
- train_y: the label tensor;
- train_g: a tensor whose values indicate which environment each sample comes from;
- train_c (optional): a tensor whose values indicate whether the spurious feature of each sample aligns with its label.
The return value of fetch_test has a similar structure to that of fetch_train.
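A minimal sketch of a custom provider follows. Since the import path and constructor of IRMDataProvider are not described here, the block defines a hypothetical stand-in base class; MyDataProvider, its constructor arguments, and the optional batch sampling are likewise illustrative. It uses numpy arrays to stay self-contained, whereas the real training code expects tensors (torch.from_numpy is one way to convert), and it returns the four values in the order listed above.

```python
import numpy as np


class IRMDataProvider:
    """Hypothetical stand-in for the repository's base class."""

    def fetch_train(self):
        raise NotImplementedError

    def fetch_test(self):
        raise NotImplementedError


class MyDataProvider(IRMDataProvider):
    """Hypothetical provider serving pre-loaded arrays from several training environments."""

    def __init__(self, train_envs, test_env, batch_size=None, seed=0):
        # train_envs: list of (x, y, c) tuples, one per environment.
        # test_env: a single (x, y, c) tuple.
        self.train_envs = train_envs
        self.test_env = test_env
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    @staticmethod
    def _stack(envs):
        """Concatenate environments and build the environment-index array g."""
        xs, ys, gs, cs = [], [], [], []
        for env_id, (x, y, c) in enumerate(envs):
            xs.append(x)
            ys.append(y)
            gs.append(np.full(len(y), env_id, dtype=np.int64))  # which environment
            cs.append(c)  # whether the spurious feature agrees with the label
        return np.concatenate(xs), np.concatenate(ys), np.concatenate(gs), np.concatenate(cs)

    def fetch_train(self):
        train_x, train_y, train_g, train_c = self._stack(self.train_envs)
        if self.batch_size is not None:
            # Draw a random mini-batch without replacement at each call.
            idx = self.rng.choice(len(train_y), size=self.batch_size, replace=False)
            train_x, train_y, train_g, train_c = train_x[idx], train_y[idx], train_g[idx], train_c[idx]
        return train_x, train_y, train_g, train_c

    def fetch_test(self):
        # Same structure as fetch_train; here the test data form a single environment.
        return self._stack([self.test_env])
```

With batch_size=None the whole training set is returned on every call, which corresponds to full-batch training; passing an integer instead samples a random subset at each step.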
For help or issues using Bayesian Invariant Risk Minimization, please submit a GitHub issue.
For personal communication related to BayesianIRM, please contact Yong Lin (ylindf@connect.ust.hk).
If you use or extend our work, please cite the following paper:
@inproceedings{lin2022bayesian,
title={Bayesian Invariant Risk Minimization},
author={Lin, Yong and Dong, Hanze and Wang, Hao and Zhang, Tong},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={16021--16030},
year={2022}
}