# Learning to Transfer Learn

* Linchao Zhu, Sercan O. Arik, Yi Yang, Tomas Pfister, "Learning to Transfer Learn", arXiv preprint arXiv:1908.11406 (2019) - https://arxiv.org/abs/1908.11406


We demonstrate the effectiveness of L2TL on the MNIST->SVHN experiments.

We illustrate the experimental setup. We then show the scripts to train and evaluate the model

## Requirements

```
virtualenv -p python3 .
source ./bin/activate

pip install -r requirements.txt

cp svhn_data/__init__.py lib/python3.5/site-packages/tensorflow_datasets/image
cp svhn_data/svhn_small.py lib/python3.5/site-packages/tensorflow_datasets/image
cd svhn_data
wget -nc http://ufldl.stanford.edu/housenumbers/train_32x32.mat
wget -nc http://ufldl.stanford.edu/housenumbers/test_32x32.mat
python gen_svhn_mat.py
cd ..
```

## Baseline model functions for training

In [57]:
import model
import model_utils
import tensorflow as tf
from tensorflow.python.estimator import estimator
import tensorflow_datasets as tfds

import os
import re


def get_train_model_fn(train_batch_size,
                       target_dataset,
                       target_base_learning_rate,
                       src_num_classes=5,
                       weight_decay=0.0005):
  """Returns the model definition."""

  def lr_schedule():
    """Learning rate scheduling."""
    target_lr = target_base_learning_rate
    current_step = tf.train.get_global_step()

    if target_dataset == 'mnist':
      return tf.train.piecewise_constant(current_step, [
          500, 1500,
    ], [target_lr, target_lr * 0.1, target_lr * 0.01])
    else:
      return tf.train.piecewise_constant(current_step, [
          800,
      ], [target_lr, target_lr * 0.1])


  def model_fn(features, labels, mode, params):
    """Returns the model function."""
    feature = features['feature']
    print(feature)
    labels = labels['label']
    one_hot_labels = model_utils.get_label(
        labels,
        params,
        src_num_classes,
        batch_size=train_batch_size)

    def get_logits():
      """Return the logits."""
      avg_pool = model.conv_model(feature, mode,
                                  target_dataset=target_dataset)
      name = 'final_dense_dst'
      with tf.variable_scope('target_CLS'):
        logits = tf.layers.dense(
            inputs=avg_pool, units=src_num_classes, name=name,
            kernel_initializer=tf.random_normal_initializer(stddev=.05),
        )
      return logits

    logits = get_logits()
    logits = tf.cast(logits, tf.float32)

    dst_loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
    )
    dst_l2_loss = weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name and 'kernel' in v.name
    ])

    loss = dst_loss + dst_l2_loss

    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      cur_finetune_step = tf.train.get_global_step()
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        finetune_learning_rate = lr_schedule()
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=finetune_learning_rate,
            momentum=0.9,
            use_nesterov=True
        )
        train_op = tf.contrib.slim.learning.create_train_op(loss, optimizer)
        with tf.variable_scope('finetune'):
          train_op = optimizer.minimize(loss, cur_finetune_step)
    else:
      train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:
      eval_metrics = model_utils.metric_fn(labels, logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      with tf.control_dependencies([train_op]):
        tf.summary.scalar('classifier/finetune_lr', finetune_learning_rate)
    else:
      train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metrics,
    )

  return model_fn


## Evaluation and baseline training setup

In [58]:
def get_eval_model_fn(src_num_classes,
                      train_batch_size,
                      target_dataset,
                      cls_dense_name):
  """Returns the model definition."""

  def model_fn(features, labels, mode, params):
    """Returns the model function."""
    feature = features['feature']
    labels = labels['label']
    one_hot_labels = model_utils.get_label(
        labels,
        params,
        src_num_classes,
        batch_size=train_batch_size)

    def get_logits():
      """Return the logits."""
      network_output = model.conv_model(feature, mode,
                                        target_dataset=target_dataset)
      name = cls_dense_name
      with tf.variable_scope('target_CLS'):
        logits = tf.layers.dense(
            inputs=network_output, units=src_num_classes, name=name)
      return logits

    logits = get_logits()
    logits = tf.cast(logits, tf.float32)

    dst_loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
    )
    loss = dst_loss

    eval_metrics = model_utils.metric_fn(labels, logits)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=None,
        eval_metric_ops=eval_metrics,
    )

  return model_fn


def evaluate(target_dataset,
             train_batch_size,
             cls_dense_name,
             ckpt_path,
             src_num_classes):
  NUM_EVAL_IMAGES = {
      'mnist': 10000,
      'svhn_cropped_small': 6000,
  }
  config = tf.estimator.RunConfig()

  classifier = tf.estimator.Estimator(
      get_eval_model_fn(
          src_num_classes,
          train_batch_size,
          target_dataset,
          cls_dense_name),
      config=config)

  def _merge_datasets(test_batch):
    feature, label = test_batch['image'], test_batch['label'],
    features = {'feature': feature}
    labels = {'label': label,}
    return (features, labels)

  def get_dataset(dataset_split):
    """Returns dataset creation function."""

    def make_input_dataset():
      """Returns input dataset."""
      test_data = tfds.load(name=target_dataset, split=dataset_split)
      test_data = test_data.batch(train_batch_size)
      dataset = tf.data.Dataset.zip((test_data,))
      dataset = dataset.map(_merge_datasets)
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
      return dataset

    return make_input_dataset

  num_eval_images = NUM_EVAL_IMAGES[target_dataset]
  eval_steps = num_eval_images // train_batch_size

  classifier.evaluate(
      input_fn=get_dataset('test'),
      steps=eval_steps,
      checkpoint_path=ckpt_path,
  )


