In [None]:
from learning_lib.nn.ffnn import FFNN

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import plotly.offline as plotly
import plotly.graph_objs as go

from PIL import Image

In [None]:
# Enable plotly's offline notebook integration so the iplot() calls below render inline.
# NOTE(review): connected=True is documented to load plotly.js from a CDN rather than
# embedding it in the notebook, so rendering presumably needs network access — confirm.
plotly.init_notebook_mode(connected=True)

# Data

In [None]:
# Download (if necessary) and parse MNIST; one_hot=True requests one-hot encoded labels.
dataset = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Take each split's full image/label arrays directly instead of calling
# next_batch() with an oversized batch: asking for more rows than the split
# holds either trips an assertion or wraps into the next epoch and returns
# samples more than once, depending on the TensorFlow release.
# NOTE(review): rows are no longer shuffled here — shuffle explicitly if
# FFNN.train does not do so itself.
train_data = (dataset.train.images, dataset.train.labels)
# Centre the inputs by subtracting the single scalar mean over all training pixels.
train_in = train_data[0] - train_data[0].mean()
train_labels = train_data[1]

# (images, labels) for the held-out split; not used by the cells shown in this notebook.
test_data = (dataset.test.images, dataset.test.labels)

# MSE Loss

In [None]:
# Network for the MSE section, described only by a list of integers.
# NOTE(review): presumably these are layer widths (784 inputs, two hidden layers of
# 2048 units, 10 outputs) with FFNN's default activation and initialisation — confirm
# against learning_lib.nn.ffnn.
network = FFNN([784, 2048, 2048, 10])

In [None]:
%%time
# Fit the network on the mean-centred training images and their one-hot labels.
# No loss_func or optimizer is passed, so whatever FFNN.train defaults to is used
# (presumably an MSE loss, going by this section's heading — confirm).
# NOTE(review): this call passes lc_interval=200 whereas the cross-entropy run further
# down passes report_interval=200 — check which keyword FFNN.train actually accepts.
# NOTE(review): 81400 and 50 are unexplained magic numbers; whether `epochs` counts
# full passes or mini-batch steps cannot be told from this notebook.
network.train(train_in, train_labels, epochs=81400, batch_size=50, lc_interval=200)

In [None]:
# Run the training inputs through the network; the accuracy cell below takes the
# arg-max along axis 1 of this result (assumes one output row per input — TODO confirm).
train_pred = network.evaluate(train_in)

In [None]:
# Training-set accuracy: count the rows where the arg-max of the network output
# matches the arg-max of the one-hot label, then divide by the number of inputs.
np.equal(
    train_pred.argmax(axis=1),
    train_labels.argmax(axis=1),
).sum() / train_in.shape[0]

In [None]:
# Learning curve recorded by the network during the MSE training run.
# Stored under its own name because `lc` is reused further down for the layer
# specification handed to FFNN; sharing one name for both invites stale-state
# bugs when cells are re-run out of order.
learning_curve = pd.DataFrame(network.learning_curve)

# NOTE(review): columns 0 and 1 are presumably (training step, loss) — confirm
# against FFNN and then add matching axis titles.
plotly.iplot(go.Figure(
    data=[go.Scatter(
        x=learning_curve[0],
        y=learning_curve[1],
        mode='lines'
    )],
    # A title lets the figure stand on its own when the notebook is skimmed.
    layout=go.Layout(title='Learning curve (MSE loss)')
))

# Cross Entropy Loss

In [None]:
# Layer specification for the cross-entropy network: an integer followed by one
# settings dict per subsequent layer.
# NOTE(review): the init_*_lower / init_*_upper entries are presumably the bounds of
# the range weights and biases are initialised from — confirm against FFNN.
lc = [
    784,
    # tanh layer with 2048 nodes
    dict(
        n_nodes=2048,
        activation=tf.nn.tanh,
        init_weight_lower=0,
        init_weight_upper=1,
        init_bias_lower=0,
        init_bias_upper=1,
    ),
    # 10-node layer with identity activation (softmax is supplied separately
    # when the network is constructed)
    dict(
        n_nodes=10,
        activation=tf.identity,
        init_weight_lower=0,
        init_weight_upper=1,
        init_bias_lower=0,
        init_bias_upper=1,
    ),
]

In [None]:
# Rebind `network` to a fresh model built from the layer spec in `lc`; the MSE model
# from the previous section is no longer reachable after this line.
# tf.nn.softmax is handed over as post_proc_function (presumably applied to the last
# layer's output when evaluating — confirm against FFNN).
network = FFNN(lc, post_proc_function=tf.nn.softmax)

In [None]:
def cross_entropy_with_softmax(model_output, true_output):
    """Return the softmax cross-entropy between outputs and labels, summed into one scalar.

    Args:
        model_output: tensor passed to TensorFlow as ``logits``.
        true_output: tensor passed to TensorFlow as ``labels``.

    Returns:
        A scalar tensor: the sum of the values returned by
        ``tf.nn.softmax_cross_entropy_with_logits``.

    NOTE(review): ``softmax_cross_entropy_with_logits`` applies softmax itself and
    expects unscaled logits, while the network in this notebook is built with
    ``post_proc_function=tf.nn.softmax``. Confirm that FFNN feeds the
    pre-post-processing output to the loss, otherwise softmax is applied twice.
    """
    per_sample_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=model_output,
        labels=true_output,
    )
    return tf.reduce_sum(per_sample_loss)

In [None]:
%%time
# Fit the cross-entropy network on the same mean-centred training data, this time
# overriding the loss with cross_entropy_with_softmax and the optimizer with
# gradient descent constructed with the argument 0.1 (its learning rate).
# NOTE(review): the keyword here is report_interval, whereas the MSE run above uses
# lc_interval — check which one FFNN.train actually accepts.
# NOTE(review): 1100 and 100 are unexplained magic numbers; whether `epochs` counts
# full passes or mini-batch steps cannot be told from this notebook.
network.train(
    train_in,
    train_labels,
    epochs=1100,
    batch_size=100,
    report_interval=200,
    loss_func=cross_entropy_with_softmax,
    optimizer=tf.train.GradientDescentOptimizer(0.1)
)

In [None]:
# Overwrite train_pred with the cross-entropy network's outputs for the training
# inputs (assumes one output row per input — TODO confirm).
train_pred = network.evaluate(train_in)

In [None]:
# Training-set accuracy of the cross-entropy network: rows whose arg-max output
# matches the arg-max of the one-hot label, divided by the number of inputs.
np.equal(
    train_pred.argmax(axis=1),
    train_labels.argmax(axis=1),
).sum() / train_in.shape[0]

In [None]:
# Learning curve recorded by the network during the cross-entropy training run.
# Deliberately NOT assigned to `lc`: that name holds the layer specification used
# to build this network, and overwriting it with a DataFrame would make re-running
# the `FFNN(lc, ...)` cell fail or misbehave.
learning_curve = pd.DataFrame(network.learning_curve)

# The first recorded point is dropped ([1:]) exactly as before.
# NOTE(review): columns 0 and 1 are presumably (training step, loss) — confirm
# against FFNN and then add matching axis titles.
plotly.iplot(go.Figure(
    data=[go.Scatter(
        x=learning_curve[0][1:],
        y=learning_curve[1][1:],
        mode='lines'
    )],
    # A title lets the figure stand on its own when the notebook is skimmed.
    layout=go.Layout(title='Learning curve (cross-entropy loss)')
))