In [6]:
import numpy as np
from datasets import load_all_data_mnist, load_all_data_cifar10, load_all_data_fashion_mnist, load_all_data_apply_vgg_cifar10
from utils import create_batch_data
import modules
import layers
from criterions import RealMSELoss

In [7]:
BATCH_SIZE = 256
EPOCH = 5
LR = 0.01
PRINT = 10
FLATTEN = False
# 0, MNIST; 1, FashionMNIST; 2, CIFAR10; 3, VGG-CIFAR10
DATASET_MODE = 2

In [8]:
# data fetching
load_path = '../../data'
if DATASET_MODE == 0:
    train_data, train_label, test_data, test_label = load_all_data_mnist(load_path, flatten=FLATTEN)
elif DATASET_MODE == 1:
    train_data, train_label, test_data, test_label = load_all_data_fashion_mnist(load_path, flatten=FLATTEN)
elif DATASET_MODE == 2:
    train_data, train_label, test_data, test_label = load_all_data_cifar10(load_path, flatten=FLATTEN)
elif DATASET_MODE == 3:
    train_data, train_label, test_data, test_label = load_all_data_apply_vgg_cifar10(load_path, flatten=FLATTEN)
else:
    train_data, train_label, test_data, test_label = None, None, None, None
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
model_arr = [
    layers.RealPiNetSecondOrderConvLayer(3, 6, (9, 9)),
    modules.Flatten(),
    layers.RealPiNetSecondOrderLinearLayer(3456, 128),
    layers.RealLinearLayer(128, 10)
]

model = modules.Network(model_arr)
criterion = RealMSELoss()

In [None]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        tot_loss += criterion.forward(preds, train_label_batch)
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            tot_acc = 0
            tot_sample = 0
            for train_acc_idx, (test_data_batch, test_label_batch) in enumerate(zip(test_data, test_label)):
                # train accuracy
                preds = model.forward(test_data_batch)
                pred_args = np.argmax(preds, axis=1)

                tot_acc += np.count_nonzero(pred_args == test_label_batch)
                tot_sample += test_data_batch.shape[0]
            accuracy = tot_acc / tot_sample
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
            tot_loss = 0

epoch: 1, idx: 1, accuracy: 0.1505, loss: 1.553851144962415
epoch: 1, idx: 10, accuracy: 0.2103, loss: 0.9121908972448518
epoch: 1, idx: 20, accuracy: 0.2421, loss: 0.929817139014753
epoch: 1, idx: 30, accuracy: 0.2654, loss: 0.8947012438482627
epoch: 1, idx: 40, accuracy: 0.2778, loss: 0.8826544528633986
epoch: 1, idx: 50, accuracy: 0.2969, loss: 0.8792823297836613
epoch: 1, idx: 60, accuracy: 0.3066, loss: 0.8620027676648764
epoch: 1, idx: 70, accuracy: 0.3204, loss: 0.8625010331976724
epoch: 1, idx: 80, accuracy: 0.3311, loss: 0.848832472630531
epoch: 1, idx: 90, accuracy: 0.3349, loss: 0.835046372928866
epoch: 1, idx: 100, accuracy: 0.3461, loss: 0.8300451117349861
epoch: 1, idx: 110, accuracy: 0.3515, loss: 0.8335340221276688
epoch: 1, idx: 120, accuracy: 0.3604, loss: 0.8231468556989101
epoch: 1, idx: 130, accuracy: 0.3637, loss: 0.8241137123750935
epoch: 1, idx: 140, accuracy: 0.3648, loss: 0.8236029069510481
epoch: 1, idx: 150, accuracy: 0.3651, loss: 0.8154964941357179
epoch: 