# Trained-MPC: A Private Inference by Training-Based Multiparty Computation

## Read Me!

This code is provided in a Jupyter notebook file.

Prerequisites: Python 3.6+  ||  PyTorch 1.4+ (needed for `scheduler.get_last_lr()`) || torchvision || NumPy || Matplotlib

Device used in our experiments: a CUDA GPU (the code falls back to the CPU when CUDA is unavailable).

## -------------------------- Input --------------------------

---- Input: DatasetName, [arg1,arg2], $\sigma$, $\mathrm{W}$

---- Default example: MNIST [Iden,Iden] 70 [[1.0,-1.0]]

### details:

---- DatasetName: (MNIST, Fashion-MNIST, and Cifar-10)

Note: We use the standard training and test splits of each dataset. Besides per-channel normalization, the only preprocessing is random cropping (with 4-pixel padding) and random horizontal flipping on the Cifar-10 training set, and zero-padding MNIST and Fashion-MNIST images by $2$ pixels on every side so that they fit a network with a $32 \times 32$ input.

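---- [arg1,arg2]: (the network structure) arg1 $\in$ (Iden, 1, 2, ...) is the number of output channels of the client's first layer $\bar{g}_{\mathrm{Pre}}$ (Iden: no client layer; the input is sent to the servers as-is). arg2 $\in$ (Iden, 10, 11, ...) is the output width of each server network, which the client's last layer $\bar{g}_{\mathrm{Post}}$ maps to the 10 classes (Iden: each server outputs 10 class scores directly and the client only sums them).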

Note: All network parameters use PyTorch's default initialization.

---- $\sigma$

Note: For each value of the noise standard deviation $\sigma$, training runs for 265 epochs, and the learning rate is decayed exponentially (by a factor of 0.9853 per epoch) from $10^{-3}$ to about $2 \times 10^{-5}$.

---- $\mathrm{W}$: (a matrix with $T$ (the number of colluding servers) rows and $N$ (the number of servers) columns for creating $\mathbb{P}_{\mathbf{Z}}$)

### -------------------------- Output --------------------------

For the given $\sigma$, every epoch prints the current learning rate, the training loss and accuracy averaged over that epoch, and the test loss and accuracy evaluated at the end of the epoch.


# Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision
import torch.backends.cudnn as cudnn

import os
import numpy as np
import matplotlib.pyplot as plt

import datetime
import time
import itertools

# device: run on the GPU when PyTorch can see one, otherwise on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Your system: {}'.format(device))

# Input parameters

In [None]:
import numpy as np

# -------------- Custom examples -------------- #
'''
MNIST [Iden,32] 70 [[1.0,-1.0]]
MNIST [Iden,Iden] 70 [[1.0,-1.0,10000.0]]
MNIST [Iden,Iden] 70 [[1]]
MNIST [Iden,32] 70 [[0.0,np.sqrt(3)/2,-np.sqrt(3)/2],[1.0,-1/2,-1/2]]
Fashion-MNIST [Iden,32] 70 [[1.0,-1.0]]
Cifar-10 [32,Iden] 70 [[1.0,-1.0]]
Cifar-10 [32,128] 70 [[1,-1,10000,1.5,-1.5]]
  .
  .
  .
'''

# Only these dataset names have a loader in the Dataset cell; any other name
# (CelebA included) falls back to MNIST instead of failing later with an
# undefined trainloader.
SUPPORTED_DATASETS = ('MNIST', 'Fashion-MNIST', 'Cifar-10')

# input
q = input('Default(d) or Custom(c) input? ')
if q == 'c' :
    # maxsplit=3: W is the last field, so spaces inside it (e.g. "[[1.0, -1.0]]")
    # no longer break the unpacking into four values
    DN,NS,Sigma,W = input('Input? ').split(maxsplit=3)
else:
    DN,NS,Sigma,W = 'MNIST', '[Iden,Iden]', '70', '[[1.0,-1.0]]'

# type fixing
if DN not in SUPPORTED_DATASETS:
    print ('DatasetName is false; MNIST selected as the default.')
    DN = 'MNIST'
# "[arg1,arg2]" -> ['arg1', 'arg2']; each item is stripped so "[Iden, 32]" also works
NS = [item.strip() for item in NS.strip()[1:-1].split(',')][:2]
Sigma = float(Sigma)
# W is a Python/NumPy expression such as [[0.0,np.sqrt(3)/2,-np.sqrt(3)/2],[1.0,-1/2,-1/2]].
# It is evaluated with only `np` in scope instead of exec'ing into the notebook
# globals. NOTE(security): eval is still not a sandbox - only enter trusted input.
W = eval(W.strip(), {'__builtins__': {}, 'np': np})
# atleast_2d also accepts a single row written as [1.0,-1.0] (T = 1)
W = np.atleast_2d(np.array(W, dtype=float))
N = W.shape[1]  # number of servers (columns of W)
T = W.shape[0]  # number of colluding servers (rows of W)
# row j of IW = [1, W[:, j]]: the coefficients server j's query uses for (U, Z_1..Z_T)
IW = np.ones((N,T+1),dtype=float)
IW[:,1:] = W.transpose()

print(DN,NS,Sigma)
print ('N = \n {}'.format(N))
print ('T = \n {}'.format(T))
print ('[1,W^T] = \n {}:'.format(IW))

# Dataset

In [None]:
# ------------ dataset table ------------ #
# name -> (torchvision dataset class, per-channel mean, per-channel std)
dataset_table = {
    'MNIST':         (torchvision.datasets.MNIST,        [0.1307],                 [0.3040]),
    'Fashion-MNIST': (torchvision.datasets.FashionMNIST, [0.2860],                 [0.3205]),
    'Cifar-10':      (torchvision.datasets.CIFAR10,      [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
}

if DN in dataset_table:
    dataset_class, mean_list, std_list = dataset_table[DN]
    meanI = torch.tensor(mean_list)
    stdI = torch.tensor(std_list)

    if DN == 'Cifar-10':
        # augmentation (random 32x32 crop from a 4-pixel-padded image + horizontal
        # flip) on the training split only; test images are used as-is
        train_prefix = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        test_prefix = []
    else:
        # MNIST / Fashion-MNIST: zero-pad 2 pixels on every side for both splits
        train_prefix = [transforms.Pad(2)]
        test_prefix = [transforms.Pad(2)]

    transform_train = transforms.Compose(
        train_prefix + [transforms.ToTensor(), transforms.Normalize(mean = meanI, std = stdI)]
    )
    transform_test = transforms.Compose(
        test_prefix + [transforms.ToTensor(), transforms.Normalize(mean = meanI, std = stdI)]
    )

    trainset = dataset_class(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = dataset_class(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


# Functions

In [None]:
# train function
def train_function(trainloader, model, optimizer, criterion):
    """Train ``model`` for one pass over ``trainloader``.

    Every batch is moved to the notebook-global ``device``, run through the
    model, and used for one optimizer step.

    Args:
        trainloader: iterable of ``(images, labels)`` batches.
        model: module to train; it is switched to train mode.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``. Presumably it
            returns a batch mean (as ``nn.CrossEntropyLoss`` does), since the
            running loss re-weights it by batch size.

    Returns:
        ``(avg_loss, avg_acc)``: sample-weighted mean loss over the pass, and
        accuracy in percent.

    Raises:
        ValueError: if ``trainloader`` yields no samples.
    """
    loss_sum = 0.0
    correct_sum = 0
    sample_num = 0

    model.train()
    for images, labels in trainloader:
        # data
        images, labels = images.to(device), labels.to(device)
        # forward
        outputs = model(images)
        loss = criterion(outputs, labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # track: re-weight the batch-mean loss so the epoch average is per sample
        batch_size = labels.size(0)
        sample_num += batch_size
        loss_sum += loss.item() * batch_size
        # detach() instead of the deprecated .data; argmax gives the predicted class
        predicted_label = outputs.detach().argmax(dim=1)
        correct_sum += (predicted_label == labels).sum().item()

    if sample_num == 0:
        raise ValueError('train_function: trainloader yielded no samples')

    # result
    avg_loss = loss_sum / sample_num
    avg_acc = correct_sum / sample_num * 100.0

    return avg_loss, avg_acc


# test function
def test_function(testloader, model, criterion):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    Every batch is moved to the notebook-global ``device``. The model is put
    in eval mode and is not updated.

    Args:
        testloader: iterable of ``(images, labels)`` batches.
        model: module to evaluate.
        criterion: loss called as ``criterion(outputs, labels)``. Presumably it
            returns a batch mean, since it is re-weighted by batch size below.

    Returns:
        ``(avg_loss, avg_acc)``: sample-weighted mean loss, and accuracy in
        percent.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    loss_sum = 0.0
    correct_sum = 0
    sample_num = 0

    model.eval()
    with torch.no_grad():
        for images, labels in testloader:
            # data
            images, labels = images.to(device), labels.to(device)
            # forward
            outputs = model(images)
            # track: re-weight the batch-mean loss so the average is per sample
            batch_size = labels.size(0)
            sample_num += batch_size
            loss_sum += criterion(outputs, labels).item() * batch_size
            # already inside no_grad, so no .data / detach is needed
            predicted_label = outputs.argmax(dim=1)
            correct_sum += (predicted_label == labels).sum().item()

    if sample_num == 0:
        raise ValueError('test_function: testloader yielded no samples')

    # result
    avg_loss = loss_sum / sample_num
    avg_acc = correct_sum / sample_num * 100.0

    return avg_loss, avg_acc


# run function
def run(track, trainloader, testloader, model, optimizer, criterion, epoch_num, scheduler):
    """Train and evaluate ``model`` for ``epoch_num`` epochs.

    Each epoch performs these steps in order:
    - calls ``train_function`` then ``test_function``;
    - prints the learning rate that epoch used;
    - steps ``scheduler`` once;
    - appends the epoch's results to ``track``;
    - prints a one-line summary.

    Args:
        track: dict of per-epoch result lists to extend, or ``None`` to start
            a new one. Passing a previously returned dict continues its history.
        trainloader, testloader: batch iterables forwarded to
            ``train_function`` / ``test_function``.
        model, optimizer, criterion: forwarded to the same functions.
        epoch_num: number of epochs to run.
        scheduler: learning-rate scheduler, stepped once per epoch.

    Returns:
        The ``track`` dict. Its ``'confusion_matrix'`` list is created here but
        never filled by this function.
    """
    # trackers
    if track is None:
        track = {'loss_train':[], 'acc_train':[], 'loss_test':[], 'acc_test':[], 'confusion_matrix':[]}

    for epoch in range(epoch_num):
        # train & test
        loss_train, acc_train = train_function(trainloader, model, optimizer, criterion)
        loss_test, acc_test = test_function(testloader, model, criterion)

        # learning rate: print the rate used in this epoch, then decay it
        print(f"lr = {scheduler.get_last_lr()[0]}")
        scheduler.step()

        # track
        track['loss_train'].append(loss_train)
        track['acc_train'].append(acc_train)
        track['loss_test'].append(loss_test)
        track['acc_test'].append(acc_test)
        print(f'Epoch [{epoch+1}/{epoch_num}] | Loss -> Test {loss_test:.4f} Train {loss_train:.4f} | Acc -> Test {acc_test:.3f} Train {acc_train:.3f} |')

    return track

# Network

In [None]:
# 1 for the 3-channel Cifar-10 images, 0 for single-channel MNIST / Fashion-MNIST;
# first conv layers take 1 + 2*isRGB input channels.
isRGB = 1 if DN == 'Cifar-10' else 0

# Output width of each server network (and input width of g_Post);
# 'Iden' means the servers emit the 10 class scores directly.
NS1 = 10 if NS[1] == 'Iden' else int(NS[1])

# ---------------- \bar{g}_Pre ---------------- #
if NS[0] != 'Iden':
    class g_Pre(nn.Module):
        """Client pre-processing layer: conv(5x5, stride 3) -> BatchNorm -> ReLU.

        Maps the (1 + 2*isRGB)-channel input to int(NS[0]) feature maps.
        """
        def __init__(self):
            super().__init__()
            width = int(NS[0])
            self.client_L1_conv2d = nn.Conv2d(1 + 2 * isRGB, width, kernel_size=5, stride=3, padding=0)
            self.b1 = nn.BatchNorm2d(width)

        def forward(self, X):
            return F.relu(self.b1(self.client_L1_conv2d(X)))

else:
    class g_Pre(nn.Module):
        """Identity pre-processing: the input is passed on unchanged."""
        def __init__(self):
            super().__init__()

        def forward(self, X):
            return X

# ---------------- f_j ---------------- #
# One server network; the 128 * 8 * 8 flatten size assumes 32 x 32 inputs.
if NS[0] != 'Iden':
    class f_j(nn.Module):
        """Server network fed with int(NS[0])-channel feature maps from g_Pre.

        conv(3x3) -> BN -> ReLU -> flatten -> fc(1024) -> BN -> ReLU -> fc(NS1).
        """
        def __init__(self):
            super().__init__()
            self.server_L1_conv2d = nn.Conv2d(int(NS[0]), 128, kernel_size=3, stride=1, padding=0)
            self.b1 = nn.BatchNorm2d(128)
            self.server_L2_fc = nn.Linear(128 * 8 * 8, 1024)
            self.b2 = nn.BatchNorm1d(1024)
            self.server_L3_fc = nn.Linear(1024, NS1)

        def forward(self, X):
            feature_maps = F.relu(self.b1(self.server_L1_conv2d(X)))
            hidden = F.relu(self.b2(self.server_L2_fc(torch.flatten(feature_maps, 1))))
            return self.server_L3_fc(hidden)

else:
    class f_j(nn.Module):
        """Server network fed directly with the (1 + 2*isRGB)-channel input.

        conv(5x5, stride 3) -> BN -> ReLU -> conv(3x3) -> BN -> ReLU ->
        flatten -> fc(1024) -> BN -> ReLU -> fc(NS1).
        """
        def __init__(self):
            super().__init__()
            self.server_L1_conv2d = nn.Conv2d(1 + 2 * isRGB, 64, kernel_size=5, stride=3, padding=0)
            self.b1 = nn.BatchNorm2d(64)
            self.server_L2_conv2d = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0)
            self.b2 = nn.BatchNorm2d(128)
            self.server_L3_fc = nn.Linear(128 * 8 * 8, 1024)
            self.b3 = nn.BatchNorm1d(1024)
            self.server_L4_fc = nn.Linear(1024, NS1)

        def forward(self, X):
            first = F.relu(self.b1(self.server_L1_conv2d(X)))
            second = F.relu(self.b2(self.server_L2_conv2d(first)))
            hidden = F.relu(self.b3(self.server_L3_fc(torch.flatten(second, 1))))
            return self.server_L4_fc(hidden)

# ---------------- \bar{g}_Post ---------------- #
if NS[1] != 'Iden':
    class g_Post(nn.Module):
        """Client post-processing on the summed server answers.

        BatchNorm1d(NS1) -> ReLU -> Linear(NS1, 10).
        """
        def __init__(self):
            super().__init__()
            self.bn = nn.BatchNorm1d(NS1)
            self.client_L2_fc = nn.Linear(NS1, 10)

        def forward(self, X):
            return self.client_L2_fc(F.relu(self.bn(X)))

else:
    class g_Post(nn.Module):
        """Identity post-processing: the summed server answers are the output."""
        def __init__(self):
            super().__init__()

        def forward(self, X):
            return X

# ---------------- network ---------------- #
class network(nn.Module):
    """Client / N-server split network trained end to end.

    Forward pass:
      1. The client applies ``g_Pre`` and normalizes each sample to zero mean
         and unit standard deviation (giving ``U``).
      2. It draws ``T`` Gaussian noise tensors ``Z_t`` with std ``Z_scale``.
      3. It builds one query per server ``j``:
         ``(IW[j,0] * U + sum_t IW[j,t+1] * Z_t) / normP_j``.
      4. Server ``j`` (attribute ``server{j+1}``, an ``f_j`` instance) answers
         its own query.
      5. The client sums the ``N`` answers and applies ``g_Post``.

    Relies on the notebook globals ``N``, ``T`` and ``IW``. ``Z_scale`` is
    only read in ``forward``, so it must exist before the first forward pass.
    """

    def __init__(self):
        super(network, self).__init__()
        # client
        self.client_g_Pre = g_Pre()
        # servers: setattr keeps the attribute names server1..serverN, so
        # parameter names / state_dict keys are the same as before
        for j in range(N):
            setattr(self, f'server{j + 1}', f_j())
        # client
        self.client_g_Post = g_Post()

    def forward(self, X):
        # \bar{g}_0
        U = self.client_g_Pre(X)
        # normalize per sample; the statistics are detached so no gradient flows
        # through them. keepdim=True already broadcasts against U.
        U = U - torch.mean(U, [1, 2, 3], keepdim=True).detach()  # zero mean
        U = U / torch.std(U, [1, 2, 3], keepdim=True).detach()   # set variance to 1
        # primary noise, drawn directly on U's device (no host->device copy)
        Z = Z_scale * torch.randn(list(U.size()) + [T], device=U.device)
        # add noise (queries)
        Q = []
        for j in range(N):
            q = U * IW[j, 0]
            for t in range(T):
                q += Z[:, :, :, :, t] * IW[j, t + 1]
            # treating U (unit variance) and each Z_t (variance Z_scale**2) as
            # independent, this is the std of q, so every query gets unit variance
            normP = float(np.sqrt(IW[j, 0] ** 2 + np.sum((Z_scale * IW[j, 1:]) ** 2)))
            Q.append(q / normP)
        # answers
        A = [getattr(self, f'server{j + 1}')(Q[j]) for j in range(N)]
        # sumA (out of place, so server1's answer tensor is not modified)
        sumA = A[0]
        for answer in A[1:]:
            sumA = sumA + answer
        # out
        out = self.client_g_Post(sumA)
        return out


In [None]:
# ---------------- model ---------------- #
model = network().to(device)
# `device` is a torch.device, and a torch.device never compares equal to the
# plain string 'cuda'. Comparing it to the string left this branch dead
# (no DataParallel, no cudnn.benchmark), so check its .type instead.
if device.type == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True
print(model)

# -------------- criterion -------------- #
criterion = nn.CrossEntropyLoss().to(device)

# -------------- optimizer -------------- #
# Adam from lr = 1e-3, multiplied by 0.9853 at each scheduler step:
# 1e-3 * 0.9853**265 is about 2e-5 after 265 epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9853)

# ---------------- track ---------------- #
track = None # is global variable; run() creates the tracker dict when given None

# Run

In [None]:
# Noise standard deviation used by network.forward (read from the global
# Z_scale), so it has to be set before training starts.
Z_scale = Sigma
epoch_num = 265
track = run(track, trainloader, testloader, model, optimizer, criterion,
            epoch_num=epoch_num, scheduler=scheduler)
