In [22]:
import numpy as np
import pandas as pd
from NeuralNet import NeuralNetwork, calculate_accuracy
from sklearn.datasets import load_svmlight_file
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from bokeh.io import output_notebook
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool
output_notebook()

In [23]:
x, label = load_svmlight_file('Datasets/mnist.scale')
x = x.toarray()
label = label.astype(int)

x_train, x_val, label_train, label_val = train_test_split(x, label, test_size=0.30, random_state=42)

y_train = np.zeros((np.size(label_train, 0), len(np.unique(label_train))))
for i in range(np.size(label_train, 0)):
    y_train[i, label_train[i]] = 1

y_val = np.zeros((np.size(label_val, 0), len(np.unique(label_val))))
for i in range(np.size(label_val, 0)):
    y_val[i, label_val[i]] = 1

In [24]:
x_test, label_test = load_svmlight_file('Datasets/mnist.scale.t')
x_test = x_test.toarray()
x_test = np.hstack((x_test, np.zeros((x_test.shape[0], 2))))
label_test = label_test.astype(int)

y_test = np.zeros((np.size(label_test, 0), len(np.unique(label_test))))
for i in range(np.size(label_test, 0)):
    y_test[i, label_test[i]] = 1

In [25]:
max_iter = 20
nn = NeuralNetwork(layers=[100], learning_rate=10e-5)
nn.train_t(x_train, y_train, x_val, y_val, x_test, y_test, max_iter, print_error=True)

Iteration: 0	Training error: 280831.19455437746	Validation error: 175265.72221374736	Testing error 97869.36603631935
Iteration: 1	Training error: 411961.5409998209	Validation error: 201761.35066781548	Testing error 113171.66022245465
Iteration: 2	Training error: 472079.9305473893	Validation error: 200206.58640575618	Testing error 110413.79270543878
Iteration: 3	Training error: 467061.4742303811	Validation error: 143461.26022395832	Testing error 79810.8845504151
Iteration: 4	Training error: 336793.3700733509	Validation error: 93697.23932778022	Testing error 51808.19314884973
Iteration: 5	Training error: 219746.09829032264	Validation error: 37718.88763872492	Testing error 20924.580083379085
Iteration: 6	Training error: 88025.56896306915	Validation error: 23289.720737346208	Testing error 12725.339150057433
Iteration: 7	Training error: 54496.03427095401	Validation error: 18869.717437908905	Testing error 10222.991124477065
Iteration: 8	Training error: 43805.35120361663	Validation error: 175

In [26]:
itr = np.linspace(1, max_iter, len(nn.train_err))

source = ColumnDataSource(dict(
    itr = itr,
    train_err = nn.train_err,
    val_err = nn.val_err,
    test_err = nn.test_err
))

hover = HoverTool(tooltips=[('Iteration', '@itr'), ('Training error', '@train_err'), ('Validation error', '@val_err'), ('Testing error', '@test_err')])

plot = figure(
    title='Training-Validation-Testing errors vs Iterations',
    x_axis_label='Iterations',
    y_axis_label='Logistic error'
)

plot.line(x='itr', y='train_err', source=source, legend_label='Training error', line_width=2, color='purple')
plot.circle(x='itr', y='train_err', source=source, width=2, color='purple')
plot.line(x='itr', y='val_err', source=source, legend_label='Validation error', line_width=2, color='orange')
plot.circle(x='itr', y='val_err', source=source, width=2, color='orange')
plot.line(x='itr', y='test_err', source=source, legend_label='Testing error', line_width=2, color='green')
plot.circle(x='itr', y='test_err', source=source, width=2, color='green')
plot.add_tools(hover)
show(plot)

In [27]:
pred_train = nn.test(x_train)
pred_val = nn.test(x_val)
pred_test = nn.test(x_test)

train_acc = calculate_accuracy(pred_train, label_train)
val_acc = calculate_accuracy(pred_val, label_val)
test_acc = calculate_accuracy(pred_test, label_test)

print("training accuracy: {}".format(train_acc))
print("validation accuracy: {}".format(val_acc))
print("testing accuracy: {}".format(test_acc))

training accuracy: 0.8021428571428572
validation accuracy: 0.8012222222222222
testing accuracy: 0.8116


In [28]:
confusion_matrix(pred_test, label_test)

array([[ 880,    0,   19,   15,    4,   25,   18,    4,   16,    9],
       [   0, 1072,   15,    5,    5,   11,    9,   18,   22,    8],
       [  11,   18,  780,   44,    7,   23,   32,   23,   31,    4],
       [   7,    2,   29,  796,    3,   83,    0,   16,   48,   24],
       [   0,    0,   26,    4,  774,   34,   24,   17,   33,   94],
       [  30,    7,   11,   60,   11,  590,   21,    8,   55,   16],
       [  34,    5,   52,    7,   29,   31,  844,    5,   29,    4],
       [   6,    2,   22,   25,    5,   31,    0,  871,   10,   27],
       [   8,   28,   68,   44,   10,   44,    6,    9,  706,   20],
       [   4,    1,   10,   10,  134,   20,    4,   57,   24,  803]])