# Encrypted Computation on MNIST using SyMPC
#### author: Oleksandr Lytvyn
CONTEXT
One party has a CNN model trained on MNIST dataset, other party wants
to make predictions on the trained model. First party do not want to share
model hyper-parameters, weights, etc.

EXPERIMENT SCENARIO
1. Prepare Data:
    1. download dataset
    2. train_data - first party
    3. test_data - second party
2. Spread data between parties
3. CNN definition and Training
4. Encrypted Prediction

In [47]:
import syft as sy
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import numpy as np


#### Load MNIST dataset and prepare Dataloader

In [26]:
train_data = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True,
)
test_data = datasets.MNIST(
    root = 'data',
    train = False,
    transform = ToTensor()
)
loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),

    'test'  : torch.utils.data.DataLoader(test_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),
}
loaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x18a4b8be940>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x18a4df6c280>}

#### Define model

In [74]:
import torch.nn as nn
import torch.nn.functional as F
class CNN(sy.Module):
    def __init__(self, torch_ref):
        super(CNN, self).__init__(torch_ref=torch_ref)
        self.conv1 = nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2)
        self.relu1=nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.relu2=nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2)
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x    # return x for visualization

In [75]:
cnn = CNN(torch)
print(cnn)

<__main__.CNN object at 0x0000018A541A90A0>


#### Define Train Function

In [76]:
from torch import optim
from torch.autograd import Variable
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr = 0.01)

def train(num_epochs, cnn, loaders):

    cnn.train()

    # Train the model
    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):

            # gives batch data, normalize x when iterate train_loader
            b_x = Variable(images)   # batch x
            b_y = Variable(labels)   # batch y
            output = cnn(b_x)[0]
            loss = loss_func(output, b_y)

            # clear gradients for this training step
            optimizer.zero_grad()

            # backpropagation, compute gradients
            loss.backward()
            # apply gradients
            optimizer.step()

            if (i+1) % 100 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                       .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
                pass
        pass


#### Perform Train

In [77]:
num_epochs = 5
train(num_epochs, cnn, loaders)

Epoch [1/5], Step [100/600], Loss: 0.1256
Epoch [1/5], Step [200/600], Loss: 0.0542
Epoch [1/5], Step [300/600], Loss: 0.1522
Epoch [1/5], Step [400/600], Loss: 0.0289
Epoch [1/5], Step [500/600], Loss: 0.1084
Epoch [1/5], Step [600/600], Loss: 0.1309
Epoch [2/5], Step [100/600], Loss: 0.0120
Epoch [2/5], Step [200/600], Loss: 0.0373
Epoch [2/5], Step [300/600], Loss: 0.0070
Epoch [2/5], Step [400/600], Loss: 0.0669
Epoch [2/5], Step [500/600], Loss: 0.1413
Epoch [2/5], Step [600/600], Loss: 0.0787
Epoch [3/5], Step [100/600], Loss: 0.0250
Epoch [3/5], Step [200/600], Loss: 0.0369
Epoch [3/5], Step [300/600], Loss: 0.0147
Epoch [3/5], Step [400/600], Loss: 0.0085
Epoch [3/5], Step [500/600], Loss: 0.0591
Epoch [3/5], Step [600/600], Loss: 0.0340
Epoch [4/5], Step [100/600], Loss: 0.0086
Epoch [4/5], Step [200/600], Loss: 0.0323
Epoch [4/5], Step [300/600], Loss: 0.0084
Epoch [4/5], Step [400/600], Loss: 0.0142
Epoch [4/5], Step [500/600], Loss: 0.0395
Epoch [4/5], Step [600/600], Loss:

#### Perform Model test

In [79]:
def test():
    # Test the model
    cnn.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in loaders['test']:
            test_output, last_layer = cnn(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
            pass
    print('Test Accuracy of the model on the 10000 test images: %.2f' % accuracy)

test()

Test Accuracy of the model on the 10000 test images: 1.00


10000

ValueError: only one element tensors can be converted to Python scalars

### Encrypted inference

In [42]:
import syft as sy
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.protocol import FSS
from sympc.protocol import Falcon

In [80]:
def get_clients(n_parties):
  #Generate required number of syft clients and return them.

  parties=[]
  for index in range(n_parties): 
      parties.append(sy.VirtualMachine(name = "worker"+str(index)).get_root_client())

  return parties

In [81]:
def split_send(data,session):
    """Splits data into number of chunks equal to number of parties and distributes it to respective 
       parties.
    """
    data_pointers = []
    
    split_size = int(len(data)/len(session.parties))+1
    for index in range(0,len(session.parties)):
        ptr=data[index*split_size:index*split_size+split_size].share(session=session)
        data_pointers.append(ptr)
        
    return data_pointers

In [82]:
import time
def inference(n_clients, model,protocol=None):
    
  # Get VM clients 
  parties=get_clients(n_clients)

  # Setup the session for the computation
  if(protocol):
     session = Session(parties = parties,protocol = protocol)
  else:
     session = Session(parties = parties)
        
  SessionManager.setup_mpc(session)

  #Split data and send data to clients
  imgs, lbls = next(iter(loaders['test']))
  actual_number = lbls[:10].numpy()
  pointers = split_send(imgs,session)

  #Encrypt model 
  mpc_model = model.share(session)

  #Encrypt test data
  #test_data=MPCTensor(secret=test_x, session = session)

  #Perform inference and measure time taken
  start_time = time.time()
    
  results = []
    
  for ptr in pointers:
     encrypted_results = mpc_model(ptr)
     plaintext_results = encrypted_results.reconstruct()
     results.append(plaintext_results)
        
  end_time = time.time()

  print(f"Time for inference: {end_time-start_time}s")
    

  pred_y = torch.max(results, 1)[1].data.numpy().squeeze()
  print(f'Prediction number: {pred_y}')
  print(f'Actual number: {actual_number}')
    
  return predictions

In [83]:
predictions=inference(3, cnn, Falcon("semi-honest"))

KeyError: 'ReLU'