In [31]:
import numpy as np

from real_networks import RealNetworkLinearTest
from real_criterion import RealMSELoss
from real_dataset import load_all_data_mnist, create_batch_data

In [39]:
BATCH_SIZE = 128
EPOCH = 1
LR = 0.001
PRINT = 20

In [40]:
# data fetching
load_path = '../data'
train_data, train_label, test_data, test_label = load_all_data_mnist(load_path)
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

In [37]:
model = RealNetworkLinearTest()
criterion = RealMSELoss()

In [41]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        tot_loss += criterion.forward(preds, train_label_batch)
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            tot_acc = 0
            tot_sample = 0
            for train_acc_idx, (train_acc_data_batch, train_acc_label_batch) in enumerate(zip(train_data, train_label)):
                # train accuracy
                preds = model.forward(train_acc_data_batch)

                pred_args = np.argmax(preds, axis=1)
                train_label_args = np.argmax(train_acc_label_batch, axis=1)
                tot_acc += np.count_nonzero(pred_args == train_label_args)
                tot_sample += train_acc_data_batch.shape[0]
            accuracy = tot_acc / tot_sample
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
            tot_loss = 0

epoch: 1, idx: 1, accuracy: 0.5385, loss: 1.7287706178749038
epoch: 1, idx: 20, accuracy: 0.5472166666666667, loss: 1.6990478310355335
epoch: 1, idx: 40, accuracy: 0.5590666666666667, loss: 1.7906268129146425
epoch: 1, idx: 60, accuracy: 0.5735666666666667, loss: 1.6804172805981203
epoch: 1, idx: 80, accuracy: 0.5844, loss: 1.676219226143078
epoch: 1, idx: 100, accuracy: 0.5899166666666666, loss: 1.5568866849462015
epoch: 1, idx: 120, accuracy: 0.6014, loss: 1.6172376447063517
epoch: 1, idx: 140, accuracy: 0.6165833333333334, loss: 1.5284456637547423
epoch: 1, idx: 160, accuracy: 0.6225333333333334, loss: 1.4277116158094127
epoch: 1, idx: 180, accuracy: 0.6256666666666667, loss: 1.3913125217686266
epoch: 1, idx: 200, accuracy: 0.6336833333333334, loss: 1.3334045302932884
epoch: 1, idx: 220, accuracy: 0.6415833333333333, loss: 1.3073363599655192
epoch: 1, idx: 240, accuracy: 0.6511, loss: 1.2488409364107547
epoch: 1, idx: 260, accuracy: 0.6613166666666667, loss: 1.2598964417304832
epoch