In [1]:
#Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
import syft as sy
# Hook PyTorch with PySyft. Later cells call `.fix_precision()` and `.share()` on
# tensors and on the model; presumably those methods are provided by this hook
# (TODO confirm against the installed PySyft version).
hook = sy.TorchHook(torch)
# Virtual workers, each identified by its `id` string.
# NOTE(review): `client` is not referenced again in the visible cells.
client = sy.VirtualWorker(hook, id='client')
# `bob` and `alice` are the two workers passed to `.share(alice, bob, ...)` below.
bob = sy.VirtualWorker(hook, id='bob')
alice = sy.VirtualWorker(hook, id='alice')
# Passed as the `crypto_provider=` argument of every `.share(...)` call below.
crypto_provider = sy.VirtualWorker(hook, id='crypto_provider')

W0802 14:42:09.632205  1724 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was 'C:\Users\Vilas_2\Anaconda3\envs\pysyft\lib\site-packages\tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0802 14:42:09.727210  1724 deprecation_wrapper.py:119] From C:\Users\Vilas_2\Anaconda3\envs\pysyft\lib\site-packages\tf_encrypted\session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [4]:
#Hyper-parameters of the learning task
class Arguments:
    """Plain container holding the settings shared by the training and test cells."""

    def __init__(self):
        # Mini-batch sizes used when building the train / test data loaders.
        self.batch_size = 64
        self.test_batch_size = 50
        # Number of passes over the training set and the optimiser learning rate.
        self.epochs = 10
        self.lr = 0.001
        # The training loop prints the loss once every `log_interval` batches.
        self.log_interval = 100


args = Arguments()

In [12]:
#Training data: the MNIST train split (downloaded on first use), converted to
#tensors, normalised with Normalize((0.1307,), (0.3081,)) and served in shuffled
#mini-batches of `args.batch_size`.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='~/.pytorch/MNIST_data/',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=args.batch_size,
    shuffle=True,
)

Next, the client has some data and would like to obtain predictions on it using the server's model. The client encrypts its data by sharing it additively across two workers, alice and bob.

SMPC uses crypto protocols that operate on integers. Here we leverage the PySyft tensor abstraction to convert PyTorch float tensors into fixed-precision tensors using .fix_precision().
For example, 0.123 with precision 2 is rounded at the 2nd decimal digit, so the number stored is the integer 12.

In [13]:
# Plain-text test data: the MNIST test split with the same ToTensor + Normalize
# preprocessing, in shuffled mini-batches of `args.test_batch_size`.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='~/.pytorch/MNIST_data/',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=args.test_batch_size,
    shuffle=True,
)

# Private copy of the test set: every (data, target) batch is converted to fixed
# precision and shared between alice and bob, with `crypto_provider` passed as
# the crypto provider. The batches are materialised once into a plain list.
private_test_loader = []
for data, target in test_loader:
    encrypted_batch = tuple(
        tensor.fix_precision().share(alice, bob, crypto_provider=crypto_provider)
        for tensor in (data, target)
    )
    private_test_loader.append(encrypted_batch)

In [14]:
#Feed-forward neural network
class Net(nn.Module):
    """Two-layer fully connected classifier: 784 -> 500 -> 10.

    The forward pass returns raw scores (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Flatten `x` to rows of 784 features and apply fc1 -> ReLU -> fc2."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [15]:
#One epoch of plain-text training
def train(args, model, train_loader, optimizer, epoch):
    """Train `model` for one pass over `train_loader`.

    Args:
        args: settings object; `log_interval` and `batch_size` are read.
        model: module called as `model(data)`; its output is passed through
            log-softmax over dim 1.
        train_loader: sized iterable yielding `(data, target)` batches.
        optimizer: optimiser stepped once per batch.
        epoch: epoch number, used only in the progress message.

    Returns:
        None. Progress is printed every `args.log_interval` batches.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch

        optimizer.zero_grad()
        log_probs = F.log_softmax(model(inputs), dim=1)
        batch_loss = F.nll_loss(log_probs, labels)
        batch_loss.backward()
        optimizer.step()

        # Only report on every `log_interval`-th batch.
        if step % args.log_interval != 0:
            continue

        n_batches = len(train_loader)
        # Both counts are nominal: they assume every batch holds `batch_size` items.
        seen = step * args.batch_size
        nominal_total = n_batches * args.batch_size
        percent_done = 100. * step / n_batches
        print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, nominal_total, percent_done, batch_loss.item()))

