
Copyright 2022 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

**LocoProp: Enhancing BackProp via Local Loss Optimization**  
Ehsan Amid, Rohan Anil, Manfred K. Warmuth - AISTATS 2022
https://proceedings.mlr.press/v151/amid22a/amid22a.pdf


![picture](https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEjX5vceZXAWIJohaqhy5tPqs52ryTd78pxjlGiF4qOkAdTZ2tA_2nCFX2lFYJSqAHyWvXG_3vSwix6YhQPQlHLYcEN8JxrC-P-E2nK1b5oSKCqbST5AisTpmo8p0F0xN7UaKfErkit2juHxHc7U4TCEBiNtBzORZ0fpCFv4IK7k_aVj5_1VaBQ8mOjW0w/s16000/image1.gif)


Second-order methods have shown state-of-the-art performance for optimizing deep neural networks. Nonetheless, their large memory requirement and high computational complexity, compared to first-order methods, hinder their versatility in a typical low-budget setup. This paper introduces a general framework of layerwise loss construction for multilayer neural networks that achieves a performance closer to second-order methods while utilizing first-order optimizers only. Our methodology lies upon a three-component loss, target, and regularizer combination, for which altering each component results in a new update rule. We provide examples using squared loss and layerwise Bregman divergences induced by the convex integral functions of various transfer
functions. Our experiments on benchmark models and datasets validate the
efficacy of our new approach, reducing the gap between first-order and
second-order optimizers. See our [Google AI blog post](https://ai.googleblog.com/2022/07/enhancing-backpropagation-via-local.html) for further details.


The following illustrates how to train the LocoProp-M and LocoProp-S variants on MNIST with a deep autoencoder. We primarily focus on optimizing the training loss, since autoencoders are known to be notoriously difficult to optimize. Our current version is written in TensorFlow v1, and we plan to release JAX versions in the future.



In [None]:
"""LocoProp: Enhancing BackProp via Local Loss Optimization.

https://arxiv.org/abs/2106.06199, AISTATS 2022
Ehsan Amid, Rohan Anil, Manfred K. Warmuth

Second-order methods have shown state-of-the-art performance for optimizing
deep neural networks. Nonetheless, their large memory requirement and
high computational complexity, compared to first-order methods, hinder their
versatility in a typical low-budget setup. This paper introduces a general
framework of layerwise loss construction for multilayer neural networks that
achieves a performance closer to second-order methods while utilizing
first-order optimizers only. Our methodology lies upon a three-component loss,
target, and regularizer combination, for which altering each component results
in a new update rule. We provide examples using squared loss and layerwise
Bregman divergences induced by the convex integral functions of various transfer
functions. Our experiments on benchmark models and datasets validate the
efficacy of our new approach, reducing the gap between first-order and
second-order optimizers.

"""
import functools
import math

from absl import app
from absl import flags
from keras.datasets import mnist
import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt

# Switch TensorFlow to v1 graph-mode behavior before any graph is built.
tf.disable_v2_behavior()
tf.disable_eager_execution()

# NOTE(review): the flags below are read at module level further down, and no
# explicit parse step is visible in this file -- presumably the notebook
# runtime has already parsed them; confirm before running as a plain script.
flags.DEFINE_float('learning_rate', 1e-5, help='Base learning rate.')
flags.DEFINE_float(
    'activation_learning_rate', 10, help='Activation learning rate.')
flags.DEFINE_integer('num_local_iters', 10, help='Number of local iterations.')
flags.DEFINE_enum('mode', 'LocoPropM', ['LocoPropS', 'LocoPropM', 'BP'],
                  'Which algorithm to use')
flags.DEFINE_enum('activation', 'TANH', ['RELU', 'TANH'],
                  'Which activation function to use')
flags.DEFINE_enum('optimizer', 'rmsprop', [
    'sgd', 'momentum', 'nesterov', 'adam', 'rmsprop', 'adagrad'],
                  'Which optimizer to use')

# The two flags below hold (1 - beta); the betas themselves are derived from
# them when the optimizer hyper-parameters are assembled.
flags.DEFINE_float(
    'one_minus_beta1', 0.001,
    help='1 - beta1 (beta1 is the Adam beta1 and the momentum of the '
    'momentum / nesterov / rmsprop optimizers).')
flags.DEFINE_float(
    'one_minus_beta2', 0.1,
    help='1 - beta2 (beta2 is the Adam beta2 and the rmsprop decay).')
flags.DEFINE_float('epsilon', 1e-5, help='Diagonal epsilon')

flags.DEFINE_float('weight_decay', 1e-5, help='Weight decay.')
flags.DEFINE_integer('batch_size',
                     1000, help='Batch size.')
flags.DEFINE_integer('model_size_multiplier',
                     1, help='Multiply model size by a constant')
flags.DEFINE_integer('model_depth_multiplier',
                     1, help='Multiply model depth by a constant')
FLAGS = flags.FLAGS



### Training setup

In [None]:
def compute_squared_error(logits, targets):
  """Returns the mean squared difference between sigmoid(logits) and targets."""
  predictions = tf.nn.sigmoid(logits)
  residuals = targets - predictions
  return tf.reduce_mean(tf.square(residuals))


def compute_cross_entropy_loss(logits, labels):
  """Sigmoid cross entropy, averaged over axis 0 and summed over the rest."""
  per_element_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits)
  mean_over_first_axis = tf.reduce_mean(per_element_loss, axis=0)
  return tf.reduce_sum(mean_over_first_axis)


