In [6]:
import numpy as np

from networks import RealNetworkLinearTest, RealNetworkLinearTest_v2, RealLeNet, RealLeNetReLU, RealPiNetNetworkLinear, RealPiNetNetworkLeNet, RealPiNetNetworkLeNetCIFAR10
from criterions import RealMSELoss
from datasets import load_all_data_mnist, load_all_data_cifar10
from utils import create_batch_data

In [7]:
BATCH_SIZE = 128
EPOCH = 1
LR = 1 / 256
PRINT = 10
MODE = 7

In [8]:
if MODE == 1:
    model = RealNetworkLinearTest()
    flatten = True
elif MODE == 2:
    model = RealNetworkLinearTest_v2()
    flatten = True
elif MODE == 3:
    model = RealLeNet()
    flatten = False
elif MODE == 4:
    model = RealLeNetReLU()
    flatten = False
elif MODE == 5:
    model = RealPiNetNetworkLinear()
    flatten = True
elif MODE == 6:
    model = RealPiNetNetworkLeNet()
    flatten = False
elif MODE == 7:
    model = RealPiNetNetworkLeNetCIFAR10()
    flatten = False
else:
    model = None
    flatten = True
criterion = RealMSELoss()

In [9]:
# data fetching
load_path = '../../data'
train_data, train_label, test_data, test_label = load_all_data_cifar10(load_path, flatten=flatten)
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        tot_loss += criterion.forward(preds, train_label_batch)
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            tot_acc = 0
            tot_sample = 0
            for train_acc_idx, (test_data_batch, test_label_batch) in enumerate(zip(test_data, test_label)):
                # train accuracy
                preds = model.forward(test_data_batch)
                pred_args = np.argmax(preds, axis=1)

                tot_acc += np.count_nonzero(pred_args == test_label_batch)
                tot_sample += test_data_batch.shape[0]
            accuracy = tot_acc / tot_sample
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
            tot_loss = 0

epoch: 1, idx: 1, accuracy: 0.1192, loss: 1.059382075073404
epoch: 1, idx: 10, accuracy: 0.1986, loss: 0.8637434682604328
epoch: 1, idx: 20, accuracy: 0.2401, loss: 0.907005111418435
epoch: 1, idx: 30, accuracy: 0.2631, loss: 0.8895534251778388
epoch: 1, idx: 40, accuracy: 0.2805, loss: 0.8807508411951801
epoch: 1, idx: 50, accuracy: 0.2961, loss: 0.8638560134202937
epoch: 1, idx: 60, accuracy: 0.3071, loss: 0.8551981889055396
epoch: 1, idx: 70, accuracy: 0.3144, loss: 0.850013507891866
epoch: 1, idx: 80, accuracy: 0.3215, loss: 0.8473751262647868
epoch: 1, idx: 90, accuracy: 0.3285, loss: 0.8422778063139387
epoch: 1, idx: 100, accuracy: 0.3342, loss: 0.8339376096032529
epoch: 1, idx: 110, accuracy: 0.3418, loss: 0.8312417403437227
epoch: 1, idx: 120, accuracy: 0.3436, loss: 0.8323257108262195
epoch: 1, idx: 130, accuracy: 0.3433, loss: 0.8304260903535203
epoch: 1, idx: 140, accuracy: 0.352, loss: 0.8205108950331299
epoch: 1, idx: 150, accuracy: 0.358, loss: 0.8138961803336106
epoch: 1

KeyboardInterrupt: 