# Implementation of (simplified) CryptoNet for inference under homomorphic encryption

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from activation import relu_approx, sigmoid_approx
from logger import Logger
from dataHandler import DataHandler
from utils import *

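The original CryptoNet architecture from [Microsoft](https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/) is kept for reference in the next cell, but we could not replicate the paper's reported results with it.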

In [2]:
# Disabled reference implementation: the full 9-layer CryptoNet below is wrapped
# in a string literal so it is never executed (its own docstring says
# "CURRENTLY NOT WORKING"). Because the string is the last expression of the
# cell, the notebook simply echoes it as the cell output.
# To re-enable it, remove the surrounding triple quotes.
# NOTE(review): ScaledAvgPool2d multiplies a stride-1 AvgPool2d by kernel_size**2,
# i.e. a windowed sum when padding zeros are counted in the average (PyTorch's
# default count_include_pad=True) — confirm that is the intended "sum pool".
# NOTE(review): CryptoNet.weights_init ignores its `m` argument and re-initialises
# every child on each call, so `model.apply(model.weights_init)` repeats the work
# once per submodule.
'''
class ScaledAvgPool2d(nn.Module):
    """Define the ScaledAvgPool layer, a.k.a the Sum Pool"""
    def __init__(self,kernel_size):
      super().__init__()
      self.kernel_size = kernel_size
      self.AvgPool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=1, padding=int(math.ceil((kernel_size-1)/2)))

    def forward(self,x):
      return (self.kernel_size**2)*self.AvgPool(x)
    

class CryptoNet(nn.Module):
  """
    Original 9-layer network used during training
    CURRENTLY NOT WORKING
  """
  def __init__(self, verbose):
    super().__init__()
    self.verbose = verbose
    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.square1 = torch.square
    self.scaledAvgPool1 = ScaledAvgPool2d(kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2)
    self.scaledAvgPool2 = ScaledAvgPool2d(kernel_size=3)
    self.fc1 = nn.Linear(in_features=1250, out_features=100)
    self.square2 = torch.square
    self.fc2 = nn.Linear(in_features=100, out_features=10)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.pad(x, (1,0,1,0))
    if self.verbose:
      print("Start --> ",x.mean())
    x = self.conv1(x)
    if self.verbose:
      print("Conv1 --> ",x.mean())
    x = self.square1(x)
    if self.verbose:
      print("Sq --> ",x.mean())
    x = self.scaledAvgPool1(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    x = self.conv2(x)
    if self.verbose:
      print("Conv2 --> ",x.mean())
    x = self.scaledAvgPool2(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    ## Flatten
    x = x.reshape(x.shape[0], -1)
    x = self.fc1(x)
    if self.verbose:
      print("fc1 --> ",x.mean())
    x = self.square2(x)
    if self.verbose:
      print("Square --> ",x.mean())
    x = self.fc2(x)
    if self.verbose:
      print("fc2 --> ",x.mean())
    x = self.sigmoid(x)
    return x

  def weights_init(self, m):
    """ Custom initilization to avoid square activation to blow up """
    for m in self.children():
      if isinstance(m,nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
      elif isinstance(m, nn.Linear):
        nn.init.uniform_(m.weight, 1e-4,1e-3)
'''

'\nclass ScaledAvgPool2d(nn.Module):\n    """Define the ScaledAvgPool layer, a.k.a the Sum Pool"""\n    def __init__(self,kernel_size):\n      super().__init__()\n      self.kernel_size = kernel_size\n      self.AvgPool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=1, padding=int(math.ceil((kernel_size-1)/2)))\n\n    def forward(self,x):\n      return (self.kernel_size**2)*self.AvgPool(x)\n    \n\nclass CryptoNet(nn.Module):\n  """\n    Original 9-layer network used during training\n    CURRENTLY NOT WORKING\n  """\n  def __init__(self, verbose):\n    super().__init__()\n    self.verbose = verbose\n    self.pad = F.pad\n    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)\n    self.square1 = torch.square\n    self.scaledAvgPool1 = ScaledAvgPool2d(kernel_size=3)\n    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2)\n    self.scaledAvgPool2 = ScaledAvgPool2d(kernel_size=3)\n    self.fc1 = nn.Linear(in_features=1250, out_f

In [3]:
class SimpleNet(nn.Module):
  '''
    Simplified CryptoNet used in the paper for inference:
    https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/

    Layer shapes for a single-channel 28x28 input (MNIST):
      pad (left/top by 1)            -> 1 x 29 x 29
      conv1: 5 filters, 5x5, s=2     -> 5 x 13 x 13
      pool1: 100 filters, 13x13      -> 100 x 1 x 1, then `activation`
      reshape                        -> 1 x 100 x 1
      pool2: 10 filters, 100x1       -> 10 x 1 x 1, then `activation`
      sigmoid, flatten               -> (batch, 10)
    pool1/pool2 are convolutions whose kernels span their whole input, so each
    behaves like a fully connected layer (stride=1000 only ensures a single
    output position).

    Args:
      batch_size: stored as `self.batch_size` for compatibility; forward() uses
        the actual batch dimension of its input, so a smaller final batch works.
      activation: "square", "relu" or "relu_approx".
      init_method: "he", "xavier", "uniform", "norm", or "random" (keep PyTorch's
        default Conv2d initialisation). Applied by `weights_init`.
      verbose: stored as `self.verbose`; not otherwise used by this class.
      sigmoid: True for nn.Sigmoid, False for the imported `sigmoid_approx`.

    Raises:
      ValueError: if `activation` or `init_method` is not one of the above.
  '''

  _ACTIVATIONS = ("square", "relu", "relu_approx")
  _INIT_METHODS = ("he", "xavier", "uniform", "norm", "random")

  def __init__(self, batch_size : int, activation : str, init_method : str, verbose : bool, sigmoid : bool):
    super().__init__()
    # Fail fast: an unknown name used to surface later as an AttributeError
    # (activation) or be silently ignored (init_method).
    if init_method not in self._INIT_METHODS:
      raise ValueError(f"Unknown init_method {init_method!r}; expected one of {self._INIT_METHODS}")

    self.verbose = verbose
    self.init_method = init_method
    self.batch_size = batch_size

    if activation == "square":
      self.activation = torch.square
    elif activation == "relu":
      self.activation = nn.ReLU()
    elif activation == "relu_approx":
      self.activation = relu_approx
    else:
      raise ValueError(f"Unknown activation {activation!r}; expected one of {self._ACTIVATIONS}")

    if sigmoid:
      self.sigmoid = nn.Sigmoid()
    else:
      self.sigmoid = sigmoid_approx

    self.pad = F.pad
    # Layer creation order is unchanged so default initialisation consumes the
    # RNG in the same sequence as before.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.pool1 = nn.Conv2d(in_channels=5, out_channels=100, kernel_size=13, stride=1000)
    self.pool2 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(100,1), stride=1000)

  def forward(self, x):
    '''Return per-class scores of shape (batch, 10) after the output sigmoid.'''
    x = self.pad(x, (1,0,1,0))
    x = self.conv1(x)
    x = self.activation(self.pool1(x))
    # pool1 leaves 100 channels of 1x1; lay them out as one 100x1 "image" per
    # sample so pool2 can act as a fully connected layer. Use the real batch
    # dimension rather than self.batch_size so a short final batch still works.
    x = x.reshape([x.shape[0], 1, 100, 1])
    x = self.activation(self.pool2(x))
    x = self.sigmoid(x)  # needed for the probabilities
    return x.reshape(x.shape[0], -1)

  def weights_init(self, m):
    '''Initialise the weights of module `m` according to `self.init_method`.

    Intended for `model.apply(model.weights_init)`, which calls it once per
    submodule, so each Conv2d is initialised exactly once. Modules that are not
    nn.Conv2d are left untouched, as are all biases. "random" keeps PyTorch's
    default initialisation.
    '''
    if not isinstance(m, nn.Conv2d):
      return
    if self.init_method == "he":
      nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
    elif self.init_method == "xavier":
      nn.init.xavier_uniform_(m.weight, gain=math.sqrt(2))
    elif self.init_method == "uniform":
      nn.init.uniform_(m.weight, -0.5, 0.5)
    elif self.init_method == "norm":
      nn.init.normal_(m.weight, 0.0, 1.0)

Load the dataset (MNIST)

In [4]:
# MNIST data in batches of 256; the models below read the size back via dataHandler.batch_size.
dataHandler = DataHandler(
  dataset="MNIST",
  batch_size=256,
)

Train and test pipeline

In [6]:
##############################
#                            #
# TRAINING AND EVAL PIPELINE #
#                            #
##############################

## Experiment grid: every init method is combined with every hidden activation.
# "he" was dropped from the grid because, with the square activation, it made values blow up.
#methods = ["he", "xavier", "random"]
methods = ["xavier", "random"]
# SimpleNet also accepts "relu", which is left out of this run.
activations = ["relu_approx", "square"]
# True -> exact nn.Sigmoid output; False -> sigmoid_approx. Also selects the save-file name below.
sigmoid = False

models = {}
for method in methods:
  for activation in activations:
    models[f"{method}_{activation}"] = SimpleNet(
        batch_size=dataHandler.batch_size,
        activation=activation,
        init_method=method,
        verbose=False,
        sigmoid=sigmoid,
    ).to(device=device)

scores = {}

## Train, evaluate and save every configuration (earlier runs found xavier + square best).
for key, model in models.items():
  logger = Logger("./logs/", f"SimpleNet_{key}")
  model.apply(model.weights_init)
  train(logger, model, dataHandler, num_epochs=150)
  loss, accuracy = eval(logger, model, dataHandler)
  scores[key] = {"loss": loss, "accuracy": accuracy}
  torch.save(model, f"SimpleNet_{key}_sigmoid.pt" if sigmoid else f"SimpleNet_{key}_approx_sigmoid.pt")

## Single-model variant, kept for reference.
# NOTE(review): stale — it passes `key` and `TPU=False` to train/eval, unlike the
# active calls above, and the "10 epochs" label disagrees with num_epochs=150.
#key = "xavier_relu"
#model = models[key]
#model.apply(model.weights_init)
#train(key, model, dataHandler, num_epochs=150, TPU=False)
#loss, accuracy = eval(key,model, dataHandler)
#scores[key] = {"loss":loss, "accuracy":accuracy}
#torch.save(model, f"SimpleNet_{key}.pt")

## Summary of the test metrics for each configuration.
for key, metrics in scores.items():
  print("=====================================================================")
  print(f"[+] Model with {key}: Avg test Loss ==> {metrics['loss']}, Accuracy ==> {metrics['accuracy']}")

[?] SimpleNet_xavier_relu_approx Epoch 1/100 Loss 0.1953
[?] SimpleNet_xavier_relu_approx Epoch 2/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 3/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 4/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 5/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 6/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 7/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 8/100 Loss 0.1952
[?] SimpleNet_xavier_relu_approx Epoch 9/100 Loss 0.1951
[?] SimpleNet_xavier_relu_approx Epoch 10/100 Loss 0.1938
[?] SimpleNet_xavier_relu_approx Epoch 11/100 Loss 0.1921
[?] SimpleNet_xavier_relu_approx Epoch 12/100 Loss 0.1910
[?] SimpleNet_xavier_relu_approx Epoch 13/100 Loss 0.1880
[?] SimpleNet_xavier_relu_approx Epoch 14/100 Loss 0.1837
[?] SimpleNet_xavier_relu_approx Epoch 15/100 Loss 0.1836
[?] SimpleNet_xavier_relu_approx Epoch 16/100 Loss 0.1777
[?] SimpleNet_xavier_relu_approx Epoch 17/100 Loss 0.1765
[?] SimpleNet_xavier_re

KeyboardInterrupt: 