In [1]:
import numpy as np
from datasets import load_all_data_mnist, load_all_data_cifar10, load_all_data_fashion_mnist, load_all_data_apply_vgg_cifar10
from utils import create_batch_data, to_finite_field_domain, to_real_domain
import modules
import layers
from criterions import FiniteFieldMSELoss
from sklearn.datasets import make_classification



In [2]:
# Training / experiment configuration.
BATCH_SIZE = 100          # batch size handed to create_batch_data
EPOCH = 100               # number of passes over the training batches
PRINT = 1                 # evaluate and print every PRINT training batches
FLATTEN = False           # forwarded to the dataset loaders as flatten=
# 0, MNIST; 1, FashionMNIST; 2, CIFAR10; 3, VGG-CIFAR10; 4 RANDOM
DATASET_MODE = 4
PRIME = 2**26 - 5         # modulus passed to every finite-field layer, the loss and the domain conversions
QUANTIZATION_WEIGHT = 8   # presumably fixed-point bits for weights/labels — TODO confirm in utils
QUANTIZATION_INPUT = 8    # presumably fixed-point bits for inputs — TODO confirm in utils
QUANTIZATION_BATCH_SIZE = 8  # passed to FiniteFieldMSELoss; meaning not visible here — TODO confirm
LR = 4                    # passed to model.optimize; an integer, so possibly a scale/shift exponent — TODO confirm

# Seed NumPy's global RNG so anything drawing from it (e.g. make_classification,
# which is called without random_state) is reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [6]:
# data fetching
# Every branch must produce (train_data, train_label, test_data, test_label) for
# create_batch_data, which groups them into BATCH_SIZE batches.
load_path = '../../data'

# Parameters of the synthetic dataset used when DATASET_MODE == 4.
SYNTH_N_SAMPLES = 25
SYNTH_N_CLASSES = 10
SYNTH_SIDE = 5  # each sample becomes a single-channel SYNTH_SIDE x SYNTH_SIDE image

if DATASET_MODE == 0:
    train_data, train_label, test_data, test_label = load_all_data_mnist(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, flatten=FLATTEN)
elif DATASET_MODE == 1:
    train_data, train_label, test_data, test_label = load_all_data_fashion_mnist(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, flatten=FLATTEN)
elif DATASET_MODE == 2:
    train_data, train_label, test_data, test_label = load_all_data_cifar10(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, flatten=FLATTEN)
elif DATASET_MODE == 3:
    train_data, train_label, test_data, test_label = load_all_data_apply_vgg_cifar10(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, flatten=FLATTEN)
elif DATASET_MODE == 4:
    # No random_state is passed, so the draw depends on NumPy's global RNG state.
    # make_classification requires n_classes * n_clusters_per_class <= 2**n_informative.
    features, class_ids = make_classification(
        n_samples=SYNTH_N_SAMPLES,
        n_features=SYNTH_SIDE * SYNTH_SIDE,
        n_classes=SYNTH_N_CLASSES,
        n_clusters_per_class=1,
        n_informative=5,
    )
    # (N, SIDE*SIDE) -> (N, 1, SIDE, SIDE): add a channel axis for the conv layers.
    features = features.reshape((-1, SYNTH_SIDE, SYNTH_SIDE))[:, np.newaxis, :, :]
    # Evaluation reuses the training samples; its labels stay integer class ids
    # because the training loop compares them against argmax predictions.
    test_data, test_label = features, class_ids
    # One-hot targets for the MSE loss: row k of the identity is the encoding of class k.
    one_hot = np.eye(SYNTH_N_CLASSES)[class_ids]
    train_data = to_finite_field_domain(features, QUANTIZATION_INPUT, PRIME)
    train_label = to_finite_field_domain(one_hot, QUANTIZATION_WEIGHT, PRIME)
    test_data = to_finite_field_domain(test_data, QUANTIZATION_INPUT, PRIME)
else:
    # Fail loudly here instead of passing None into create_batch_data.
    raise ValueError('DATASET_MODE must be one of 0-4, got {}'.format(DATASET_MODE))
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