def train(model_dir,
          target_dataset="svhn_cropped_small",
          train_batch_size=8,
          train_steps=1200,
          target_base_learning_rate=0.005,
          src_num_classes=5,
          warm_start_ckpt_path=None):
  tf.set_random_seed(1)

  run_config_args = {
      'model_dir': model_dir,
      'save_checkpoints_steps': 200,
      'log_step_count_steps': 128,
      'keep_checkpoint_max': 20,
  }

  config = tf.estimator.RunConfig(**run_config_args)

  if warm_start_ckpt_path:
    var_names = []
    checkpoint_path = warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    for key in reader.get_variable_to_shape_map():
      keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'
      if not re.findall('({})'.format(keep_str,), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    vars_to_warm_start = var_names
    warm_start_settings = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=vars_to_warm_start)
  else:
    warm_start_settings = None

  classifier = tf.estimator.Estimator(
      get_train_model_fn(
          train_batch_size=train_batch_size,
          target_dataset=target_dataset,
          target_base_learning_rate=target_base_learning_rate,
          src_num_classes=src_num_classes,
      ),
      config=config, warm_start_from=warm_start_settings)

  def _merge_datasets(train_batch):
    feature, label = train_batch['image'], train_batch['label'],
    features = {'feature': feature}
    labels = {'label': label}
    return (features, labels)

  def get_dataset(dataset_split):
    """Returns dataset creation function."""

    def make_input_dataset():
      """Returns input dataset."""
      train_data = tfds.load(name=target_dataset, split=dataset_split)
      train_data = train_data.shuffle(1024).repeat().batch(
          train_batch_size)
      dataset = tf.data.Dataset.zip((train_data,))
      dataset = dataset.map(_merge_datasets)
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
      return dataset

    return make_input_dataset

  # pylint: disable=protected-access
  current_step = estimator._load_global_step_from_checkpoint_dir(
      model_dir)

  while current_step < train_steps:
    print('Run {}'.format(current_step))
    next_checkpoint = current_step + 400
    classifier.train(input_fn=get_dataset('train'), max_steps=next_checkpoint)
    current_step = next_checkpoint

## Train SVHN from scratch

In [70]:
model_dir = "./tmp/scratch_svhn"
target_dataset = "svhn_cropped_small"
train_batch_size = 8
train_steps = 1200
learning_rate = 0.005
warm_start_ckpt_path = None

train(model_dir=model_dir,
      target_dataset=target_dataset,
      train_batch_size=train_batch_size,
      train_steps=train_steps,
      target_base_learning_rate=learning_rate,
      warm_start_ckpt_path=warm_start_ckpt_path)
evaluate(
    target_dataset=target_dataset,
    train_batch_size=600,
    cls_dense_name='final_dense_dst',
    ckpt_path=os.path.join(model_dir, 'model.ckpt-%d' % train_steps),
    src_num_classes=5
)

