Official PyTorch implementation of Domain Generalization by Rejecting Extreme Augmentations.
Install the dependencies:
pip install -r requirements.txt
Download the datasets:
python -m domainbed.scripts.download --data_dir=/my/datasets/path
Environment details used in our study:
Python: 3.8.16
PyTorch: 1.11.0+cu113
Torchvision: 0.12.0+cu113
CUDA: 11.3
NumPy: 1.23.5
PIL: 9.4.0
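To check a local setup against the versions listed above, a script like the following can help. This is a minimal sketch and is not part of the repository. It reads package metadata through `importlib.metadata` (Python 3.8+), so it never imports torch. It assumes the PyPI distribution names `torch`, `torchvision`, `numpy` and `Pillow`, where PIL is published as `Pillow`. The CUDA version is not a distribution package, so the script does not check it.

```python
"""Compare installed package versions with the environment used in the study (sketch)."""
import platform
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions listed in this README, keyed by PyPI distribution name (assumed names).
EXPECTED: Dict[str, str] = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "numpy": "1.23.5",
    "Pillow": "9.4.0",  # PIL is distributed on PyPI as "Pillow"
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def compare_versions(expected: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str], bool]]:
    """Map each package to (expected version, installed version, exact match)."""
    report = {}
    for name, want in expected.items():
        have = installed_version(name)
        report[name] = (want, have, have == want)
    return report


if __name__ == "__main__":
    print(f"Python: expected 3.8.16, installed {platform.python_version()}")
    for name, (want, have, ok) in compare_versions(EXPECTED).items():
        status = "ok" if ok else "differs"
        print(f"{name:12s} expected {want:15s} installed {have or 'missing':15s} {status}")
```

A mismatch does not necessarily break anything. It is simply one possible source of the small result differences mentioned in the reproduction notes.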
The train_all.py script performs leave-one-out cross-validation for a given target domain.
python train_all.py exp_name --dataset PACS --test_envs target_domain --data_dir /my/datasets/path
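In leave-one-out cross-validation, `--test_envs` picks one domain to hold out as the target, and the model trains on the remaining source domains. The sketch below shows that split. It assumes this repository keeps upstream DomainBed's environment ordering for each dataset, so `domainbed/datasets.py` is the authority on the actual lists. The helper name `leave_one_out` is hypothetical.

```python
"""Show which domains are sources and which is the target for each --test_envs index (sketch)."""
from typing import Dict, List, Tuple

# Environment order as defined in upstream DomainBed (assumed to match this repository).
ENVIRONMENTS: Dict[str, List[str]] = {
    "PACS": ["A", "C", "P", "S"],  # art_painting, cartoon, photo, sketch
    "VLCS": ["C", "L", "S", "V"],  # Caltech101, LabelMe, SUN09, VOC2007
    "OfficeHome": ["A", "C", "P", "R"],  # Art, Clipart, Product, Real World
    "TerraIncognita": ["L100", "L38", "L43", "L46"],
    "DomainNet": ["clip", "info", "paint", "quick", "real", "sketch"],
}


def leave_one_out(dataset: str, test_env: int) -> Tuple[List[str], str]:
    """Hypothetical helper: return (source domains used for training, held-out target domain)."""
    envs = ENVIRONMENTS[dataset]
    if not 0 <= test_env < len(envs):
        raise ValueError(f"{dataset} has {len(envs)} environments; got test_env={test_env}")
    sources = [name for i, name in enumerate(envs) if i != test_env]
    return sources, envs[test_env]


if __name__ == "__main__":
    # One run per target domain: train on the other domains, evaluate on the held-out one.
    for t in range(len(ENVIRONMENTS["PACS"])):
        sources, target = leave_one_out("PACS", t)
        print(f"--test_envs {t}: train on {sources}, test on {target}")
```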
The train_seed.py script is similar to train_all.py but repeats the run over multiple seeds.
python train_seed.py exp_name --dataset PACS --test_envs target_domain --data_dir /my/datasets/path --trials 3
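With several trials, results are usually reported as a mean and a spread across seeds. The sketch below computes the mean and the sample standard deviation from a list of per-trial accuracies. It makes no claim about what train_seed.py actually logs. The input format and the name `summarize_trials` are hypothetical, and some papers report the standard error instead of the standard deviation.

```python
"""Summarize per-trial accuracies as mean and sample standard deviation (sketch)."""
import math
import statistics
from typing import Sequence, Tuple


def summarize_trials(accuracies: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: return (mean, sample standard deviation) over trials."""
    if len(accuracies) < 2:
        raise ValueError("need at least two trials to estimate a standard deviation")
    return statistics.mean(accuracies), statistics.stdev(accuracies)


def standard_error(accuracies: Sequence[float]) -> float:
    """Standard error of the mean, an alternative spread measure some tables report."""
    _, std = summarize_trials(accuracies)
    return std / math.sqrt(len(accuracies))


if __name__ == "__main__":
    dummy = [1.0, 2.0, 3.0]  # placeholder values, not results from the paper
    mean, std = summarize_trials(dummy)
    print(f"{mean:.1f} +/- {std:.1f}")
```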
We provide instructions to reproduce the main results of the paper (Table 3). Note that differences in the detailed environment, or uncontrolled randomness, may produce results slightly different from those reported in the paper.
- PACS, VLCS, OfficeHome, TerraIncognita
for DS in "PACS" "VLCS" "OfficeHome" "TerraIncognita"
do
for test_envs in 0 1 2 3
do
python train_seed.py domain configs/config_erm_domain.yaml --algorithm "ERMDAdv" --dataset $DS --test_envs $test_envs --deterministic --trials 3
done
for test_envs in 0 1 2 3
do
python train_seed.py teach_label configs/config_erm_label.yaml --algorithm "ERMAdv" --dataset $DS --test_envs $test_envs --deterministic --trials 3 --use_teacher True
done
for test_envs in 0 1 2 3
do
python train_seed.py ta_wider configs/config_erm_ta.yaml --algorithm "ERM" --dataset $DS --test_envs $test_envs --deterministic --trials 3 --auto_da "uniform" --tf_range "wider" --da_mode "online"
done
done
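The loops above launch three configurations for each of four datasets and four target domains. That is 48 train_seed.py invocations, each running 3 trials. To review or schedule these runs before launching them, a dry run that builds the same argument lists can be useful. The sketch below copies the flags verbatim from the loops and only prints the commands. The names `RUNS` and `build_commands` are hypothetical. Passing each list to `subprocess.run(cmd, check=True)` would execute it.

```python
"""Dry run: build the Table 3 commands for PACS, VLCS, OfficeHome and TerraIncognita (sketch)."""
import shlex
from typing import List, Sequence

DATASETS = ["PACS", "VLCS", "OfficeHome", "TerraIncognita"]
COMMON = ["--deterministic", "--trials", "3"]

# (experiment name, config file, algorithm, extra flags), copied from the shell loops.
RUNS = [
    ("domain", "configs/config_erm_domain.yaml", "ERMDAdv", []),
    ("teach_label", "configs/config_erm_label.yaml", "ERMAdv", ["--use_teacher", "True"]),
    ("ta_wider", "configs/config_erm_ta.yaml", "ERM",
     ["--auto_da", "uniform", "--tf_range", "wider", "--da_mode", "online"]),
]


def build_commands(datasets: Sequence[str], num_envs: int = 4) -> List[List[str]]:
    """Return one argument list per (dataset, configuration, test environment), in loop order."""
    cmds = []
    for ds in datasets:
        for name, config, algorithm, extra in RUNS:
            for env in range(num_envs):
                cmds.append(["python", "train_seed.py", name, config,
                             "--algorithm", algorithm, "--dataset", ds,
                             "--test_envs", str(env), *COMMON, *extra])
    return cmds


if __name__ == "__main__":
    for cmd in build_commands(DATASETS):
        print(shlex.join(cmd))  # shlex.join requires Python 3.8+
```

Building argument lists rather than shell strings means the values never pass through shell quoting when the commands are launched with `subprocess.run`.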
- DomainNet
for test_envs in 0 1 2 3 4 5
do
python train_seed.py domain configs/config_erm_domain.yaml --algorithm "ERMDAdv" --dataset DomainNet --test_envs $test_envs --deterministic --trials 3
done
for test_envs in 0 1 2 3 4 5
do
python train_seed.py teach_label configs/config_erm_label.yaml --algorithm "ERMAdv" --dataset DomainNet --test_envs $test_envs --deterministic --trials 3 --use_teacher True
done
for test_envs in 0 1 2 3 4 5
do
python train_seed.py ta_wider configs/config_erm_ta.yaml --algorithm "ERM" --dataset DomainNet --test_envs $test_envs --deterministic --trials 3 --auto_da "uniform" --tf_range "wider" --da_mode "online"
done
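DomainNet has six domains, so the loops run over test_envs 0 to 5, and the per-target accuracies are then typically averaged into a single score. The sketch below labels each `--test_envs` index with its domain name and adds the average. It assumes upstream DomainBed's DomainNet ordering, and `average_over_targets` is a hypothetical helper. The numbers a user feeds in come from their own logs. The code produces no results.

```python
"""Label DomainNet per-target accuracies by domain name and average them (sketch)."""
from typing import Dict, Sequence

# Environment order as defined in upstream DomainBed (assumed to match this repository).
DOMAINNET_ENVS = ["clip", "info", "paint", "quick", "real", "sketch"]


def average_over_targets(acc_by_env: Dict[int, float],
                         envs: Sequence[str] = DOMAINNET_ENVS) -> Dict[str, float]:
    """Hypothetical helper: map --test_envs indices to domain names and add an "avg" entry."""
    missing = [i for i in range(len(envs)) if i not in acc_by_env]
    if missing:
        raise ValueError(f"missing results for test_envs {missing}")
    table = {envs[i]: acc_by_env[i] for i in range(len(envs))}
    table["avg"] = sum(table.values()) / len(envs)  # computed before "avg" is inserted
    return table


if __name__ == "__main__":
    placeholder = {i: 0.0 for i in range(len(DOMAINNET_ENVS))}  # replace with your own accuracies
    print(average_over_targets(placeholder))
```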
This project includes some code from DomainBed and SWAD, which are also MIT licensed.