In [1]:
import sys
sys.path.append('../..')

import numpy as np
from galois_datasets import load_all_data_mnist, load_all_data_cifar10, load_all_data_fashion_mnist
from galois_utils import create_batch_data, to_finite_field_domain
import modules
import galois_layers
from galois_criterions import GaloisFieldMSELoss
from sklearn.datasets import make_classification
import galois_activations
import galois

In [2]:
# --- Training configuration ---
BATCH_SIZE = 16
EPOCH = 1
PRINT = 1  # report the running-average loss every PRINT batches
FLATTEN = False  # forwarded to the dataset loaders as `flatten=`

# Dataset selector. Named constants replace the bare magic numbers; DATASET_MODE
# keeps its integer value, so code that compares it against 0..3 still works.
DATASET_MNIST, DATASET_FASHION_MNIST, DATASET_CIFAR10, DATASET_RANDOM = 0, 1, 2, 3
DATASET_MODE = DATASET_CIFAR10

# --- Finite-field / quantization parameters ---
QUANTIZATION_INPUT = 8  # bit-width passed to the first layer as `quantization_bit_input=`
QUANTIZATION_WEIGHT = 32  # NOTE(review): weight/label quantization setting passed to layers, loss and loaders — confirm units (bits?)
QUANTIZATION_BATCH_SIZE = 4  # NOTE(review): passed to GaloisFieldMSELoss; meaning defined there — confirm
LR = 7  # NOTE(review): integer learning-rate parameter for model.optimize; semantics live in the Galois-field optimizer — confirm
PRIME = 136759815150493740654140208079  # order of the prime field GF(PRIME) used throughout

In [3]:
# Construct the prime field GF(PRIME). This single field class is passed to the
# dataset loaders, every layer, and the loss, so they all operate over the same field.
field = galois.GF(PRIME)

