# Implementation of (simplified) CryptoNet for inference under homomorphic encryption

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from activation import *
from logger import Logger
from dataHandler import DataHandler
from utils import *

In [2]:
class SimpleNet(nn.Module):
  '''
    Simplified network used for inference in the CryptoNets paper:
    https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/

    Per-sample shapes:
      Input size: 1x28x28 pixel
      Pad: image is padded with 0s on left and top side --> 1x29x29
      Conv1 --> 5x13x13
      Pool1 --> 100x1x1 --> resized to 1x100x1
      Pool2 --> 10x1x1 --> resized to a vector of len 10

    Args:
      batch_size: nominal batch size; stored on the instance. forward() infers
        the batch dimension from its input, so other batch sizes are accepted.
      activation: "square", "relu" or "relu_approx"; a callable is used as-is.
      sigmoid: final activation, "sigmoid", "approx" or "none" (identity);
        a callable is used as-is.
      init_method: "he" or "xavier"; any other value makes weights_init a
        no-op, leaving the layers' default initialisation untouched.
      verbose: stored on the instance; not read by this class.

    Raises:
      ValueError: if activation or sigmoid is neither a known name nor a callable.
  '''

  def __init__(self, batch_size : int, activation : str, sigmoid : str, init_method : str, verbose : bool):
    super().__init__()
    self.verbose = verbose
    self.init_method = init_method
    self.batch_size = batch_size

    # Resolve the hidden activation. The approximations come from the
    # project's own modules and are only looked up when actually selected.
    if activation == "square":
      self.activation = torch.square
    elif activation == "relu":
      self.activation = nn.ReLU()
    elif activation == "relu_approx":
      self.activation = ReLUApprox()
    elif callable(activation):
      self.activation = activation
    else:
      # Fail here instead of with "'str' object is not callable" in forward().
      raise ValueError(
        f"unknown activation {activation!r}; expected 'square', 'relu', 'relu_approx' or a callable")

    # Resolve the final activation.
    if sigmoid == "sigmoid":
      self.sigmoid = nn.Sigmoid()
    elif sigmoid == "approx":
      self.sigmoid = SigmoidApprox()
    elif sigmoid == "none":
      self.sigmoid = identity
    elif callable(sigmoid):
      self.sigmoid = sigmoid
    else:
      raise ValueError(
        f"unknown sigmoid {sigmoid!r}; expected 'sigmoid', 'approx', 'none' or a callable")

    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    # The two "pool" layers are convolutions whose kernel spans their whole
    # input, so each one produces a single spatial position per output channel.
    self.pool1 = nn.Conv2d(in_channels=5, out_channels=100, kernel_size=13, stride=1000)
    self.pool2 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(100,1), stride=1000) # in chans is 1 not 100

  def forward(self, x):
    '''
      Run the network on a batch of images.

      Args:
        x: tensor of 1x28x28 images (see the class docstring for the shapes).

      Returns:
        Tensor with one row of 10 values per input image.
    '''
    # One column of zeros on the left and one row on top: 28x28 --> 29x29.
    x = self.pad(x, (1,0,1,0))

    x = self.conv1(x)

    x = self.activation(self.pool1(x))
    # Move the 100 channels into the height axis: N x 1 x 100 x 1. The batch
    # dimension is inferred (-1) rather than taken from self.batch_size, so a
    # batch of any size (e.g. a final partial batch) goes through.
    x = x.reshape(-1, 1, 100, 1)
    x = self.activation(self.pool2(x))

    # A sigmoid as last activation improved performance, but the original
    # architecture has none; constructing the model with sigmoid="none"
    # makes this step f(x) = x.
    x = self.sigmoid(x)
    return x.reshape(x.shape[0], -1)

  def weights_init(self, m=None):
    '''
      Initialise convolution weights according to self.init_method.

      Meant to be used as model.apply(model.weights_init): apply() calls this
      once per sub-module, and only the module passed in is initialised, so
      every layer is initialised exactly once.

      Args:
        m: module to initialise; non-Conv2d modules are ignored. When None,
          every direct child of this network is initialised.
    '''
    layers = self.children() if m is None else (m,)
    for layer in layers:
      if not isinstance(layer, nn.Conv2d):
        continue
      if self.init_method == "he":
        nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_out', nonlinearity='relu')
      elif self.init_method == "xavier":
        nn.init.xavier_uniform_(layer.weight, gain=math.sqrt(2))
      # Any other init_method keeps the layer's existing weights.
