In [1]:
import numpy as np
from nets.real_net.datasets import load_all_data_mnist, load_all_data_cifar10, load_all_data_fashion_mnist, load_all_data_apply_vgg_cifar10
from utils import create_batch_data
import modules
import nets.real_net.layers as layers
from nets.real_net.activations import GAPTruncation
from nets.real_net.criterions import RealMSELoss

In [2]:
# --- Training hyper-parameters ---------------------------------------------
# Batch size handed to create_batch_data when the arrays are split into batches.
BATCH_SIZE = 16
# Upper bound of the outer training loop (range(EPOCH)).
EPOCH = 500
# Value passed to model.optimize() after every backprop step.
LR = 0.01
# Reporting interval: every PRINT training batches the test set is evaluated
# and the loss averaged over those PRINT batches is printed.
PRINT = 10

# --- Data options ----------------------------------------------------------
# Forwarded unchanged to the dataset loaders as their `flatten=` argument.
FLATTEN = False
# Dataset selector:
#   0 -> MNIST
#   1 -> FashionMNIST
#   2 -> CIFAR10
#   3 -> VGG-CIFAR10
DATASET_MODE = 2

In [3]:
# Data fetching: pick the loader that matches DATASET_MODE, load the
# train/test arrays from `load_path`, then split them into batches.
load_path = './data'

# DATASET_MODE value -> loader function (same mapping as the legend next to
# the DATASET_MODE constant).
dataset_loaders = {
    0: load_all_data_mnist,
    1: load_all_data_fashion_mnist,
    2: load_all_data_cifar10,
    3: load_all_data_apply_vgg_cifar10,
}

# Fail fast on an unknown mode. Previously all four arrays were silently set
# to None and handed to create_batch_data, which only surfaced the mistake
# later and far from its cause.
if DATASET_MODE not in dataset_loaders:
    raise ValueError(
        'Unsupported DATASET_MODE {!r}; expected one of {}'.format(
            DATASET_MODE, sorted(dataset_loaders)
        )
    )

train_data, train_label, test_data, test_label = dataset_loaders[DATASET_MODE](load_path, flatten=FLATTEN)

# NOTE: the names are rebound to the batched versions, so after this cell
# train_data/train_label/test_data/test_label are iterated batch by batch.
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Both convolutional layers use the same stride and padding settings.
conv_options = {'stride': (2, 2), 'padding': (1, 1, 1, 1)}

# Layer stack, built in the order the network applies it.
# NOTE(review): the leading conv arguments are presumably
# (in_channels, out_channels, kernel_size) - confirm against nets.real_net.layers.
model_arr = []
model_arr.append(layers.RealPiNetSecondOrderConvLayer(3, 64, (4, 4), **conv_options))
model_arr.append(layers.RealPiNetSecondOrderConvLayer(64, 256, (4, 4), **conv_options))
model_arr.append(GAPTruncation())
# Final linear layer takes 256 inputs and produces 10 outputs.
model_arr.append(layers.RealLinearLayer(256, 10))

# Wrap the layer list in a Network and pair it with an MSE criterion.
model = modules.Network(model_arr)
criterion = RealMSELoss()

In [5]:
def evaluate_accuracy(network, data_batches, label_batches):
    """Return the fraction of samples that `network` classifies correctly.

    Each data batch is run through ``network.forward``; the predicted class is
    the argmax along axis 1 of the output and is compared element-wise with
    the matching label batch.

    Args:
        network: object exposing ``forward(batch)``.
        data_batches: iterable of input batches; each must expose ``shape[0]``
            as its sample count.
        label_batches: iterable of label batches, paired with `data_batches`
            via ``zip``.  Assumes labels are class indices comparable with the
            argmax output - TODO confirm against create_batch_data.

    Returns:
        Number of matching predictions divided by the number of samples seen.

    Raises:
        ZeroDivisionError: if there are no batches to evaluate.
    """
    correct = 0
    seen = 0
    for data_batch, label_batch in zip(data_batches, label_batches):
        scores = network.forward(data_batch)
        predicted = np.argmax(scores, axis=1)
        correct += np.count_nonzero(predicted == label_batch)
        seen += data_batch.shape[0]
    return correct / seen


for epoch in range(EPOCH):
    # Running sum of the per-batch losses since the last report (or epoch start).
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # One training step: forward pass, loss, backward pass, parameter update.
        preds = model.forward(train_data_batch)
        curr_training_loss = criterion.forward(preds, train_label_batch)
        print(curr_training_loss)
        tot_loss += curr_training_loss
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        # Report only once every PRINT training batches.
        if (train_idx + 1) % PRINT != 0:
            continue

        # Accuracy is measured on the *test* batches, not the training ones.
        accuracy = evaluate_accuracy(model, test_data, test_label)
        # Exactly PRINT losses were accumulated since the last reset, so this
        # is their mean (when PRINT == 1 the division is a no-op).
        tot_loss = tot_loss / PRINT
        print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
        tot_loss = 0

1.1563062482208086
1.110690994421935
1.125026166615066
0.9685392448506268
0.9404804686279006
0.9576365342648065
0.9829336698755847
0.9068246146741519
0.9623560464844473
0.9723662447371626
epoch: 1, idx: 10, accuracy: 0.1462, loss: 1.0083160232772492
0.9814801559288954
0.907891134509239
0.9491250848541535
0.8983824839718597
0.9686038688037201
0.8620076758002848
0.9696682155347127
0.9200370480686
0.9006921134164847
0.8913391463717057
epoch: 1, idx: 20, accuracy: 0.1927, loss: 0.9249226927259656
1.0493774739090715
0.8974140721271434
0.9346210130168189
0.8335115121363622
0.8994961805204779
0.916852834358256
0.9481442878787354
0.8286001411042655
0.8293420365797838
0.9209455464933753
epoch: 1, idx: 30, accuracy: 0.1783, loss: 0.905830509812429
0.9185102736530629
0.8747079570210322
0.9403185848477618
1.011163395770012
0.947466683403855
0.9007624265764139
0.8747253064573429
0.9623607670653002
0.8532893351131778
0.8834892762069965
epoch: 1, idx: 40, accuracy: 0.1896, loss: 0.9166794006114956
0.