In [4]:
# Data fetching: load the dataset selected by DATASET_MODE, map it into the prime
# field, and split train/test into batches of BATCH_SIZE.
#   0 = MNIST, 1 = FashionMNIST, 2 = CIFAR10, 3 = small random sanity-check set
load_path = '../../data'
if DATASET_MODE == 0:
    train_data, train_label, test_data, test_label = load_all_data_mnist(
        load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 1:
    train_data, train_label, test_data, test_label = load_all_data_fashion_mnist(
        load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 2:
    train_data, train_label, test_data, test_label = load_all_data_cifar10(
        load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, field, flatten=FLATTEN)
elif DATASET_MODE == 3:
    # Synthetic set: RANDOM_N_SAMPLES x 100 features, reshaped to (N, 4, 5, 5).
    # NOTE(review): that is 4 channels, while the model cell declares 3 input
    # channels for its first conv layer — confirm before running this mode.
    RANDOM_N_SAMPLES, RANDOM_N_CLASSES, RANDOM_SEED = 10, 10, 0
    train_data, train_label = make_classification(
        n_samples=RANDOM_N_SAMPLES, n_features=100, n_classes=RANDOM_N_CLASSES,
        n_clusters_per_class=1, n_informative=10,
        random_state=RANDOM_SEED)  # fixed seed so Restart & Run All is reproducible
    train_data = train_data.reshape((-1, 4, 5, 5))
    # Test split reuses the training samples and keeps the raw integer labels.
    test_data, test_label = train_data, train_label
    # One-hot training targets; no longer assumes n_samples == n_classes.
    train_label = np.eye(RANDOM_N_CLASSES)[test_label]
    train_data = to_finite_field_domain(train_data, QUANTIZATION_INPUT, PRIME)
    train_label = to_finite_field_domain(train_label, QUANTIZATION_WEIGHT, PRIME)
    test_data = to_finite_field_domain(test_data, QUANTIZATION_INPUT, PRIME)
    train_data, train_label, test_data = field(train_data), field(train_label), field(test_data)
else:
    # Fail fast here instead of passing None into create_batch_data.
    raise ValueError(
        'Unknown DATASET_MODE {!r}; expected 0 (MNIST), 1 (FashionMNIST), '
        '2 (CIFAR10) or 3 (random)'.format(DATASET_MODE))
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Network: two second-order conv layers, a GAPTruncation stage, then a linear
# layer, wrapped in modules.Network; trained with a Galois-field MSE loss.
# NOTE(review): the sizes look CIFAR10-shaped — presumably 3 input channels on the
# first conv and 10 outputs on the linear layer; confirm before switching datasets.
downsample = dict(stride=(2, 2), padding=(1, 1, 1, 1))  # shared by both conv layers
conv_in = galois_layers.GaloisFieldPiNetSecondOrderConvLayer(
    3, 64, (4, 4), QUANTIZATION_WEIGHT, PRIME, field,
    first_layer=True, quantization_bit_input=QUANTIZATION_INPUT, **downsample)
conv_mid = galois_layers.GaloisFieldPiNetSecondOrderConvLayer(
    64, 256, (4, 4), QUANTIZATION_WEIGHT, PRIME, field, **downsample)
pool = galois_activations.GAPTruncation(PRIME, field)
classifier = galois_layers.GaloisFieldLinearLayer(256, 10, QUANTIZATION_WEIGHT, PRIME, field)
model_arr = [conv_in, conv_mid, pool, classifier]

model = modules.Network(model_arr)
criterion = GaloisFieldMSELoss(PRIME, QUANTIZATION_WEIGHT, QUANTIZATION_BATCH_SIZE, field)

In [7]:
# Training loop: forward pass, loss, error derivative, backprop, then an
# optimizer step with LR for every batch.
# `tot_loss` sums the losses since the last report and `n_accum` counts them, so
# each reported average is over the batches actually seen. (The report at batch 1
# consumes one loss, so the next window is one batch shorter than PRINT.)
for epoch in range(EPOCH):
    tot_loss = 0
    n_accum = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        loss = criterion.forward(preds, train_label_batch)
        tot_loss += loss
        n_accum += 1
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)

        print('epoch: {}, idx: {}, curr loss: {}'.format(epoch + 1, train_idx + 1, loss))
        if train_idx == 0 or (train_idx + 1) % PRINT == 0:
            print('epoch: {}, idx: {}, avg loss: {}'.format(epoch + 1, train_idx + 1, tot_loss / n_accum))
            tot_loss = 0
            n_accum = 0
    # Report a trailing partial window so the end of the epoch is not silently dropped.
    if n_accum:
        print('epoch: {}, idx: {}, avg loss: {}'.format(epoch + 1, train_idx + 1, tot_loss / n_accum))

epoch: 1, idx: 1, curr loss: 1.1185869487179363
epoch: 1, idx: 1, avg loss: 1.1185869487179363
epoch: 1, idx: 2, curr loss: 1.0432649704920263
epoch: 1, idx: 2, avg loss: 1.0432649704920263
epoch: 1, idx: 3, curr loss: 1.0132264232068984
epoch: 1, idx: 3, avg loss: 1.0132264232068984
epoch: 1, idx: 4, curr loss: 0.9248980861451757
epoch: 1, idx: 4, avg loss: 0.9248980861451757
epoch: 1, idx: 5, curr loss: 0.9505719734633858
epoch: 1, idx: 5, avg loss: 0.9505719734633858
epoch: 1, idx: 6, curr loss: 0.9203075411623645
epoch: 1, idx: 6, avg loss: 0.9203075411623645
epoch: 1, idx: 7, curr loss: 0.9502521579892661
epoch: 1, idx: 7, avg loss: 0.9502521579892661
epoch: 1, idx: 8, curr loss: 0.8725383448622618
epoch: 1, idx: 8, avg loss: 0.8725383448622618
epoch: 1, idx: 9, curr loss: 0.9456538836274562
epoch: 1, idx: 9, avg loss: 0.9456538836274562
epoch: 1, idx: 10, curr loss: 0.986148118492385
epoch: 1, idx: 10, avg loss: 0.986148118492385
epoch: 1, idx: 11, curr loss: 0.958095495013495
ep

KeyboardInterrupt: 