In [1]:
import numpy as np

from networks import RealNetworkLinearTest, RealNetworkLinearTest_v2, RealLeNet, RealLeNetReLU
from criterions import RealMSELoss
from datasets import real_load_all_data_mnist, create_batch_data

In [2]:
BATCH_SIZE = 128
EPOCH = 1
LR = 0.1
PRINT = 10
MODE = 4

In [3]:
# data fetching
load_path = '../../data'
train_data, train_label, test_data, test_label = real_load_all_data_mnist(load_path, flatten=False)
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

In [4]:
if MODE == 1:
    model = RealNetworkLinearTest()
elif MODE == 2:
    model = RealNetworkLinearTest_v2()
elif MODE == 3:
    model = RealLeNet()
elif MODE == 4:
    model = RealLeNetReLU()
else:
    model = None
criterion = RealMSELoss()

In [5]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        tot_loss += criterion.forward(preds, train_label_batch)
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, loss: {}'.format(epoch + 1, train_idx + 1, tot_loss))
            tot_loss = 0

        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            tot_acc = 0
            tot_sample = 0
            for train_acc_idx, (train_acc_data_batch, train_acc_label_batch) in enumerate(zip(train_data, train_label)):
                # train accuracy
                preds = model.forward(train_acc_data_batch)

                pred_args = np.argmax(preds, axis=1)
                train_label_args = np.argmax(train_acc_label_batch, axis=1)
                tot_acc += np.count_nonzero(pred_args == train_label_args)
                tot_sample += train_acc_data_batch.shape[0]
            accuracy = tot_acc / tot_sample
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
            tot_loss = 0

epoch: 1, idx: 1, loss: 1.001598355036903
epoch: 1, idx: 1, accuracy: 0.14895, loss: 0
epoch: 1, idx: 10, loss: 0.7942631328887184
epoch: 1, idx: 10, accuracy: 0.6212, loss: 0.0
epoch: 1, idx: 20, loss: 0.7271700729453302
epoch: 1, idx: 20, accuracy: 0.7560333333333333, loss: 0.0
epoch: 1, idx: 30, loss: 0.5640655864069177
epoch: 1, idx: 30, accuracy: 0.7700166666666667, loss: 0.0
epoch: 1, idx: 40, loss: 0.452692307562613
epoch: 1, idx: 40, accuracy: 0.7904833333333333, loss: 0.0
epoch: 1, idx: 50, loss: 0.3744141319042222
epoch: 1, idx: 50, accuracy: 0.8731666666666666, loss: 0.0
epoch: 1, idx: 60, loss: 0.31410981510619906
epoch: 1, idx: 60, accuracy: 0.8892666666666666, loss: 0.0
epoch: 1, idx: 70, loss: 0.25362671952433674
epoch: 1, idx: 70, accuracy: 0.90435, loss: 0.0
epoch: 1, idx: 80, loss: 0.2323650339021249
epoch: 1, idx: 80, accuracy: 0.9241, loss: 0.0
epoch: 1, idx: 90, loss: 0.21168667754711273
epoch: 1, idx: 90, accuracy: 0.8597666666666667, loss: 0.0
epoch: 1, idx: 100,

KeyboardInterrupt: 