In [16]:
# Fresh model plus an Adam optimiser over its parameters, using the learning
# rate stored in `args`.
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Run `args.epochs` training passes; epochs are numbered from 1 so that the
# progress messages printed by `train` start at "Epoch:1".
for epoch in range(1, args.epochs + 1):
    train(args, model, train_loader, optimizer, epoch)



In [19]:
def test(args, model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            output = F.log_softmax(output, dim=1)
            test_loss += F.nll_loss(output, target, reduction='sum').item() #sum up batch loss
            pred = output.argmax(1, keepdim=True) #get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy:{}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

In [20]:
# Plain-text baseline: score the trained model on the unencrypted test loader.
test(args, model, test_loader)


Test set: Average loss: 0.0755, Accuracy:9820/10000 (98%)



The model is now trained and ready to be provided as a service.

Secure Evaluation

In [21]:
# Convert the trained model to fixed precision and share it between alice and
# bob, with `crypto_provider` passed as the crypto provider.
# NOTE(review): the result is not re-assigned, so the encrypted evaluation below
# relies on this call modifying `model` in place - confirm for the installed
# PySyft version.
model.fix_precision().share(alice, bob, crypto_provider=crypto_provider)

Net(
  (fc1): Linear(in_features=784, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [27]:
def test(args, model, test_loader):
    """Evaluate `model` batch by batch and print the running accuracy.

    This definition replaces the earlier plain-text `test` and is meant for the
    encrypted evaluation: when the model and the batches in `test_loader` have
    been secret-shared, the model weights, the data inputs, the predictions and
    the targets used for scoring all stay encrypted. The syntax is nonetheless
    very similar to normal PyTorch testing.

    The only value decrypted is the running count of correct predictions. It is
    fetched after every batch (not only once at the end) so the cumulative
    accuracy can be printed.

    Args:
        args: settings object; only `test_batch_size` is read.
        model: module called as `model(data)`.
        test_loader: iterable of `(data, target)` batches.

    Returns:
        None. One progress line is printed per batch.
    """
    model.eval()
    n_correct_priv = 0
    n_total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            # Running number of correct predictions, kept in whatever (shared)
            # form `pred` and `target` have.
            n_correct_priv += pred.eq(target.view_as(pred)).sum()
            # NOTE(review): assumes every batch holds exactly
            # `args.test_batch_size` items; a smaller final batch would make the
            # denominator too large.
            n_total += args.test_batch_size

            # Work on a copy so the running counter itself stays usable, then
            # fetch it and convert back from fixed precision to a Python int
            # (presumably this reconstructs the shared value - see PySyft docs).
            n_correct = n_correct_priv.copy().get().float_precision().long().item()

            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
                 n_correct, n_total,
                 100. * n_correct/ n_total))

In [28]:
test(args, model, private_test_loader)

Test set: Accuracy: 50/50 (100%)
Test set: Accuracy: 96/100 (96%)
Test set: Accuracy: 146/150 (97%)
Test set: Accuracy: 195/200 (98%)
Test set: Accuracy: 245/250 (98%)
Test set: Accuracy: 294/300 (98%)
Test set: Accuracy: 343/350 (98%)
Test set: Accuracy: 391/400 (98%)
Test set: Accuracy: 441/450 (98%)
Test set: Accuracy: 491/500 (98%)
Test set: Accuracy: 540/550 (98%)
Test set: Accuracy: 589/600 (98%)
Test set: Accuracy: 639/650 (98%)
Test set: Accuracy: 687/700 (98%)
Test set: Accuracy: 736/750 (98%)
Test set: Accuracy: 786/800 (98%)
Test set: Accuracy: 836/850 (98%)
Test set: Accuracy: 886/900 (98%)
Test set: Accuracy: 936/950 (99%)
Test set: Accuracy: 983/1000 (98%)
Test set: Accuracy: 1033/1050 (98%)
Test set: Accuracy: 1082/1100 (98%)
Test set: Accuracy: 1132/1150 (98%)
Test set: Accuracy: 1181/1200 (98%)
Test set: Accuracy: 1230/1250 (98%)
Test set: Accuracy: 1279/1300 (98%)
Test set: Accuracy: 1328/1350 (98%)
Test set: Accuracy: 1375/1400 (98%)
Test set: Accuracy: 1424/1450 (98