Corrections to the omit_from_saving_and_reloading flag's default value and to the backbone saver.

PiperOrigin-RevId: 289243134
eleniTriantafillou authored and lamblin committed Jan 22, 2020
1 parent f707943 commit e6c3e70
Showing 3 changed files with 73 additions and 34 deletions.
16 changes: 16 additions & 0 deletions meta_dataset/learner.py
@@ -29,6 +29,8 @@
from six.moves import zip
import tensorflow.compat.v1 as tf

FLAGS = tf.flags.FLAGS

MAX_WAY = 50 # The maximum number of classes we will see in any batch.


@@ -1945,7 +1947,21 @@ def __init__(self, is_training, transductive_batch_norm,
Raises:
ValueError: The embedding function must be MAML-compatible.
RuntimeError: Requested to meta-learn the initialization of the linear
layer weights but they are unexpectedly omitted from saving/restoring.
"""
    if not zero_fc_layer and not proto_maml_fc_layer_init:
      # So the linear classifier weights initialization is meta-learned.
      if 'linear_classifier' in FLAGS.omit_from_saving_and_reloading:
        raise RuntimeError('The linear layer is requested to be meta-learned '
                           'since both zero_fc_layer and '
                           'proto_maml_fc_layer_init are False, but the '
                           'linear_classifier weights are found in '
                           'FLAGS.omit_from_saving_and_reloading so they will '
                           'not be properly restored. Please exclude these '
                           'weights from omit_from_saving_and_reloading for '
                           'this setting to work as expected.')

    super(MAMLLearner, self).__init__(is_training, transductive_batch_norm,
                                      backprop_through_moments, ema_object,
                                      embedding_fn, reader)
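To make the intent of this check concrete, the following is a small self-contained sketch, not part of the commit, that enumerates the three MAML configurations and reports which of them are compatible with an omission list containing 'linear_classifier'. The names DEFAULT_OMIT and fc_init_is_meta_learned are invented for illustration; only the membership test mirrors the check above.

# Hypothetical illustration of the constructor check above.
DEFAULT_OMIT = ['num_left_in_epoch', 'fc_finetune', 'linear_classifier',
                'adam_opt', 'weight_copy']


def fc_init_is_meta_learned(zero_fc_layer, proto_maml_fc_layer_init):
  # The linear layer's initialization is meta-learned only if it is neither
  # zeroed out nor set from prototypes at the start of each episode.
  return not zero_fc_layer and not proto_maml_fc_layer_init


for zero_fc, proto_init in [(True, False), (False, True), (False, False)]:
  needs_restoring = fc_init_is_meta_learned(zero_fc, proto_init)
  compatible = not (needs_restoring and 'linear_classifier' in DEFAULT_OMIT)
  print('zero_fc_layer=%s, proto_maml_fc_layer_init=%s -> compatible: %s' %
        (zero_fc, proto_init, compatible))

Only the last configuration, where neither option is set, requires excluding 'linear_classifier' from the omission list, which is exactly the case the RuntimeError above guards against.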
33 changes: 29 additions & 4 deletions meta_dataset/train.py
@@ -89,16 +89,41 @@
'sub-graphs of ImageNet too, since the test sub-graph evidently does not '
'exhibit enough variation in the fine-grainedness of its different tasks '
'to allow for a meaningful analysis.')

# The following flag specifies substrings of variable names that should not be
# reloaded. `num_left_in_epoch' is a variable that influences the behavior of
# the EpochTrackers. Since the state of those trackers is not reloaded, neither
# should this variable. `fc_finetune' is a substring of the names of the
# variables in the episode-specific linear layer of the finetune baseline (used
# at meta-validation and meta-test times). Since this layer gets re-initialized
# to random weights in each new episode, there is no need to ever restore these
# weights. `linear_classifier' plays that role but for the MAML model: similarly
# in each new episode it is re-initialized (e.g. set to zeros or to the
# prototypes in the case of proto-MAML), so there is no need to restore these
# weights. `adam_opt' captures the variables of the within-episode optimizer of
# the finetune baseline when it is configured to perform that finetuning with
# adam. `fc' captures the variable names of the fully-connected layer for the
# all-way classification problem that the baselines solve at training time.
# There are certain situations where we need to omit reloading these weights to
# avoid getting an error. Consider for example the experiments where we train
# a baseline model, starting from weights that were previously trained on
# ImageNet. If this training now takes place on all datasets, the size of the
# all-way classification layer is now different (equal to the number of
# meta-training classes of all datasets not just of ImageNet). Thus when
# training baselines from pre-trained weights, we only reload the backbone and
# not the `fc' all-way classification layer (similarly for inference-only
# experiments for the same reason).
tf.flags.DEFINE_multi_enum(
-    'omit_from_saving_and_reloading',
-    ['num_left_in_epoch', 'finetune', 'linear_classifier', 'adam_opt', 'fc'], [
-        'num_left_in_epoch', 'finetune', 'linear_classifier', 'adam_opt', 'fc',
+    'omit_from_saving_and_reloading', [
+        'num_left_in_epoch', 'fc_finetune', 'linear_classifier', 'adam_opt',
+        'weight_copy'
+    ], [
+        'num_left_in_epoch', 'fc_finetune', 'linear_classifier', 'adam_opt',
+        'weight_copy', 'fc'
+    ],
'A comma-separated list of substrings such that all variables containing '
'them should not be saved and reloaded.')


FLAGS = tf.flags.FLAGS


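Since the omit_from_saving_and_reloading flag defined above holds substrings rather than full variable names, omission amounts to a substring match over tf.global_variables(). The following is a minimal sketch of that pattern; variables_to_save_and_reload is a hypothetical helper, not a function from the repository.

import tensorflow.compat.v1 as tf


def variables_to_save_and_reload(omit_substrings):
  """Returns the global variables whose names contain none of the substrings."""
  return [
      var for var in tf.global_variables()
      if not any(substring in var.name for substring in omit_substrings)
  ]


# A Saver built from this list would neither save nor restore, for example, the
# episodic 'fc_finetune' layer or the within-episode 'adam_opt' variables:
# saver = tf.train.Saver(
#     var_list=variables_to_save_and_reload(FLAGS.omit_from_saving_and_reloading))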
58 changes: 28 additions & 30 deletions meta_dataset/trainer.py
@@ -413,6 +413,7 @@ def __init__(
self.eval_finegrainedness = eval_finegrainedness
self.eval_finegrainedness_split = eval_finegrainedness_split
self.eval_imbalance_dataset = eval_imbalance_dataset
self.omit_from_saving_and_reloading = omit_from_saving_and_reloading

self.eval_split = VALID_SPLIT if is_training else TEST_SPLIT
if eval_finegrainedness:
@@ -827,41 +828,38 @@ def initialize_session(self):

      # Load the embedding variables from the pre-trained checkpoint. Since the
      # pre-trained checkpoint comes from a BaselineLearner, we need a Saver
-      # that only considers Variables from a BaselineLearner. In particular, we
-      # exclude 'relationnet*' Variables as they are not present in the
-      # checkpoint.
-      not_relationnet_vars = []
+      # that only considers embedding Variables from a BaselineLearner. In
+      # particular, we exclude 'relationnet*' Variables as they are not present
+      # in the checkpoint. We also exclude any variables that are not related
+      # to the embedding (e.g. `beta1_power:0') and any variables that are
+      # requested to be omitted. Notably, this leads to not reloading ADAM
+      # variables. We do not reload these since this episodic finetuning
+      # procedure is a different optimization problem than the original training
+      # of the baseline whose embedding weights are re-used.
+      baselinelearner_embed_vars_to_reload = []
      for var in tf.global_variables():
-        if not var.name.startswith('relationnet'):
-          not_relationnet_vars.append(var)
+        is_relationnet_var = var.name.startswith('relationnet')
+        requested_to_omit = any([
+            substring in var.name
+            for substring in self.omit_from_saving_and_reloading
+        ])
+        is_embedding_var = any(
+            keyword in var.name for keyword in EMBEDDING_KEYWORDS)
+        is_adam_var = 'adam' in var.name.lower()
+        if (not is_relationnet_var and not requested_to_omit and
+            is_embedding_var):
+          if is_adam_var:
+            raise RuntimeError('Variable name unexpectedly indicates it is '
+                               'both related to an embedding, and to ADAM: %s' %
+                               var.name)
+          baselinelearner_embed_vars_to_reload.append(var)
      backbone_saver = tf.train.Saver(
-          var_list=not_relationnet_vars, max_to_keep=1)
+          var_list=baselinelearner_embed_vars_to_reload, max_to_keep=1)
      backbone_saver.restore(self.sess,
                             self.learner_config.pretrained_checkpoint)
-      logging.info('Restored checkpoint: %s',
+      logging.info('Restored only vars %s from checkpoint: %s',
+                   [var.name for var in baselinelearner_embed_vars_to_reload],
                   self.learner_config.pretrained_checkpoint)
-      # We only want the embedding weights of the checkpoint we just restored.
-      # So we re-initialize everything that's not an embedding weight. Also,
-      # since this episodic finetuning procedure is a different optimization
-      # problem than the original training of the baseline whose embedding
-      # weights are re-used, we do not reload ADAM's variables and instead learn
-      # them from scratch.
-      # TODO(etriantafillou): modify backbone_saver's set of variables in order
-      # to exclude *all* non-embedding variables. Then we won't need the
-      # following block of code which explicitly re-initializes those, as they
-      # won't have been reloaded in the first place.
-      vars_to_reinit, embedding_var_names, vars_to_reinit_names = [], [], []
-      for var in tf.global_variables():
-        if (any(keyword in var.name for keyword in EMBEDDING_KEYWORDS) and
-            'adam' not in var.name.lower()):
-          embedding_var_names.append(var.name)
-          continue
-        vars_to_reinit.append(var)
-        vars_to_reinit_names.append(var.name)
-      logging.info('Initializing all variables except for %s.',
-                   embedding_var_names)
-      self.sess.run(tf.variables_initializer(vars_to_reinit))
-      logging.info('Re-initialized vars %s.', vars_to_reinit_names)

def _create_held_out_specification(self, split=TEST_SPLIT):
"""Create an EpisodeSpecification for either validation or testing.
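The corrected backbone saver follows the usual partial-restore pattern: initialize every variable first, then overwrite only the selected backbone variables from the checkpoint, which is why the explicit re-initialization block above could be deleted. A minimal sketch of that pattern follows; the variable names and checkpoint path are placeholders, not taken from the repository.

import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  # Stand-ins for an embedding weight and an episode-specific classifier weight.
  embedding_w = tf.get_variable('embedding_conv_w', shape=[3, 3, 3, 64])
  classifier_w = tf.get_variable('fc_w', shape=[64, 10])

  # Restore only the embedding variable; the classifier keeps whatever it was
  # initialized to, so no second initialization pass is required.
  backbone_saver = tf.train.Saver(var_list=[embedding_w], max_to_keep=1)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # backbone_saver.restore(sess, '/path/to/pretrained_checkpoint')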