def optimizer_from_params(params):
  """Constructs a tf.train.Optimizer from a hyper-parameter dict.

  Args:
    params: Dict with an 'optimizer' key naming the optimizer ('sgd',
      'momentum', 'nesterov', 'adam', 'rmsprop' or 'adagrad') and a
      'learning_rate' key. Depending on the optimizer it must also hold
      'momentum' (momentum, nesterov), 'beta1' / 'beta2' / 'epsilon'
      (adam, rmsprop) or 'epsilon' (adagrad). Extra keys are ignored.

  Returns:
    The constructed optimizer instance.

  Raises:
    ValueError: If params['optimizer'] is not one of the supported names.
  """
  name = params['optimizer']
  if name == 'sgd':
    optimizer_class = tf.train.GradientDescentOptimizer
    optimizer_hparams = {'learning_rate': params['learning_rate']}
  elif name == 'momentum':
    optimizer_class = tf.train.MomentumOptimizer
    optimizer_hparams = {
        'learning_rate': params['learning_rate'],
        'momentum': params['momentum']
    }
  elif name == 'nesterov':
    optimizer_class = tf.train.MomentumOptimizer
    optimizer_hparams = {
        'learning_rate': params['learning_rate'],
        'momentum': params['momentum'],
        'use_nesterov': True
    }
  elif name == 'adam':
    optimizer_class = tf.train.AdamOptimizer
    optimizer_hparams = {
        'learning_rate': params['learning_rate'],
        'beta1': params['beta1'],
        'beta2': params['beta2'],
        'epsilon': params['epsilon']}
  elif name == 'rmsprop':
    # RMSProp reuses the Adam-style keys: beta1 is its momentum and beta2 is
    # the decay of its squared-gradient average.
    optimizer_class = tf.train.RMSPropOptimizer
    optimizer_hparams = {
        'learning_rate': params['learning_rate'],
        'momentum': params['beta1'],
        'decay': params['beta2'],
        'epsilon': params['epsilon']
    }
  elif name == 'adagrad':
    # Adagrad has no epsilon argument here; the value seeds its accumulator.
    optimizer_class = tf.train.AdagradOptimizer
    optimizer_hparams = {
        'learning_rate': params['learning_rate'],
        'initial_accumulator_value': params['epsilon']
    }
  else:
    # Previously an unknown name fell through and failed with an
    # UnboundLocalError on `optimizer_class`.
    raise ValueError('Unsupported optimizer: %r' % (name,))
  return optimizer_class(**optimizer_hparams)

