# Implementation of (simplified) CryptoNet and AlexNet for inference under homomorphic encryption

In [23]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torchvision.utils import save_image
import torch.optim as optim
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from torch.utils.tensorboard import SummaryWriter
import math

## interactive off
plt.ioff()
## setup torch enviro
torch.manual_seed(42)
torch.autograd.set_detect_anomaly(True)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Approximated Relus

In [24]:
def approx_relu_2d(x):
  """2-degree approx of relu in [-6,6] from https://arxiv.org/pdf/2009.03727.pdf"""
  a = 0.563059
  b = 0.5
  c = 0.078047
  x_2 = torch.square(x)
  return a + b*x + c*x_2
  
def approx_relu_4d(x):
  """4-degree approx of relu in [-6,6] from https://arxiv.org/pdf/2009.03727.pdf"""
  a = 0.119782
  b = 0.5
  c = 0.147298
  d = -0.002015
  x_2 = torch.square(x)
  x_4 = torch.square(x_2)
  return a + b*x + c*x_2 + d*x_4

CryptoNet from [Microsoft](https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/) ==> couldn't replicate results

In [25]:
'''
class ScaledAvgPool2d(nn.Module):
    """Define the ScaledAvgPool layer, a.k.a the Sum Pool"""
    def __init__(self,kernel_size):
      super().__init__()
      self.kernel_size = kernel_size
      self.AvgPool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=1, padding=int(math.ceil((kernel_size-1)/2)))

    def forward(self,x):
      return (self.kernel_size**2)*self.AvgPool(x)
    

class CryptoNet(nn.Module):
  """
    Original 9-layer network used during training
    CURRENTLY NOT WORKING
  """
  def __init__(self, verbose):
    super().__init__()
    self.verbose = verbose
    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.square1 = torch.square
    self.scaledAvgPool1 = ScaledAvgPool2d(kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2)
    self.scaledAvgPool2 = ScaledAvgPool2d(kernel_size=3)
    self.fc1 = nn.Linear(in_features=1250, out_features=100)
    self.square2 = torch.square
    self.fc2 = nn.Linear(in_features=100, out_features=10)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.pad(x, (1,0,1,0))
    if self.verbose:
      print("Start --> ",x.mean())
    x = self.conv1(x)
    if self.verbose:
      print("Conv1 --> ",x.mean())
    x = self.square1(x)
    if self.verbose:
      print("Sq --> ",x.mean())
    x = self.scaledAvgPool1(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    x = self.conv2(x)
    if self.verbose:
      print("Conv2 --> ",x.mean())
    x = self.scaledAvgPool2(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    ## Flatten
    x = x.reshape(x.shape[0], -1)
    x = self.fc1(x)
    if self.verbose:
      print("fc1 --> ",x.mean())
    x = self.square2(x)
    if self.verbose:
      print("Square --> ",x.mean())
    x = self.fc2(x)
    if self.verbose:
      print("fc2 --> ",x.mean())
    x = self.sigmoid(x)
    return x

  def weights_init(self, m):
    """ Custom initilization to avoid square activation to blow up """
    for m in self.children():
      if isinstance(m,nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
      elif isinstance(m, nn.Linear):
        nn.init.uniform_(m.weight, 1e-4,1e-3)
'''

'\nclass ScaledAvgPool2d(nn.Module):\n    """Define the ScaledAvgPool layer, a.k.a the Sum Pool"""\n    def __init__(self,kernel_size):\n      super().__init__()\n      self.kernel_size = kernel_size\n      self.AvgPool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=1, padding=int(math.ceil((kernel_size-1)/2)))\n\n    def forward(self,x):\n      return (self.kernel_size**2)*self.AvgPool(x)\n    \n\nclass CryptoNet(nn.Module):\n  """\n    Original 9-layer network used during training\n    CURRENTLY NOT WORKING\n  """\n  def __init__(self, verbose):\n    super().__init__()\n    self.verbose = verbose\n    self.pad = F.pad\n    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)\n    self.square1 = torch.square\n    self.scaledAvgPool1 = ScaledAvgPool2d(kernel_size=3)\n    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2)\n    self.scaledAvgPool2 = ScaledAvgPool2d(kernel_size=3)\n    self.fc1 = nn.Linear(in_features=1250, out_f

In [26]:
class SimpleNet(nn.Module):
  '''
    Simpliefied network used in paper for inference https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/
  '''
  def __init__(self, batch_size : int, activation : str, init_method : str, verbose : bool):
    super().__init__()
    self.verbose = verbose
    self.init_method = init_method
    self.batch_size = batch_size

    if activation == "square":
      self.activation = torch.square
    elif activation == "relu":
      self.activation = nn.ReLU()
    elif activation == "a-relu-2d":
      self.activation = approx_relu_2d
    elif activation == "a-relu-4d":
      self.activation = approx_relu_4d

    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.pool1 = nn.Conv2d(in_channels=5, out_channels=100, kernel_size=13, stride=1000)
    self.pool2 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(100,1), stride=1000)

  def forward(self, x):
    x = self.pad(x, (1,0,1,0))
    x = self.conv1(x)
    x = self.activation(self.pool1(x))
    x = x.reshape([self.batch_size,1,100,1]) #batch_size tensors in 1 channel, 100x1
    x = self.activation(self.pool2(x))
    x = x.reshape(x.shape[0], -1)
    return x
 
  def weights_init(self, m):
    for m in self.children():
      if isinstance(m,nn.Conv2d):
        if self.init_method == "he":
          nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
        elif self.init_method == "xavier":
          nn.init.xavier_uniform_(m.weight, gain=math.sqrt(2))
        elif self.init_method == "uniform":
          nn.init.uniform_(m.weight, -0.5, 0.5)
        elif self.init_method == "norm":
          nn.init.normal_(m.weight, 0.0, 1.0)

