# L4 stepsize adaptation performance on MNIST

This short notebook contains a minimal working example of the L4 optimizers (Rolinek & Martius, 2018) applied to the classic MNIST dataset.

## Imports

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import layers

import L4

## Network structure

In [2]:
def mlp(x, hidden=(300,100), num_output=10):
    """Build a fully connected ReLU network and return its output tensor.

    Args:
        x: 2-D input tensor; its second static dimension is used as the
            input feature size of the first layer.
        hidden: widths of the hidden layers, in order.
        num_output: width of the final affine layer.

    Returns:
        The result of the final affine layer (no activation is applied to it).
    """
    def affine(inputs, suffix, fan_in, fan_out, weight_init):
        # Weight is created before the bias so variables are registered in
        # the order W<suffix>, b<suffix>.
        weights = tf.get_variable("W" + suffix, [fan_in, fan_out],
                                  initializer=weight_init)
        bias = tf.get_variable("b" + suffix, [fan_out],
                               initializer=tf.zeros_initializer())
        return tf.matmul(inputs, weights) + bias

    width_in = x.get_shape().as_list()[1]
    activations = x
    for index, width_out in enumerate(hidden):
        # Hidden layers: Xavier-initialised weights, zero biases, ReLU.
        pre_activation = affine(activations, str(index), width_in, width_out,
                                layers.xavier_initializer())
        activations = tf.nn.relu(pre_activation)
        width_in = width_out

    # The output layer starts from all-zero weights and biases.
    return affine(activations, "_final", width_in, num_output,
                  tf.zeros_initializer())

## Training parameters

In [3]:
# Hyper-parameters of the experiment; the later cells read them by key.
config = dict(
    data_dir='data',        # directory handed to the MNIST reader
    epochs=35,              # number of epochs to train for
    batch_size=64,          # examples per training step
    hidden=[300, 100],      # hidden-layer widths passed to mlp()
    epochs_per_report=1,    # print the batch loss every this many epochs
)


## Computational graph setup

In [4]:
# Read the MNIST data sets from config['data_dir'] with one-hot encoded labels.
mnist = input_data.read_data_sets(config['data_dir'], one_hot=True)

# Number of examples in the split that training actually iterates over.
# read_data_sets holds out a validation split by default (validation_size=5000),
# so mnist.train is smaller than the 60000 raw MNIST training images; the
# previously hard-coded 60000 made every reported "epoch" longer than one
# pass over mnist.train.
MNIST_size = mnist.train.num_examples

# Flattened 28x28 input images and the logits the network produces for them.
x = tf.placeholder(tf.float32, [None, 784])
y = mlp(x, hidden=config['hidden'], num_output=10)

# One-hot target labels and the mean softmax cross-entropy over the batch.
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

# Fraction of examples whose arg-max prediction matches the arg-max label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


## Optimizer choice

In [5]:
# Optimizer under test. Exactly one `opt = ...` line should be active; the
# commented-out alternatives are kept so other optimizers can be swapped in
# for comparison.
# NOTE(review): `fraction` is presumably the L4 step fraction from the paper;
# confirm against L4.py.
opt = L4.L4Adam(fraction=0.25)
#opt = L4.L4Mom(fraction=0.25)
#opt = tf.train.AdamOptimizer(0.001, epsilon=1e-4)
#opt = tf.train.MomentumOptimizer(learning_rate=0.05, momentum=0.9)
#opt = tf.train.GradientDescentOptimizer(learning_rate=0.7)

# Op that applies one optimizer update minimizing the mean cross-entropy loss.
train_op = opt.minimize(cross_entropy)

## Training Session

In [6]:
# Open an interactive session (it becomes the default session) and
# initialise every variable in the graph.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

batches_per_epoch = MNIST_size // config['batch_size']
batches_to_run = batches_per_epoch * config['epochs']

# The range includes batches_to_run itself so that the report for the last
# epoch boundary is printed as well.
for step in range(batches_to_run + 1):
    images, labels = mnist.train.next_batch(config['batch_size'])
    feed = {x: images, y_: labels}
    _, loss = sess.run((train_op, cross_entropy), feed_dict=feed)

    # Report only on steps that fall exactly on an epoch boundary.
    epoch_nr, batches_into_epoch = divmod(step, batches_per_epoch)
    if batches_into_epoch == 0 and epoch_nr % config['epochs_per_report'] == 0:
        print("Epoch {}; Current Batch Loss: {}".format(epoch_nr, loss))

# Evaluate the trained model on the full test split.
test_feed = {x: mnist.test.images, y_: mnist.test.labels}
accuracy_value = sess.run(accuracy, feed_dict=test_feed)
print("Test accuracy: {}".format(accuracy_value))

Epoch 0; Current Batch Loss: 2.3025853633880615
Epoch 1; Current Batch Loss: 0.05435074120759964
Epoch 2; Current Batch Loss: 0.0886421799659729
Epoch 3; Current Batch Loss: 0.049092553555965424
Epoch 4; Current Batch Loss: 0.08038255572319031
Epoch 5; Current Batch Loss: 0.0029454075265675783
Epoch 6; Current Batch Loss: 0.0061236158944666386
Epoch 7; Current Batch Loss: 0.00017982714052777737
Epoch 8; Current Batch Loss: 0.00031989847775548697
Epoch 9; Current Batch Loss: 0.0009350153268314898
Epoch 10; Current Batch Loss: 8.277507004095241e-05
Epoch 11; Current Batch Loss: 2.9688142149097985e-06
Epoch 12; Current Batch Loss: 7.34933273633942e-05
Epoch 13; Current Batch Loss: 1.6763797461294416e-08
Epoch 14; Current Batch Loss: 1.8626450382086546e-09
Epoch 15; Current Batch Loss: 3.1664953326071554e-08
Epoch 16; Current Batch Loss: 5.196699248699588e-07
Epoch 17; Current Batch Loss: 1.1175867342672063e-08
Epoch 18; Current Batch Loss: 5.401665958970625e-08
Epoch 19; Current Batch Los