def act_fn(activation_fn_name):
  """Maps an activation name to the corresponding TensorFlow function.

  Args:
    activation_fn_name: One of 'NONE', 'TANH', 'RELU' or 'SIGMOID'.

  Returns:
    tf.identity, tf.nn.tanh, tf.nn.relu or tf.nn.sigmoid respectively.

  Raises:
    ValueError: If the name is not one of the supported activations.
  """
  if activation_fn_name == 'NONE':
    return tf.identity
  elif activation_fn_name == 'TANH':
    return tf.nn.tanh
  elif activation_fn_name == 'RELU':
    return tf.nn.relu
  elif activation_fn_name == 'SIGMOID':
    return tf.nn.sigmoid
  # Previously an unknown name fell through and failed with an
  # UnboundLocalError on the return statement.
  raise ValueError(
      'Unsupported activation function: %r' % (activation_fn_name,))


# A simple autoencoder model with cross-entropy loss.
def create_autoencoder_model(input_image,
                             optimizer_hparams,
                             encoder_decoder_sizes,
                             ext='global',
                             mode='train',
                             act_lr=1.0,
                             transfer_func='RELU',
                             batch_size=1000):
  """Builds the autoencoder graph and, in train mode, its training ops.

  All variables are created under the `'autoencoder_' + ext` variable scope
  with `reuse=tf.AUTO_REUSE`, so a second call with the same `ext` builds
  another graph on top of the same variables.

  Args:
    input_image: Input tensor; it is multiplied by a weight with 784 rows and
      is also used as the reconstruction target of both losses.
    optimizer_hparams: Dict understood by `optimizer_from_params`. It is only
      read when `mode == 'train'` and is never modified.
    encoder_decoder_sizes: Pair `(encoder_sizes, decoder_sizes)` with the
      output widths of the encoder and decoder hidden layers.
    ext: String inserted into the variable scope and variable names.
    mode: Training ops are only built when this equals 'train'.
    act_lr: Step size applied to the loss gradient w.r.t. a layer's
      pre-activation when forming that layer's local targets.
    transfer_func: Activation name used for the hidden layers. The last
      encoder layer (when there is more than one) has no activation and the
      output layer always uses 'SIGMOID'.
    batch_size: Unused; kept so existing callers keep working.

  Returns:
    Dict with the training ops ('train_op_loco_prop_s', 'train_op_loco_prop_m',
    'train_op_bp', and the per-layer lists 'train_ops_sp' / 'train_ops_mp'),
    'reset_optimizer_loco_prop', the scalar 'loss' (cross entropy) and
    'squared_err', 'reconstructed_image', the per-layer 'fc_layers',
    'activations', 'post_activations' and 'fc_fns', plus the placeholders
    'assign_op', 'assign_ops_input' and 'input_checkpoints' which are always
    None / empty. Outside train mode the training ops are None or empty lists.
  """
  fc_layers = []  # (weight, bias) variable pairs in forward order.
  fc_fns = []  # Name of the activation applied after each layer.

  def get_weight_bias(name, ext, shape):
    """Creates (or reuses) the weight matrix and bias vector of one layer."""
    weight = tf.get_variable(
        name + '_' + ext + '_weight',
        shape=shape,
        initializer=tf.keras.initializers.glorot_uniform(),
        dtype=tf.float32)
    bias = tf.get_variable(
        name + '_' + ext + '_bias',
        shape=(shape[1]),
        initializer=tf.keras.initializers.glorot_uniform(),
        dtype=tf.float32)
    return weight, bias

  encoder_sizes, decoder_sizes = encoder_decoder_sizes

  # A very simple autoencoder with a bottleneck layer.
  with tf.variable_scope('autoencoder_' + ext, reuse=tf.AUTO_REUSE):
    # First layer.
    fc_layers.append(
        get_weight_bias('encoder_layer_0', ext, (784, encoder_sizes[0])))
    fc_fns.append(transfer_func)

    for i in range(1, len(encoder_sizes)):
      fc_layers.append(
          get_weight_bias('encoder_layer_%d' % i, ext,
                          (encoder_sizes[i - 1], encoder_sizes[i])))
      # The bottleneck (last encoder layer) is linear.
      if i == len(encoder_sizes) - 1:
        fc_fns.append('NONE')
      else:
        fc_fns.append(transfer_func)

    fc_layers.append(
        get_weight_bias('decoder_layer_0', ext,
                        (encoder_sizes[-1], decoder_sizes[0])))
    fc_fns.append(transfer_func)

    for i in range(1, len(decoder_sizes)):
      fc_layers.append(
          get_weight_bias('decoder_layer_%d' % i, ext,
                          (decoder_sizes[i - 1], decoder_sizes[i])))
      fc_fns.append(transfer_func)

    fc_layers.append(
        get_weight_bias('decoder_layer_%d' % len(decoder_sizes), ext,
                        (decoder_sizes[-1], 784)))
    fc_fns.append('SIGMOID')  # last layer (applied implicitly in the loss)

    # LocoProp requires knowing the activation / post activation values to
    # compute per layer targets.
    activations = []
    post_activations = []
    x = input_image
    for (weight, bias), fn_name in zip(fc_layers, fc_fns):
      x = tf.matmul(x, weight) + bias
      activations.append(x)
      if fn_name == 'TANH':
        x = tf.nn.tanh(x)
      elif fn_name == 'RELU':
        x = tf.nn.relu(x)
      elif fn_name == 'SIGMOID':
        # Keep the pre-sigmoid values: both losses below expect logits.
        output_logits = x
        x = tf.nn.sigmoid(x)
      post_activations.append(x)
    weight_variables = [w for w, _ in fc_layers]
    bias_variables = [b for _, b in fc_layers]

    # compute_squared_error applies the sigmoid itself, so it must receive the
    # logits. Passing the reconstruction `x` would apply the sigmoid twice.
    squared_err = compute_squared_error(output_logits, input_image)
    ce_loss = compute_cross_entropy_loss(output_logits, input_image)
    gs = tf.get_variable('total_steps', shape=[],
                         initializer=tf.zeros_initializer(), dtype=tf.int32)
    gs = gs.assign(gs + 1)
    assign_op = None
    train_op_loco_prop_s = None
    train_op_loco_prop_m = None
    train_op_bp = None
    assign_ops_input = []
    train_ops_s = []  # train_ops for LocoProp-S
    train_ops_m = []  # train_ops for LocoProp-M
    reset_optimizer_loco_prop = []
    input_checkpoints = []
    if mode == 'train':
      base_lr = optimizer_hparams['learning_rate']

      def _learning_fn(i):
        """Learning rate of layer `i` as a function of its local step."""
        gs_layer = tf.get_variable(
            'total_steps_layer_' + str(i), shape=[],
            initializer=tf.zeros_initializer(), dtype=tf.int32)
        # Internally, LocoProp involves T steps of training.
        # We use a decreasing schedule here.
        decay = tf.maximum(
            (1.0 - tf.cast(gs_layer, tf.float32) / FLAGS.num_local_iters),
            0.25)
        return base_lr * decay

      # construct the matching losses
      for i in range(len(fc_layers)):
        gs_layer = tf.get_variable(
            'total_steps_layer_' + str(i), shape=[],
            initializer=tf.zeros_initializer(), dtype=tf.int32)
        w, b = fc_layers[i]

        # Work on a copy so the caller's dict (and the BackProp optimizer
        # built from it below) keeps the plain base learning rate.
        layer_hparams = dict(optimizer_hparams)
        layer_hparams['learning_rate'] = functools.partial(_learning_fn, i)
        optimizer_mp = optimizer_from_params(params=layer_hparams)

        act_fun = act_fn(fc_fns[i])
        activation = activations[i]
        post_activation = post_activations[i]
        input_to_layer = input_image if i == 0 else post_activations[i - 1]

        # Both targets take one gradient step on the loss w.r.t. the layer's
        # pre-activation; they differ in the point the step starts from.
        grad_activation = tf.gradients(ce_loss, [activation])[0]
        target_gd = activation - act_lr * grad_activation
        target_primal = post_activation - act_lr * grad_activation

        # Every chain starts by resetting the layer's local step counter and
        # each update is sequenced after the previous one.
        train_local_s = [tf.assign(gs_layer, 0)]
        train_local_m = [tf.assign(gs_layer, 0)]
        for _ in range(FLAGS.num_local_iters):
          # The gradients of the local losses are formed explicitly with
          # matmul rather than with autodiff.
          with tf.control_dependencies(train_local_s):
            fake_activation = tf.matmul(input_to_layer, w) + b
            delta_s = fake_activation - target_gd
            gradient_weights_s = tf.matmul(
                tf.transpose(input_to_layer), delta_s)
            gradient_bias_s = tf.reduce_sum(delta_s, 0)
            train_local_s = [
                optimizer_mp.apply_gradients([(gradient_weights_s, w),
                                              (gradient_bias_s, b)],
                                             global_step=gs_layer)
            ]
          with tf.control_dependencies(train_local_m):
            fake_activation = tf.matmul(input_to_layer, w) + b
            fake_post_activation = act_fun(fake_activation)
            delta_m = fake_post_activation - target_primal
            gradient_weights_m = tf.matmul(
                tf.transpose(input_to_layer), delta_m)
            gradient_bias_m = tf.reduce_sum(delta_m, 0)
            train_local_m = [
                optimizer_mp.apply_gradients([(gradient_weights_m, w),
                                              (gradient_bias_m, b)],
                                             global_step=gs_layer)
            ]

        # Built after the apply_gradients calls so that any state variables
        # the optimizer creates while applying gradients are covered by the
        # reset op.
        reset_optimizer_loco_prop.append(
            tf.variables_initializer(optimizer_mp.variables()))
        train_ops_s.append(tf.group(*train_local_s))
        train_ops_m.append(tf.group(*train_local_m))

      global_optimizer = optimizer_from_params(params=optimizer_hparams)
      train_op_bp = global_optimizer.minimize(
          ce_loss, var_list=weight_variables + bias_variables)
      train_op_loco_prop_s = tf.group(*train_ops_s)
      train_op_loco_prop_m = tf.group(*train_ops_m)
      reset_optimizer_loco_prop = tf.group(*reset_optimizer_loco_prop)
  return {
      'train_op_loco_prop_s': train_op_loco_prop_s,
      'train_op_loco_prop_m': train_op_loco_prop_m,
      'input_checkpoints': input_checkpoints,
      'train_op_bp': train_op_bp,
      'loss': ce_loss,
      'squared_err': squared_err,
      'assign_op': assign_op,
      'assign_ops_input': assign_ops_input,
      'train_ops_sp': train_ops_s,
      'train_ops_mp': train_ops_m,
      'reset_optimizer_loco_prop': reset_optimizer_loco_prop,
      'fc_layers': fc_layers,
      'reconstructed_image': x,
      'activations': activations,
      'post_activations': post_activations,
      'fc_fns': fc_fns,
  }


