First changes to help reproducing arXiv v2 experiments
- Update "best" .gin files to reflect the new experiments
- Start updating the instructions

PiperOrigin-RevId: 283621968
lamblin committed Dec 4, 2019
1 parent d7b4f4f commit 4b909ad
Showing 63 changed files with 821 additions and 317 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -49,7 +49,7 @@ hope to inspire work in these directions.

# User instructions
## Installation
Meta-Dataset is now compatible with Python 2 and Python 3, but has mostly been used with Python 2 up to now, so glitches with Python 3 are still possible.
Meta-Dataset is now compatible with Python 2 and Python 3; please report any glitches with Python 3.
The code has not been tested with TensorFlow 2 yet.

- We recommend you follow [these instructions](https://www.tensorflow.org/install/pip) to install TensorFlow.
@@ -95,7 +95,8 @@ Experiments are defined via [gin](google/gin-config) configuration files, that a
- `default/` contains files that each correspond to one experiment, mostly defining a setup and a model, with default values for training hyperparameters.
- `best/` contains files with values for training hyperparameters that achieved the best performance during hyperparameter search.

There are two main architectures, or "backbones": `four_layer_convnet` (sometimes `convnet` for short) and `resnet`, that can be used in the baselines ("k-NN" and "Finetune"), ProtoNet, and MatchingNet. Their layers do not have a trainable bias since it would be negated by the use of batch normalization. For fo-MAML and ProtoMAML, each of the backbones have a version with trainable biases (due to the way batch normalization is handled), resp. `four_layer_convnet_maml` (or `mamlconvnet`) and `resnet_maml` (sometimes `mamlresnet`); these can also be used by the baseline for pre-training of the MAML models.
There are three main architectures, also called "backbones" (or "embedding networks"): `four_layer_convnet` (sometimes `convnet` for short), `resnet`, and `wide_resnet`. These architectures can be used by all baselines and episodic models.
Another backbone, `relationnet_embedding` (similar to `four_layer_convnet` but without pooling on the last layer), is only used by RelationNet (and baseline, for pre-training purposes).
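In a gin file, the backbone is selected with a single `LearnerConfig.embedding_network` binding, as the `best/` files in this commit do. For example:

```
# Pick one of the backbones described above.
LearnerConfig.embedding_network = 'wide_resnet'  # or 'resnet', 'four_layer_convnet'
```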

### Reproducing results

2 changes: 1 addition & 1 deletion doc/reproducing_best_results.md
@@ -61,7 +61,7 @@ do
done
```

Each of the jobs took between 8 and 12 hours to reach 50k steps (episodes).
Each of the jobs took between 12 and 18 hours to reach 75k steps (episodes).

## Training on ImageNet

23 changes: 23 additions & 0 deletions meta_dataset/learn/gin/best/baseline_all.gin
@@ -0,0 +1,23 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'
BatchSplitReaderGetReader.add_dataset_offset = True

# Backbone hypers.
LearnerConfig.embedding_network = 'resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = False
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.9509421043104758
LearnerConfig.learning_rate = 0.0005102311315353643
weight_decay = 0.0005155272693869694
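The `decay_every` / `decay_rate` pair in these files suggests a staircase exponential learning-rate schedule (as in TensorFlow's `exponential_decay` with `staircase=True`); that interpretation is an assumption here, not something this diff confirms. A minimal sketch:

```python
def decayed_lr(step, learning_rate, decay_rate, decay_every):
    """Staircase exponential decay: lr * decay_rate ** (step // decay_every).
    The rate is multiplied by decay_rate once every decay_every steps."""
    return learning_rate * decay_rate ** (step // decay_every)

# Using the baseline_all.gin values above:
lr0 = 0.0005102311315353643
rate_at_start = decayed_lr(0, lr0, 0.9509421043104758, 100)
rate_at_1k = decayed_lr(1000, lr0, 0.9509421043104758, 100)  # 10 decay periods
```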
23 changes: 23 additions & 0 deletions meta_dataset/learn/gin/best/baseline_all_option_to_pretrain.gin
@@ -0,0 +1,23 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'
BatchSplitReaderGetReader.add_dataset_offset = True

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'imagenet'
LearnerConfig.pretrained_checkpoint = '/path/to/checkpoints/baseline_imagenet_wide_resnet_best/model_46000.ckpt'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 500
LearnerConfig.decay_rate = 0.8778059962506467
LearnerConfig.learning_rate = 0.000253906846867988
weight_decay = 0.00002393929026012612
22 changes: 22 additions & 0 deletions meta_dataset/learn/gin/best/baseline_imagenet.gin
@@ -0,0 +1,22 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.5082121576573064
LearnerConfig.learning_rate = 0.007084776688116927
weight_decay = 0.000005078192976503067
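The `cosine_classifier` / `cosine_logits_multiplier` options set in these files correspond to computing class logits as a scaled cosine similarity between embeddings and class weight vectors. A minimal NumPy sketch of that idea (function and variable names are illustrative, not the repository's API):

```python
import numpy as np

def cosine_logits(embeddings, class_weights, multiplier=1.0, eps=1e-8):
    """Scaled cosine similarity: multiplier * <e, w> / (|e| * |w|)."""
    e = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + eps)
    w = class_weights / (np.linalg.norm(class_weights, axis=1, keepdims=True) + eps)
    return multiplier * e @ w.T

# 4 query embeddings, 10 classes, 64-dim features.
logits = cosine_logits(np.random.randn(4, 64), np.random.randn(10, 64), multiplier=10)
# Each logit is bounded by the multiplier, since cosine similarity lies in [-1, 1].
```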
22 changes: 22 additions & 0 deletions meta_dataset/learn/gin/best/baseline_imagenet_convnet.gin
@@ -0,0 +1,22 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'four_layer_convnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 1000
LearnerConfig.decay_rate = 0.9105573818947892
LearnerConfig.learning_rate = 0.008644114436633987
weight_decay = 0.000005171477829794739
@@ -0,0 +1,22 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'scratch'
LearnerConfig.pretrained_checkpoint = ''

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = False
BaselineLearner.cosine_logits_multiplier = 10
BaselineLearner.use_weight_norm = False

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 10000
LearnerConfig.decay_rate = 0.7294597641152971
LearnerConfig.learning_rate = 0.007634189137886614
weight_decay = 0.000007138118976497546
22 changes: 22 additions & 0 deletions meta_dataset/learn/gin/best/baseline_imagenet_resnet.gin
@@ -0,0 +1,22 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = False
BaselineLearner.cosine_logits_multiplier = 2
BaselineLearner.use_weight_norm = False

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 1000
LearnerConfig.decay_rate = 0.9967524905880909
LearnerConfig.learning_rate = 0.00375640851370052
weight_decay = 0.00002628042826116842
22 changes: 22 additions & 0 deletions meta_dataset/learn/gin/best/baseline_imagenet_wide_resnet.gin
@@ -0,0 +1,22 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baseline_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_checkpoint = ''
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.knn_distance = 'cosine'
BaselineLearner.cosine_classifier = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineLearner.use_weight_norm = True

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.5082121576573064
LearnerConfig.learning_rate = 0.007084776688116927
weight_decay = 0.000005078192976503067
24 changes: 24 additions & 0 deletions meta_dataset/learn/gin/best/baselinefinetune_all.gin
@@ -0,0 +1,24 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.cosine_classifier = False
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineFinetuneLearner.num_finetune_steps = 50
BaselineFinetuneLearner.finetune_lr = 0.1
BaselineFinetuneLearner.finetune_all_layers = True
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 2500
LearnerConfig.decay_rate = 0.5508783586336378
LearnerConfig.learning_rate = 0.005493938830376542
weight_decay = 0.0000031050368100770684
@@ -0,0 +1,27 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'
BatchSplitReaderGetReader.add_dataset_offset = True

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'imagenet'
LearnerConfig.pretrained_checkpoint = '/path/to/checkpoints/baseline_imagenet_wide_resnet_best/model_46000.ckpt'


# Model hypers.
BaselineLearner.cosine_classifier = False
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_all_layers = True
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 5000
LearnerConfig.decay_rate = 0.5559080744371039
LearnerConfig.learning_rate = 0.0027015533546616804
weight_decay = 0.00002266979856832968
@@ -0,0 +1,14 @@
include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_cosine_config.gin'

LearnerConfig.embedding_network = 'wide_resnet'
DataConfig.image_height = 84
LearnerConfig.pretrained_source = 'scratch'
weight_decay = 0

BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_with_adam = False
BaselineLearner.use_weight_norm = False
BaselineFinetuneLearner.finetune_all_layers = False
BaselineLearner.cosine_logits_multiplier = 10
@@ -0,0 +1,15 @@
include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot.
include 'meta_dataset/learn/gin/models/baselinefinetune_cosine_config.gin'

LearnerConfig.embedding_network = 'wide_resnet'
DataConfig.image_height = 84
LearnerConfig.pretrained_source = 'scratch'
weight_decay = 0

BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_with_adam = False
BaselineLearner.use_weight_norm = False
BaselineFinetuneLearner.finetune_all_layers = False
BaselineLearner.cosine_logits_multiplier = 10
24 changes: 24 additions & 0 deletions meta_dataset/learn/gin/best/baselinefinetune_imagenet.gin
@@ -0,0 +1,24 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'resnet'
LearnerConfig.pretrained_source = 'scratch'

# Model hypers.
BaselineLearner.cosine_classifier = True
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 10
BaselineFinetuneLearner.num_finetune_steps = 100
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_all_layers = False
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 126

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 2500
LearnerConfig.decay_rate = 0.7427378713742643
LearnerConfig.learning_rate = 0.003725198463674423
weight_decay = 0.000003337891450479888
@@ -0,0 +1,25 @@
include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'wide_resnet'
LearnerConfig.pretrained_source = 'imagenet'
LearnerConfig.pretrained_checkpoint = '/path/to/checkpoints/baseline_imagenet_wide_resnet_best/model_46000.ckpt'

# Model hypers.
BaselineLearner.cosine_classifier = False
BaselineLearner.use_weight_norm = True
BaselineLearner.cosine_logits_multiplier = 1
BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_all_layers = True
BaselineFinetuneLearner.finetune_with_adam = True

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 5000
LearnerConfig.decay_rate = 0.5559080744371039
LearnerConfig.learning_rate = 0.0027015533546616804
weight_decay = 0.00002266979856832968
@@ -0,0 +1,15 @@
include 'meta_dataset/learn/gin/setups/mini_imagenet.gin'
EpisodeDescriptionConfig.num_support = 5

include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

LearnerConfig.embedding_network = 'wide_resnet'
DataConfig.image_height = 84
LearnerConfig.pretrained_source = 'scratch'
weight_decay = 0

BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_with_adam = False
BaselineLearner.use_weight_norm = False
BaselineFinetuneLearner.finetune_all_layers = False
@@ -0,0 +1,14 @@
include 'meta_dataset/learn/gin/setups/mini_imagenet_five_way_five_shot.gin'
EpisodeDescriptionConfig.num_support = 1 # Change 5-shot to 1-shot.
include 'meta_dataset/learn/gin/models/baselinefinetune_config.gin'

LearnerConfig.embedding_network = 'wide_resnet'
DataConfig.image_height = 84
LearnerConfig.pretrained_source = 'scratch'
weight_decay = 0

BaselineFinetuneLearner.num_finetune_steps = 200
BaselineFinetuneLearner.finetune_lr = 0.01
BaselineFinetuneLearner.finetune_with_adam = False
BaselineLearner.use_weight_norm = False
BaselineFinetuneLearner.finetune_all_layers = False
23 changes: 23 additions & 0 deletions meta_dataset/learn/gin/best/maml_all.gin
@@ -0,0 +1,23 @@
include 'meta_dataset/learn/gin/setups/all.gin'
include 'meta_dataset/learn/gin/models/maml_config.gin'

# Backbone hypers.
LearnerConfig.embedding_network = 'four_layer_convnet'
LearnerConfig.pretrained_source = 'imagenet'
LearnerConfig.pretrained_checkpoint = '/path/to/checkpoints/baseline_imagenet_convnet_best/model_42500.ckpt'


# Model hypers.
MAMLLearner.first_order = True
MAMLLearner.alpha = 0.01
MAMLLearner.additional_test_update_steps = 0
MAMLLearner.num_update_steps = 10

# Data hypers.
DataConfig.image_height = 84

# Training hypers (not needed for eval).
LearnerConfig.decay_every = 100
LearnerConfig.decay_rate = 0.7391196845071122
LearnerConfig.learning_rate = 0.0033100944028230375
weight_decay = 0.0003592684468298601
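The `MAMLLearner` bindings above (`first_order = True`, `alpha`, `num_update_steps`) describe a first-order MAML inner loop: plain SGD on the support-set loss with step size `alpha`, with gradients treated as constants so no second-order terms flow through. A toy sketch of that inner adaptation (names and the 1-D loss are illustrative, not the repository's implementation):

```python
import numpy as np

def inner_adapt(params, grad_fn, alpha=0.01, num_update_steps=10):
    """First-order MAML inner loop: num_update_steps SGD steps of size alpha
    on the support-set loss, without backpropagating through the updates."""
    adapted = params.copy()
    for _ in range(num_update_steps):
        adapted = adapted - alpha * grad_fn(adapted)
    return adapted

# Toy example: adapt a 1-D parameter toward the minimum of (p - 3)^2,
# whose gradient is 2 * (p - 3).
theta = inner_adapt(np.array([0.0]), lambda p: 2 * (p - 3),
                    alpha=0.01, num_update_steps=10)
```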
