Variational Adversarial Deep Domain Adaptation implementation in TensorFlow
Variational Recurrent Adversarial Deep Domain Adaptation (VRADA)

Implementation of VRADA in TensorFlow. See their paper or blog post for details about the method. In their 2016 workshop paper, they called this method Variational Adversarial Deep Domain Adaptation (VADDA). The two are more or less the same, though the iterative optimization may differ slightly.

You have a choice of running with or without domain adaptation and with two types of RNNs. In their paper, they refer to the LSTM with domain adaptation as "R-DANN" and the VRNN with domain adaptation as "VRADA."

  • --lstm -- use LSTM without adaptation
  • --vrnn -- use VRNN without adaptation
  • --cnn -- use CNN without adaptation
  • --tcn -- use TCN without adaptation
  • --lstm-da -- use LSTM with adaptation (R-DANN)
  • --vrnn-da -- use VRNN with adaptation (VRADA/VADDA)
  • --cnn-da -- use CNN with adaptation
  • --tcn-da -- use TCN with adaptation
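
The "-da" variants rely on a gradient reversal layer (see flip_gradient.py): it passes features through unchanged in the forward pass but negates (and scales) the gradient in the backward pass, so the feature extractor learns to fool the domain classifier. As a minimal NumPy sketch of that behavior (the actual TensorFlow op in this repo is implemented differently):

```python
import numpy as np

def grl_forward(x):
    """Forward pass: identity, features reach the domain classifier unchanged."""
    return x

def grl_backward(grad, lam=1.0):
    """Backward pass: negate (and scale by lam) the domain-classifier gradient
    before it reaches the feature extractor, turning minimization of the
    domain loss into maximization."""
    return -lam * grad

# Toy check: activations pass through, gradients are flipped
x = np.array([1.0, -2.0, 3.0])
g = np.array([0.5, 0.5, 0.5])
assert np.allclose(grl_forward(x), x)
assert np.allclose(grl_backward(g, lam=2.0), [-1.0, -1.0, -1.0])
```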

To try these out, make sure you clone the repository recursively, since it contains submodules.

git clone --recursive https://github.com/floft/vrada
cd vrada

Datasets

This method uses RNNs, so it requires time-series datasets. See README.md in datasets/ for information about generating simple synthetic datasets, or about using the RF sleep stage dataset or the MIMIC-III health care dataset that the VRADA paper used. Select a dataset with a command-line argument:

  • --mimic-{ahrf,icd9} -- use MIMIC-III health care datasets
  • --sleep -- use the RF sleep stage dataset
  • --trivial-{line,sine} -- use datasets generated by datasets/generate_trivial_datasets.py
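
To get a feel for what the trivial datasets might look like, here is a hypothetical sketch (the real datasets/generate_trivial_datasets.py may use different shapes, labels, and noise): lines whose slope sign encodes the class, or sines with a class-dependent frequency sign.

```python
import numpy as np

def make_trivial(kind="line", n=100, length=50, seed=0):
    """Generate toy time series: class 0 gets a negative slope, class 1 positive.
    Hypothetical stand-in for datasets/generate_trivial_datasets.py."""
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 2.0 * np.pi, length)
    X = np.empty((n, length))
    y = rng.integers(0, 2, size=n)
    for i, label in enumerate(y):
        slope = (1.0 if label == 1 else -1.0) * rng.uniform(0.5, 1.0)
        series = slope * t if kind == "line" else np.sin(slope * t)
        X[i] = series + 0.05 * rng.standard_normal(length)  # small noise
    return X, y

X, y = make_trivial("sine")
print(X.shape)  # (100, 50)
```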

Usage

Training Locally

For example, to run domain adaptation locally using a VRNN on the synthetically generated "trivial" line dataset:

python3 VRADA.py --logdir logs --modeldir models --debug --vrnn-da --trivial-line

Note: the "--debug" flag tells it to start a new log and model directory for each run (incrementing the folder number each time) rather than continuing from where the previous run left off.
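
As a sketch of what that incrementing-directory behavior might look like (the helper name and numbering scheme here are hypothetical, not taken from VRADA.py):

```python
import os
import re

def next_run_dir(base="logs", prefix="run"):
    """Return base/<prefix><N+1>, where N is the highest existing run number.
    Hypothetical illustration of the --debug fresh-directory behavior."""
    os.makedirs(base, exist_ok=True)
    nums = []
    for name in os.listdir(base):
        m = re.fullmatch(re.escape(prefix) + r"(\d+)", name)
        if m:
            nums.append(int(m.group(1)))
    # No existing runs -> run1; otherwise one past the highest
    return os.path.join(base, f"{prefix}{max(nums, default=0) + 1}")
```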

Training on a High-Performance Cluster

Alternatively, you can train on a cluster with Slurm (in my case, Kamiak) after editing kamiak_config.sh:

sbatch kamiak_train.srun --vrnn-da --trivial-line

Then, on your local computer, monitor the training progress in TensorBoard:

./kamiak_tflogs.sh
tensorboard --logdir vrada-logs

If you want to see images at more than 10 time steps:

tensorboard --logdir vrada-logs --samples_per_plugin images=100