Modified AlexNet with ReLU approximation

In [27]:
class AlexNet(nn.Module):
  def __init__(self, verbose: bool):
    super().__init__()
    self.verbose = verbose
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4, padding=0)
    self.pool = nn.MaxPool2d(kernel_size=3, stride=2) 
    self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding= 2)
    self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding= 1)
    self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
    self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
    self.fc1  = nn.Linear(in_features= 9216, out_features= 4096)
    self.fc2  = nn.Linear(in_features= 4096, out_features= 4096)
    self.fc3 = nn.Linear(in_features=4096 , out_features=10)
    self.ReLU = nn.ReLU()

    nn.init.kaiming_uniform_(self.conv1.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv2.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv3.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv4.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.conv5.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.fc1.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.kaiming_uniform_(self.fc2.weight, a=0, mode='fan_in', nonlinearity='relu')
    nn.init.xavier_uniform_(self.fc3.weight, gain=math.sqrt(2))

  def forward(self,x):
    x = self.ReLU(self.conv1(x))
    x = self.pool(x)
    x = self.ReLU(self.conv2(x))
    x = self.pool(x)
    x = self.ReLU(self.conv3(x))
    x = self.ReLU(self.conv4(x))
    x = self.ReLU(self.conv5(x))
    x = self.pool(x)
    x = x.reshape(x.shape[0], -1)
    x = self.ReLU(self.fc1(x))
    x = self.ReLU(self.fc2(x))
    x = self.fc3(x)
    x = torch.sigmoid(x) ## this used to be softmax
    return x

Load Datasets

In [28]:
class DataHandler():
  def __init__(self, dataset : str, batch_size : int):
    if dataset == "MNIST":
      self.batch_size = batch_size
      
     
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
      
    train_ds = MNIST("data/", train=True, download=True, transform=transform)
    test_ds = MNIST("data/", train=False, download=True, transform=transform)
    self.train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle=True, drop_last=True,num_workers=2, pin_memory=True)
    self.test_dl = DataLoader(test_ds, batch_size = batch_size, shuffle=True, drop_last=True,num_workers=2, pin_memory=True)

dataHandler = DataHandler(dataset="MNIST", batch_size=256)

Plot gradient flow for debug

In [29]:
def plot_grad_flow(named_parameters):
    ## From https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063
    ## Beware it's a little bit tricky to interpret results
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''

    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print(f"Layer {n}, grad avg {p.grad.mean()}, data {p.data.mean()}")
    plt.bar(np.arange(len(max_grads)), max(max_grads), alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), np.mean(ave_grads), alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    


Train and test pipeline

In [22]:
## training params setup
learning_rate = 3e-4
total_step = len(dataHandler.train_dl)
criterion = nn.CrossEntropyLoss()

## PLOT HELPER
def plot_history(key, train, history):
  """ 
    Plot loss and accuracy history during model run
    Input:
          key : str => name of the model
          train : bool => training 1 or test 0
          history : dict{str : list of floats}
  """
  if train:
    when = "train"
  else:
    when = "test"
  fig, ax = plt.subplots( 1, 2, figsize = (12,4) )
  ax[0].plot(history['loss'], label = when+"----"+key)
  ax[0].set_title( "Loss" )
  ax[0].set_xlabel( "Epochs" )
  ax[0].set_ylabel( "Loss" )
  ax[0].grid( True )
  ax[0].legend()

  ax[1].plot(history['accuracy'], label = when+"----"+key)
  ax[1].set_title( "Accuracy" )
  ax[1].set_xlabel( "Epochs" )
  ax[1].set_ylabel( "Accuracy" )
  ax[1].grid( True )
  ax[1].legend()

  plt.savefig(f"{key}_{when}.png")

