# Encrypted Inference-Linear Regression

Encrypted Inference is the process of performing inference with machine learning models such that model owner cannot observe the true input data nor can the data owner see the true model weights. The weights and data are encrypted by splitting them into shares and performiming computations according to a protocol. The general class of methods know as Secure Multi Party Computation (SMPC).

SOURCE: https://github.com/OpenMined/SyMPC/blob/main/examples/Encrypted-Inference-LinearRegression.ipynb

### Imports

In [28]:
#External libraries
import pandas as pd
import numpy as np
import time

#Import torch
import torch
import torch.nn as nn
import torch.utils.data as data_utils
torch.manual_seed(0)

#Import syft
import syft as sy
sy.logger.remove()

### Data Loading and Processing

In [2]:
#Download Boston housing dataset
#!wget https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data

--2021-09-28 21:02:37--  https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252, 52.35.62.125, 128.200.200.146
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49082 (48K) [application/x-httpd-php]
Saving to: ‘housing.data.1’


2021-09-28 21:02:38 (69.3 KB/s) - ‘housing.data.1’ saved [49082/49082]



In [29]:
#Import dataset and add headers
dataset=pd.read_csv("housing.data",delim_whitespace=True,
                    names=["crim","zn","indus",
                           "chas","nox","rm",
                           "age","dis","rad",
                           "tax","ptratio","black",
                           "lstat","medv"])

In [30]:
#Visualize and look at columns and rows of dataset
dataset.head()

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat,medv
0,0.00632,18.0,2.31,0,0.538,6.575,65.2,4.09,1,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0,0.469,6.421,78.9,4.9671,2,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0,0.469,7.185,61.1,4.9671,2,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0,0.458,6.998,45.8,6.0622,3,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0,0.458,7.147,54.2,6.0622,3,222.0,18.7,396.9,5.33,36.2


In [31]:
#Split data into features and target variables
features = dataset.drop("medv",axis=1)
targets = dataset["medv"]

In [32]:
#Normalize features
features = features.apply(
    lambda x: (x - x.mean()) / x.std()
)

In [33]:
features

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,black,lstat
0,-0.419367,0.284548,-1.286636,-0.272329,-0.144075,0.413263,-0.119895,0.140075,-0.981871,-0.665949,-1.457558,0.440616,-1.074499
1,-0.416927,-0.487240,-0.592794,-0.272329,-0.739530,0.194082,0.366803,0.556609,-0.867024,-0.986353,-0.302794,0.440616,-0.491953
2,-0.416929,-0.487240,-0.592794,-0.272329,-0.739530,1.281446,-0.265549,0.556609,-0.867024,-0.986353,-0.302794,0.396035,-1.207532
3,-0.416338,-0.487240,-1.305586,-0.272329,-0.834458,1.015298,-0.809088,1.076671,-0.752178,-1.105022,0.112920,0.415751,-1.360171
4,-0.412074,-0.487240,-1.305586,-0.272329,-0.834458,1.227362,-0.510674,1.076671,-0.752178,-1.105022,0.112920,0.440616,-1.025487
...,...,...,...,...,...,...,...,...,...,...,...,...,...
501,-0.412820,-0.487240,0.115624,-0.272329,0.157968,0.438881,0.018654,-0.625178,-0.981871,-0.802418,1.175303,0.386834,-0.417734
502,-0.414839,-0.487240,0.115624,-0.272329,0.157968,-0.234316,0.288648,-0.715931,-0.981871,-0.802418,1.175303,0.440616,-0.500355
503,-0.413038,-0.487240,0.115624,-0.272329,0.157968,0.983986,0.796661,-0.772919,-0.981871,-0.802418,1.175303,0.440616,-0.982076
504,-0.407361,-0.487240,0.115624,-0.272329,0.157968,0.724955,0.736268,-0.667776,-0.981871,-0.802418,1.175303,0.402826,-0.864446


In [34]:
targets

0      24.0
1      21.6
2      34.7
3      33.4
4      36.2
       ... 
501    22.4
502    20.6
503    23.9
504    22.0
505    11.9
Name: medv, Length: 506, dtype: float64

In [35]:
#Convert features and targets into torch tensors
features = torch.tensor(features.values.astype(np.float32)) 
targets = torch.tensor(targets.values.astype(np.float32))

In [36]:
#Hyperparameters
batch_size = 16
epochs = 300
train_test_split = 0.8
lr = 0.001

In [37]:
train_indices=int(len(features)*train_test_split)

train_x = features[:train_indices]
train_y = targets[:train_indices]

test_x = features[train_indices+1:]
test_y = targets[train_indices+1:]

In [38]:
def get_batches(X, y):
    batches = []
    for index in range(0,len(train_x)+1,batch_size):
        batches.append((X[index:index+batch_size],y[index:index+batch_size]))
    
    return batches

In [39]:
train_batches = get_batches(train_x, train_y)

### Plaintext training

In [41]:
#Define Linear regression model
class LinearSyNet(sy.Module):
    def __init__(self, torch_ref):
        super(LinearSyNet, self).__init__(torch_ref=torch_ref)
        self.fc1 = self.torch_ref.nn.Linear(13,1)

    def forward(self, x):
        x = self.fc1(x)
        return x

