# Corella: A Private Multi Server Learning Approach based on Correlated Queries

## Read Me!

This code is provided in a Jupyter notebook file.

Prerequisites: Python 3.6+  ||  PyTorch 1.0+ || NumPy

Our system device: cuda

## -------------------------- Input --------------------------

---- Input: DatasetName, [arg1,arg2], $\sigma_{\bar{Z}}^*$, $\mathrm{W}$

---- Default example: MNIST [Iden,Iden] 70 [[1.0,-1.0]]

### details:

---- DatasetName: (MNIST, Fashion-MNIST, and Cifar-10)

Note: We use the standard training and test sets of each dataset. The training batch size is 128. The only preprocessing applied to the images is Random Crop and Random Horizontal Flip, on the Cifar-10 dataset only.

---- [arg1,arg2]: (The network structure) arg1: (Iden, 1, 2, ...) ----->  arg2: (Iden, 10, 11, ...)

Note:  We initialize the network parameters by PyTorch default.

---- $\sigma_{\bar{Z}}^*$

Note: To evaluate the accuracy of the proposed algorithm for a fixed noise standard deviation, $\sigma_{\bar{Z}}^*$, we start training the model from $\sigma_{\bar{Z}} = 0$ and gradually increase the noise standard deviation, with linearly increasing step sizes, up to $\sigma_{\bar{Z}} = \sigma_{\bar{Z}}^*$. In each step we run one epoch of learning, and finally we report the accuracy at $\sigma_{\bar{Z}} = \sigma_{\bar{Z}}^*$. The step sizes increase linearly: 0.002, 0.004, 0.006, ... . We also gradually decrease the learning rate from $10^{-3}$ to $2\times10^{-5}$ during training.

Note that in this paper we are concerned with the privacy of the client data, not of the training samples.

---- $\mathrm{W}$: (a matrix with $T$ (the number of colluding servers) rows and $N$ (the number of servers) columns for creating $\mathbb{P}_{\mathbf{Z}}$)

### -------------------------- Output --------------------------

Reporting training accuracy during each learning epoch and testing accuracy at the end of each epoch for $\sigma_{\bar{Z}}$s.

Note: If arg1=Iden, you can ignore $\bar{g}_0()$ and $\textsf{Normalized}()$ function, where the dataset is $\textsf{Normalized}$.


# Library

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.datasets
import numpy as np

# device
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print ('Your system: {}'.format(device))

# Input

In [0]:
# -------------- Custom examples -------------- #
'''
MNIST [Iden,32] 70 [[1.0,-1.0]]
MNIST [Iden,Iden] 70 [[1.0,-1.0,10000.0]]
MNIST [Iden,Iden] 70 [[1]]
MNIST [Iden,32] 70 [[0.0,np.sqrt(3)/2,-np.sqrt(3)/2],[1.0,-1/2,-1/2]]
Fashion-MNIST [Iden,32] 70 [[1.0,-1.0]]
Cifar-10 [32,Iden] 70 [[1.0,-1.0]]
Cifar-10 [32,128] 70 [[1,-1,10000,1.5,-1.5]]
  .
  .
  .
'''

# input
# Four whitespace-separated fields: DatasetName, [arg1,arg2], Sigma, W.
q = input('Default(d) or Custom(c) input? ')
if q == 'c' :
    # Split on the first three whitespace runs only, so the trailing W
    # expression may itself contain spaces (e.g. "[[1.0, -1.0]]").
    DN,NS,Sigma,W = input('Input? ').split(None, 3)
else:
    DN,NS,Sigma,W = 'MNIST', '[Iden,Iden]', '70', '[[1.0,-1.0]]'

# type fixing
if DN not in ('MNIST', 'Fashion-MNIST', 'Cifar-10'):
    print ('DatasetName is false; MNIST selected as the default.')
    DN = 'MNIST'

# "[arg1,arg2]" -> ['arg1', 'arg2']; entries stay strings ('Iden' or a number).
NS = [entry.strip() for entry in NS[1:-1].split(',')][:2]
if len(NS) != 2:
    raise ValueError('The network structure must have the form [arg1,arg2].')

Sigma = float(Sigma)

# NOTE(security): W is a user-typed Python expression (the examples above use
# np.sqrt), so it has to be evaluated.  eval() accepts a single expression
# only, and is given no builtins and just `np` -- but this is NOT a sandbox;
# never feed it untrusted input.
W = eval(W, {'__builtins__': {}, 'np': np})
# A flat list such as [1.0,-1.0] is treated as a single row (T = 1).
W = np.atleast_2d(np.array(W,dtype=float))
if W.ndim != 2:
    raise ValueError('W must be a matrix with T rows and N columns.')