INFO:tensorflow:Using config: {'_model_dir': './tmp/scratch_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb887bc320>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': './tmp/scratch_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb887bc320>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






Run 0
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:loss = 1.8790246, step = 0


INFO:tensorflow:loss = 1.8790246, step = 0


INFO:tensorflow:global_step/sec: 142.16


INFO:tensorflow:global_step/sec: 142.16


INFO:tensorflow:loss = 1.8503022, step = 128 (0.905 sec)


INFO:tensorflow:loss = 1.8503022, step = 128 (0.905 sec)


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 183.349


INFO:tensorflow:global_step/sec: 183.349


INFO:tensorflow:loss = 1.93315, step = 256 (0.696 sec)


INFO:tensorflow:loss = 1.93315, step = 256 (0.696 sec)


INFO:tensorflow:global_step/sec: 199.797


INFO:tensorflow:global_step/sec: 199.797


INFO:tensorflow:loss = 1.8712884, step = 384 (0.640 sec)


INFO:tensorflow:loss = 1.8712884, step = 384 (0.640 sec)


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 1.8143762.


INFO:tensorflow:Loss for final step: 1.8143762.


Run 400
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-400


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-400


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:loss = 1.8000953, step = 400


INFO:tensorflow:loss = 1.8000953, step = 400


INFO:tensorflow:global_step/sec: 128.608


INFO:tensorflow:global_step/sec: 128.608


INFO:tensorflow:loss = 1.8308533, step = 528 (0.997 sec)


INFO:tensorflow:loss = 1.8308533, step = 528 (0.997 sec)


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 185.062


INFO:tensorflow:global_step/sec: 185.062


INFO:tensorflow:loss = 1.5834279, step = 656 (0.693 sec)


INFO:tensorflow:loss = 1.5834279, step = 656 (0.693 sec)


INFO:tensorflow:global_step/sec: 209.381


INFO:tensorflow:global_step/sec: 209.381


INFO:tensorflow:loss = 0.6006835, step = 784 (0.613 sec)


INFO:tensorflow:loss = 0.6006835, step = 784 (0.613 sec)


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.7741643.


INFO:tensorflow:Loss for final step: 0.7741643.


Run 800
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-800


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-800


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:loss = 0.85390353, step = 800


INFO:tensorflow:loss = 0.85390353, step = 800


INFO:tensorflow:global_step/sec: 128.352


INFO:tensorflow:global_step/sec: 128.352


INFO:tensorflow:loss = 0.46725968, step = 928 (0.999 sec)


INFO:tensorflow:loss = 0.46725968, step = 928 (0.999 sec)


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 178


INFO:tensorflow:global_step/sec: 178


INFO:tensorflow:loss = 0.45704037, step = 1056 (0.719 sec)


INFO:tensorflow:loss = 0.45704037, step = 1056 (0.719 sec)


INFO:tensorflow:global_step/sec: 222.285


INFO:tensorflow:global_step/sec: 222.285


INFO:tensorflow:loss = 0.5036832, step = 1184 (0.578 sec)


INFO:tensorflow:loss = 0.5036832, step = 1184 (0.578 sec)


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/scratch_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.41581798.


INFO:tensorflow:Loss for final step: 0.41581798.






INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmprf81d13g', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb88603080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmprf81d13g', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb88603080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-03-08T21:02:38Z


INFO:tensorflow:Starting evaluation at 2020-03-08T21:02:38Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-1200


INFO:tensorflow:Restoring parameters from ./tmp/scratch_svhn/model.ckpt-1200


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Finished evaluation at 2020-03-08-21:02:40


INFO:tensorflow:Finished evaluation at 2020-03-08-21:02:40


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.11495, top_1_accuracy = 0.64016664, top_5_accuracy = 1.0


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.11495, top_1_accuracy = 0.64016664, top_5_accuracy = 1.0


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/scratch_svhn/model.ckpt-1200


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/scratch_svhn/model.ckpt-1200


#### Train SVHN from scratch: the top-1 accuracy is 0.64%.

## Pre-train MNIST

In [71]:
model_dir = "./tmp/mnist_pretrain"
target_dataset = "mnist"
train_batch_size = 128
train_steps = 2000
learning_rate = 0.01
warm_start_ckpt_path = None

train(model_dir=model_dir,
      target_dataset=target_dataset,
      train_batch_size=train_batch_size,
      train_steps=train_steps,
      target_base_learning_rate=learning_rate,
      src_num_classes=10,
      warm_start_ckpt_path=warm_start_ckpt_path)
evaluate(
    target_dataset=target_dataset,
    train_batch_size=600,
    cls_dense_name='final_dense_dst',
    ckpt_path=os.path.join(model_dir, 'model.ckpt-%d' % train_steps),
    src_num_classes=10
)

INFO:tensorflow:Using config: {'_model_dir': './tmp/mnist_pretrain', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb90157ef0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': './tmp/mnist_pretrain', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb90157ef0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






Run 0
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:loss = 2.5687585, step = 0


INFO:tensorflow:loss = 2.5687585, step = 0


INFO:tensorflow:global_step/sec: 76.9963


INFO:tensorflow:global_step/sec: 76.9963


INFO:tensorflow:loss = 0.46617663, step = 128 (1.665 sec)


INFO:tensorflow:loss = 0.46617663, step = 128 (1.665 sec)


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:global_step/sec: 89.1677


INFO:tensorflow:global_step/sec: 89.1677


INFO:tensorflow:loss = 0.41375953, step = 256 (1.438 sec)


INFO:tensorflow:loss = 0.41375953, step = 256 (1.438 sec)


INFO:tensorflow:global_step/sec: 57.0442


INFO:tensorflow:global_step/sec: 57.0442


INFO:tensorflow:loss = 0.33184546, step = 384 (2.242 sec)


INFO:tensorflow:loss = 0.33184546, step = 384 (2.242 sec)


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Loss for final step: 0.29182482.


INFO:tensorflow:Loss for final step: 0.29182482.


Run 400
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-400


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-400


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:loss = 0.51420915, step = 400


INFO:tensorflow:loss = 0.51420915, step = 400


INFO:tensorflow:global_step/sec: 74.1483


INFO:tensorflow:global_step/sec: 74.1483


INFO:tensorflow:loss = 0.33345065, step = 528 (1.729 sec)


INFO:tensorflow:loss = 0.33345065, step = 528 (1.729 sec)


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:global_step/sec: 69.9444


INFO:tensorflow:global_step/sec: 69.9444


INFO:tensorflow:loss = 0.2970869, step = 656 (1.829 sec)


INFO:tensorflow:loss = 0.2970869, step = 656 (1.829 sec)


INFO:tensorflow:global_step/sec: 58.8558


INFO:tensorflow:global_step/sec: 58.8558


INFO:tensorflow:loss = 0.32211864, step = 784 (2.179 sec)


INFO:tensorflow:loss = 0.32211864, step = 784 (2.179 sec)


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Loss for final step: 0.35087806.


INFO:tensorflow:Loss for final step: 0.35087806.


Run 800
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-800


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-800


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:loss = 0.37344068, step = 800


INFO:tensorflow:loss = 0.37344068, step = 800


INFO:tensorflow:global_step/sec: 36.7906


INFO:tensorflow:global_step/sec: 36.7906


INFO:tensorflow:loss = 0.34726804, step = 928 (3.483 sec)


INFO:tensorflow:loss = 0.34726804, step = 928 (3.483 sec)


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:global_step/sec: 97.7485


INFO:tensorflow:global_step/sec: 97.7485


INFO:tensorflow:loss = 0.36406294, step = 1056 (1.307 sec)


INFO:tensorflow:loss = 0.36406294, step = 1056 (1.307 sec)


INFO:tensorflow:global_step/sec: 74.6865


INFO:tensorflow:global_step/sec: 74.6865


INFO:tensorflow:loss = 0.31425574, step = 1184 (1.716 sec)


INFO:tensorflow:loss = 0.31425574, step = 1184 (1.716 sec)


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Loss for final step: 0.3123991.


INFO:tensorflow:Loss for final step: 0.3123991.


Run 1200
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-1200


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-1200


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:loss = 0.34780872, step = 1200


INFO:tensorflow:loss = 0.34780872, step = 1200


INFO:tensorflow:global_step/sec: 44.8194


INFO:tensorflow:global_step/sec: 44.8194


INFO:tensorflow:loss = 0.32264873, step = 1328 (2.860 sec)


INFO:tensorflow:loss = 0.32264873, step = 1328 (2.860 sec)


INFO:tensorflow:Saving checkpoints for 1400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1400 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:global_step/sec: 73.0058


INFO:tensorflow:global_step/sec: 73.0058


INFO:tensorflow:loss = 0.37420636, step = 1456 (1.751 sec)


INFO:tensorflow:loss = 0.37420636, step = 1456 (1.751 sec)


INFO:tensorflow:global_step/sec: 57.4841


INFO:tensorflow:global_step/sec: 57.4841


INFO:tensorflow:loss = 0.28522557, step = 1584 (2.230 sec)


INFO:tensorflow:loss = 0.28522557, step = 1584 (2.230 sec)


INFO:tensorflow:Saving checkpoints for 1600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Loss for final step: 0.3057857.


INFO:tensorflow:Loss for final step: 0.3057857.


Run 1600
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-1600


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-1600


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 1600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1600 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:loss = 0.3174793, step = 1600


INFO:tensorflow:loss = 0.3174793, step = 1600


INFO:tensorflow:global_step/sec: 64.9286


INFO:tensorflow:global_step/sec: 64.9286


INFO:tensorflow:loss = 0.29272425, step = 1728 (1.976 sec)


INFO:tensorflow:loss = 0.29272425, step = 1728 (1.976 sec)


INFO:tensorflow:Saving checkpoints for 1800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1800 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:global_step/sec: 66.1024


INFO:tensorflow:global_step/sec: 66.1024


INFO:tensorflow:loss = 0.2900865, step = 1856 (1.934 sec)


INFO:tensorflow:loss = 0.2900865, step = 1856 (1.934 sec)


INFO:tensorflow:global_step/sec: 57.2027


INFO:tensorflow:global_step/sec: 57.2027


INFO:tensorflow:loss = 0.294807, step = 1984 (2.240 sec)


INFO:tensorflow:loss = 0.294807, step = 1984 (2.240 sec)


INFO:tensorflow:Saving checkpoints for 2000 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Saving checkpoints for 2000 into ./tmp/mnist_pretrain/model.ckpt.


INFO:tensorflow:Loss for final step: 0.28967467.


INFO:tensorflow:Loss for final step: 0.28967467.






INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpq_adqh3i', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb80454668>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpq_adqh3i', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb80454668>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-03-08T21:03:31Z


INFO:tensorflow:Starting evaluation at 2020-03-08T21:03:31Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Restoring parameters from ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [1/16]


INFO:tensorflow:Evaluation [1/16]


INFO:tensorflow:Evaluation [2/16]


INFO:tensorflow:Evaluation [2/16]


INFO:tensorflow:Evaluation [3/16]


INFO:tensorflow:Evaluation [3/16]


INFO:tensorflow:Evaluation [4/16]


INFO:tensorflow:Evaluation [4/16]


INFO:tensorflow:Evaluation [5/16]


INFO:tensorflow:Evaluation [5/16]


INFO:tensorflow:Evaluation [6/16]


INFO:tensorflow:Evaluation [6/16]


INFO:tensorflow:Evaluation [7/16]


INFO:tensorflow:Evaluation [7/16]


INFO:tensorflow:Evaluation [8/16]


INFO:tensorflow:Evaluation [8/16]


INFO:tensorflow:Evaluation [9/16]


INFO:tensorflow:Evaluation [9/16]


INFO:tensorflow:Evaluation [10/16]


INFO:tensorflow:Evaluation [10/16]


INFO:tensorflow:Evaluation [11/16]


INFO:tensorflow:Evaluation [11/16]


INFO:tensorflow:Evaluation [12/16]


INFO:tensorflow:Evaluation [12/16]


INFO:tensorflow:Evaluation [13/16]


INFO:tensorflow:Evaluation [13/16]


INFO:tensorflow:Evaluation [14/16]


INFO:tensorflow:Evaluation [14/16]


INFO:tensorflow:Evaluation [15/16]


INFO:tensorflow:Evaluation [15/16]


INFO:tensorflow:Evaluation [16/16]


INFO:tensorflow:Evaluation [16/16]


INFO:tensorflow:Finished evaluation at 2020-03-08-21:03:33


INFO:tensorflow:Finished evaluation at 2020-03-08-21:03:33


INFO:tensorflow:Saving dict for global step 2000: global_step = 2000, loss = 0.054582532, top_1_accuracy = 0.9814583, top_5_accuracy = 0.9998958


INFO:tensorflow:Saving dict for global step 2000: global_step = 2000, loss = 0.054582532, top_1_accuracy = 0.9814583, top_5_accuracy = 0.9998958


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: ./tmp/mnist_pretrain/model.ckpt-2000


## Fine-tuning SVHN with pre-trained MNIST model

In [73]:
model_dir = "./tmp/finetune_svhn"
target_dataset = "svhn_cropped_small"
train_batch_size = 8
train_steps = 1200
learning_rate = 0.005
warm_start_ckpt_path = "./tmp/mnist_pretrain/model.ckpt-2000"

train(model_dir=model_dir,
      target_dataset=target_dataset,
      train_batch_size=train_batch_size,
      train_steps=train_steps,
      target_base_learning_rate=learning_rate,
      warm_start_ckpt_path=warm_start_ckpt_path)
evaluate(
    target_dataset=target_dataset,
    train_batch_size=600,
    cls_dense_name='final_dense_dst',
    ckpt_path=os.path.join(model_dir, 'model.ckpt-%d' % train_steps),
    src_num_classes=5
)

INFO:tensorflow:Warm-starting tensors: ['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense1/bias', 'dense1/kernel', 'dense2/bias', 'dense2/kernel']


INFO:tensorflow:Warm-starting tensors: ['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense1/bias', 'dense1/kernel', 'dense2/bias', 'dense2/kernel']


INFO:tensorflow:Using config: {'_model_dir': './tmp/finetune_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb60087c50>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': './tmp/finetune_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 20, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 128, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb60087c50>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






Run 0
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:loss = 2.3080974, step = 0


INFO:tensorflow:loss = 2.3080974, step = 0


INFO:tensorflow:global_step/sec: 128.541


INFO:tensorflow:global_step/sec: 128.541


INFO:tensorflow:loss = 1.5260081, step = 128 (1.002 sec)


INFO:tensorflow:loss = 1.5260081, step = 128 (1.002 sec)


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 190.712


INFO:tensorflow:global_step/sec: 190.712


INFO:tensorflow:loss = 0.60132706, step = 256 (0.666 sec)


INFO:tensorflow:loss = 0.60132706, step = 256 (0.666 sec)


INFO:tensorflow:global_step/sec: 217.392


INFO:tensorflow:global_step/sec: 217.392


INFO:tensorflow:loss = 0.48986638, step = 384 (0.588 sec)


INFO:tensorflow:loss = 0.48986638, step = 384 (0.588 sec)


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.6591983.


INFO:tensorflow:Loss for final step: 0.6591983.


Run 400
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-400


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-400


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:loss = 0.48054022, step = 400


INFO:tensorflow:loss = 0.48054022, step = 400


INFO:tensorflow:global_step/sec: 129.917


INFO:tensorflow:global_step/sec: 129.917


INFO:tensorflow:loss = 0.2696886, step = 528 (0.987 sec)


INFO:tensorflow:loss = 0.2696886, step = 528 (0.987 sec)


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 185.933


INFO:tensorflow:global_step/sec: 185.933


INFO:tensorflow:loss = 0.37132114, step = 656 (0.688 sec)


INFO:tensorflow:loss = 0.37132114, step = 656 (0.688 sec)


INFO:tensorflow:global_step/sec: 217.251


INFO:tensorflow:global_step/sec: 217.251


INFO:tensorflow:loss = 0.28101885, step = 784 (0.589 sec)


INFO:tensorflow:loss = 0.28101885, step = 784 (0.589 sec)


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.27427408.


INFO:tensorflow:Loss for final step: 0.27427408.


Run 800
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


Tensor("IteratorGetNext:0", shape=(?, 32, 32, 3), dtype=uint8)
INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Warm-started 16 variables.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-800


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-800


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:loss = 0.47439495, step = 800


INFO:tensorflow:loss = 0.47439495, step = 800


INFO:tensorflow:global_step/sec: 146.73


INFO:tensorflow:global_step/sec: 146.73


INFO:tensorflow:loss = 0.26509672, step = 928 (0.874 sec)


INFO:tensorflow:loss = 0.26509672, step = 928 (0.874 sec)


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 186.88


INFO:tensorflow:global_step/sec: 186.88


INFO:tensorflow:loss = 0.2611407, step = 1056 (0.687 sec)


INFO:tensorflow:loss = 0.2611407, step = 1056 (0.687 sec)


INFO:tensorflow:global_step/sec: 216.033


INFO:tensorflow:global_step/sec: 216.033


INFO:tensorflow:loss = 0.2608101, step = 1184 (0.592 sec)


INFO:tensorflow:loss = 0.2608101, step = 1184 (0.592 sec)


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/finetune_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.26533544.


INFO:tensorflow:Loss for final step: 0.26533544.






INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30mfsrjf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8844a588>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30mfsrjf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8844a588>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-03-08T21:04:19Z


INFO:tensorflow:Starting evaluation at 2020-03-08T21:04:19Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-1200


INFO:tensorflow:Restoring parameters from ./tmp/finetune_svhn/model.ckpt-1200


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Finished evaluation at 2020-03-08-21:04:21


INFO:tensorflow:Finished evaluation at 2020-03-08-21:04:21


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.688844, top_1_accuracy = 0.70933336, top_5_accuracy = 1.0


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.688844, top_1_accuracy = 0.70933336, top_5_accuracy = 1.0


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/finetune_svhn/model.ckpt-1200


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/finetune_svhn/model.ckpt-1200


#### Fine-tuning SVHN with a  pre-trained MNIST model: top-1 accuracy is 70.9%.

## Train L2TL

In [74]:
import tensorflow_probability as tfp


dst_weight_decay = 0.0005
first_pretrain_steps = 0
init_rl_learning_rate = 0.01
learning_rate = 0.005
loss_weight_scale = 100.
model_dir = "./tmp/l2tl_svhn"
num_choices = 100
source_dataset = 'mnist'
source_train_batch_multiplier = 2
src_num_classes = 10
target_dataset = 'svhn_cropped_small'
target_train_batch_multiplier = 1
target_val_batch_multiplier = 4
target_num_classes = 5
train_batch_size = 8
train_steps = 1200
uniform_weight = 0
warm_start_ckpt_path = "./tmp/mnist_pretrain/model.ckpt-2000"


def get_global_step(name):
  """Returns the global step variable."""
  global_step = tf.get_variable(
      name,
      shape=[],
      dtype=tf.int64,
      initializer=tf.initializers.zeros(),
      trainable=False,
      collections=[tf.GraphKeys.GLOBAL_VARIABLES])
  return global_step


def get_src_train_op(loss):  # pylint: disable=unused-argument
  """Returns the source training op."""
  global_step = tf.train.get_global_step()
  src_learning_rate = learning_rate
  src_learning_rate = tf.train.piecewise_constant(
      global_step, [800,],
      [learning_rate, learning_rate * 0.1])
  optimizer = tf.train.MomentumOptimizer(
      learning_rate=src_learning_rate,
      momentum=0.9,
      use_nesterov=True
  )
  with tf.variable_scope('src'):
    return optimizer.minimize(loss, global_step), src_learning_rate


def meta_train_op(acc, rl_entropy, log_prob, rl_scope, params):  # pylint: disable=unused-argument
  """Returns the target training op.

  Update the control variables using policy gradient.
  Args:
    acc: reward on validation set. In our case, the reward is the top-1 acc;
    rl_entropy: entropy of action logits;
    log_prob: log prob of the action;
    rl_scope: variable scope;
    params: other params;

  Returns:
    target_train_op: train op;
    rl_learning_rate: lr;
    out_metric: metric dict;
  """
  target_global_step = get_global_step('train_rl_global_step')
  rl_reward = acc
  rl_step_baseline = rl_reward
  rl_baseline_momentum = 0.9
  rl_entropy_regularization = 0.001

  def update_rl_baseline():
    return model_utils.update_exponential_moving_average(
        rl_step_baseline, momentum=rl_baseline_momentum)

  rl_baseline = update_rl_baseline()

  rl_advantage = rl_reward - rl_baseline
  rl_empirical_loss = -tf.stop_gradient(rl_advantage) * log_prob

  rl_entropy_loss = -rl_entropy_regularization * rl_entropy

  enable_rl_optimizer = tf.cast(
      tf.greater_equal(target_global_step, first_pretrain_steps),
      tf.float32)
  rl_learning_rate = init_rl_learning_rate * enable_rl_optimizer
  rl_learning_rate = tf.train.piecewise_constant(
      target_global_step, [800,],
      [rl_learning_rate, rl_learning_rate * 0.1])

  optimizer = tf.train.AdamOptimizer(rl_learning_rate)
  target_train_op = optimizer.minimize(
      rl_empirical_loss,
      target_global_step,
      var_list=tf.trainable_variables(rl_scope.name))

  out_metric = {
      'rl_empirical_loss': rl_empirical_loss,
      'rl_entropy_loss': rl_entropy_loss,
      'rl_reward': rl_reward,
      'rl_step_baseline': rl_step_baseline,
      'rl_baseline': rl_baseline,
      'rl_advantage': rl_advantage,
      'log_prob': log_prob,
  }
  return target_train_op, rl_learning_rate, out_metric


def get_logits(feature, mode, dataset_name, reuse=None):
  """Returns the network logits."""
  avg_pool = model.conv_model(feature, mode,
                              target_dataset=target_dataset,
                              src_hw=28,
                              target_hw=32,
                              dataset_name=dataset_name,
                              reuse=reuse)
  return avg_pool


def do_cls(avg_pool, num_classes, name='dense'):
  """Applies classification."""
  with tf.variable_scope('target_CLS', reuse=tf.AUTO_REUSE):
    logits = tf.layers.dense(
        inputs=avg_pool,
        units=num_classes,
        kernel_initializer=tf.random_normal_initializer(stddev=.05),
        name=name)
    return logits


def get_model_logits(src_features, finetune_features, mode, num_classes,
                     target_num_classes):
  """Gets the logits from different models."""
  src_avg_pool = get_logits(
      src_features, mode, source_dataset, reuse=None)
  dst_avg_pool = get_logits(
      finetune_features, mode, target_dataset, reuse=True)

  src_logits = do_cls(src_avg_pool, num_classes, name='final_dense_dst')
  dst_logits = do_cls(
      dst_avg_pool, target_num_classes, name='final_target_dense')
  return src_logits, dst_logits


def get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                   finetune_one_hot_labels, global_step, loss_weights,
                   inst_weights):
  """Gets the final loss for l2tl."""
  if uniform_weight:
    inst_weights = 1.0

  def get_loss(logits, inst_weights, one_hot_labels):
    """Returns the loss function."""
    loss = tf.losses.softmax_cross_entropy(
        logits=logits, weights=inst_weights, onehot_labels=one_hot_labels)
    return loss

  src_loss = get_loss(src_logits, inst_weights, src_one_hot_labels)
  dst_loss = get_loss(dst_logits, 1., finetune_one_hot_labels)
  l2_loss = []
  for v in tf.trainable_variables():
    if 'batch_normalization' not in v.name and 'rl_controller' not in v.name:
      l2_loss.append(tf.nn.l2_loss(v))
  l2_loss = dst_weight_decay * tf.add_n(l2_loss)

  enable_pretrain = tf.cast(
      tf.greater_equal(global_step, first_pretrain_steps), tf.float32)

  loss = src_loss * tf.stop_gradient(loss_weights) * enable_pretrain
  loss += dst_loss + l2_loss

  return tf.identity(loss), src_loss, dst_loss


def train_model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """Defines the model function."""
  global_step = tf.train.get_global_step()

  src_features, src_labels = features['src'], tf.cast(labels['src'], tf.int64)
  finetune_features = features['finetune']
  target_features = features['target']

  num_classes = src_num_classes

  finetune_one_hot_labels = tf.one_hot(
      tf.cast(labels['finetune'], tf.int64), target_num_classes)
  target_one_hot_labels = tf.one_hot(
      tf.cast(labels['target'], tf.int64), target_num_classes)

  with tf.variable_scope('rl_controller') as rl_scope:
    # It creates a `rl_scope` which will be used for ops.
    pass
  rl_entropy, label_weights, log_prob = rl_label_weights(rl_scope)
  loss_entropy, loss_weights, loss_log_prob = get_loss_weights(rl_scope)

  def gather_init_weights():
    inst_weights = tf.stop_gradient(tf.gather(label_weights, src_labels))
    return inst_weights

  inst_weights = gather_init_weights()
  bs = train_batch_size
  hw = 28
  inst_weights, indices = tf.nn.top_k(
      inst_weights,
      k=bs,
      sorted=True,
  )
  src_features = tf.reshape(src_features, [
      bs * source_train_batch_multiplier,
      hw,
      hw,
      1,
  ])
  src_features = tf.gather(src_features, indices, axis=0)
  src_features = tf.stop_gradient(src_features)

  src_labels = tf.gather(src_labels, indices)

  inst_weights = bs * inst_weights / tf.reduce_sum(inst_weights)

  src_one_hot_labels = tf.one_hot(tf.cast(src_labels, tf.int64), num_classes)

  src_logits, dst_logits = get_model_logits(src_features, finetune_features,
                                            mode, num_classes,
                                            target_num_classes)

  loss, _, _ = get_final_loss(src_logits, src_one_hot_labels, dst_logits,
                              finetune_one_hot_labels, global_step,
                              loss_weights, inst_weights)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  with tf.control_dependencies(update_ops):
    src_train_op, _ = get_src_train_op(loss)
    with tf.control_dependencies([src_train_op]):
      target_avg_pool = get_logits(
          target_features, mode, target_dataset, reuse=True)
      target_logits = do_cls(
          target_avg_pool, target_num_classes, name='final_target_dense')
      is_prediction_correct = tf.equal(
          tf.argmax(tf.identity(target_logits), axis=1),
          tf.argmax(target_one_hot_labels, axis=1))
      acc = tf.reduce_mean(tf.cast(is_prediction_correct, tf.float32))

      entropy = loss_entropy + rl_entropy
      log_prob = loss_log_prob + log_prob
      train_op, _, _ = meta_train_op(acc, entropy, log_prob, rl_scope, params)

  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


def rl_label_weights(name=None):
  """Returns the weight for importance."""
  with tf.variable_scope(name, 'rl_op_selection'):
    num_classes = src_num_classes

    logits = tf.get_variable(
        name='logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[num_classes, num_choices],
        dtype=tf.float32)
    dist = tfp.distributions.Categorical(logits=logits)
    dist_entropy = tf.reduce_sum(dist.entropy())

    sample = dist.sample()
    sample_masks = 1. * tf.cast(sample, tf.float32) / num_choices
    sample_log_prob = tf.reduce_mean(dist.log_prob(sample))

  return (dist_entropy, sample_masks, sample_log_prob)


def get_loss_weights(name=None):
  """Returns the weight for loss."""
  with tf.variable_scope(name, 'rl_op_selection'):

    logits = tf.get_variable(
        name='loss_logits_rl_w',
        initializer=tf.initializers.zeros(),
        shape=[
            num_choices,
        ],
        dtype=tf.float32)
    dist = tfp.distributions.Categorical(logits=logits)
    dist_entropy = tf.reduce_sum(dist.entropy())

    sample = dist.sample()
    sample_masks = 1. * tf.cast(sample, tf.float32) / loss_weight_scale
    sample_log_prob = tf.reduce_mean(dist.log_prob(sample))

  return (dist_entropy, sample_masks, sample_log_prob)


def train_l2tl():
  tf.set_random_seed(1)

  run_config_args = {
      'model_dir': model_dir,
      'save_checkpoints_steps': 200,
      'log_step_count_steps': 100,
      'keep_checkpoint_max': 100,
  }
  config = tf.contrib.tpu.RunConfig(**run_config_args)

  if warm_start_ckpt_path:
    var_names = []
    checkpoint_path = warm_start_ckpt_path
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    for key in reader.get_variable_to_shape_map():
      keep_str = 'Momentum|global_step|finetune_global_step'
      if not re.findall('({})'.format(keep_str,), key):
        var_names.append(key)

    tf.logging.info('Warm-starting tensors: %s', sorted(var_names))

    vars_to_warm_start = var_names
    warm_start_settings = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=checkpoint_path,
        vars_to_warm_start=vars_to_warm_start)
  else:
    warm_start_settings = None

  l2tl_classifier = tf.estimator.Estimator(
      train_model_fn, config=config, warm_start_from=warm_start_settings)

  def make_input_dataset():
    """Return input dataset."""

    def _merge_datasets(train_batch, finetune_batch, target_batch):
      """Merge different splits."""
      train_features, train_labels = train_batch['image'], train_batch['label']
      finetune_features, finetune_labels = finetune_batch[
          'image'], finetune_batch['label']
      target_features, target_labels = target_batch['image'], target_batch[
          'label']
      features = {
          'src': train_features,
          'finetune': finetune_features,
          'target': target_features
      }
      labels = {
          'src': train_labels,
          'finetune': finetune_labels,
          'target': target_labels
      }
      return (features, labels)

    source_train_batch_size = int(
        round(train_batch_size * source_train_batch_multiplier))

    train_data = tfds.load(name=source_dataset, split='train')
    train_data = train_data.shuffle(512).repeat().batch(source_train_batch_size)

    target_train_batch_size = int(
        round(train_batch_size * target_train_batch_multiplier))
    finetune_data = tfds.load(name=target_dataset, split='train')
    finetune_data = finetune_data.shuffle(512).repeat().batch(
        target_train_batch_size)

    target_val_batch_size = int(
        round(train_batch_size * target_val_batch_multiplier))

    target_data = tfds.load(name=target_dataset, split='validation')
    target_data = target_data.shuffle(512).repeat().batch(target_val_batch_size)

    dataset = tf.data.Dataset.zip((train_data, finetune_data, target_data))
    dataset = dataset.map(_merge_datasets)
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return dataset

  max_train_steps = train_steps
  l2tl_classifier.train(make_input_dataset, max_steps=max_train_steps)

train_l2tl()

evaluate(
    target_dataset=target_dataset,
    train_batch_size=600,
    cls_dense_name='final_target_dense',
    ckpt_path=os.path.join(model_dir, 'model.ckpt-%d' % train_steps),
    src_num_classes=5
)

INFO:tensorflow:Warm-starting tensors: ['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense1/bias', 'dense1/kernel', 'dense2/bias', 'dense2/kernel', 'target_CLS/final_dense_dst/bias', 'target_CLS/final_dense_dst/kernel']


INFO:tensorflow:Warm-starting tensors: ['conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense1/bias', 'dense1/kernel', 'dense2/bias', 'dense2/kernel', 'target_CLS/final_dense_dst/bias', 'target_CLS/final_dense_dst/kernel']


INFO:tensorflow:Using config: {'_model_dir': './tmp/l2tl_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 100, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8835be10>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=None, num_cores_per_replica=None, 

INFO:tensorflow:Using config: {'_model_dir': './tmp/l2tl_svhn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 200, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 100, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8835be10>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=None, num_cores_per_replica=None, 





INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['target_CLS/final_dense_dst/kernel', 'target_CLS/final_dense_dst/bias', 'conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./tmp/mnist_pretrain/model.ckpt-2000', vars_to_warm_start=['target_CLS/final_dense_dst/kernel', 'target_CLS/final_dense_dst/bias', 'conv1/bias', 'conv1/kernel', 'conv2/bias', 'conv2/kernel', 'dense2/bias', 'dense1/bias', 'dense1/kernel', 'dense2/kernel'], var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-starting from: ./tmp/mnist_pretrain/model.ckpt-2000


INFO:tensorflow:Warm-started 10 variables.


INFO:tensorflow:Warm-started 10 variables.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 0 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:loss = 1.9980861, step = 0


INFO:tensorflow:loss = 1.9980861, step = 0


INFO:tensorflow:global_step/sec: 75.5949


INFO:tensorflow:global_step/sec: 75.5949


INFO:tensorflow:loss = 1.7651411, step = 100 (1.326 sec)


INFO:tensorflow:loss = 1.7651411, step = 100 (1.326 sec)


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 200 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 87.2253


INFO:tensorflow:global_step/sec: 87.2253


INFO:tensorflow:loss = 1.3776346, step = 200 (1.146 sec)


INFO:tensorflow:loss = 1.3776346, step = 200 (1.146 sec)


INFO:tensorflow:global_step/sec: 97.3368


INFO:tensorflow:global_step/sec: 97.3368


INFO:tensorflow:loss = 0.6075663, step = 300 (1.028 sec)


INFO:tensorflow:loss = 0.6075663, step = 300 (1.028 sec)


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 400 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 90.3951


INFO:tensorflow:global_step/sec: 90.3951


INFO:tensorflow:loss = 0.41853, step = 400 (1.107 sec)


INFO:tensorflow:loss = 0.41853, step = 400 (1.107 sec)


INFO:tensorflow:global_step/sec: 96.9352


INFO:tensorflow:global_step/sec: 96.9352


INFO:tensorflow:loss = 0.39777467, step = 500 (1.032 sec)


INFO:tensorflow:loss = 0.39777467, step = 500 (1.032 sec)


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 600 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 84.913


INFO:tensorflow:global_step/sec: 84.913


INFO:tensorflow:loss = 0.5270952, step = 600 (1.177 sec)


INFO:tensorflow:loss = 0.5270952, step = 600 (1.177 sec)


INFO:tensorflow:global_step/sec: 95.1414


INFO:tensorflow:global_step/sec: 95.1414


INFO:tensorflow:loss = 0.29745102, step = 700 (1.052 sec)


INFO:tensorflow:loss = 0.29745102, step = 700 (1.052 sec)


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 800 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 87.806


INFO:tensorflow:global_step/sec: 87.806


INFO:tensorflow:loss = 0.6977656, step = 800 (1.137 sec)


INFO:tensorflow:loss = 0.6977656, step = 800 (1.137 sec)


INFO:tensorflow:global_step/sec: 92.699


INFO:tensorflow:global_step/sec: 92.699


INFO:tensorflow:loss = 0.26891646, step = 900 (1.080 sec)


INFO:tensorflow:loss = 0.26891646, step = 900 (1.080 sec)


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1000 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:global_step/sec: 85.4764


INFO:tensorflow:global_step/sec: 85.4764


INFO:tensorflow:loss = 0.2688369, step = 1000 (1.169 sec)


INFO:tensorflow:loss = 0.2688369, step = 1000 (1.169 sec)


INFO:tensorflow:global_step/sec: 94.5767


INFO:tensorflow:global_step/sec: 94.5767


INFO:tensorflow:loss = 0.2757079, step = 1100 (1.059 sec)


INFO:tensorflow:loss = 0.2757079, step = 1100 (1.059 sec)


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1200 into ./tmp/l2tl_svhn/model.ckpt.


INFO:tensorflow:Loss for final step: 0.27123824.


INFO:tensorflow:Loss for final step: 0.27123824.






INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn16vixb8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8835bef0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn16vixb8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fcb8835bef0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}






INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-03-08T21:05:50Z


INFO:tensorflow:Starting evaluation at 2020-03-08T21:05:50Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from ./tmp/l2tl_svhn/model.ckpt-1200


INFO:tensorflow:Restoring parameters from ./tmp/l2tl_svhn/model.ckpt-1200


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Finished evaluation at 2020-03-08-21:05:52


INFO:tensorflow:Finished evaluation at 2020-03-08-21:05:52


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.0148227, top_1_accuracy = 0.76783335, top_5_accuracy = 1.0


INFO:tensorflow:Saving dict for global step 1200: global_step = 1200, loss = 1.0148227, top_1_accuracy = 0.76783335, top_5_accuracy = 1.0


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/l2tl_svhn/model.ckpt-1200


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1200: ./tmp/l2tl_svhn/model.ckpt-1200


#### Train L2TL with a  pre-trained MNIST model: top-1 accuracy is 76.8%.