In [1]:
import numpy as np

from networks import FiniteFieldPiNetNetworkLinear, FiniteFieldPiNetNetworkLeNet, FiniteFieldPiNetNetworkLeNetCIFAR10, FiniteFieldPiNetNetworkDebug, FiniteFieldPiNetNetworkDebug2
from criterions import FiniteFieldMSELoss
from datasets import load_all_data_mnist, load_all_data_cifar10
from utils import create_batch_data, to_real_domain

In [2]:
BATCH_SIZE = 256
EPOCH = 1
LR = 7
PRINT = 20
MODE = 'debug_2'
PRIME = 2**26 - 5
QUANTIZATION_WEIGHT = 8
QUANTIZATION_INPUT = 8
QUANTIZATION_BATCH_SIZE = 8

In [3]:
if MODE == 1:
    model = FiniteFieldPiNetNetworkLinear(QUANTIZATION_WEIGHT, PRIME, QUANTIZATION_INPUT)
    flatten = True
elif MODE == 2:
    model = FiniteFieldPiNetNetworkLeNet(QUANTIZATION_WEIGHT, PRIME, QUANTIZATION_INPUT)
    flatten = False
elif MODE == 3:
    model = FiniteFieldPiNetNetworkLeNetCIFAR10(QUANTIZATION_WEIGHT, PRIME, QUANTIZATION_INPUT)
    flatten = False
elif MODE == 'debug':
    model = FiniteFieldPiNetNetworkDebug(QUANTIZATION_WEIGHT, PRIME, QUANTIZATION_INPUT)
    flatten = False
elif MODE == 'debug_2':
    model = FiniteFieldPiNetNetworkDebug2(QUANTIZATION_WEIGHT, PRIME, QUANTIZATION_INPUT)
    flatten = False
else:
    model = None
    flatten = True
criterion = FiniteFieldMSELoss(PRIME, QUANTIZATION_WEIGHT, QUANTIZATION_BATCH_SIZE)

In [4]:
# data fetching
load_path = '../../data'
train_data, train_label, test_data, test_label = load_all_data_mnist(load_path, QUANTIZATION_INPUT, QUANTIZATION_WEIGHT, PRIME, flatten=flatten)
train_data, train_label, test_data, test_label = create_batch_data(train_data, train_label, test_data, test_label, BATCH_SIZE)

In [5]:
for epoch in range(EPOCH):
    tot_loss = 0
    for train_idx, (train_data_batch, train_label_batch) in enumerate(zip(train_data, train_label)):
        # train
        preds = model.forward(train_data_batch)

        loss = criterion.forward(preds, train_label_batch)
        tot_loss += loss
        propagated_error = criterion.error_derivative()

        model.backprop(propagated_error)
        model.optimize(LR)
        print('idx: {}, loss: {}'.format(train_idx + 1, loss))
        if (train_idx + 1) % PRINT == 0:
            print('epoch: {}, idx: {}, loss: {}'.format(epoch + 1, train_idx + 1, tot_loss / PRINT))
            tot_loss = 0

        if (train_idx + 1) % PRINT == 0:
            tot_acc = 0
            tot_sample = 0
            for train_acc_idx, (test_data_batch, test_label_batch) in enumerate(zip(test_data, test_label)):
                # train accuracy
                preds = model.forward(test_data_batch)
                real_preds = to_real_domain(preds, QUANTIZATION_WEIGHT, PRIME)
                pred_args = np.argmax(real_preds, axis=1)

                tot_acc += np.count_nonzero(pred_args == test_label_batch)
                tot_sample += test_data_batch.shape[0]

            accuracy = tot_acc / tot_sample
            if train_idx != 0:
                tot_loss = tot_loss / PRINT
            print('epoch: {}, idx: {}, accuracy: {}, loss: {}'.format(epoch + 1, train_idx + 1, accuracy, tot_loss))
            tot_loss = 0

idx: 1, loss: 4.529284656047821
idx: 2, loss: 344523.0065196157
idx: 3, loss: 834972.4083711504
idx: 4, loss: 858051.1641591787
idx: 5, loss: 891226.0604442358
idx: 6, loss: 868316.565542102
idx: 7, loss: 867157.2670397162
idx: 8, loss: 874109.4636483192
idx: 9, loss: 856025.4221897124
idx: 10, loss: 870294.5295047163
idx: 11, loss: 867264.2277969718
idx: 12, loss: 865303.8806425928
idx: 13, loss: 853777.6647468805
idx: 14, loss: 885227.8337913752
idx: 15, loss: 906567.2516385316
idx: 16, loss: 873516.5326761008
idx: 17, loss: 881709.0769244432
idx: 18, loss: 857226.9802350402
idx: 19, loss: 865868.3371910453
idx: 20, loss: 897636.7145737408
epoch: 1, idx: 20, loss: 800938.9458460063
epoch: 1, idx: 20, accuracy: 0.101, loss: 0.0
idx: 21, loss: 871528.1233577133
idx: 22, loss: 886244.2541662456
idx: 23, loss: 894427.5418386458
idx: 24, loss: 865978.8637493849
idx: 25, loss: 888552.5914028882
idx: 26, loss: 869981.2278617622
idx: 27, loss: 842376.3674026728
idx: 28, loss: 857729.33624607