In [1]:
from mnistdataloader import *
from network import *
from os.path import join
import numpy as np
import matplotlib.pyplot as plt

# Loading data from files
input_path = 'data'
training_images_filepath = join(input_path, 'train-images-idx3-ubyte/train-images-idx3-ubyte')
training_labels_filepath = join(input_path, 'train-labels-idx1-ubyte/train-labels-idx1-ubyte')
test_images_filepath = join(input_path, 't10k-images-idx3-ubyte/t10k-images-idx3-ubyte')
test_labels_filepath = join(input_path, 't10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte')

mnist_dataloader = MnistDataloader(training_images_filepath, training_labels_filepath, test_images_filepath, test_labels_filepath)

(x_train, y_train), (x_test, y_test) = mnist_dataloader.load_data()

# Formatting data
x_train = np.array([np.array(x).flatten() for x in x_train])
y_train = np.eye(10)[y_train]

# Shuffling data
shuffle_order = np.random.permutation(len(x_train))
x_train = x_train[shuffle_order, ...]
y_train = y_train[shuffle_order, ...]

# Splitting data
batch_size = 256
x_batches = [x_train[i:i+64] for i in range(int(np.ceil(len(x_train[:(batch_size * 100)]) / batch_size)))]
y_batches = [y_train[i:i+64] for i in range(int(np.ceil(len(y_train[:(batch_size * 100)]) / batch_size)))]

x_test = [np.array(x).flatten() for x in x_test]
y_test = np.eye(10)[y_test]

In [2]:
nn = Network(28 * 28, 16, 16, 10, regularization=0)

In [None]:
for i in range(100):
    nn.batch_train(x_train, y_train, learning_rate=1)
    print('...', end='')
    print(f'{nn.batch_loss(x_train, y_train):.10f}')

...20.7839005777
...20.4385128138
...2.3748076326
...2.3639479020
...2.3544801530
...2.3462670733
...2.3391779528
...2.3330891816
...2.3278848348
...2.3234572118
...2.3197072422
...2.3165447090
...2.3138882734
...2.3116653069
...2.3098115549


In [9]:
count = 0
for x, y in zip(x_batches, y_batches):
    for i in range(100):
        nn.batch_train(x, y, learning_rate=0.0001)
    print(f"{nn.batch_loss(x, y):.10f}")
    count += 1
    if (count % 10) == 0:
        print(f"\n TEST LOSS: {nn.batch_loss(x_test, y_test):.10f} \n")
loss = nn.batch_loss(x_train, y_train)
print(f'Loss after: {loss:.10f}')

1.1485831908
0.7944455092
0.7753582391
0.7943594204
0.7915377680
0.8150180870
0.7840123218
0.7397353978
0.6970897223
0.8634204689

 TEST LOSS: 10.5248384982 

0.7778841147
0.6919870850
0.6729958632
0.6546037616
0.6126061657
0.6266826134
0.6423795851
0.6375686145
0.6420562774
0.6397535053

 TEST LOSS: 10.4643651894 

0.6530222622
0.6472877927
0.6095282656
0.5677839998
0.5655585984
0.5454364727
0.5372670388
0.5261407817
0.5137510146
0.5293419312

 TEST LOSS: 10.6405608320 

0.5269415317
0.5320639230
0.5295389766
0.5644657946
0.5633683381
0.5980506586
0.6333045778
0.6376288987


KeyboardInterrupt: 

In [4]:
nn.feed_forward(x_train[0])

array([0.07980572, 0.18291377, 0.09687323, 0.07711792, 0.07719099,
       0.11609646, 0.10441765, 0.08062703, 0.10162784, 0.08332939])

In [5]:
nn.batch_loss(x_test, y_test)

11.27521073865499

In [6]:
print(y_train[1])

[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]


In [7]:
def render(x):
    for i in range(28 * 28):
        if x[i] > 128:
            print('.', end='')
        else:
            print('#', end='')
        if (i + 1) % 28 == 0:
            print('')

In [8]:
render(x_train[1])

############################
############################
############################
#################.##########
#################.##########
#################.##########
#################.##########
#################.##########
#########..#####..##########
########...#####..##########
########...#####..##########
#########..#####..##########
#########..#####..##########
#########..#####..##########
##########.#####..##########
##########...###..##########
##########....#...##########
##########.###.......#######
#########..#####.###########
################.###########
################..##########
################..##########
################..##########
############################
############################
############################
############################
############################
