In [1]:
import numpy as np
from datetime import datetime

In [2]:
from pond.tensor import NativeTensor, PrivateEncodedTensor, PublicEncodedTensor
from pond.nn import Dense, Sigmoid, Reveal, Diff, Softmax, CrossEntropy, Sequential, DataLoader

In [3]:
# Two-layer network: Dense -> Sigmoid -> Dense, then Reveal and Softmax.
# Dense appears to take (output_dim, input_dim): 6272 input features ->
# 128 hidden units -> 5 class scores (TODO confirm argument order in pond.nn).
# Reveal sits before Softmax, so (assuming Sequential applies layers in list
# order) the softmax runs on the revealed output-layer values.
# NOTE: the weight-loading cells address the Dense layers as layers[0] and
# layers[2]; inserting a layer between them (e.g. the previously disabled
# Dropout(.5)) would shift those indices.
classifier = Sequential([
    Dense(128, 6272),
    Sigmoid(),
    Dense(5, 128),
    Reveal(),
    Softmax(),
])

# Helper: load one test sample and classify it

In [4]:
def predict(classifier, wrapper, sample_index: int,
            x_path: str = 'x_test_features.npy',
            y_path: str = 'y_test.npy'):
    """Classify one test sample and return ``(true_label, predicted_label)``.

    Args:
        classifier: model exposing ``predict(tensor)``, whose result exposes
            ``unwrap()`` returning the per-class scores.
        wrapper: callable (e.g. ``PublicEncodedTensor`` or
            ``PrivateEncodedTensor``) that turns the raw feature row into the
            tensor type the classifier expects.
        sample_index: row of the test set to classify.
        x_path: ``.npy`` file holding the test features.
        y_path: ``.npy`` file holding the test labels, one row per sample;
            the true label is the argmax of that row (presumably one-hot
            encoded -- confirm).

    Returns:
        Tuple of two Python ints: the true label and the predicted label.
    """
    # Memory-map the files so each call reads only the requested row instead
    # of loading the entire test set from disk on every call.
    x_all = np.load(x_path, mmap_mode='r')
    y_all = np.load(y_path, mmap_mode='r')

    # np.array copies the row into a regular, writable in-memory array;
    # reshape(1, -1) flattens it into a single-row batch.
    x = np.array(x_all[sample_index]).reshape(1, -1)
    y = int(np.argmax(y_all[sample_index]))

    likelihoods = classifier.predict(wrapper(x))
    y_predicted = int(np.argmax(likelihoods.unwrap()))

    return y, y_predicted

# Perform prediction using unencrypted (public) weights

In [5]:
# Load plaintext weights and biases into the two Dense layers (indices 0 and 2
# of the Sequential model; the Sigmoid at index 1 is skipped).
for layer_index, weights_file, bias_file in (
    (0, 'layer0_weights.npy', 'layer0_bias.npy'),
    (2, 'layer2_weights.npy', 'layer2_bias.npy'),
):
    dense_layer = classifier.layers[layer_index]
    dense_layer.weights = PublicEncodedTensor.from_elements(np.load(weights_file))
    dense_layer.bias = PublicEncodedTensor.from_elements(np.load(bias_file))

In [6]:
# For each of the first ten test samples print "<true label> <predicted label>",
# feeding the classifier plaintext-encoded inputs (PublicEncodedTensor).
for sample_index in range(10):
    print(*predict(classifier, PublicEncodedTensor, sample_index))

2 2
4 4
0 1
4 4
1 1
4 4
0 0
4 4
2 2
4 4


# Perform prediction using encrypted (secret-shared) weights

In [7]:
# Load private weights and biases into the two Dense layers (indices 0 and 2).
# Each parameter is rebuilt from two share files (suffixes _0 and _1), which
# are loaded in that order and passed to PrivateEncodedTensor.from_shares.
for layer_index, weight_shares, bias_shares in (
    (0, ('layer0_weights_0.npy', 'layer0_weights_1.npy'), ('layer0_bias_0.npy', 'layer0_bias_1.npy')),
    (2, ('layer2_weights_0.npy', 'layer2_weights_1.npy'), ('layer2_bias_0.npy', 'layer2_bias_1.npy')),
):
    dense_layer = classifier.layers[layer_index]
    dense_layer.weights = PrivateEncodedTensor.from_shares(*map(np.load, weight_shares))
    dense_layer.bias = PrivateEncodedTensor.from_shares(*map(np.load, bias_shares))

In [8]:
# For each of the first ten test samples print "<true label> <predicted label>",
# feeding the classifier private inputs (PrivateEncodedTensor).
for sample_index in range(10):
    print(*predict(classifier, PrivateEncodedTensor, sample_index))

2 2
4 4
0 1
4 4
1 1
4 4
0 0
4 4
2 2
4 4
