# CryptoNet implementation and training on MNIST dataset

In [4]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torchvision.utils import save_image
import torch.optim as optim
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from torch.utils.tensorboard import SummaryWriter
import math


Dummy model for testing training pipeline

In [None]:
"""
## DUMMY MODEL
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()        
        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)    
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output  # return x for visualization
  """

CryptoNet from [Microsoft](https://www.microsoft.com/en-us/research/publication/cryptonets-applying-neural-networks-to-encrypted-data-with-high-throughput-and-accuracy/)

In [5]:
class ScaledAvgPool2d(nn.Module):
    """Define the ScaledAvgPool layer, a.k.a the Sum Pool"""
    def __init__(self,kernel_size):
      super().__init__()
      self.kernel_size = kernel_size
      self.AvgPool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=1, padding=1)

    def forward(self,x):
      return (self.kernel_size**2)*self.AvgPool(x)
    

class CryptoNet(nn.Module):
  '''
    Original 9-layer network used during training
    CURRENTLY NOT WORKING
  '''
  def __init__(self, verbose):
    super().__init__()
    self.verbose = verbose
    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.square1 = torch.square
    self.scaledAvgPool1 = ScaledAvgPool2d(kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2)
    self.scaledAvgPool2 = ScaledAvgPool2d(kernel_size=3)
    self.fc1 = nn.Linear(in_features=1250, out_features=100) # in paper in_features was 1250
    self.square2 = torch.square
    self.fc2 = nn.Linear(in_features=100, out_features=10)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.pad(x, (1,0,1,0))
    if self.verbose:
      print("Start --> ",x.mean())
    x = self.conv1(x)
    if self.verbose:
      print("Conv1 --> ",x.mean())
    x = self.square1(x)
    if self.verbose:
      print("Sq --> ",x.mean())
    x = self.scaledAvgPool1(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    x = self.conv2(x)
    if self.verbose:
      print("Conv2 --> ",x.mean())
    x = self.scaledAvgPool2(x)
    if self.verbose:
      print("Pool --> ",x.mean())
    ## Flatten
    x = x.reshape(x.shape[0], -1)
    x = self.fc1(x)
    if self.verbose:
      print("fc1 --> ",x.mean())
    x = self.square2(x)
    if self.verbose:
      print("Square --> ",x.mean())
    x = self.fc2(x)
    if self.verbose:
      print("fc2 --> ",x.mean())
    x = self.sigmoid(x)
    return x

  def weights_init(self, m):
    """ Custom initilization to avoid square activation to blow up """
    for m in self.children():
      if isinstance(m,nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
      elif isinstance(m, nn.Linear):
        nn.init.uniform_(m.weight, 1e-4,1e-3)


In [30]:
class SimpleNet(nn.Module):
  '''
    Simpliefied network used in paper for inference
  '''
  def __init__(self, batch_size, verbose):
    super().__init__()
    self.verbose = verbose
    self.batch_size = batch_size
    self.pad = F.pad
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    self.square1 = torch.square
    self.pool1 = nn.Conv2d(in_channels=5, out_channels=100, kernel_size=13, stride=1000)
    self.square2 = torch.square
    self.pool2 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(100,1), stride=1000)

  def forward(self, x):
    x = self.pad(x, (1,0,1,0))
    x = self.conv1(x)
    x = self.square1(self.pool1(x))
    x = x.reshape([self.batch_size,1,100,1]) #batch_size tensors in 1 channel, 100x1
    x = self.square2(self.pool2(x))
    x = x.reshape(x.shape[0], -1)
    return x

  def weights_init(self, m):
    """ HE weigth init --> do not use, worse performance"""
    for m in self.children():
      if isinstance(m,nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')

  def set_batch_size(self, batch_size):
    self.batch_size = batch_size

Load Datasets

In [37]:
class DataHandler():
  def __init__(self, dataset : str, batch_size : int):
    if dataset == "MNIST":
      self.batch_size = batch_size
      transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
      train_ds = MNIST("data/", train=True, download=True, transform=transform)
      test_ds = MNIST("data/", train=False, download=True, transform=transforms.ToTensor())

      self.train_dl = DataLoader(train_ds, batch_size = batch_size, shuffle=True, drop_last=True)
      self.test_dl = DataLoader(test_ds, batch_size = batch_size, shuffle=True, drop_last=True)

Plot gradient flow

In [9]:
def plot_grad_flow(named_parameters):
    ## From https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063
    ## Beware it's a little bit tricky to interpret results
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.
    
    Usage: Plug this function in Trainer class after loss.backwards() as 
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''

    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if(p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print(f"Layer {n}, grad avg {p.grad.mean()}, data {p.data.mean()}")
    plt.bar(np.arange(len(max_grads)), max(max_grads), alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), np.mean(ave_grads), alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    


Training of 9-layer CryptoNet

In [17]:
"""
## setup torch enviro
torch.manual_seed(9325345339582034)
torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## init model
model = CryptoNet(verbose=False)
model.apply(model.weights_init)
model = model.to(device=device)

dataHandler = DataHandler("MNIST")

## training params setup
learning_rate = 3e-4
momentum = 0.9
num_epochs = 5000
total_step = len(dataHandler.train_dl)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
  for i, (data, labels) in enumerate(dataHandler.train_dl):
    data = data.to(device=device)
    labels = labels.to(device=device)
    #labels = labels.to(torch.float32)

    optimizer.zero_grad()
    predictions = model(data)
    loss = criterion(predictions, labels)
    loss.backward()
    if model.verbose:
      print(f"[?] Step {i+1} Epoch {epoch+1}")
      plot_grad_flow(model.named_parameters())
    optimizer.step()

    if (i+1) % 50 == 0:
      print ('[!] Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

torch.save(model, "cryptoNet.pt")
"""

[!] Epoch [1/5000], Step [50/234], Loss: 2.3518
[!] Epoch [1/5000], Step [100/234], Loss: 2.3479


KeyboardInterrupt: ignored

Simple Model

In [38]:
## setup torch enviro
torch.manual_seed(9325345339582034)
torch.autograd.set_detect_anomaly(True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataHandler = DataHandler(dataset="MNIST", batch_size=256)

## init model
model = SimpleNet(batch_size=dataHandler.batch_size, verbose=False,)
#model.apply(model.weights_init)
model = model.to(device=device)

## training params setup
learning_rate = 3e-4
num_epochs = 50
total_step = len(dataHandler.train_dl)
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
  for i, (data, labels) in enumerate(dataHandler.train_dl):
    data = data.to(device=device)
    labels = labels.to(device=device)
    #labels = labels.to(torch.float32)

    optimizer.zero_grad()
    predictions = model(data)
    loss = criterion(predictions, labels)
    loss.backward()
    if model.verbose:
      print(f"[?] Step {i+1} Epoch {epoch+1}")
      plot_grad_flow(model.named_parameters())
    optimizer.step()

    if (i+1) % 50 == 0:
      print ('[!] Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

torch.save(model, "simpleNet.pt")

[!] Epoch [1/50], Step [50/234], Loss: 0.7270
[!] Epoch [1/50], Step [100/234], Loss: 0.3414
[!] Epoch [1/50], Step [150/234], Loss: 0.2210
[!] Epoch [1/50], Step [200/234], Loss: 0.1596
[!] Epoch [2/50], Step [50/234], Loss: 0.1058
[!] Epoch [2/50], Step [100/234], Loss: 0.1671
[!] Epoch [2/50], Step [150/234], Loss: 0.0880
[!] Epoch [2/50], Step [200/234], Loss: 0.1568
[!] Epoch [3/50], Step [50/234], Loss: 0.1020
[!] Epoch [3/50], Step [100/234], Loss: 0.0552
[!] Epoch [3/50], Step [150/234], Loss: 0.1263
[!] Epoch [3/50], Step [200/234], Loss: 0.1348
[!] Epoch [4/50], Step [50/234], Loss: 0.0698
[!] Epoch [4/50], Step [100/234], Loss: 0.0528
[!] Epoch [4/50], Step [150/234], Loss: 0.0784
[!] Epoch [4/50], Step [200/234], Loss: 0.0814
[!] Epoch [5/50], Step [50/234], Loss: 0.0540
[!] Epoch [5/50], Step [100/234], Loss: 0.0786
[!] Epoch [5/50], Step [150/234], Loss: 0.0745
[!] Epoch [5/50], Step [200/234], Loss: 0.0932
[!] Epoch [6/50], Step [50/234], Loss: 0.0704
[!] Epoch [6/50], S

Testing

In [40]:
num_correct = 0
num_samples = 0

model.eval()
for _, (data,labels) in enumerate(dataHandler.test_dl):
    data = data.to(device="cpu")
    labels = labels.to(device="cpu")
    ## Forward Pass
    predictions = model(data)
    _, predictions = predictions.max(1)
    num_correct += (predictions == labels).sum()
    num_samples += predictions.size(0)
print(f"Accuracy {float(num_correct) / float(num_samples) * 100:.2f}")

Accuracy 95.03
