In [15]:
import numpy as np
import pandas as pd
from NeuralNet import NeuralNetwork, calculate_accuracy
from sklearn.datasets import load_svmlight_file
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from bokeh.io import output_notebook
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool
output_notebook()

In [16]:
x, label = load_svmlight_file('Datasets/mnist.scale')
x = x.toarray()
label = label.astype(int)

x_train, x_val, label_train, label_val = train_test_split(x, label, test_size=0.30, random_state=42)

y_train = np.zeros((np.size(label_train, 0), len(np.unique(label_train))))
for i in range(np.size(label_train, 0)):
    y_train[i, label_train[i]] = 1

y_val = np.zeros((np.size(label_val, 0), len(np.unique(label_val))))
for i in range(np.size(label_val, 0)):
    y_val[i, label_val[i]] = 1

In [17]:
x_test, label_test = load_svmlight_file('Datasets/mnist.scale.t')
x_test = x_test.toarray()
x_test = np.hstack((x_test, np.zeros((x_test.shape[0], 2))))
label_test = label_test.astype(int)

y_test = np.zeros((np.size(label_test, 0), len(np.unique(label_test))))
for i in range(np.size(label_test, 0)):
    y_test[i, label_test[i]] = 1

In [18]:
max_iter = 2000
nn = NeuralNetwork(layers=[50, 75], learning_rate=10e-5)
nn.train_t(x_train, y_train, x_val, y_val, x_test, y_test, max_iter, print_error=True)

teration: 1833	Training error: 3575.3176211496016	Validation error: 3332.5094756121694	Testing error 1720.9351445932373
Iteration: 1834	Training error: 3573.7276406152955	Validation error: 3332.447660627103	Testing error 1720.820909107069
Iteration: 1835	Training error: 3572.1389064763875	Validation error: 3332.3863179210907	Testing error 1720.7068827299584
Iteration: 1836	Training error: 3570.551416124086	Validation error: 3332.3254465823466	Testing error 1720.593065117793
Iteration: 1837	Training error: 3568.9651669414725	Validation error: 3332.265045699923	Testing error 1720.4794559294326
Iteration: 1838	Training error: 3567.3801563034745	Validation error: 3332.2051143636995	Testing error 1720.3660548267715
Iteration: 1839	Training error: 3565.79638157686	Validation error: 3332.1456516643752	Testing error 1720.252861474821
Iteration: 1840	Training error: 3564.2138401202105	Validation error: 3332.0866566934733	Testing error 1720.1398755417763
Iteration: 1841	Training error: 3562.6325

In [19]:
itr = np.linspace(1, max_iter, len(nn.train_err))

source = ColumnDataSource(dict(
    itr = itr,
    train_err = nn.train_err,
    val_err = nn.val_err,
    test_err = nn.test_err
))

hover = HoverTool(tooltips=[('Iteration', '@itr'), ('Training error', '@train_err'), ('Validation error', '@val_err'), ('Testing error', '@test_err')])

plot = figure(
    title='Training-Validation-Testing errors vs Iterations',
    x_axis_label='Iterations',
    y_axis_label='Logistic error'
)

plot.line(x='itr', y='train_err', source=source, legend_label='Training error', line_width=2, color='purple')
plot.circle(x='itr', y='train_err', source=source, width=2, color='purple')
plot.line(x='itr', y='val_err', source=source, legend_label='Validation error', line_width=2, color='orange')
plot.circle(x='itr', y='val_err', source=source, width=2, color='orange')
plot.line(x='itr', y='test_err', source=source, legend_label='Testing error', line_width=2, color='green')
plot.circle(x='itr', y='test_err', source=source, width=2, color='green')
plot.add_tools(hover)
show(plot)

In [20]:
pred_train = nn.test(x_train)
pred_val = nn.test(x_val)
pred_test = nn.test(x_test)

train_acc = calculate_accuracy(pred_train, label_train)
val_acc = calculate_accuracy(pred_val, label_val)
test_acc = calculate_accuracy(pred_test, label_test)

print("training accuracy: {}".format(train_acc))
print("validation accuracy: {}".format(val_acc))
print("testing accuracy: {}".format(test_acc))

training accuracy: 0.9795
validation accuracy: 0.9471111111111111
testing accuracy: 0.9506


In [21]:
confusion_matrix(pred_test, label_test)

array([[ 961,    0,    5,    0,    1,    5,    9,    2,    3,    6],
       [   0, 1114,    1,    2,    0,    0,    3,   11,    2,    6],
       [   1,    4,  977,   13,    3,    4,    5,   15,    3,    3],
       [   1,    2,   12,  951,    2,   27,    0,    3,   15,    6],
       [   0,    0,    6,    0,  943,    2,    7,    4,    5,   25],
       [   6,    1,    3,   17,    1,  827,   12,    1,   12,    6],
       [   3,    3,    5,    2,    5,   10,  917,    1,   10,    1],
       [   4,    3,   12,   10,    3,    3,    2,  966,    5,   13],
       [   3,    6,   10,   12,    2,    9,    3,    4,  911,    4],
       [   1,    2,    1,    3,   22,    5,    0,   21,    8,  939]])