### Training steps

In [None]:
(train_inputs, _), (test_inputs, test_labels) = mnist.load_data()

# Flatten every image into a 784-vector of float32 values and rescale the
# inputs to [0, 1].
train_inputs = train_inputs.astype(np.float32).reshape([-1, 784]) / 255.0
test_inputs = test_inputs.astype(np.float32).reshape([-1, 784]) / 255.0

num_train_examples, num_test_examples = len(train_inputs), len(test_inputs)
print('MNIST dataset:')
print('Num train examples: %d' % num_train_examples)
print('Num test examples: %d' % num_test_examples)

tf.reset_default_graph()

batch_size = FLAGS.batch_size

# We find second-order methods and LocoProp to work quite well for deeper
# autoencoders. The depth multiplier repeats the 500-unit layers and the size
# multiplier scales every layer width.
depth_mult = FLAGS.model_depth_multiplier
size_mult = FLAGS.model_size_multiplier
encoder_sizes = [
    size_mult * e for e in [1000] + [500] * depth_mult + [250, 30]]
decoder_sizes = [
    size_mult * e for e in [250] + [500] * depth_mult + [1000]]
encoder_decoder_sizes = encoder_sizes, decoder_sizes

# Fixed-size placeholder for training batches, free-size one for evaluation.
input_image_batch = tf.placeholder(tf.float32, (batch_size, 784))
input_image = tf.placeholder(tf.float32, (None, 784))