## TRAIN
def train(key, model, dataHandler, num_epochs, TPU=False):
  num_epochs = num_epochs
  model.train()
  #optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  trainHistory = {}
  trainHistory['loss'] = []
  trainHistory['accuracy'] = []

  for epoch in range(num_epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    num_correct = 0
    num_samples = 0
    for i, (data, labels) in enumerate(dataHandler.train_dl):
      data = data.to(device=device)
      labels = labels.to(device=device)
      #labels = labels.to(torch.float32)

      optimizer.zero_grad()
      predictions = model(data)
      loss = criterion(predictions, labels)
      loss.backward()
      
      if model.verbose:
        print(f"[?] Step {i+1} Epoch {epoch+1}")
        plot_grad_flow(model.named_parameters())
      
      if not TPU:
        optimizer.step()
      else:
        xm.optimizer_step(optimizer, barrier=True) ## if TPU 
      
      _, predicted_labels = predictions.max(1)
      num_correct += (predicted_labels == labels).sum()
      num_samples += predicted_labels.size(0)
      
      epoch_accuracy += num_correct/num_samples
      epoch_loss += loss.item()

      if (i+1) % 100 == 0:
        print("=====================================================================================================================")
        print ('[!] Train Epoch [{}/{}], Step [{}/{}] ==> Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
      
    trainHistory['loss'].append(epoch_loss/len(dataHandler.train_dl))
    trainHistory['accuracy'].append(epoch_accuracy/len(dataHandler.train_dl))
    
  plot_history(key, True, trainHistory)


## EVAL 
def eval(key, model, dataHandler):
  num_correct = 0
  num_samples = 0

  model.eval()
  testHistory = {}
  testHistory['loss'] = []
  testHistory['accuracy'] = []
  test_loss = 0
  accuracy = 0
  with torch.no_grad():
    for _, (data,labels) in enumerate(dataHandler.test_dl):
        data = data.to(device="cpu")
        labels = labels.to(device="cpu")
        ## Forward Pass
        predictions = model(data)
        loss = criterion(predictions, labels).item()
        test_loss += loss
        _, predicted_labels = predictions.max(1)
        num_correct += (predicted_labels == labels).sum()
        num_samples += predicted_labels.size(0)
        testHistory['loss'].append(loss)
        testHistory['accuracy'].append(float(num_correct) / float(num_samples))
  
    accuracy = float(num_correct) / float(num_samples)
    test_loss = test_loss/len(dataHandler.test_dl)
    print("=============================")
    print(f"Average test Loss ==> {test_loss}")
    print(f"Test accuracy ==> {float(num_correct) / float(num_samples) * 100:.2f}")

    plot_history(key, False, testHistory)

  return test_loss, accuracy


Training and evaluation of SimpleNet

In [None]:
##############################
#                            #
# TRAINING AND EVAL PIPELINE #
#                            #
##############################

## init models
methods = ["random", "he", "xavier", "uniform", "norm"]
activations = ["relu","square", "a-relu-2d", "a-relu-4d"]
models = {}

for method in methods:
  for activation in activations:
    models[method+"_"+activation] = SimpleNet(batch_size=dataHandler.batch_size, activation=activation, init_method=method,verbose=False).to(device=device)
scores = {}

## Testing of different stuff ==> result was best xavier+square
for key, model in models.items():
  model.apply(model.weights_init)
  train(key, model, dataHandler, num_epochs=10)
  loss, accuracy = eval(key,model, dataHandler)
  scores[key] = {"loss":loss, "accuracy":accuracy}
  torch.save(model, f"SimpleNet_{key}.pt")

## Best Model on 60 epochs
#key = "xavier_square"
#model = models[key]
#model.apply(model.weights_init)
#train(key, model, dataHandler, num_epochs=150, TPU=False)
#loss, accuracy = eval(key,model, dataHandler)
#scores[key] = {"loss":loss, "accuracy":accuracy}
#torch.save(model, f"SimpleNet_{key}.pt")

for key, metrics in scores.items():
  print("=====================================================================")
  print(f"[+] Model with {key}: Avg test Loss ==> {metrics['loss']}, Accuracy ==> {metrics['accuracy']}")

[!] Train Epoch [1/10], Step [100/234] ==> Loss: 0.4587
[!] Train Epoch [1/10], Step [200/234] ==> Loss: 0.3362
[!] Train Epoch [2/10], Step [100/234] ==> Loss: 0.2130
[!] Train Epoch [2/10], Step [200/234] ==> Loss: 0.2530
[!] Train Epoch [3/10], Step [100/234] ==> Loss: 0.1318
[!] Train Epoch [3/10], Step [200/234] ==> Loss: 0.1515
[!] Train Epoch [4/10], Step [100/234] ==> Loss: 0.1708
[!] Train Epoch [4/10], Step [200/234] ==> Loss: 0.1400
[!] Train Epoch [5/10], Step [100/234] ==> Loss: 0.0843
[!] Train Epoch [5/10], Step [200/234] ==> Loss: 0.0835
[!] Train Epoch [6/10], Step [100/234] ==> Loss: 0.1128
[!] Train Epoch [6/10], Step [200/234] ==> Loss: 0.1780
[!] Train Epoch [7/10], Step [100/234] ==> Loss: 0.1083
[!] Train Epoch [7/10], Step [200/234] ==> Loss: 0.1103
[!] Train Epoch [8/10], Step [100/234] ==> Loss: 0.0748
[!] Train Epoch [8/10], Step [200/234] ==> Loss: 0.0897
[!] Train Epoch [9/10], Step [100/234] ==> Loss: 0.0803
[!] Train Epoch [9/10], Step [200/234] ==> Loss:

# Results of SimpleNet evaluation
Best model seems to be the one with square activation and xavier initialization.