Load Datasets

In [3]:
# Data handler for MNIST with a batch size of 256; the training/eval pipeline
# reads dataHandler.batch_size and passes the handler to train() and eval().
dataHandler = DataHandler(batch_size=256, dataset="MNIST")

Train and test pipeline

In [4]:
##############################
#                            #
# TRAINING AND EVAL PIPELINE #
#                            #
##############################
import os

# NOTE(review): `device`, `train` and `eval` are not defined in this notebook;
# presumably they come from the star imports in the first cell -- confirm.
# (`eval` then shadows the Python builtin of the same name.)

# Models to train, keyed by "<init_method>_<activation>".
#
# Notes from earlier experiments (xavier / he init combined with relu /
# relu_approx activations), which are no longer part of the run:
#   * he init blows up values with the square activation
models = {}

## Most promising model. With approximated sigmoid we can increase accuracy
## up to 96%, but it's not faithful to the original model, plus it is more complex
models["xavier_relu_approx"] = SimpleNet(batch_size=dataHandler.batch_size,
                                    activation="relu_approx",
                                    init_method="xavier",
                                    sigmoid="none",
                                    verbose=False).to(device=device)

scores = {}

# torch.save does not create missing directories. Make sure the target exists
# before training, so a missing folder cannot cost a finished 400-epoch run.
os.makedirs("./models", exist_ok=True)

for key, model in models.items():
  logger = Logger("./logs/",f"SimpleNet_{key}")
  # (Re-)initialise the convolution weights with the model's init_method.
  model.apply(model.weights_init)
  train(logger, model, dataHandler, num_epochs=400, lr=0.001)
  loss, accuracy = eval(logger, model, dataHandler)
  scores[key] = {"loss":loss, "accuracy":accuracy}
  # Saves the whole module object, not just its state_dict.
  torch.save(model, f"./models/SimpleNet_{key}.pt")


# Summary of the test metrics, one entry per trained model.
for key, metrics in scores.items():
  print("=====================================================================")
  print(f"[+] Model with {key}: Avg test Loss ==> {metrics['loss']}, Accuracy ==> {metrics['accuracy']}")

[?] SimpleNet_xavier_relu_approx Epoch 1/400 Loss 0.1487
[?] SimpleNet_xavier_relu_approx Epoch 2/400 Loss 0.0805
[?] SimpleNet_xavier_relu_approx Epoch 3/400 Loss 0.0736
[?] SimpleNet_xavier_relu_approx Epoch 4/400 Loss 0.0672
[?] SimpleNet_xavier_relu_approx Epoch 5/400 Loss 0.0613
[?] SimpleNet_xavier_relu_approx Epoch 6/400 Loss 0.0565
[?] SimpleNet_xavier_relu_approx Epoch 7/400 Loss 0.0527
[?] SimpleNet_xavier_relu_approx Epoch 8/400 Loss 0.0498
[?] SimpleNet_xavier_relu_approx Epoch 9/400 Loss 0.0475
[?] SimpleNet_xavier_relu_approx Epoch 10/400 Loss 0.0458
[?] SimpleNet_xavier_relu_approx Epoch 11/400 Loss 0.0444
[?] SimpleNet_xavier_relu_approx Epoch 12/400 Loss 0.0433
[?] SimpleNet_xavier_relu_approx Epoch 13/400 Loss 0.0424
[?] SimpleNet_xavier_relu_approx Epoch 14/400 Loss 0.0416
[?] SimpleNet_xavier_relu_approx Epoch 15/400 Loss 0.0409
[?] SimpleNet_xavier_relu_approx Epoch 16/400 Loss 0.0403
[?] SimpleNet_xavier_relu_approx Epoch 17/400 Loss 0.0398
[?] SimpleNet_xavier_re