# LocoProp inner routine requires setting up activation learning rates.
act_lr = tf.placeholder(tf.float32, name='activation_lr')
lr = tf.placeholder(tf.float32, name='lr')
transfer_func = FLAGS.activation
optimizer_type = FLAGS.optimizer

# Hyper-parameters for every supported optimizer; only the entry selected by
# the `optimizer` flag is kept.
beta1 = 1.0 - FLAGS.one_minus_beta1
beta2 = 1.0 - FLAGS.one_minus_beta2
hparams_by_optimizer = {
    'sgd': dict(optimizer='sgd', learning_rate=lr),
    'momentum': dict(optimizer='momentum', learning_rate=lr, momentum=beta1),
    'nesterov': dict(
        optimizer='nesterov',
        learning_rate=lr,
        momentum=beta1,
        use_nesterov=True),
    'adam': dict(
        optimizer='adam',
        learning_rate=lr,
        beta1=beta1,
        beta2=beta2,
        epsilon=FLAGS.epsilon),
    'adagrad': dict(
        optimizer='adagrad', learning_rate=lr, epsilon=FLAGS.epsilon),
    'rmsprop': dict(
        optimizer='rmsprop',
        learning_rate=lr,
        beta1=beta1,
        beta2=beta2,
        epsilon=FLAGS.epsilon),
}
optimizer_hparams = hparams_by_optimizer[optimizer_type]

