# Reproduction code for MatchDG

### Paper: Domain Generalization using Causal Matching [Arxiv](https://arxiv.org/abs/2006.07500)

The following code reproduces results for Rotated MNIST and Fashion-MNIST datasets, corresponding to Tables 1, 2 and 3 in the paper.

For convenience, we provide the exact commands for the Rotated MNIST dataset with the training domains set to [15, 30, 45, 60, 75] and the test domains set to [0, 90].

To obtain results for the Fashion-MNIST dataset, change the `--dataset` parameter from `rot_mnist` to `fashion_mnist`.

To obtain results for the other sets of training domains used in the paper, pass the desired list of domains to the `--train_domains` parameter: `--train_domains [30, 45]` or `--train_domains [30, 45, 60]`.
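The dataset and training-domain variations described above can be generated programmatically. The following is a minimal sketch, not part of the repository: `build_train_command` is a hypothetical helper, and passing each domain as a separate shell token is an assumption about how `train.py` parses `--train_domains`; adjust the formatting if the script expects a different form.

```python
import shlex


def build_train_command(dataset="rot_mnist", train_domains=None, extra_args=()):
    """Return the argv list for one train.py run (hypothetical helper)."""
    cmd = ["python", "train.py", "--dataset", dataset]
    if train_domains is not None:
        # Assumption: one token per domain; change this if train.py
        # expects the list in another format.
        cmd += ["--train_domains", *[str(d) for d in train_domains]]
    cmd += list(extra_args)
    return cmd


# ERM flags copied from the Table 1 command in this notebook.
erm_args = ["--method_name", "erm_match", "--match_case", "0.01", "--penalty_ws", "0.0"]

for dataset in ("rot_mnist", "fashion_mnist"):
    for domains in ([30, 45], [30, 45, 60], [15, 30, 45, 60, 75]):
        # Print the command instead of running it, so it can be inspected first.
        print(shlex.join(build_train_command(dataset, domains, erm_args)))
```

The commands are only printed, so the output can be checked or pasted into a `%%bash` cell before starting a long training run.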

In [1]:
%cd ../../

/data/home/t-dimaha/RobustDG/robustdg


## Prepare Data

Generate the data by running `data_gen.py` from the directory `data/rot_mnist`:

In [2]:
%%bash
cd data/rot_mnist
python data_gen.py resnet18

## Table 1
Run the following commands from the root directory of the repository.

* ERM: 

In [3]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0

<class 'torch.Tensor'> torch.Size([2000]) torch.Size([2000, 1, 28, 28])
Source Domain  15
Source Domain  30
Source Domain  45
Source Domain  60
Source Domain  75
Max Class Size:  195 0 0
Max Class Size:  229 0 1
Max Class Size:  206 0 2
Max Class Size:  204 0 3
Max Class Size:  192 0 4
Max Class Size:  178 0 5
Max Class Size:  212 0 6
Max Class Size:  211 0 7
Max Class Size:  172 0 8
Max Class Size:  201 0 9
torch.Size([10000, 224, 224]) torch.Size([10000]) (10000,)
[2000, 2000, 2000, 2000, 2000]
torch.Size([10000, 224, 224]) torch.Size([10000, 10]) torch.Size([10000, 5]) (10000,)
<class 'torch.Tensor'> torch.Size([100]) torch.Size([100, 1, 28, 28])
Source Domain  15
Source Domain  30
Source Domain  45
Source Domain  60
Source Domain  75
torch.Size([500, 224, 224]) torch.Size([500]) (500,)
[100, 100, 100, 100, 100]
torch.Size([500, 224, 224]) torch.Size([500, 10]) torch.Size([500, 5]) (500,)
<class 'torch.Tensor'> torch.Size([2000]) torch.Size([2000, 1, 28, 28])
Source Domain  0
Source



* ERM_RandomMatch:

In [None]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1

* ERM_PerfectMatch:

In [None]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1

* MatchDG:

In [4]:
%%bash
python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --batch_size 128 --match_flag 1
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5

<class 'torch.Tensor'> torch.Size([2000]) torch.Size([2000, 1, 28, 28])
Source Domain  15
Source Domain  30
Source Domain  45
Source Domain  60
Source Domain  75
Max Class Size:  195 0 0
Max Class Size:  229 0 1
Max Class Size:  206 0 2
Max Class Size:  204 0 3
Max Class Size:  192 0 4
Max Class Size:  178 0 5
Max Class Size:  212 0 6
Max Class Size:  211 0 7
Max Class Size:  172 0 8
Max Class Size:  201 0 9
torch.Size([10000, 224, 224]) torch.Size([10000]) (10000,)
[2000, 2000, 2000, 2000, 2000]
torch.Size([10000, 224, 224]) torch.Size([10000, 10]) torch.Size([10000, 5]) (10000,)
<class 'torch.Tensor'> torch.Size([100]) torch.Size([100, 1, 28, 28])
Source Domain  15
Source Domain  30
Source Domain  45
Source Domain  60
Source Domain  75
torch.Size([500, 224, 224]) torch.Size([500]) (500,)
[100, 100, 100, 100, 100]
torch.Size([500, 224, 224]) torch.Size([500, 10]) torch.Size([500, 5]) (500,)
<class 'torch.Tensor'> torch.Size([2000]) torch.Size([2000, 1, 28, 28])
Source Domain  0
Source

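The MatchDG cell above issues two `train.py` commands in sequence (`matchdg_ctr`, then `matchdg_erm`), and the second presumably relies on what the first produces. A plain `%%bash` cell keeps going even if the first command fails, so the sketch below shows one way to stop at the first failure. `run_stages` and its `dry_run` flag are hypothetical names introduced here for illustration; the two command strings are copied from the cell above.

```python
import shlex
import subprocess

# The two MatchDG commands, copied verbatim from the cell above.
MATCHDG_STAGES = [
    "python train.py --dataset rot_mnist --method_name matchdg_ctr "
    "--match_case 0.01 --batch_size 128 --match_flag 1",
    "python train.py --dataset rot_mnist --method_name matchdg_erm "
    "--penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 "
    "--ctr_match_flag 1 --ctr_match_interrupt 5",
]


def run_stages(stages, dry_run=True):
    """Process the stages in order and return their argv lists.

    With dry_run=False each stage is executed, and check=True raises
    CalledProcessError on a non-zero exit, so later stages never start.
    """
    processed = []
    for stage in stages:
        argv = shlex.split(stage)
        if dry_run:
            print("would run:", shlex.join(argv))
        else:
            subprocess.run(argv, check=True)
        processed.append(argv)
    return processed


# Dry run by default: only prints the commands.
run_stages(MATCHDG_STAGES)
```

Calling `run_stages(MATCHDG_STAGES, dry_run=False)` from the repository root would launch the actual training runs.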


## Table 2

* ERM: 

In [None]:
%%bash
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --test_metric match_score 

* MatchDG (Default):

In [None]:
%%bash
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --test_metric match_score

* MatchDG (PerfMatch):

In [None]:
%%bash
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --test_metric match_score

## Table 3

* Approx 25:

In [None]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.25 --penalty_ws 0.1

* Approx 50:

In [None]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.50 --penalty_ws 0.1

* Approx 75:

In [None]:
%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.75 --penalty_ws 0.1
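The three Table 3 commands differ only in the value of `--match_case`, and the labels suggest that this value is the fraction of matches that are perfect (an interpretation inferred from the labels, not stated in this notebook). A minimal sketch of the sweep, using a hypothetical helper `approx_match_commands`:

```python
import shlex


def approx_match_commands(fractions=(0.25, 0.50, 0.75), penalty_ws=0.1,
                          dataset="rot_mnist"):
    """Map each Table 3 label to the argv list of its train.py command."""
    commands = {}
    for frac in fractions:
        label = f"Approx {int(round(frac * 100))}"
        commands[label] = [
            "python", "train.py",
            "--dataset", dataset,
            "--method_name", "erm_match",
            "--match_case", f"{frac:.2f}",  # two decimals, as in the cells above
            "--penalty_ws", str(penalty_ws),
        ]
    return commands


for label, argv in approx_match_commands().items():
    # Print each command next to its label rather than running it.
    print(f"{label}: {shlex.join(argv)}")
```

The printed commands match the three cells above and can be extended to other fractions by changing `fractions`.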