In [21]:
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils

from Dense import Dense
from Convolutional import Convolutional
from Reshape import Reshape
from Activations import ReLu, Sigmoid, Softmax
from Losses import binary_cross_entropy, binary_cross_entropy_prime
from Network import train, predict
from Layer import Layer

def preprocess_data(x, y, limit):
    """Select up to ``limit`` images each of digit 0 and digit 1, shuffled.

    Args:
        x: Image array indexed along its first axis; every image must contain
            28 * 28 values (e.g. MNIST ``uint8`` pixels in 0..255).
        y: Integer label array aligned with ``x``.
        limit: Maximum number of samples kept per class (0 and 1 each).

    Returns:
        ``(x, y)`` where ``x`` is ``float32`` with shape ``(n, 1, 28, 28)``,
        divided by 255, and ``y`` is a ``float32`` one-hot array of shape
        ``(n, 2, 1)`` (one column vector per sample).
    """
    zero_index = np.where(y == 0)[0][:limit]
    one_index = np.where(y == 1)[0][:limit]
    # Shuffle so the samples are not ordered all-zeros-then-all-ones.
    all_indices = np.random.permutation(np.hstack((zero_index, one_index)))
    x, y = x[all_indices], y[all_indices]
    x = x.reshape(len(x), 1, 28, 28).astype("float32") / 255
    # One-hot with an explicit class count. Inferring the width from the labels
    # (as ``to_categorical(y)`` without ``num_classes`` does) produces a single
    # column whenever no 1s were selected, and the (n, 2, 1) reshape then fails.
    y = np.eye(2, dtype="float32")[y.astype(int)]
    y = y.reshape(len(y), 2, 1)
    return x, y

def preprocess_data_all_digits(x, y, limit):
    """Draw ``limit`` random samples across all ten digits.

    Args:
        x: Image array indexed along its first axis; every image must contain
            28 * 28 values (e.g. MNIST ``uint8`` pixels in 0..255).
        y: Integer labels in 0..9 aligned with ``x``.
        limit: Total number of samples to keep (not per class). If it exceeds
            ``len(x)``, every sample is kept, in shuffled order.

    Returns:
        ``(x, y)`` where ``x`` is ``float32`` with shape ``(n, 1, 28, 28)``,
        divided by 255, and ``y`` is a ``float32`` one-hot array of shape
        ``(n, 10, 1)``.
    """
    all_indices = np.random.permutation(len(x))[:limit]
    x, y = x[all_indices], y[all_indices]
    x = x.reshape(len(x), 1, 28, 28).astype("float32") / 255
    # Fix the width at 10 classes. A small random sample can miss the highest
    # digit(s); a width inferred from max(y) + 1 would then not fit (n, 10, 1).
    y = np.eye(10, dtype="float32")[y.astype(int)]
    y = y.reshape(len(y), 10, 1)
    return x, y

# Load MNIST through keras, then keep only digits 0 and 1 (at most 100
# images of each per split) so training stays quick without a GPU.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = preprocess_data(x_train, y_train, limit=100)
x_test, y_test = preprocess_data(x_test, y_test, limit=100)

# Layer stack. The Convolutional arguments are presumably
# (input_shape, kernel_size, depth) -- TODO confirm against Convolutional.py.
# That reading matches the 5 x 26 x 26 input the Reshape expects
# (28 - 3 + 1 = 26).
network = [
    Convolutional((1, 28, 28), 3, 5),
    Sigmoid(),
    # flatten the feature maps into one column vector for the Dense layers
    Reshape((5, 26, 26), (5 * 26 * 26, 1)),
    Dense(5 * 26 * 26, 100),
    Sigmoid(),
    Dense(100, 2),  # one output per class, matching the 2-wide one-hot labels
    Sigmoid(),
]

# Fit with binary cross-entropy: 20 epochs, learning rate 0.1, batch size 1.
train(
    network, binary_cross_entropy, binary_cross_entropy_prime,
    x_train, y_train,
    epochs=20, learning_rate=0.1, batch_size=1,
)

# Evaluate on the held-out split: echo each prediction next to its label
# and tally how many agree.
correct = 0
for sample, target in zip(x_test, y_test):
    predicted = np.argmax(predict(network, sample))
    actual = np.argmax(target)
    print(f"pred: {predicted}, true: {actual}")
    correct += int(predicted == actual)

# Report accuracy as a percentage of the test set.
print(f"{correct / len(x_test) * 100}% of test data correctly predicted")



1/20, error=0.6068859299535104
2/20, error=0.48767436095218214
3/20, error=0.6761470737122821
4/20, error=0.8010702993476315
5/20, error=0.8625519253987862
6/20, error=0.8311413589764026
7/20, error=0.8435947408778496
8/20, error=0.8625927141847266
9/20, error=0.8625961871205408
10/20, error=0.8625950006142089
11/20, error=0.8625933863139015
12/20, error=0.8625924459962981
13/20, error=0.8625920085966903
14/20, error=0.8625917723419596
15/20, error=0.8625916212238697
16/20, error=0.8625915176705563
17/20, error=0.8625914486876909
18/20, error=0.8625914059206145
19/20, error=0.8625913809518295
20/20, error=0.8625913655602722
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 0
pred: 0, true: 0
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 1
pred: 0, true: 0
pred: 0, true: 0
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 1
pred: 0, true: 1
pred: 0, true: 0
pred: 0, true: 0
pred: 0, true: 1
pred: 0, true: 1
pred: 0, tr