# The training and inference graphs use the same `ext` and differ only in
# their input placeholder and mode.
model_kwargs = dict(
    ext='global',
    act_lr=act_lr,
    transfer_func=transfer_func,
    batch_size=batch_size)
global_params_train = create_autoencoder_model(
    input_image_batch,
    optimizer_hparams,
    encoder_decoder_sizes,
    mode='train',
    **model_kwargs)
global_params_infer = create_autoencoder_model(
    input_image,
    optimizer_hparams,
    encoder_decoder_sizes,
    mode='infer',
    **model_kwargs)

# All experiments use 100 epochs of training with 5 epochs used as a warmup.
# A linear warmup followed by a decay is used for training.
num_epochs = 100
warmup_epochs = 5
disp_epoch = 1  # Print the metrics every `disp_epoch` epochs.
act_lr_val = FLAGS.activation_learning_rate
reset_optimizers = False  # Whether to run the optimizer reset op per batch.
lr_val = FLAGS.learning_rate
train_log = []  # Per-evaluation [cross-entropy loss, squared error] pairs.
test_log = []
train_method = FLAGS.mode
num_local_iters = FLAGS.num_local_iters

# Summary banner of the run configuration.
print('-----------------------')
print('Autoencoder model with (%d x size, %d x depth) multipliers and'
      ' %s activation function.' % (FLAGS.model_size_multiplier,
                                    FLAGS.model_depth_multiplier,
                                    FLAGS.activation))
print('Train method: %s' % train_method)
if 'Loco' in train_method:
  print('Number of local iterations: %d' % num_local_iters)
  print('Internal optimizer: %s' % FLAGS.optimizer)
else:
  print('Optimizer: %s' % FLAGS.optimizer)
print('-----------------------')

# This is mainly recorded for hparam tuning setup that we used.
best_train_loss = 1e6