In [42]:
#Define model, loss function and optimizer
model = LinearSyNet(torch)
criterion = torch.nn.MSELoss(reduction='mean') 
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [43]:
#Training Loop
for epoch in range(epochs):
  running_loss = 0.0
  for index in range(0,len(train_batches)):
    # Clear gradient buffers because we don't want any gradient from previous epoch to carry forward, dont want to cummulate gradients
    optimizer.zero_grad()

    # get output from the model, given the inputs
    outputs = model(train_batches[index][0]).reshape([-1])

    # get loss for the predicted output
    loss = criterion(outputs,train_batches[index][1])
    running_loss += loss
    # get gradients w.r.t to parameters
    loss.backward()

    # update parameters
    optimizer.step()
    
  test_accuracy = criterion(model(test_x).reshape([-1]),test_y)
  if((epoch%50)==0):
     print(f"Epoch {epoch}/{epochs}  Running Loss : {running_loss.item()/batch_size} and test loss: {test_accuracy.item()}")

Epoch 0/300  Running Loss : 1006.2069091796875 and test loss: 355.1054992675781
Epoch 50/300  Running Loss : 53.09187698364258 and test loss: 110.3436050415039
Epoch 100/300  Running Loss : 43.0759391784668 and test loss: 48.66814422607422
Epoch 150/300  Running Loss : 41.09862518310547 and test loss: 30.31635856628418
Epoch 200/300  Running Loss : 40.13597869873047 and test loss: 23.175142288208008
Epoch 250/300  Running Loss : 39.59026336669922 and test loss: 20.673261642456055


### Plaintext Inference

In [44]:
#Perform inference in plaintext
start_time=time.time()
plaintext_predictions = model(test_x).reshape([-1])
end_time=time.time()

In [45]:
#Calculate inference time and MSELoss
print("MSE Loss: ",criterion(plaintext_predictions,test_y).item())
print("Inference time: ",str(end_time-start_time),"s")

MSE Loss:  20.229232788085938
Inference time:  0.001733541488647461 s


### Encrypted Inference

In [46]:
#SyMPC imports required for encrypted inference
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import MPCTensor
from sympc.protocol import FSS,Falcon
# from sympc.protocol import Falcon,FSS

In [47]:
def get_clients(n_parties):
  #Generate required number of syft clients and return them.

  parties=[]
  for index in range(n_parties): 
      parties.append(sy.VirtualMachine(name = "worker"+str(index)).get_root_client())

  return parties

In [48]:
def split_send(data,session):
    """Splits data into number of chunks equal to number of parties and distributes it to respective 
       parties.
    """
    data_pointers = []
    
    split_size = int(len(data)/len(session.parties))+1
    for index in range(0,len(session.parties)):
        ptr=data[index*split_size:index*split_size+split_size].share(session=session)
        data_pointers.append(ptr)
        
    return data_pointers

In [23]:
def inference(n_clients,protocol=None):
    
  # Get VM clients 
  parties=get_clients(n_clients)

  # Setup the session for the computation
  if(protocol):
     session = Session(parties = parties,protocol = protocol)
  else:
     session = Session(parties = parties)
        
  SessionManager.setup_mpc(session)

  #Split data and send data to clients
  pointers = split_send(test_x,session)

  #Encrypt model 
  mpc_model = model.share(session)

  #Encrypt test data
  #test_data=MPCTensor(secret=test_x, session = session)

  #Perform inference and measure time taken
  start_time = time.time()
    
  results = []
    
  for ptr in pointers:
     encrypted_results = mpc_model(ptr)
     plaintext_results = encrypted_results.reconstruct()
     results.append(plaintext_results)
        
  end_time = time.time()

  print(f"Time for inference: {end_time-start_time}s")
    
  predictions = torch.cat(results).reshape([-1])

  #Calculate Loss
  print("MSE Loss: ",criterion(predictions,test_y).item())
    
  return predictions

In [49]:
predictions=inference(3, Falcon("semi-honest"))

Time for inference: 0.2723851203918457s
MSE Loss:  20.22910499572754


In [53]:
for index in range(0,101):
    print(f"Index {index}")
    print(f"Encrypted Prediction Output {predictions[index].item()}")
    print(f"Plaintext Prediction Output {plaintext_predictions[index].item()}")
    print(f"Expected Prediction: {test_y[index]}")
    print("\n")

Index 0
Encrypted Prediction Output 2.598541259765625
Plaintext Prediction Output 2.59842848777771
Expected Prediction: 5.0


Index 1
Encrypted Prediction Output 4.947998046875
Plaintext Prediction Output 4.9478864669799805
Expected Prediction: 11.899999618530273


Index 2
Encrypted Prediction Output 19.206314086914062
Plaintext Prediction Output 19.206314086914062
Expected Prediction: 27.899999618530273


Index 3
Encrypted Prediction Output 12.046829223632812
Plaintext Prediction Output 12.04680061340332
Expected Prediction: 17.200000762939453


Index 4
Encrypted Prediction Output 19.068893432617188
Plaintext Prediction Output 19.06889533996582
Expected Prediction: 27.5


Index 5
Encrypted Prediction Output 11.845855712890625
Plaintext Prediction Output 11.845837593078613
Expected Prediction: 15.0


Index 6
Encrypted Prediction Output 16.331619262695312
Plaintext Prediction Output 16.331565856933594
Expected Prediction: 17.200000762939453


Index 7
Encrypted Prediction Output -1.26194

In [54]:
predictions=inference(3,Falcon("malicious"))

Time for inference: 1.6142754554748535s
MSE Loss:  20.22911262512207


In [27]:
print('Inference with 3 parties')
predictions=inference(3)
print()
print('Inference with 5 parties')
predictions=inference(5)


Inference with 3 parties
Time for inference: 0.9718637466430664s
MSE Loss:  20.227685928344727

Inference with 5 parties
Time for inference: 2.1683151721954346s
MSE Loss:  20.227663040161133