# Rows of W are the T noise directions, columns are the N servers.
T, N = W.shape
# IW = [1, W^T]: column 0 weights the signal, columns 1..T weight the noises.
IW = np.ones((N,T+1),dtype=float)
IW[:,1:] = W.transpose()

print(DN,NS,Sigma)
print ('N = \n {}'.format(N))
print ('T = \n {}'.format(T))
print ('[1,W^T] = \n {}:'.format(IW))

# Dataset

In [0]:
# For each supported DatasetName: the torchvision dataset class and the
# per-channel mean / standard deviation handed to transforms.Normalize.
_DATASETS = {
    'MNIST': (torchvision.datasets.MNIST, [0.1307], [0.3040]),
    'Fashion-MNIST': (torchvision.datasets.FashionMNIST, [0.2860], [0.3205]),
    'Cifar-10': (torchvision.datasets.CIFAR10,
                 [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
}

if DN in _DATASETS:
    _dataset_cls, _channel_mean, _channel_std = _DATASETS[DN]
    meanI = torch.tensor(_channel_mean)
    stdI = torch.tensor(_channel_std)

    # Random crop and horizontal flip are applied to Cifar-10 training images
    # only; every split is converted to a tensor and normalized.
    _train_steps = []
    if DN == 'Cifar-10':
        _train_steps.append(transforms.RandomCrop(32, padding=4))
        _train_steps.append(transforms.RandomHorizontalFlip())
    _train_steps.append(transforms.ToTensor())
    _train_steps.append(transforms.Normalize(mean = meanI, std = stdI))
    transform_train = transforms.Compose(_train_steps)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = meanI, std = stdI),
    ])

    # Standard training split, shuffled, in batches of 128.
    trainset = _dataset_cls(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    # Standard test split, in order, in batches of 100.
    testset = _dataset_cls(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


# Network

In [0]:
# Channel flag used to size the layers below:
# 1 for Cifar-10, 0 for MNIST and Fashion-MNIST.
isRGB = 1 if DN == 'Cifar-10' else 0

# Width of each server's output layer: 10 when arg2 is 'Iden', otherwise the
# integer given as arg2.
NS1 = 10 if NS[1] == 'Iden' else int(NS[1])

# ---------------- \bar{g}_0 ---------------- #
if NS[0] != 'Iden':
    class g_0(nn.Module):
        """Client-side front end: zero-pad, one strided convolution, then ReLU."""

        def __init__(self):
            super().__init__()
            # Padding is (left, right, top, bottom): one zero column/row on the
            # left/top when isRGB == 0, no padding when isRGB == 1.
            self.client_L1_pad = nn.ConstantPad2d((1 - isRGB, 0, 1 - isRGB, 0), 0)
            # 1 input channel when isRGB == 0, 3 when isRGB == 1; arg1 output channels.
            self.client_L2_conv2d = nn.Conv2d(1 + 2 * isRGB, int(NS[0]), kernel_size=5, stride=3, padding=0)

        def forward(self, X):
            padded = self.client_L1_pad(X)
            return F.relu(self.client_L2_conv2d(padded))

else:
    class g_0(nn.Module):
        """Identity front end: the input is passed through unchanged."""

        def __init__(self):
            super().__init__()

        def forward(self, X):
            return X

# ---------------- f_j ---------------- #
if NS[0] != 'Iden':
    class f_j(nn.Module):
        """Server network fed with arg1-channel feature maps.

        One 3x3 convolution followed by two fully connected layers; the final
        layer has NS1 outputs and no activation.
        """

        def __init__(self):
            super().__init__()
            self.server_L1_conv2d = nn.Conv2d(int(NS[0]), 128, kernel_size=3, stride=1, padding=0)
            # Spatial side of the flattened feature map: 7 when isRGB == 0, 8 when isRGB == 1.
            side = 7 + isRGB
            self.server_L2_fc = nn.Linear(128 * side * side, 1024)
            self.server_L3_fc = nn.Linear(1024, NS1)

        def forward(self, X):
            feat = F.relu(self.server_L1_conv2d(X))
            flat = feat.reshape(feat.size(0), -1)
            hidden = F.relu(self.server_L2_fc(flat))
            return self.server_L3_fc(hidden)

else:
    class f_j(nn.Module):
        """Server network fed with image-shaped queries.

        Zero-pad, two convolutions and two fully connected layers; the final
        layer has NS1 outputs and no activation.
        """

        def __init__(self):
            super().__init__()
            # Padding is (left, right, top, bottom): one zero column/row on the
            # left/top when isRGB == 0, no padding when isRGB == 1.
            self.server_L1_pad = nn.ConstantPad2d((1 - isRGB, 0, 1 - isRGB, 0), 0)
            self.server_L2_conv2d = nn.Conv2d(1 + 2 * isRGB, 64, kernel_size=5, stride=3, padding=0)
            self.server_L3_conv2d = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0)
            # Spatial side of the flattened feature map: 7 when isRGB == 0, 8 when isRGB == 1.
            side = 7 + isRGB
            self.server_L4_fc = nn.Linear(128 * side * side, 1024)
            self.server_L5_fc = nn.Linear(1024, NS1)

        def forward(self, X):
            feat = F.relu(self.server_L2_conv2d(self.server_L1_pad(X)))
            feat = F.relu(self.server_L3_conv2d(feat))
            flat = feat.reshape(feat.size(0), -1)
            hidden = F.relu(self.server_L4_fc(flat))
            return self.server_L5_fc(hidden)

# ---------------- \bar{g}_1 ---------------- #
if NS[1] != 'Iden':
    class g_1(nn.Module):
        """Client-side head: ReLU followed by a linear map from NS1 features to 10 outputs."""

        def __init__(self):
            super().__init__()
            self.client_L2_fc = nn.Linear(NS1, 10)

        def forward(self, X):
            return self.client_L2_fc(F.relu(X))

else:
    class g_1(nn.Module):
        """Identity head: the summed server answers are returned unchanged."""

        def __init__(self):
            super().__init__()

        def forward(self, X):
            return X

# ---------------- network ---------------- #
class network(nn.Module):
    """Client / N-server model.

    The client applies g_0, normalizes the result per sample, and builds one
    noisy query per server as a linear combination (row j of IW) of the
    normalized signal and T noise tensors.  Server j answers with its own f_j;
    the client sums the answers and applies g_1.

    Reads the module-level names N, T, IW, g_0, f_j and g_1.
    """

    def __init__(self):
        super(network, self).__init__()
        # client
        self.client_g_0 = g_0()
        # servers, registered under the attribute names server1 ... serverN
        for j in range(N):
            setattr(self, 'server{}'.format(j+1), f_j())
        # client
        self.client_g_1 = g_1()

    def forward(self, X,Z_scale):
        """Return the client output for batch X with noise standard deviation Z_scale.

        Fresh noise is drawn on every call, in training and in evaluation alike.
        """
        # \bar{g}_0
        U = self.client_g_0(X)
        # normalize each sample over dims 1-3 (broadcasting does the expansion)
        U = U - torch.mean(U,[1,2,3],keepdim=True) # zero mean
        U = U / torch.std(U,[1,2,3],keepdim=True) # set variance to 1
        # primary noise, created on U's own device so that every DataParallel
        # replica keeps its tensors on its own GPU
        Z = Z_scale * torch.randn(list(U.size())+[T], device=U.device, dtype=U.dtype)
        # add noise (queries)
        Q = []
        for j in range(N):
            q = U * float(IW[j,0])
            for t in range(T):
                q += Z[...,t] * float(IW[j,t+1])
            # scale of query j: sqrt(signal weight^2 + sum of (Z_scale * noise weight)^2)
            normP = float(np.sqrt(IW[j,0]**2 + np.sum((Z_scale*IW[j,1:])**2)))
            Q.append(q / normP)
        # answers
        A = [getattr(self, 'server{}'.format(j+1))(Q[j]) for j in range(N)]
        # sumA (out of place, so A[0] is not overwritten)
        sumA = A[0]
        for j in range(1,N):
            sumA = sumA + A[j]
        # out
        out = self.client_g_1(sumA)
        return out

# ---------------- model ---------------- #
model = network().to(device)
# `device` is a torch.device, which never compares equal to a plain string,
# so the device type is tested instead.
if device.type == 'cuda':
    model = torch.nn.DataParallel(model)
    torch.backends.cudnn.benchmark = True
print(model)

# Training and Evaluating

In [0]:
# variables storing
# Histories appended to by the training and test code: loss / accuracy for
# each split, plus the noise scale (ZR) and learning rate (LR) at each test.
Train_store = dict(Acc=[], Loss=[])
Test_store = dict(Acc=[], Loss=[])
ZL_store = dict(ZR=[], LR=[])


In [0]:
# loss Function
# Cross-entropy on the raw network outputs, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')


### Training code

In [0]:
def train_model():
    """Train `model` for one pass over `trainloader` and record the results.

    Reads the module-level names model, optimizer, criterion, trainloader,
    device, ZR, LR, epoch_num and TotalEpochs.  Prints the sample-weighted mean
    loss and the accuracy (in percent) of the epoch and appends both to
    Train_store.
    """
    loss_sum = 0
    correct_sum = 0
    sample_num = 0

    model.train()
    for (Step_images, Step_labels) in trainloader:
        # data
        images, labels = Step_images.to(device), Step_labels.to(device)
        # forward
        outputs = model(images,ZR)
        loss = criterion(outputs, labels)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # track: reuse the loss already computed for this batch instead of
        # evaluating the criterion a second time
        batch_size = labels.size(0)
        sample_num += batch_size
        loss_sum += loss.item() * batch_size
        correct_sum += (outputs.argmax(dim=1) == labels).sum().item()

    mean_loss = float(loss_sum)/float(sample_num)
    accuracy = float(correct_sum)/float(sample_num)*100

    # print
    print('NewEpoch: [{}/{}] '.format(epoch_num+1,TotalEpochs), end=' ')
    print('NoiseScale: {:.3f} LearningRate: {:.2e}'.format(ZR,LR), end=' ')
    print('Loss:',end=' ')
    print('{:.4f}'.format(mean_loss), end=' ')
    print('Acc:',end=' ')
    print('{:.2f}'.format(accuracy), end=' ')
    print('')

    # store
    Train_store["Loss"].append(mean_loss)
    Train_store["Acc"].append(accuracy)


### Test code

In [0]:
def test_model():
    """Evaluate `model` on `testloader` at the current noise scale and record the results.

    Reads the module-level names model, criterion, testloader, device, ZR and
    LR.  Prints the sample-weighted mean loss and the accuracy (in percent),
    appends both to Test_store, and appends ZR and LR to ZL_store.
    """
    total_loss = 0
    total_correct = 0
    total_seen = 0

    model.eval()
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            # data
            images = batch_images.to(device)
            labels = batch_labels.to(device)
            # forward
            outputs = model(images,ZR)
            # track
            count = labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_seen += count
            total_loss += criterion(outputs, labels).item() * count
            total_correct += (predicted == labels).sum().item()

    mean_loss = float(total_loss)/float(total_seen)
    accuracy = float(total_correct)/float(total_seen)*100

    # print
    print('Test on {} images'.format(total_seen), end=' ')
    print('NoiseScale: {:.3f}'.format(ZR), end=' ')
    print('Loss:',end=' ')
    print('{:.4f}'.format(mean_loss), end=' ')
    print('Acc:',end=' ')
    print('{:.2f}'.format(accuracy), end=' ')
    print('')
    print('')

    # store
    Test_store["Loss"].append(mean_loss)
    Test_store["Acc"].append(accuracy)
    ZL_store["ZR"].append(ZR)
    ZL_store["LR"].append(LR)


### Run

In [0]:
# total epochs
# Smallest n with 0.002 * n * (n + 1) / 2 >= Sigma, i.e. the number of
# linearly growing steps (0.002, 0.004, ...) needed to reach the target.
TotalEpochs = int(np.ceil((-1+np.sqrt(1+8/0.002*Sigma))/2)) # solving [0.002 * (n * (n + 1) / 2) = Sigma]
# learning rates: log-spaced from 1e-3 down to 2e-5, one per epoch
LRs = np.logspace(np.log10(1e-3), np.log10(2e-5), num=TotalEpochs)
# noise scale
step_size = 0
ZR = 0

for epoch_num in range(TotalEpochs):
    # noise scale
    step_size += 0.002
    # Rounding the epoch count up lets the running sum pass Sigma (70.49 for
    # Sigma = 70), so the scale is capped and the last epoch is pinned to
    # Sigma exactly: the final reported accuracy is the one at the target.
    if epoch_num == TotalEpochs - 1:
        ZR = Sigma
    else:
        ZR = min(ZR + step_size, Sigma)
    # learning rate
    LR = LRs[epoch_num]
    # optimizer
    # NOTE(review): a new Adam instance is built every epoch, so its
    # accumulated state does not carry over between epochs -- confirm that
    # this is intended and not just a way of changing the learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # train
    train_model()
    # test
    test_model()