# Reproduction results


The following commands reproduce the results for the Rotated MNIST, Fashion-MNIST, and Chest X-Ray datasets corresponding to Figure 1, Table 4, and Table 5 in the paper.


## Rotated MNIST

For convenience, we provide the exact commands for the Rotated MNIST dataset with the training domains set to [15, 30, 45, 60, 75] and the test domains set to [0, 90].

To obtain results for the Fashion-MNIST dataset, change the `--dataset` parameter from `rot_mnist` to `fashion_mnist`.

### Prepare Data

From the `data` directory, run the following command: `python data_gen.py resnet18`


## Chest X-ray

Before running the code, perform the following steps:

- Follow the steps in the `Preprocess.ipynb` notebook to download and process the Chest X-Ray datasets.
- Then follow the steps in the `ChestXRay_Translate.ipynb` notebook to perform the image translations.

## Installing Libraries

Move back to the root directory.

All required packages are listed in the file `requirements.txt`.

You can install them as follows: `pip install -r requirements.txt`
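
The Rotated MNIST and Fashion-MNIST setup steps can be combined into a single shell function. This is a minimal sketch, not a script shipped with the repository: it assumes it is sourced from the repository root, and the helper names `run` and `setup_mnist` and the `DRY_RUN` variable are hypothetical. It installs the requirements before generating the data, on the assumption that `data_gen.py` imports some of those packages. The Chest X-Ray notebooks are interactive and are not covered.

```sh
#!/bin/sh
# Hypothetical setup helper for the Rotated MNIST / Fashion-MNIST runs.
# Source this file from the repository root. DRY_RUN=1 prints commands only.

# Print a command, then execute it unless DRY_RUN=1.
run() {
  printf '%s\n' "$*"
  if [ "${DRY_RUN:-0}" != 1 ]; then
    "$@"
  fi
}

setup_mnist() {
  # Install dependencies first (assumption: data_gen.py needs them).
  run pip install -r requirements.txt || return 1
  # Generate the data in a subshell so the caller stays in the repository root.
  ( run cd data && run python data_gen.py resnet18 )
}
```

For example, `. ./setup.sh && setup_mnist` (the file name is arbitrary) performs both steps; `DRY_RUN=1 setup_mnist` only prints the three commands. Running `cd` inside a subshell removes the need to move back to the root directory afterwards.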

# Figure 1 & Table 4

# Rotated MNIST


## OOD Accuracy

### ERM
`python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 25`

### RandomMatch
`python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --epochs 25`

### MatchDG
`python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 256 --pos_metric cos`

`python train.py --dataset rot_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18`

### CSD
`python train.py --dataset rot_mnist --method_name csd --match_case 0.01 --penalty_ws 0.0 --rep_dim 512 --epochs 25`

### IRM
`python train.py --dataset rot_mnist --method_name irm --match_case 0.01 --penalty_irm 1.0 --penalty_s 5 --epochs 25`

### Perfect Match
`python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --epochs 25`
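
The OOD-accuracy commands above can be run in sequence for either dataset with a small driver. This is a sketch under stated assumptions: the function names `run`, `train`, and `run_mnist_ood` and the `DRY_RUN` variable are hypothetical, and the flag values are copied from the commands above. It keeps the two MatchDG commands in order because the second one presumably reuses the contrastive model from the first (note its `--ctr_*` flags), and it stops at the first failing command.

```sh
#!/bin/sh
# Hypothetical driver for the Rotated MNIST / Fashion-MNIST OOD-accuracy runs.
# Flag values are copied from the commands above. DRY_RUN=1 prints commands only.

# Print a command, then execute it unless DRY_RUN=1.
run() {
  printf '%s\n' "$*"
  if [ "${DRY_RUN:-0}" != 1 ]; then
    "$@"
  fi
}

# Invoke train.py; the first argument is the dataset name.
train() {
  run python train.py --dataset "$@"
}

# Usage: run_mnist_ood [rot_mnist|fashion_mnist]; stops at the first failure.
run_mnist_ood() {
  local ds="${1:-rot_mnist}"
  # ERM
  train "$ds" --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 25 || return 1
  # RandomMatch
  train "$ds" --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --epochs 25 || return 1
  # MatchDG, phase 1: contrastive training, run before phase 2
  train "$ds" --method_name matchdg_ctr --match_case 0.01 --match_flag 1 \
    --epochs 100 --batch_size 256 --pos_metric cos || return 1
  # MatchDG, phase 2: presumably reuses the phase-1 model (see the --ctr_* flags)
  train "$ds" --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 \
    --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 \
    --ctr_model_name resnet18 || return 1
  # CSD
  train "$ds" --method_name csd --match_case 0.01 --penalty_ws 0.0 --rep_dim 512 --epochs 25 || return 1
  # IRM
  train "$ds" --method_name irm --match_case 0.01 --penalty_irm 1.0 --penalty_s 5 --epochs 25 || return 1
  # Perfect Match
  train "$ds" --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --epochs 25
}
```

For example, `run_mnist_ood fashion_mnist` runs the Fashion-MNIST variant described at the top of this page, and `DRY_RUN=1 run_mnist_ood` lists the seven Rotated MNIST commands without training anything.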

## Privacy Attacks and Mean Rank

Run the following commands to generate results for the privacy attacks and mean rank.
The results will be stored in the `results/rot_mnist/` directory.

### Classifier Attack
`python3 metric_eval.py rot_mnist privacy_classifier`

### Entropy Attack
`python3 metric_eval.py rot_mnist privacy_entropy`

### Loss Attack
`python3 metric_eval.py rot_mnist privacy_loss_attack`

### Attribute Attack
`python3 metric_eval.py rot_mnist attribute_attack test`

### Mean Rank
`python3 metric_eval.py rot_mnist match_score test`
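
The five evaluations above differ only in the metric name, plus a trailing `test` argument for the attribute attack and mean rank, so a loop can run them all. This is a minimal sketch: `run`, `run_mnist_metrics`, and `DRY_RUN` are hypothetical names, and passing `fashion_mnist` as the dataset argument is an assumption made by analogy with `--dataset` in `train.py`. The trained models from the OOD accuracy step are presumably required first.

```sh
#!/bin/sh
# Hypothetical loop over the Rotated MNIST privacy / mean-rank evaluations.
# DRY_RUN=1 prints commands only.

# Print a command, then execute it unless DRY_RUN=1.
run() {
  printf '%s\n' "$*"
  if [ "${DRY_RUN:-0}" != 1 ]; then
    "$@"
  fi
}

# Usage: run_mnist_metrics [dataset]; stops at the first failure.
run_mnist_metrics() {
  local ds="${1:-rot_mnist}" metric
  for metric in privacy_classifier privacy_entropy privacy_loss_attack; do
    run python3 metric_eval.py "$ds" "$metric" || return 1
  done
  # The attribute attack and mean rank take an extra "test" argument.
  for metric in attribute_attack match_score; do
    run python3 metric_eval.py "$ds" "$metric" test || return 1
  done
}
```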

# Chest X-Ray

Before running the commands below, complete the data preparation steps listed in the Chest X-ray section above (the `Preprocess.ipynb` and `ChestXRay_Translate.ipynb` notebooks).

## OOD Accuracy

### ERM
`python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121`

### RandomMatch
`python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121`

### MatchDG
`python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 64 --pos_metric cos --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121`

`python train.py --dataset chestxray --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --model_name densenet121 --train_domains nih_trans kaggle_trans --test_domains chex_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --weight_decay 0.0005 --penalty_ws 50.0`


### CSD
`python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --rep_dim 1024`

### IRM
`python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121`

### MDGHybrid
`python train.py --dataset chestxray --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --model_name densenet121 --train_domains nih_trans kaggle_trans --test_domains chex_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --lr 0.001 --batch_size 16 --weight_decay 0.0005 --penalty_ws 1.0 --penalty_aug 50.0`
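
All seven Chest X-Ray commands share a set of flags, so factoring those out makes the per-method differences easier to see. This is a sketch, not a script from the repository: the function names (`run`, `chest`, `to_kaggle`, `to_chex`, `run_chestxray_ood`) and `DRY_RUN` are hypothetical. Each command's train/test domain split is kept exactly as written above. The flags are emitted in a different order than above, which assumes `train.py` parses them order-independently (as Python's `argparse` does).

```sh
#!/bin/sh
# Hypothetical driver for the Chest X-Ray OOD-accuracy commands above.
# DRY_RUN=1 prints commands only.

# Print a command, then execute it unless DRY_RUN=1.
run() {
  printf '%s\n' "$*"
  if [ "${DRY_RUN:-0}" != 1 ]; then
    "$@"
  fi
}

# train.py with the flags every Chest X-Ray command above shares.
chest() {
  run python train.py --dataset chestxray --out_classes 2 --perfect_match 0 \
    --img_c 3 --pre_trained 1 --model_name densenet121 "$@"
}

# Train on nih_trans + chex_trans, test on kaggle_trans.
to_kaggle() {
  chest --train_domains nih_trans chex_trans --test_domains kaggle_trans "$@"
}

# Train on nih_trans + kaggle_trans, test on chex_trans (as written for MatchDG phase 2 and MDGHybrid).
to_chex() {
  chest --train_domains nih_trans kaggle_trans --test_domains chex_trans "$@"
}

# Runs the commands in the order listed above; stops at the first failure.
run_chestxray_ood() {
  # ERM
  to_kaggle --method_name erm_match --match_case 0.01 --epochs 40 --lr 0.001 \
    --batch_size 16 --penalty_ws 0.0 || return 1
  # RandomMatch
  to_kaggle --method_name erm_match --match_case 0.01 --epochs 40 --lr 0.001 \
    --batch_size 16 --penalty_ws 10.0 || return 1
  # MatchDG, phase 1 (contrastive)
  to_kaggle --method_name matchdg_ctr --match_case 0.01 --match_flag 1 \
    --epochs 100 --batch_size 64 --pos_metric cos || return 1
  # MatchDG, phase 2
  to_chex --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 \
    --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 \
    --epochs 40 --lr 0.001 --batch_size 16 --weight_decay 0.0005 \
    --penalty_ws 50.0 || return 1
  # CSD
  to_kaggle --method_name csd --match_case 0.01 --epochs 40 --lr 0.001 \
    --batch_size 16 --penalty_ws 0.0 --rep_dim 1024 || return 1
  # IRM
  to_kaggle --method_name irm --match_case 0.01 --epochs 40 --lr 0.001 \
    --batch_size 16 --penalty_ws 10.0 --penalty_s 5 || return 1
  # MDGHybrid
  to_chex --method_name hybrid --match_case -1 --ctr_match_case 0.01 \
    --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 \
    --epochs 40 --lr 0.001 --batch_size 16 --weight_decay 0.0005 \
    --penalty_ws 1.0 --penalty_aug 50.0
}
```

Running `DRY_RUN=1 run_chestxray_ood` prints the seven commands, which is a quick way to diff them against the list above before launching the long DenseNet-121 runs.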

## Privacy Attacks and Mean Rank

Run the following commands to generate results for the privacy attacks and mean rank.
The results will be stored in the `results/chestxray/` directory.

### Classifier Attack
`python3 metric_eval_chest.py privacy_classifier`

### Entropy Attack
`python3 metric_eval_chest.py privacy_entropy`

### Loss Attack
`python3 metric_eval_chest.py privacy_loss_attack`

### Attribute Attack
`python3 metric_eval_chest.py attribute_attack`

### Mean Rank
`python3 metric_eval_chest.py match_score`
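
Unlike `metric_eval.py`, the commands above pass `metric_eval_chest.py` only the metric name, so all five Chest X-Ray evaluations fit in one loop. A minimal sketch; `run`, `run_chestxray_metrics`, and `DRY_RUN` are hypothetical names, and the loop stops at the first failing evaluation.

```sh
#!/bin/sh
# Hypothetical loop over the Chest X-Ray privacy / mean-rank evaluations.
# DRY_RUN=1 prints commands only.

# Print a command, then execute it unless DRY_RUN=1.
run() {
  printf '%s\n' "$*"
  if [ "${DRY_RUN:-0}" != 1 ]; then
    "$@"
  fi
}

run_chestxray_metrics() {
  local metric
  for metric in privacy_classifier privacy_entropy privacy_loss_attack \
                attribute_attack match_score; do
    run python3 metric_eval_chest.py "$metric" || return 1
  done
}
```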