Skip to content

Commit

Permalink
Merge pull request #7 from DoctorKey/issues6
Browse files Browse the repository at this point in the history
add an argument to use only labeled data in training process.
  • Loading branch information
avital committed Jul 14, 2019
2 parents 7068633 + 33de18e commit 81fe4b2
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 4 deletions.
8 changes: 7 additions & 1 deletion lib/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_simple_mixed_batch(
shuffle_buffer_size,
labeled_data_filter_fn=None,
unlabeled_data_filter_fn=None,
mode="mix",
):
"""A less flexible, more memory-efficient version of get_simple_mixed_batch.
Expand All @@ -53,6 +54,8 @@ def get_simple_mixed_batch(
key. Returns a boolean tensor. Defaults to no filter (all images).
unlabeled_data_filter_fn (function): Function to decide which unlabeled
data to look at. Same signature as labeled_data_filter.
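mode (str): "labeled" - use only labeled data,
"unlabeled" - use only unlabeled data,
"mix" (default) - use mixed (interspersed labeled and unlabeled) data.
Ignored when labeled_dataset_name is "imagenet_32", which always uses
only labeled data.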
Returns:
A tuple (images, labels, batch_count, remainder, num_classes), where:
Expand Down Expand Up @@ -124,10 +127,13 @@ def get_simple_mixed_batch(

# These operations merge the datasets in a way that intersperses
# elements from each, rather than just concatenating them.
if labeled_dataset_name == "imagenet_32":
if labeled_dataset_name == "imagenet_32" or mode == "labeled":
# Don't waste batch space for imagenet pre-training.
dataset = labeled_dataset
elif mode == "unlabeled":
dataset = unlabeled_dataset
else:
assert mode == "mix"
dataset = dataset_utils.shuffle_merge(
labeled_dataset, unlabeled_dataset
)
Expand Down
2 changes: 1 addition & 1 deletion runs/figure-2-cifar10-4000-fullysup-olna.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ windows:
layout: even-vertical
shell_command_before: cd /root/realistic-ssl-evaluation
panes:
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='cifar10' --secondary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna --n_labeled=4000 --consistency_model=none --labeled_classes_filter=2,3,4,5,6,7 --unlabeled_classes_filter=0,1,8,9 2>&1 | tee /mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna_train.log
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='cifar10' --secondary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna --n_labeled=4000 --dataset_mode=labeled --consistency_model=none --labeled_classes_filter=2,3,4,5,6,7 --unlabeled_classes_filter=0,1,8,9 2>&1 | tee /mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna_train.log
- CUDA_VISIBLE_DEVICES=1 python3 evaluate_model.py --split=test --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna --consistency_model=none --labeled_classes_filter=2,3,4,5,6,7 2>&1 | tee /mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna_eval_test.log
- CUDA_VISIBLE_DEVICES=2 python3 evaluate_model.py --split=valid --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna --consistency_model=none --labeled_classes_filter=2,3,4,5,6,7 2>&1 | tee /mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna_eval_valid.log
- CUDA_VISIBLE_DEVICES=3 python3 evaluate_model.py --split=train --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna --consistency_model=none --labeled_classes_filter=2,3,4,5,6,7 2>&1 | tee /mnt/experiment-logs/figure-2-cifar10-4000-fullysup-olna_eval_train.log
2 changes: 1 addition & 1 deletion runs/table-1-cifar10-4000-fullysup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ windows:
layout: even-vertical
shell_command_before: cd /root/realistic-ssl-evaluation
panes:
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='cifar10' --secondary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/table-1-cifar10-4000-fullysup --n_labeled=4000 --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-cifar10-4000-fullysup_train.log
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='cifar10' --secondary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/table-1-cifar10-4000-fullysup --n_labeled=4000 --dataset_mode=labeled --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-cifar10-4000-fullysup_train.log
- CUDA_VISIBLE_DEVICES=1 python3 evaluate_model.py --split=test --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/table-1-cifar10-4000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-cifar10-4000-fullysup_eval_test.log
- CUDA_VISIBLE_DEVICES=2 python3 evaluate_model.py --split=valid --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/table-1-cifar10-4000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-cifar10-4000-fullysup_eval_valid.log
- CUDA_VISIBLE_DEVICES=3 python3 evaluate_model.py --split=train --verbosity=0 --primary_dataset_name='cifar10' --root_dir=/mnt/experiment-logs/table-1-cifar10-4000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-cifar10-4000-fullysup_eval_train.log
2 changes: 1 addition & 1 deletion runs/table-1-svhn-1000-fullysup.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ windows:
layout: even-vertical
shell_command_before: cd /root/realistic-ssl-evaluation
panes:
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='svhn' --secondary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/table-1-svhn-1000-fullysup --n_labeled=1000 --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-svhn-1000-fullysup_train.log
- CUDA_VISIBLE_DEVICES=0 python3 train_model.py --verbosity=0 --primary_dataset_name='svhn' --secondary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/table-1-svhn-1000-fullysup --n_labeled=1000 --dataset_mode=labeled --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-svhn-1000-fullysup_train.log
- CUDA_VISIBLE_DEVICES=1 python3 evaluate_model.py --split=test --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/table-1-svhn-1000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-svhn-1000-fullysup_eval_test.log
- CUDA_VISIBLE_DEVICES=2 python3 evaluate_model.py --split=valid --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/table-1-svhn-1000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-svhn-1000-fullysup_eval_valid.log
- CUDA_VISIBLE_DEVICES=3 python3 evaluate_model.py --split=train --verbosity=0 --primary_dataset_name='svhn' --root_dir=/mnt/experiment-logs/table-1-svhn-1000-fullysup --consistency_model=none 2>&1 | tee /mnt/experiment-logs/table-1-svhn-1000-fullysup_eval_train.log
7 changes: 7 additions & 0 deletions train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@
"datasets being used as unlabeled data. Defaults to all "
"classes.",
)
# Selects which training data stream is used; train() forwards this value
# as the `mode` argument when building the training batch. The help text
# lists every value the data provider branches on.
flags.DEFINE_string(
    "dataset_mode",
    "mix",
    "'labeled' - use only labeled data to train the model. "
    "'unlabeled' - use only unlabeled data to train the model. "
    "'mix' (default) - use mixed data to train the model",
)

# Flags for book-keeping
flags.DEFINE_string(
Expand Down Expand Up @@ -154,6 +160,7 @@ def train(hps, result_dir, tuner=None, trial_name=None):
shuffle_buffer_size=1000,
labeled_data_filter_fn=labeled_data_filter_fn,
unlabeled_data_filter_fn=unlabeled_data_filter_fn,
mode=FLAGS.dataset_mode,
)

logging.info("Training data tensors constructed.")
Expand Down

0 comments on commit 81fe4b2

Please sign in to comment.