In [2]:
from lib_learning.nn.ffnn import FFNN
from lib_learning.nn_monitoring.loss import LossMonitor

import gzip
import pandas as pd
import numpy as np
import tensorflow as tf

import plotly.offline as plotly
import plotly.graph_objs as go

In [3]:
plotly.init_notebook_mode(connected=True)

# Data

In [4]:
with gzip.open('../data/MNIST_data/test_inputs.gz', 'rb') as f:
    test_in = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28) / 255.0
with gzip.open('../data/MNIST_data/train_inputs.gz', 'rb') as f:
    train_in = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28) / 255.0
with gzip.open('../data/MNIST_data/test_targets.gz', 'rb') as f:
    test_labels_raw = np.frombuffer(f.read(), np.uint8, offset=8)
with gzip.open('../data/MNIST_data/train_targets.gz', 'rb') as f:
    train_labels_raw = np.frombuffer(f.read(), np.uint8, offset=8)
    
test_labels = np.zeros((test_labels_raw.shape[0], 10))
test_labels[np.arange(test_labels_raw.shape[0]), test_labels_raw] = 1
train_labels = np.zeros((train_labels_raw.shape[0], 10))
train_labels[np.arange(train_labels_raw.shape[0]), train_labels_raw] = 1

test_in = test_in.astype('float32')
train_in = train_in.astype('float32')
test_lables = test_labels.astype('float32')
train_lables = train_labels.astype('float32')

In [5]:
# Run this if you want to normalize to zero mean
test_in = test_in - train_in.mean()
train_in = train_in - train_in.mean()

# Problem Statement

Theoretically, neural networks with Mean Squared Loss should be able fit a standard classification problem. However many classification models use cross-entropy loss with softmax instead. Here we investigate the hypothesis that the latter exhibits much stronger optimization properties and investigate why this might be.

# MSE Loss

In [20]:
lc = [
    784,
    {
        'n_nodes': 2048, 'activation': tf.nn.tanh, 'init_weight_mean': 0, 'init_weight_stddev': 0.01,
        'init_bias_mean': 0, 'init_bias_stddev': 0.01
    },
    {
        'n_nodes': 10, 'activation': tf.nn.softmax, 'init_weight_mean': 0, 'init_weight_stddev': 0.01,
        'init_bias_mean': 0, 'init_bias_stddev': 0.01
    }
]

In [21]:
monitors = [LossMonitor(100)]

In [22]:
network_mse_softmax = FFNN(lc, monitors=monitors)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [27]:
%%time
network_mse_softmax.init_session()
network_mse_softmax.train_offline(
    train_in,
    train_labels,
    epochs=440 * 100,
    batch_size=100
)



Epoch: 44000
CPU times: user 1h 24min 41s, sys: 26min 3s, total: 1h 50min 45s
Wall time: 46min 38s


# Cross Entropy Loss

In [28]:
def cross_entropy_with_softmax(model_output, true_output):
    return tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=true_output, logits=model_output))

In [32]:
lc = [
    784,
    {
        'n_nodes': 2048, 'activation': tf.nn.tanh, 'init_weight_mean': 0, 'init_weight_stddev': 0.01,
        'init_bias_mean': 0, 'init_bias_stddev': 0.01
    },
    {
        'n_nodes': 10, 'activation': tf.nn.softmax, 'init_weight_mean': 0, 'init_weight_stddev': 0.01,
        'init_bias_mean': 0, 'init_bias_stddev': 0.01
    }
]

In [33]:
monitors = [LossMonitor(100)]

In [38]:
%%time
network_xe = FFNN(lc, post_proc_function=tf.nn.softmax, loss_func=cross_entropy_with_softmax, monitors=monitors)
network_xe.init_session()
network_xe.train_offline(
    train_in,
    train_labels,
    epochs=440 * 100,
    batch_size=100,
)

Epoch: 44000
CPU times: user 1h 24min 46s, sys: 26min 49s, total: 1h 51min 35s
Wall time: 48min 13s


# Comparison

In [41]:
train_pred = network_xe.predict(train_in)
np.sum(train_pred.argmax(axis=1) == train_labels.argmax(axis=1)) / train_in.shape[0]

0.98085

In [42]:
train_pred = network_mse_softmax.predict(train_in)
np.sum(train_pred.argmax(axis=1) == train_labels.argmax(axis=1)) / train_in.shape[0]

0.9299

In [49]:
plotly.iplot([
    go.Scatter(
        x=network_xe.monitors[0].values['epochs'],
        y=network_xe.monitors[0].values['loss']
    )
])

In [48]:
plotly.iplot([
#     go.Scatter(
#         x=network_xe.monitors[0].values['epochs'],
#         y=network_xe.monitors[0].values['loss']
#     ),
    go.Scatter(
        x=network_mse_softmax.monitors[0].values['epochs'],
        y=network_mse_softmax.monitors[0].values['loss']
    )
])

# Conclusion

These two examples show that all else being equal, having a softmax + cross entopy loss converges at a much faster rate than just MSE. Both methods appear to exhibit strong fitment to the data given sufficient time but convergence of the MSE loss was significantly slower than cross-entropy.