# Encrypted Computation on MNIST using SyMPC
#### author: Oleksandr Lytvyn

In [2]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

Load MNIST dataset

In [3]:
torch.manual_seed(73)

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


1.0%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw


102.8%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz



7.9%

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


112.7%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Processing...



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


#### Define model

In [4]:
class ConvNet(torch.nn.Module):
    def __init__(self, hidden=64, output=10):
        super(ConvNet, self).__init__()        
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.fc1 = torch.nn.Linear(256, hidden)
        self.fc2 = torch.nn.Linear(hidden, output)

    def forward(self, x):
        x = self.conv1(x)
        # the model uses the square activation function
        x = x * x
        # flattening while keeping the batch axis
        x = x.view(-1, 256)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        return x

#### Define Train Function

In [7]:
def train(model, train_loader, criterion, optimizer, n_epochs=10):
    model.train()
    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # calculate average losses
        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))
    
    model.eval()
    return model

#### Perform Train

In [8]:
model = ConvNet()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = train(model, train_loader, criterion, optimizer, 10)

Epoch: 1 	Training Loss: 0.392145
Epoch: 2 	Training Loss: 0.131439
Epoch: 3 	Training Loss: 0.090824
Epoch: 4 	Training Loss: 0.070182
Epoch: 5 	Training Loss: 0.059312
Epoch: 6 	Training Loss: 0.049882
Epoch: 7 	Training Loss: 0.045490
Epoch: 8 	Training Loss: 0.038414
Epoch: 9 	Training Loss: 0.035350
Epoch: 10 	Training Loss: 0.032657


#### Perform Model test

In [9]:
def test(model, test_loader, criterion):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    # model in evaluation mode
    model.eval()

    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # calculate and print avg test loss
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% ' 
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )
    
test(model, test_loader, criterion)

Test Loss: 0.089131

Test Accuracy of 0: 99% (971/980)
Test Accuracy of 1: 99% (1129/1135)
Test Accuracy of 2: 97% (1010/1032)
Test Accuracy of 3: 98% (997/1010)
Test Accuracy of 4: 98% (963/982)
Test Accuracy of 5: 97% (872/892)
Test Accuracy of 6: 98% (939/958)
Test Accuracy of 7: 97% (1000/1028)
Test Accuracy of 8: 96% (940/974)
Test Accuracy of 9: 96% (969/1009)

Test Accuracy (Overall): 97% (9790/10000)


### Encrypted inference

In [2]:
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.protocol import FSS
from sympc.protocol import Falcon

In [11]:
def get_clients(n_parties):
  #Generate required number of syft clients and return them.

  parties=[]
  for index in range(n_parties): 
      parties.append(sy.VirtualMachine(name = "worker"+str(index)).get_root_client())

  return parties

In [14]:
def split_send(data,session):
    """Splits data into number of chunks equal to number of parties and distributes it to respective 
       parties.
    """
    data_pointers = []
    
    split_size = int(len(data)/len(session.parties))+1
    for index in range(0,len(session.parties)):
        ptr=data[index*split_size:index*split_size+split_size].share(session=session)
        data_pointers.append(ptr)
        
    return data_pointers

In [18]:
def inference(n_clients,protocol=None):
    
  # Get VM clients 
  parties=get_clients(n_clients)

  # Setup the session for the computation
  if(protocol):
     session = Session(parties = parties,protocol = protocol)
  else:
     session = Session(parties = parties)
        
  SessionManager.setup_mpc(session)

  #Split data and send data to clients
  pointers = split_send(test_data,session)

  #Encrypt model 
  mpc_model = model.share(session)

  #Encrypt test data
  #test_data=MPCTensor(secret=test_x, session = session)

  #Perform inference and measure time taken
  start_time = time.time()
    
  results = []
    
  for ptr in pointers:
     encrypted_results = mpc_model(ptr)
     plaintext_results = encrypted_results.reconstruct()
     results.append(plaintext_results)
        
  end_time = time.time()

  print(f"Time for inference: {end_time-start_time}s")
    
  predictions = torch.cat(results).reshape([-1])

  #Calculate Loss
  print("MSE Loss: ",criterion(predictions,test_y).item())
    
  return predictions

In [19]:
predictions=inference(3, Falcon("semi-honest"))

ValueError: only one element tensors can be converted to Python scalars

In [24]:
test_loader

TypeError: 'DataLoader' object is not subscriptable