# Encrypted Inference: Linear Regression with Syft and SMPC

Encrypted inference is the process of performing inference with a machine learning model such
that the model owner cannot observe the true input data and the data owner cannot see the true model weights.
The weights and data are encrypted by splitting them into secret shares, and computations on those shares are performed
according to a protocol. This general class of methods is known as Secure Multi-Party Computation (SMPC).
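As a minimal illustration of splitting a value into shares, the sketch below implements additive secret sharing in plain Python. The ring size (2**64), the 16 fractional bits of the fixed-point encoding, and all function names are assumptions chosen for illustration; they are not taken from SyMPC's API.

```python
import random

RING = 2 ** 64    # assumed ring size for the shares (illustrative)
PRECISION = 16    # assumed number of fractional bits for fixed-point encoding
SCALE = 2 ** PRECISION


def encode(x):
    """Map a real number to a ring element using fixed-point encoding."""
    return round(x * SCALE) % RING


def decode(v):
    """Map a ring element back to a real number (upper half of the ring = negatives)."""
    if v >= RING // 2:
        v -= RING
    return v / SCALE


def share(x, n_parties):
    """Split x into n_parties random shares that sum to encode(x) modulo RING."""
    shares = [random.randrange(RING) for _ in range(n_parties - 1)]
    last = (encode(x) - sum(shares)) % RING
    return shares + [last]


def reconstruct(shares):
    """Recombine all shares; any strict subset of them looks uniformly random."""
    return decode(sum(shares) % RING)


# Addition needs no communication: each party adds its own two shares locally.
a = share(3.25, 3)
b = share(-1.5, 3)
c = [(sa + sb) % RING for sa, sb in zip(a, b)]
print(reconstruct(a))  # 3.25
print(reconstruct(c))  # 1.75
```

No single party learns anything from its own share, which is what keeps both the data and the weights hidden; only the sum of all shares reveals the value.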

This case study demonstrates successful encrypted inference with an encrypted ML model (linear regression)
on encrypted data.
The experiment is implemented with the PyTorch, Syft, and SyMPC frameworks.

### Imports

In [None]:
#External libraries
import pandas as pd
import numpy as np
import time

#Import torch
import torch
torch.manual_seed(0)

#Import syft
import syft as sy
sy.logger.remove()

### Data Loading and Processing
Load the data, normalize the features, and split the data into train and test sets.

In [None]:
#Import dataset and add headers
dataset=pd.read_csv("../../../data/housing_data/housing.data",sep=r"\s+",
                    names=["crim","zn","indus",
                           "chas","nox","rm",
                           "age","dis","rad",
                           "tax","ptratio","black",
                           "lstat","medv"])

In [None]:
#Inspect the first rows and the columns of the dataset
dataset.head()

In [None]:
#Split data into features and target variables
features = dataset.drop("medv",axis=1)
targets = dataset["medv"]

#Normalize features
features = features.apply(
    lambda x: (x - x.mean()) / x.std()
)
#display dataset
display(features)
display(targets)


#Convert features and targets into torch tensors
features = torch.tensor(features.values.astype(np.float32))
targets = torch.tensor(targets.values.astype(np.float32))

In [None]:
#Hyperparameters
batch_size = 16
epochs = 300
train_test_split = 0.8
lr = 0.001

#split data to train and test
train_indices=int(len(features)*train_test_split)

train_x = features[:train_indices]
train_y = targets[:train_indices]

# Test rows start exactly where training rows end, so no sample is skipped
test_x = features[train_indices:]
test_y = targets[train_indices:]

In [None]:
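def get_batches(X, y):
    """Split X and y into consecutive mini-batches of size batch_size (the last one may be smaller)."""
    batches = []
    # Iterate over the rows of X itself (not the global train_x) and stop before len(X),
    # so no empty trailing batch is produced
    for index in range(0, len(X), batch_size):
        batches.append((X[index:index+batch_size], y[index:index+batch_size]))
    return batches

train_batches = get_batches(train_x, train_y)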
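The slicing in the split and batching cells is easy to get off by one. The sketch below reproduces that arithmetic in plain Python for a hypothetical dataset of 500 rows (not the housing data) so the boundaries can be checked: the test slice should start at the split index itself, and stepping `range` up to the row count, not one past it, avoids a trailing empty batch. The helper names are hypothetical.

```python
def split_point(n_rows, train_fraction):
    """Index of the first test row: rows [0, k) are train, rows [k, n_rows) are test."""
    return int(n_rows * train_fraction)


def batch_bounds(n_rows, batch_size):
    """(start, end) slice bounds that cover every row exactly once."""
    return [(start, min(start + batch_size, n_rows))
            for start in range(0, n_rows, batch_size)]


n_rows = 500  # hypothetical dataset size, for illustration only
k = split_point(n_rows, 0.8)
train_rows, test_rows = k, n_rows - k
print(train_rows, test_rows)  # 400 100 (starting the test slice at k + 1 would drop one row)

bounds = batch_bounds(train_rows, 16)
print(len(bounds), bounds[-1])  # 25 (384, 400)

# Stepping to n_rows + 1 adds an empty batch (400:400) whenever n_rows is a
# multiple of the batch size; a mean loss over zero elements is NaN.
extra = [(s, min(s + 16, train_rows)) for s in range(0, train_rows + 1, 16)]
print(len(extra), extra[-1])  # 26 (400, 400)
```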

### Plaintext training
This part contains the linear model definition, the training loop, and model evaluation in a regular, non-encrypted environment.

In [None]:
#Define Linear regression model
class LinearSyNet(sy.Module): # the sy.Module wrapper is required so the model can be shared within an SMPC session
    def __init__(self, torch_ref):
        super(LinearSyNet, self).__init__(torch_ref=torch_ref)
        self.fc1 = self.torch_ref.nn.Linear(13,1)

    def forward(self, x):
        x = self.fc1(x)
        return x
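For each sample, `nn.Linear(13, 1)` computes the dot product of the 13 features with a learned weight vector plus a bias, `y = x · w + b`. The plain-Python sketch below spells that out with 3 toy features; the values and the helper name are invented for illustration. Because the forward pass consists only of multiplications and additions, it maps directly onto SMPC primitives, which is one reason linear regression is a convenient first model for encrypted inference.

```python
def linear_forward(x_rows, weight, bias):
    """Compute y = x . w + b for a single-output linear layer.

    x_rows: list of feature vectors, each of length n_features
    weight: list of n_features coefficients (the single row of W)
    bias:   scalar intercept
    """
    return [sum(xi * wi for xi, wi in zip(row, weight)) + bias for row in x_rows]


# Toy example with 3 features instead of the 13 housing features.
x = [[1.0, 2.0, 3.0],
     [4.0, 0.0, 1.0]]
w = [0.5, 0.25, -1.0]
b = 2.0
print(linear_forward(x, w, b))  # [0.0, 3.0]
```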

In [None]:
#Define model, loss function and optimizer
model = LinearSyNet(torch)
criterion = torch.nn.MSELoss(reduction='mean') 
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
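`MSELoss(reduction='mean')` averages the squared errors over a batch, and plain `SGD` (no momentum, as configured here) moves each parameter against its gradient, `w <- w - lr * dL/dw`. The sketch below performs that step by hand for a one-feature toy problem; the data, learning rate, and helper names are made up for illustration, and the gradient formulas are the standard ones for mean squared error.

```python
def mse(predictions, targets):
    """Mean squared error, averaged over the batch like MSELoss(reduction='mean')."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def sgd_step(weight, bias, x_rows, targets, lr):
    """Take one plain SGD step for y = x . w + b under mean squared error.

    Returns the updated weight, the updated bias, and the loss before the update.
    Gradients (N = number of samples):
        dL/dw_j = (2 / N) * sum_i (y_hat_i - y_i) * x_ij
        dL/db   = (2 / N) * sum_i (y_hat_i - y_i)
    """
    n = len(targets)
    preds = [sum(xi * wi for xi, wi in zip(row, weight)) + bias for row in x_rows]
    errors = [p - t for p, t in zip(preds, targets)]
    grad_w = [2.0 / n * sum(e * row[j] for e, row in zip(errors, x_rows))
              for j in range(len(weight))]
    grad_b = 2.0 / n * sum(errors)
    new_weight = [w - lr * g for w, g in zip(weight, grad_w)]
    new_bias = bias - lr * grad_b
    return new_weight, new_bias, mse(preds, targets)


# Toy data: one feature, targets y = 2x (invented for illustration).
x = [[1.0], [2.0]]
y = [2.0, 4.0]
w_new, b_new, loss0 = sgd_step([0.0], 0.0, x, y, lr=0.0625)
_, _, loss1 = sgd_step(w_new, b_new, x, y, lr=0.0625)
print(w_new, b_new)  # [0.625] 0.375
print(loss0, loss1)  # 10.0 3.3203125
```

In the notebook, `loss.backward()` computes these gradients automatically and `optimizer.step()` applies the same update rule.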

In [None]:
#Training Loop
for epoch in range(epochs):
    running_loss = 0.0
    for batch_x, batch_y in train_batches:
        # Clear gradient buffers so gradients from the previous step do not accumulate
        optimizer.zero_grad()

        # Get output from the model, given the inputs
        outputs = model(batch_x).reshape([-1])

        # Get loss for the predicted output
        loss = criterion(outputs, batch_y)
        # Accumulate a Python float rather than the tensor, so the autograd graph is not kept alive
        running_loss += loss.item()

        # Get gradients w.r.t. the parameters
        loss.backward()

        # Update parameters
        optimizer.step()

    # Evaluate on the test set without tracking gradients
    with torch.no_grad():
        test_loss = criterion(model(test_x).reshape([-1]), test_y)
    if epoch % 50 == 0:
        # running_loss is a sum of per-batch mean losses, so divide by the number of batches
        print(f"Epoch {epoch}/{epochs}  Running loss: {running_loss/len(train_batches)}  Test loss: {test_loss.item()}")

### Plaintext Inference
This part performs regular plaintext inference to provide a baseline for comparison.

In [None]:
#Perform inference in plaintext
start_time=time.time()
plaintext_predictions = model(test_x).reshape([-1])
end_time=time.time()
#Calculate inference time and MSELoss
print("MSE Loss: ",criterion(plaintext_predictions,test_y).item())
print("Inference time: ",str(end_time-start_time),"s")

### Encrypted Inference
This part contains the steps for performing encrypted inference
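Adding two secret-shared values is local, but multiplying them (as in `x · w`) requires interaction between the parties. One standard technique is Beaver-triple multiplication, sketched below in plain Python over integers modulo 2**64 with a simulated trusted dealer. This illustrates the general idea only; it is not a claim about exactly how SyMPC's FSS or Falcon protocols multiply internally, and the helper names are hypothetical. With fixed-point encoding, a truncation step would also follow each multiplication.

```python
import random

RING = 2 ** 64  # assumed ring size; plain integers only, no fixed-point encoding


def share(value, n_parties):
    """Split an integer into n_parties additive shares modulo RING."""
    parts = [random.randrange(RING) for _ in range(n_parties - 1)]
    return parts + [(value - sum(parts)) % RING]


def reveal(shares):
    """Open a shared value by summing all shares modulo RING."""
    return sum(shares) % RING


def beaver_multiply(x_sh, y_sh, triple):
    """Multiply two shared integers with a precomputed triple (a, b, c = a * b).

    Only e = x - a and d = y - b are opened; because a and b are uniformly
    random, e and d reveal nothing about x or y. Then
        x * y = c + e * b + d * a + e * d,
    where the first three terms are computed share by share.
    """
    a_sh, b_sh, c_sh = triple
    e = reveal([(xi - ai) % RING for xi, ai in zip(x_sh, a_sh)])
    d = reveal([(yi - bi) % RING for yi, bi in zip(y_sh, b_sh)])
    z_sh = [(ci + e * bi + d * ai) % RING for ai, bi, ci in zip(a_sh, b_sh, c_sh)]
    # The public term e * d must be added by exactly one party.
    z_sh[0] = (z_sh[0] + e * d) % RING
    return z_sh


# A trusted dealer (simulated locally here) prepares the triple in advance.
n_parties = 3
a, b = random.randrange(RING), random.randrange(RING)
triple = (share(a, n_parties), share(b, n_parties), share(a * b % RING, n_parties))

z = beaver_multiply(share(6, n_parties), share(7, n_parties), triple)
print(reveal(z))  # 42
```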

In [None]:
#SyMPC imports required for encrypted inference
import sympc
from sympc.session import Session
from sympc.session import SessionManager
from sympc.protocol import FSS,Falcon

In [None]:
# Define a function that creates the required number of Syft VM clients and returns them
def get_clients(n_parties):
    parties = []
    for index in range(n_parties):
        parties.append(sy.VirtualMachine(name="worker" + str(index)).get_root_client())
    return parties

In [None]:
# Define a function to split the data into chunks and secret-share each chunk among the simulated parties
def split_send(data,session):
    """Split data into one chunk per party and secret-share each chunk among all parties in the session."""
    data_pointers = []
    split_size = int(len(data)/len(session.parties))+1
    for index in range(0,len(session.parties)):
        ptr=data[index*split_size:index*split_size+split_size].share(session=session)
        data_pointers.append(ptr)
    return data_pointers
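`split_send` uses a fixed chunk size of `int(len(data)/len(session.parties))+1`. The sketch below (plain Python, hypothetical helper names) compares the resulting chunk sizes with a balanced split whose sizes differ by at most one, the rule `numpy.array_split` follows. Both rules cover every row, but with very few rows the fixed size can leave the last party with an empty chunk, which is at best wasted work.

```python
def chunk_sizes_fixed(n_rows, n_parties):
    """Chunk sizes produced by a fixed chunk size of int(n_rows / n_parties) + 1."""
    size = int(n_rows / n_parties) + 1
    return [max(0, min(size, n_rows - i * size)) for i in range(n_parties)]


def chunk_sizes_balanced(n_rows, n_parties):
    """Near-equal chunk sizes that differ by at most one (larger chunks first)."""
    base, extra = divmod(n_rows, n_parties)
    return [base + 1 if i < extra else base for i in range(n_parties)]


print(chunk_sizes_fixed(101, 3), chunk_sizes_balanced(101, 3))  # [34, 34, 33] [34, 34, 33]
print(chunk_sizes_fixed(4, 3), chunk_sizes_balanced(4, 3))      # [2, 2, 0] [2, 1, 1]
```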

In [None]:
# Define a function to perform encrypted inference
def inference(n_clients, protocol=None):
    # Get VM clients
    parties = get_clients(n_clients)
    # Set up the session for the computation (the session's default protocol is used if none is given)
    if protocol is not None:
        session = Session(parties=parties, protocol=protocol)
    else:
        session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    # Split the test data and secret-share it among the clients
    pointers = split_send(test_x, session)
    # Encrypt (secret-share) the model
    mpc_model = model.share(session)
    # Perform inference and measure the time taken (including reconstruction of the results)
    start_time = time.time()
    results = []
    for ptr in pointers:
        encrypted_results = mpc_model(ptr)
        plaintext_results = encrypted_results.reconstruct()
        results.append(plaintext_results)
    end_time = time.time()
    print(f"Time for inference: {end_time-start_time}s")
    predictions = torch.cat(results).reshape([-1])
    # Calculate the loss
    print("MSE Loss: ", criterion(predictions, test_y).item())

    return predictions

In [None]:
predictions=inference(3, Falcon("semi-honest"))

In [None]:
# Compare the first 100 encrypted inference results with the plaintext results
for index in range(100):
    print(f"Index {index}")
    print(f"Encrypted Prediction Output {predictions[index].item()}")
    print(f"Plaintext Prediction Output {plaintext_predictions[index].item()}")
    print(f"Expected Prediction: {test_y[index].item()}")
    print("\n")
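SMPC protocols compute on integers, so real-valued inputs and weights are typically stored in fixed point: scaled by a power of two and rounded. The sketch below assumes 16 fractional bits (an assumption for illustration; check the precision configured for your SyMPC session) and shows that encoding alone moves a value by at most half a step, about 7.6e-6. Truncation after each multiplication adds similarly small errors, which is why encrypted and plaintext predictions agree closely but not bit for bit.

```python
PRECISION_BITS = 16          # assumed number of fractional bits (illustrative)
SCALE = 2 ** PRECISION_BITS  # one fixed-point step is 1 / SCALE


def fixed_point_round_trip(x):
    """Encode x as the nearest integer multiple of 1/SCALE, then decode it."""
    return round(x * SCALE) / SCALE


# Sample target-like values, invented for illustration.
values = [22.6, 19.4, 31.05, 8.75]
errors = [abs(fixed_point_round_trip(v) - v) for v in values]
print(max(errors) <= 0.5 / SCALE)  # True: at most half a step, about 7.6e-6
```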

In [None]:
predictions=inference(3,Falcon("malicious"))

In [None]:
print('Inference with 3 parties')
predictions=inference(3)
print()
print('Inference with 5 parties')
predictions=inference(5)

### Summary
This case study demonstrates inference on encrypted data using an
encrypted model. The inference process is implemented with SyMPC, Syft, and PyTorch.
Inference was performed several times with different protocols and adversary models (Falcon in semi-honest and malicious mode, and the session's default protocol)
and with different numbers of parties (3 and 5).
The plaintext and encrypted inference results are very close to each other; differences in the fifth decimal place
are considered acceptable.
Overall, the experiment is considered successful: the
privacy of the data was preserved without significant loss of precision.