# Metrics evaluated on the full train / test sets after every epoch.
eval_fetches = [global_params_infer['loss'], global_params_infer['squared_err']]
# Only full batches are used; leftover examples of an epoch are skipped.
num_batches = num_train_examples // batch_size
loco_train_op_name = ('train_op_loco_prop_s' if train_method == 'LocoPropS'
                      else 'train_op_loco_prop_m')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  train_loss_val, train_squared_err_val = sess.run(
      eval_fetches, feed_dict={input_image: train_inputs})
  test_loss_val, test_squared_err_val = sess.run(
      eval_fetches, feed_dict={input_image: test_inputs})
  print('init (train, test) loss (%3.3f, %3.3f), '
                  '(train, test) squared error (%3.3f, %3.3f)' %
                  (train_loss_val, test_loss_val, train_squared_err_val,
                    test_squared_err_val))
  train_log.append([train_loss_val, train_squared_err_val])
  test_log.append([test_loss_val, test_squared_err_val])

  # The decay below reaches 0 at epoch == num_epochs - 1, so training stops
  # one epoch earlier (99 epochs for num_epochs == 100).
  for epoch in range(num_epochs - 1):
    # The schedule depends on the epoch only: linear warmup, then linear
    # decay.
    # NOTE(review): with this formula the first epoch (epoch == 0) runs with a
    # learning rate of 0 -- confirm that this is intended.
    if epoch < warmup_epochs:
      lr_bp = lr_val * (epoch / warmup_epochs)
    else:
      lr_bp = lr_val * (1.0 - (epoch + 1 - warmup_epochs) /
                        (num_epochs - warmup_epochs))

    idx_epoch = np.random.permutation(train_inputs.shape[0])
    for bb in range(num_batches):
      idx_batch = idx_epoch[bb * batch_size:(bb + 1) * batch_size]
      train_x = train_inputs[idx_batch]
      if train_method in ['LocoPropS', 'LocoPropM']:
        sess.run(
            global_params_train[loco_train_op_name],
            feed_dict={
                input_image_batch: train_x,
                act_lr: act_lr_val,
                lr: lr_bp
            })

        if reset_optimizers:
          sess.run(global_params_train['reset_optimizer_loco_prop'])
      elif train_method == 'BP':
        sess.run(
            global_params_train['train_op_bp'],
            feed_dict={
                input_image_batch: train_x,
                lr: lr_bp
            })

    train_loss_val, train_squared_err_val = sess.run(
        eval_fetches, feed_dict={input_image: train_inputs})
    test_loss_val, test_squared_err_val = sess.run(
        eval_fetches, feed_dict={input_image: test_inputs})
    if (epoch + 1) % disp_epoch == 0:
      print('epoch %d, (train, test) loss (%3.3f, %3.3f), '
                      '(train, test) squared error (%3.3f, %3.3f)' %
                      (epoch + 1, train_loss_val, test_loss_val,
                        train_squared_err_val, test_squared_err_val))
    best_train_loss = min(train_loss_val, best_train_loss)

    # Used for hyper-parameter tuning early exits. The metrics of an epoch
    # that triggers an exit are not appended to the logs.
    if math.isnan(
        train_loss_val) or best_train_loss > 600 or train_loss_val > 1000:
      best_train_loss = 1000
      break
    if epoch > 10 and best_train_loss > 350:
      break
    train_log.append([train_loss_val, train_squared_err_val])
    test_log.append([test_loss_val, test_squared_err_val])


### Plot the results

In [None]:
plt.figure(figsize=(8, 6), dpi=300)

# Legend suffix: LocoProp runs also report their number of local iterations.
if 'Loco' in train_method:
  method_name = '%s (num local iters=%d)' % (train_method, num_local_iters)
else:
  method_name = 'BP'

# Plot the cross-entropy loss (first entry of every log row) on a log scale;
# solid line for train, dashed line for test.
for log, line_style, label_prefix in ((train_log, '-', 'Train loss - '),
                                      (test_log, '--', 'Test loss  - ')):
  plt.semilogy(
      np.arange(len(log)),
      [row[0] for row in log],
      line_style,
      label=label_prefix + method_name,
      linewidth=2.0)

plt.xlabel('Epochs', fontsize=18)
plt.xticks(fontsize=12)
plt.ylabel('CE loss', fontsize=18)
plt.yticks(fontsize=12)
plt.legend(fontsize=14)
plt.show()