# Getting started with RobustDG: Generalization and Privacy Attacks on Rotated MNIST dataset

Domain Generalization (DG) is the task of learning a predictive model that can generalize to unseen data distributions. Intuitively, a model trained by simply aggregating data from different domains might overfit to the domains observed during training. Many DG methods have been proposed to improve the generalization of models to out-of-distribution (OOD) data.

Here we present a simple application of the RobustDG library to build a model on a modified MNIST dataset and then evaluate its out-of-distribution accuracy and robustness to privacy attacks. 

## Dataset: Rotated MNIST

Rotated MNIST consists of several data domains, each corresponding to a specific rotation angle of the digit images. It provides an easy way to generate out-of-distribution (OOD) data samples. For example, the model could be shown digits rotated between 15 and 75 degrees during training, while at test time it has to classify digits rotated by 90 degrees. Hence, different rotations (domains) lead to a shift between the train and the test distributions.
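
To make the domain construction concrete, here is a minimal pure-Python sketch that builds one domain per rotation angle from the same set of base images. It is not RobustDG's data loader: `rotate_image` and `make_rotated_domains` are hypothetical helpers, rotation uses simple nearest-neighbour sampling, and the 15-degree spacing of the training angles is an assumption for illustration.

```python
import math
from typing import Dict, List

# A square grayscale image stored as a list of rows (hypothetical format).
Image = List[List[float]]


def rotate_image(img: Image, angle_deg: float) -> Image:
    """Rotate a square image counter-clockwise about its centre (hypothetical helper).

    Uses inverse mapping with nearest-neighbour sampling; output pixels whose
    source falls outside the image are filled with 0 (background).
    """
    n = len(img)
    c = (n - 1) / 2.0  # centre of the pixel grid
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out: Image = [[0.0] * n for _ in range(n)]
    for r in range(n):
        for col in range(n):
            # Output pixel relative to the centre, with x to the right and y up.
            x, y = col - c, c - r
            # Apply the inverse rotation to find the source pixel.
            xs = cos_t * x + sin_t * y
            ys = -sin_t * x + cos_t * y
            src_col = int(round(xs + c))
            src_row = int(round(c - ys))
            if 0 <= src_row < n and 0 <= src_col < n:
                out[r][col] = img[src_row][src_col]
    return out


def make_rotated_domains(images: List[Image], angles: List[float]) -> Dict[float, List[Image]]:
    """Build one domain per angle by rotating every base image (hypothetical helper)."""
    return {angle: [rotate_image(img, angle) for img in images] for angle in angles}


if __name__ == "__main__":
    toy = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    train_domains = make_rotated_domains([toy], [15, 30, 45, 60, 75])  # seen during training
    test_domain = make_rotated_domains([toy], [90])                    # unseen at training time
    print(test_domain[90][0])  # [[3, 6, 9], [2, 5, 8], [1, 4, 7]]
```

Because every domain is built from the same base images, sample `i` in one domain and sample `i` in another show the same digit under different rotations. A real pipeline would typically use an image library with proper interpolation instead of this nearest-neighbour loop.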

TODO: Include Images from the rotated MNIST dataset.



## Training ML models that can generalize to new domains 

### Baseline: Empirical risk minimization
We first train a baseline model with empirical risk minimization (ERM), which simply pools the data from all training domains and minimizes the average training loss.

In [None]:
# Training ERM model (TODO)
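
The ERM training cell above is still a TODO, but the objective itself is simple. The following sketch is not RobustDG's training code; `pool_domains`, `erm_objective`, and the toy threshold classifier are hypothetical. It shows what "pooling" means: domain labels are discarded and the loss is averaged over all training examples.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# (features, label); a hypothetical representation of one training example.
Example = Tuple[Sequence[float], int]


def pool_domains(domains: Dict[str, List[Example]]) -> List[Example]:
    """Concatenate the examples of all training domains, dropping the domain labels."""
    pooled: List[Example] = []
    for examples in domains.values():
        pooled.extend(examples)
    return pooled


def erm_objective(
    domains: Dict[str, List[Example]],
    loss_fn: Callable[[Sequence[float], int], float],
) -> float:
    """Empirical risk: the average per-example loss over the pooled training data."""
    pooled = pool_domains(domains)
    return sum(loss_fn(x, y) for x, y in pooled) / len(pooled)


def zero_one_loss(x: Sequence[float], y: int) -> float:
    """0-1 loss of a toy classifier that predicts 1 when the first feature is positive."""
    prediction = 1 if x[0] > 0 else 0
    return float(prediction != y)


# Toy data: two hypothetical rotation domains with one-dimensional features.
toy_domains = {"rot15": [([1.0], 1), ([-1.0], 0)], "rot30": [([0.5], 0)]}
print(erm_objective(toy_domains, zero_one_loss))  # 1 error out of 3 pooled examples
```

Since ERM never looks at which domain an example came from, nothing in this objective discourages the model from relying on rotation-specific features, which is the gap DG methods try to close.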

### MatchDG: Domain generalization via causal matching

The MatchDG method regularizes the ERM training objective by matching data samples across domains that were generated from the same base object (e.g., the same digit image under different rotations). More details are in the [arXiv paper](https://arxiv.org/abs/2006.07500).
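
A minimal sketch of such a regularized objective, assuming representations are plain vectors and matched pairs are given as index pairs. The squared Euclidean distance, the averaging over pairs, and the weight `lam` are assumptions made for illustration; the paper's exact distance and normalization may differ, and `match_regularized_loss` is a hypothetical name, not a RobustDG function.

```python
from typing import List, Sequence, Tuple


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two representation vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def match_regularized_loss(
    erm_loss: float,
    reps: List[Sequence[float]],
    matches: List[Tuple[int, int]],
    lam: float,
) -> float:
    """ERM loss plus lam times the mean distance between matched representations.

    `matches` holds index pairs (i, j) of samples from different domains that
    are assumed to share the same base object. Hypothetical helper, not RobustDG code.
    """
    if not matches:
        return erm_loss  # nothing to match, plain ERM
    penalty = sum(squared_l2(reps[i], reps[j]) for i, j in matches) / len(matches)
    return erm_loss + lam * penalty


# Samples 0 and 1 are matched and already have identical representations;
# samples 0 and 2 are matched but far apart, so only that pair is penalized.
reps = [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
loss = match_regularized_loss(erm_loss=0.5, reps=reps, matches=[(0, 1), (0, 2)], lam=0.1)
print(loss)  # penalty = (0 + 2) / 2 = 1.0, so 0.5 + 0.1 * 1.0
```

The penalty pushes matched samples toward the same representation, so the classifier is encouraged to ignore the features (here, rotation) that differ between them.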

MatchDG operates in two phases: in the first phase, it learns a matching function, and in the second phase, it learns a classifier regularized by the matching function learnt in the first phase.

To train the MatchDG model on Rotated MNIST, execute the following commands, one for each phase.

#### Phase 1: Learning the Match Function

<code> python3 train.py --method_name match_dg_ctr </code>
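
Phase 1 learns the match function with contrastive learning. As an illustration, the sketch below computes a standard InfoNCE-style contrastive loss for one anchor: a matched sample from another domain is treated as the positive and other samples as negatives. This is a generic contrastive loss under stated assumptions (cosine similarity, temperature `tau=0.1`), not necessarily the exact loss used by `match_dg_ctr`, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce_loss(
    anchor: Sequence[float],
    positive: Sequence[float],
    negatives: List[Sequence[float]],
    tau: float = 0.1,
) -> float:
    """InfoNCE-style loss: -log softmax score of the positive among all candidates."""
    logits = [cosine_similarity(anchor, positive) / tau]
    logits += [cosine_similarity(anchor, neg) / tau for neg in negatives]
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[0]


# The anchor's matched sample (positive) points the same way; the negatives do not.
anchor = [1.0, 0.0]
print(info_nce_loss(anchor, positive=[1.0, 0.0], negatives=[[0.0, 1.0], [-1.0, 0.0]]))
```

The loss is close to zero when the matched sample is much more similar to the anchor than any negative, and grows when a negative is more similar, so minimizing it pulls matched samples together in representation space.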

#### Phase 2: Learning a Classifier Regularized by the Match Function

<code> python3 train.py --method_name match_dg_erm </code>
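
When the true matches are unavailable, the matches used by the Phase 2 regularizer can be inferred from the representation learnt in Phase 1. One plausible way to do this, sketched below under the assumption that matches are restricted to samples of the same class and chosen by nearest neighbour in Euclidean distance, is to pair each sample in one domain with its closest same-class sample in another domain. `infer_matches` is a hypothetical helper, not part of RobustDG's API.

```python
import math
from typing import List, Sequence, Tuple


def infer_matches(
    reps_a: List[Sequence[float]],
    labels_a: List[int],
    reps_b: List[Sequence[float]],
    labels_b: List[int],
) -> List[Tuple[int, int]]:
    """Pair each sample in domain A with its nearest same-class sample in domain B.

    Hypothetical helper: distances are Euclidean in the learnt representation space.
    Samples whose class does not appear in domain B are left unmatched.
    """
    matches: List[Tuple[int, int]] = []
    for i, (rep, label) in enumerate(zip(reps_a, labels_a)):
        candidates = [j for j, other in enumerate(labels_b) if other == label]
        if not candidates:
            continue
        best = min(candidates, key=lambda j: math.dist(rep, reps_b[j]))
        matches.append((i, best))
    return matches


reps_a, labels_a = [[0.0, 0.0], [5.0, 5.0]], [0, 1]
reps_b, labels_b = [[4.0, 4.0], [0.1, 0.0], [6.5, 6.5], [10.0, 10.0]], [1, 0, 1, 0]
print(infer_matches(reps_a, labels_a, reps_b, labels_b))  # [(0, 1), (1, 0)]
```

Restricting candidates to the same class keeps the inferred pairs label-consistent, so the regularizer never pulls together representations of different digits.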

## Evaluating the trained model
After training, we can evaluate the model on various test metrics, such as the accuracy on the unseen test domain and metrics for the learned match function.

### Out-of-distribution accuracy

Here, we evaluate the representations learnt with contrastive learning (Phase 1) by visualizing them with t-SNE plots, to check whether samples from different rotations cluster by digit class:

<code> python3 test.py --test_metric t_sne </code>
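
The out-of-distribution accuracy itself is the fraction of correctly classified samples from the unseen domain, and comparing it with the accuracy on the training domains shows how much performance drops under the rotation shift. A minimal sketch with hypothetical helper names, assuming predictions and labels are already available as lists of integer class labels:

```python
from typing import Dict, List, Tuple


def accuracy(preds: List[int], labels: List[int]) -> float:
    """Fraction of predictions that equal the true labels."""
    if not labels or len(preds) != len(labels):
        raise ValueError("preds and labels must be non-empty and of equal length")
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def per_domain_accuracy(results: Dict[str, Tuple[List[int], List[int]]]) -> Dict[str, float]:
    """Map each domain name to its accuracy, given (predictions, labels) per domain."""
    return {domain: accuracy(preds, labels) for domain, (preds, labels) in results.items()}


# Hypothetical predictions on one training domain and on the unseen 90-degree domain.
results = {
    "rot15": ([1, 2, 3, 4], [1, 2, 3, 4]),
    "rot90": ([1, 0, 3, 0], [1, 2, 3, 4]),
}
print(per_domain_accuracy(results))  # {'rot15': 1.0, 'rot90': 0.5}
```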

### Robustness to membership inference privacy attack

In [None]:
# Privacy attack
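
The notebook cell above is still a placeholder. To illustrate what a membership inference attack measures, here is a common baseline, a loss-threshold attack: it guesses that a sample was in the training set whenever the model's loss on it is below a threshold, exploiting the fact that models often fit training samples more closely than unseen ones. This sketch is not necessarily the attack implemented in RobustDG, and the helper names are hypothetical. An attack accuracy near 0.5 on balanced member and non-member sets means the attacker does no better than random guessing.

```python
from typing import List, Tuple


def attack_accuracy(member_losses: List[float], nonmember_losses: List[float], threshold: float) -> float:
    """Accuracy of guessing 'member' exactly when a sample's loss is below the threshold."""
    correct = sum(loss < threshold for loss in member_losses)
    correct += sum(loss >= threshold for loss in nonmember_losses)
    return correct / (len(member_losses) + len(nonmember_losses))


def best_threshold(member_losses: List[float], nonmember_losses: List[float]) -> Tuple[float, float]:
    """Try every observed loss value (plus infinity) as threshold; return (threshold, accuracy)."""
    candidates = sorted(set(member_losses) | set(nonmember_losses)) + [float("inf")]
    scored = [(t, attack_accuracy(member_losses, nonmember_losses, t)) for t in candidates]
    return max(scored, key=lambda pair: pair[1])


# Hypothetical per-sample losses: training samples (members) have lower loss.
members = [0.1, 0.2, 0.3]
nonmembers = [0.8, 0.9, 1.0]
print(best_threshold(members, nonmembers))  # (0.8, 1.0)
```

Picking the threshold on the same losses it is evaluated on, as `best_threshold` does, overestimates the attack; in practice the threshold would be tuned on separate data.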