In [7]:
# Network sized for the (N, 1, 5, 5) synthetic inputs of DATASET_MODE == 4:
# two second-order conv layers, Flatten, a second-order linear layer and a plain
# linear layer that emits 10 scores per sample.
# Spatial bookkeeping (assumes stride-1, unpadded convolutions — TODO confirm in `layers`):
#   5x5 --2x2 conv--> 4x4 --2x2 conv--> 3x3, so Flatten yields 3 channels * 3 * 3 = 27 features.
# NOTE(review): this in_features value only fits 5x5 inputs; the MNIST/CIFAR modes need another value.
conv_channels = 3
flat_features = conv_channels * 3 * 3

model_arr = [
    layers.FiniteFieldPiNetSecondOrderConvLayer(
        1, conv_channels, (2, 2), QUANTIZATION_WEIGHT, PRIME,
        first_layer=True, quantization_bit_input=QUANTIZATION_INPUT,
    ),
    layers.FiniteFieldPiNetSecondOrderConvLayer(
        conv_channels, conv_channels, (2, 2), QUANTIZATION_WEIGHT, PRIME,
    ),
    modules.Flatten(),
    layers.FiniteFieldPiNetSecondOrderLinearLayer(
        flat_features, 10, QUANTIZATION_WEIGHT, PRIME,
    ),
    layers.FiniteFieldLinearLayer(
        10, 10, QUANTIZATION_WEIGHT, PRIME,
    ),
]

model = modules.Network(model_arr)
criterion = FiniteFieldMSELoss(PRIME, QUANTIZATION_WEIGHT, QUANTIZATION_BATCH_SIZE)

In [8]:
def _evaluate_accuracy(net, data_batches, label_batches):
    """Return the fraction of samples whose argmax prediction equals the label.

    Each batch is passed through ``net.forward`` and mapped back from the finite
    field with ``to_real_domain``. The arg-max over axis 1 is then compared with
    the labels. Labels are expected to be integer class ids, as they are for
    DATASET_MODE 4. Returns 0.0 when there are no samples instead of dividing by zero.
    """
    correct = 0
    seen = 0
    for data_batch, label_batch in zip(data_batches, label_batches):
        scores = to_real_domain(net.forward(data_batch), QUANTIZATION_WEIGHT, PRIME)
        correct += np.count_nonzero(np.argmax(scores, axis=1) == label_batch)
        seen += data_batch.shape[0]
    return correct / seen if seen else 0.0


for epoch in range(EPOCH):
    # Loss summed over the batches since the last report. Batches after the last
    # multiple of PRINT in an epoch are not reported, because both counters reset per epoch.
    tot_loss = 0
    tot_batches = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # One training step: forward, loss, backward, parameter update.
        preds = model.forward(train_data_batch)

        loss = criterion.forward(preds, train_label_batch)
        tot_loss += loss
        tot_batches += 1
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        if (train_idx + 1) % PRINT == 0:
            # Accuracy on the test split (for DATASET_MODE 4 this is the training data itself).
            accuracy = _evaluate_accuracy(model, test_data, test_label)
            avg_loss = tot_loss / tot_batches
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, avg_loss))
            tot_loss = 0
            tot_batches = 0

epoch: 1, idx: 1, accuracy: 0.08, loss: 1.1233520507812502
epoch: 2, idx: 1, accuracy: 0.12, loss: 1.1164111328125
epoch: 3, idx: 1, accuracy: 0.12, loss: 1.1125518798828125
epoch: 4, idx: 1, accuracy: 0.08, loss: 1.1082269287109374
epoch: 5, idx: 1, accuracy: 0.12, loss: 1.1010241699218748
epoch: 6, idx: 1, accuracy: 0.12, loss: 1.0968377685546875
epoch: 7, idx: 1, accuracy: 0.12, loss: 1.0941638183593752
epoch: 8, idx: 1, accuracy: 0.12, loss: 1.0857434082031248
epoch: 9, idx: 1, accuracy: 0.12, loss: 1.0846063232421874
epoch: 10, idx: 1, accuracy: 0.12, loss: 1.0777478027343748
epoch: 11, idx: 1, accuracy: 0.12, loss: 1.0740057373046874
epoch: 12, idx: 1, accuracy: 0.12, loss: 1.0692596435546873
epoch: 13, idx: 1, accuracy: 0.12, loss: 1.0637530517578124
epoch: 14, idx: 1, accuracy: 0.12, loss: 1.0622784423828124
epoch: 15, idx: 1, accuracy: 0.12, loss: 1.0561614990234376
epoch: 16, idx: 1, accuracy: 0.12, loss: 1.055255126953125
epoch: 17, idx: 1, accuracy: 0.12, loss: 1.0557104492