In [1]:
import sys
sys.path.append('../..')

import numpy as np
from galois_datasets import load_all_data_mnist, load_all_data_cifar10, load_all_data_fashion_mnist
from utils import create_batch_data, to_finite_field_domain, from_galois_to_real_domain
import modules
import galois_layers
from galois_criterions import GaloisFieldMSELoss
from sklearn.datasets import make_classification
import galois

In [2]:
BATCH_SIZE = 256
EPOCH = 1
PRINT = 1
FLATTEN = False
# 0, MNIST; 1, FashionMNIST; 2, CIFAR10; 3 RANDOM
DATASET_MODE = 0

QUANTIZATION_INPUT = 8
QUANTIZATION_WEIGHT = 16
QUANTIZATION_BATCH_SIZE = 8
LR = 7
PRIME = 684502462494449

In [3]:
field = galois.GF(PRIME)

In [4]:
# data fetching
load_path = '../../data'
if DATASET_MODE == 0:
    train_data, train_label, test_data, test_label = load_all_data_mnist(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 1:
    train_data, train_label, test_data, test_label = load_all_data_fashion_mnist(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 2:
    train_data, train_label, test_data, test_label = load_all_data_cifar10(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 3:
    train_data, train_label = make_classification(n_samples=10, n_features=100, n_classes=10, n_clusters_per_class=1, n_informative=10)
    train_data = train_data.reshape((-1, 4, 5, 5))
    test_data, test_label = train_data, train_label
    train_label = np.zeros((10, 10))
    for idx, label in enumerate(test_label):
        train_label[idx][label] = 1
    train_data, train_label, test_data = to_finite_field_domain(train_data, QUANTIZATION_INPUT, PRIME), to_finite_field_domain(train_label, QUANTIZATION_WEIGHT, PRIME), to_finite_field_domain(test_data, QUANTIZATION_INPUT, PRIME)
    train_data, train_label, test_data = field(train_data), field(train_label), field(test_data)
else:
    train_data, train_label, test_data, test_label = None, None, None, None
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

In [5]:
model_arr = [
    galois_layers.GaloisFieldPiNetSecondOrderConvLayer(1, 6, (5, 5), QUANTIZATION_WEIGHT, PRIME, field, first_layer=True,
                                                quantization_bit_input=QUANTIZATION_INPUT),
    galois_layers.GaloisFieldPiNetSecondOrderConvLayer(6, 6, (5, 5), QUANTIZATION_WEIGHT, PRIME, field),
    modules.Flatten(),
    galois_layers.GaloisFieldPiNetSecondOrderLinearLayer(2400, 128, QUANTIZATION_WEIGHT, PRIME, field),
    galois_layers.GaloisFieldLinearLayer(128, 10, QUANTIZATION_WEIGHT, PRIME, field)
]

model = modules.Network(model_arr)
criterion = GaloisFieldMSELoss(PRIME, QUANTIZATION_WEIGHT, QUANTIZATION_BATCH_SIZE, field)

In [6]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        loss = criterion.forward(preds, train_label_batch)
        tot_loss += loss
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        print('epoch: {}, idx: {}, curr loss: {}'.format(epoch + 1, train_idx + 1, loss))
        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, avg loss: {}'.format(epoch + 1, train_idx + 1, tot_loss))
            tot_loss = 0

epoch: 1, idx: 1, curr loss: 1.468828772892266
epoch: 1, idx: 1, avg loss: 1.468828772892266
epoch: 1, idx: 2, curr loss: 1.0373511156176394
epoch: 1, idx: 2, avg loss: 1.0373511156176394
epoch: 1, idx: 3, curr loss: 0.9575303546607755
epoch: 1, idx: 3, avg loss: 0.9575303546607755
epoch: 1, idx: 4, curr loss: 0.9297337685493402
epoch: 1, idx: 4, avg loss: 0.9297337685493402
epoch: 1, idx: 5, curr loss: 0.9147911656755241
epoch: 1, idx: 5, avg loss: 0.9147911656755241
epoch: 1, idx: 6, curr loss: 0.8402015753035812
epoch: 1, idx: 6, avg loss: 0.8402015753035812
epoch: 1, idx: 7, curr loss: 0.7775595990251531
epoch: 1, idx: 7, avg loss: 0.7775595990251531
epoch: 1, idx: 8, curr loss: 0.793746826368988
epoch: 1, idx: 8, avg loss: 0.793746826368988
epoch: 1, idx: 9, curr loss: 0.7356878174978192
epoch: 1, idx: 9, avg loss: 0.7356878174978192
epoch: 1, idx: 10, curr loss: 0.7275589056707757
epoch: 1, idx: 10, avg loss: 0.7275589056707757
epoch: 1, idx: 11, curr loss: 0.7172351948183859
epo

In [8]:
tot_acc = 0
tot_sample = 0
for train_acc_idx, (test_data_batch, test_label_batch) in enumerate(zip(test_data, test_label)):
    # train accuracy
    preds = model.forward(test_data_batch)
    preds = from_galois_to_real_domain(preds, QUANTIZATION_WEIGHT, PRIME)
    pred_args = np.argmax(preds, axis=1)

    tot_acc += np.count_nonzero(pred_args == test_label_batch)
    tot_sample += test_data_batch.shape[0]
    print(tot_acc)
accuracy = tot_acc / tot_sample

244
481
712
947
1172
1409
1639
1879
2110
2352
2589
2833
3071
3309
3539
3773
4004
4238
4473
4712
4962
5211
5457
5690
5939
6180
6425
6676
6930
7177
7421
7674
7919
8172
8427
8673
8924
9172
9395
9411


In [9]:
accuracy

0.9411