Permalink
Browse files

Add SVHN2MNIST

  • Loading branch information...
Ming-Yu Liu
Ming-Yu Liu committed Oct 5, 2017
1 parent 8da2a98 commit 65481145d5fd666129409de1cd8cb551cde8c68b
Showing with 43 additions and 2 deletions.
  1. +8 −2 USAGE.md
  2. +35 −0 exps/unit/svhn2mnist.yaml
View
@@ -69,14 +69,20 @@ pip install tensorboard
3. Setup the yaml file. Check out <exps/unit/blondhair.yaml>
4. Do training
4. Go to <src> and do training
```
python cocogan_train.py --config ../exps/unit/blondhair.yaml --log ../logs
```
5. Resume training
5. Go to <src> and do resume training
```
python cocogan_train.py --config ../exps/unit/blondhair.yaml --log ../logs --resume 1
```
6. Intermediate image outputs and model binary files are in <outputs/unit/blondhair>
#### SVHN2MNIST Adaptation
1. Go to <src> and execute
```
python cocogan_train_domain_adaptation.py --config ../exps/unit/svhn2mnist.yaml --log ../logs
```
View
@@ -0,0 +1,35 @@
# Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-ND 4.0 license (https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode).
train:
snapshot_save_iterations: 100 # How often do you want to save trained models
image_save_iterations: 500 # How often do you want to save output images during training
display: 10 # How often do you want to log the training stats
snapshot_prefix: ../outputs/unit/svhn2mnist/svhn2mnist # Where do you want to save the outputs
hyperparameters:
trainer: COCOGANDAContextTrainer # CoVAE-GAN for domain adaptation
gen: CoVAE32x32
dis: CoDis32x32
gen_ch: 64 # base channel number per layer
dis_ch: 64 # base channel number per layer
kl_normalized_direct_w: 0.0001 # weight on the KL divergence loss
ll_normalized_direct_w: 0.001 # weight on the reconstruction loss
feature_w: 0.0001 # weight on discriminator feature matching
cls_w: 10.0 # weight on classification accuracy
gan_w: 1.0 # weight on the adversarial loss
batch_size: 64 # image batch size per domain
test_batch_size: 100
max_iterations: 200000 # maximum number of training epochs
input_dim_a: 3
input_dim_b: 1
datasets:
train_a: # Domain 1 dataset
class_name: dataset_svhn_extra # dataset class
root: ../datasets/svhn/ # dataset root location
train_b: # Domain 2 dataset
class_name: dataset_mnist32x32_train # dataset class
root: /data/projects/unit_release/datasets/mnist/ # dataset root location
use_inversion: 1
test_b: # Domain 1 dataset
class_name: dataset_mnist32x32_test # dataset class
root: ../datasets/mnist/ # dataset root location
use_inversion: 0

0 comments on commit 6548114

Please sign in to comment.