Domain-Adversarial Neural Network (DANN) is one of the well-known benchmarks for unsupervised domain adaptation tasks. DANN was presented in these papers:
@misc{https://doi.org/10.48550/arxiv.1409.7495,
url = {https://arxiv.org/abs/1409.7495},
author = {Ganin, Yaroslav and Lempitsky, Victor},
title = {Unsupervised Domain Adaptation by Backpropagation},
publisher = {arXiv},
year = {2014},
copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{https://doi.org/10.48550/arxiv.1505.07818,
url = {https://arxiv.org/abs/1505.07818},
author = {Ganin, Yaroslav and Ustinova, Evgeniya and Ajakan, Hana and Germain, Pascal and Larochelle, Hugo and Laviolette, François and Marchand, Mario and Lempitsky, Victor},
title = {Domain-Adversarial Training of Neural Networks},
publisher = {arXiv},
year = {2015},
copyright = {arXiv.org perpetual, non-exclusive license}
}
This DANN implementation uses a 2D toy dataset with built-in plots that help to visualize how the DANN algorithm is learning the new features.
The code starts by retrieving source dataset
from data folder. Then it performs a rotation (domain shift) on a copy of the dataset. The rotated dataset is the target dataset
. Here is a visualization of source and target datasets:
The function core.train_src
trains the feature_extractor
to separate source class 0
and source class 1
. Then, the learned model is tested on the test data:
Avg Loss = 0.20282, Avg Accuracy = 88.500000%, ARI = 0.59085
Now, we used the same feature_extractor
to classify target
samples. Note that we still did not perform domain adaptation:
Avg Loss = 0.61630, Avg Accuracy = 81.000000%, ARI = 0.38154
Most of domain adaptation logic is performed in core.train_tgt
function. The goal is to train the feature_extractor
to learn features for both source
and target
smaples. The feature_extractor
attempts to minimize a loss computed only on source
samples, since target
samples do not have labels. The feature_extractor
is optimized simultaneously with the discriminator
, which tries to "discriminate" if the sample is coming from source
or target
domains. Eventually, the feature_extractor
will learn features that make the discriminator
unable to tell which domain the sample is coming from. Now, we can use the feature_extractor
to classify target
samples:
Avg Loss = 0.26856, Avg Accuracy = 88.000000%, ARI = 0.57547
I reused some code from this repository.