# CryptoNet implementation and training on the MNIST dataset

In [4]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torchvision.utils import save_image
import torch.optim as optim
from torch.utils.data import random_split


Model

In [5]:
## DUMMY MODEL: plain ReLU / max-pool CNN that the training cell can swap in for CryptoNet
class CNN(nn.Module):
    """Small two-stage convolutional classifier with a 10-way linear head.

    Each stage is a 5x5 convolution (stride 1, padding 2) followed by ReLU and
    2x2 max pooling; the stages widen the channels 1 -> 16 -> 32. The head is
    sized for 32 * 7 * 7 flattened features, i.e. 1x28x28 inputs, since each
    stage keeps the spatial size in the convolution and halves it in the pool.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv(5x5, stride 1, padding 2) -> ReLU -> max-pool(2) stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_stage(1, 16)
        self.conv2 = self._conv_stage(16, 32)
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return the raw 10-class scores (no activation is applied to the head)."""
        feature_maps = self.conv2(self.conv1(x))
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.out(flat)

In [27]:
class CryptoNet(nn.Module):
  '''
    CryptoNets-style classifier: squaring is the only hidden non-linearity,
    pooling is an un-normalised window sum, and the 10 outputs go through a
    sigmoid.

    Pipeline: zero-pad (left/top) -> conv1 -> square -> sum-pool -> conv2
              -> sum-pool -> flatten -> fc1 -> square -> fc2 -> sigmoid

    fc1 is sized for 50*3*3 features, which is what the layers above yield for
    1x28x28 inputs. Note that forward() returns sigmoid-squashed scores in
    [0, 1], not raw logits.

    TO DO: check how in the paper the avg pool does not downscale the input size...weird padding?
  '''
  def __init__(self):
    super().__init__()

    # Parameter-free operations, stored as attributes so forward() reads as a layer list.
    self.pad = F.pad
    self.square = torch.square
    self.sigmoid = torch.sigmoid

    # Feature extractor.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=2)
    # LPPool2d with norm_type=1 is a sum over each 3x3 window: an average pool
    # scaled by |window|, i.e. the mean is never divided (SumPooling).
    self.scaledAvgPool = nn.LPPool2d(norm_type=1, kernel_size=3, stride=1)
    # The paper may use no padding here.
    self.conv2 = nn.Conv2d(in_channels=5, out_channels=50, kernel_size=5, stride=2, padding=1)

    # Classifier head; the paper has in_features=1250 for fc1.
    self.fc1 = nn.Linear(in_features=50*3*3, out_features=100)
    self.fc2 = nn.Linear(in_features=100, out_features=10)

  def forward(self, x):
    # One row/column of zeros on the left and on the top only.
    padded = self.pad(x, (1,0,1,0))

    # conv -> square -> sum-pool, then conv -> sum-pool (no activation after conv2).
    maps = self.scaledAvgPool(self.square(self.conv1(padded)))
    maps = self.scaledAvgPool(self.conv2(maps))

    ## Flatten
    flat = maps.reshape(maps.shape[0], -1)

    hidden = self.square(self.fc1(flat))
    return self.sigmoid(self.fc2(hidden))

Load Datasets

In [10]:

class DataHandler():
  """Creates the train and test ``DataLoader`` objects for a named dataset.

  Attributes set on success:
    train_dl: shuffled loader over the training split, batch size 64.
    test_dl:  loader over the test split, batch size 64, in dataset order.

  Args:
    dataset: name of the dataset to load. Only ``"MNIST"`` is supported.

  Raises:
    ValueError: if ``dataset`` is not a supported name. Previously such a name
      silently produced an object without ``train_dl`` / ``test_dl``.
  """
  def __init__(self, dataset : str):
    if dataset != "MNIST":
      raise ValueError("Unsupported dataset: {!r} (supported: 'MNIST')".format(dataset))

    # Normalisation constants are presumably the MNIST mean / std — TODO confirm.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_ds = MNIST("data/", train=True, download=True, transform=transform)
    # The test split must use the same transform as the training split: without
    # it the dataset yields PIL images, which the DataLoader cannot collate into
    # tensors, and the model would be evaluated on un-normalised inputs.
    test_ds = MNIST("data/", train=False, download=True, transform=transform)

    self.train_dl = DataLoader(train_ds, batch_size = 64, shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps runs comparable.
    self.test_dl = DataLoader(test_ds, batch_size = 64, shuffle=False)



Training

In [28]:
# Training cell: builds the model and data loaders, then runs a mini-batch
# training loop and pickles the trained model to "cryptoNet.pt".

# Seed torch's RNG before the model is constructed so weight initialisation is repeatable.
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CryptoNet()
# Alternative model: uncomment to train the baseline CNN instead.
#model = CNN()

model = model.to(device=device)
dataHandler = DataHandler("MNIST")

learning_rate = 1e-2
# NOTE(review): `momentum` is never used in this cell — Adam below is built with `lr` only.
momentum = 0.9
num_epochs = 50
# Number of mini-batches per epoch; only used for the progress message.
total_step = len(dataHandler.train_dl)
# NOTE(review): nn.CrossEntropyLoss expects raw logits, but CryptoNet.forward
# ends with a sigmoid, so the loss only ever sees values in [0, 1]. The logged
# loss hovering around 2.3-2.5 is consistent with that — confirm whether the
# sigmoid should be skipped during training or a different loss used.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
  # Put the module in training mode at the start of every epoch.
  model.train()
  for i, (data, labels) in enumerate(dataHandler.train_dl):
    # Move the batch to the same device as the model.
    data = data.to(device=device)
    labels = labels.to(device=device)
    #labels = labels.to(torch.float32)

    ## Clear old gradients, forward pass, loss, backward pass, parameter update
    optimizer.zero_grad()
    predictions = model(data)
    loss = criterion(predictions, labels)
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 100 steps.
    if (i+1) % 100 == 0:
      print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# Pickles the whole module object (not just its state_dict). This line runs only
# after the final epoch, so interrupting the loop leaves nothing saved.
torch.save(model, "cryptoNet.pt")

Epoch [1/50], Step [100/938], Loss: 2.4696
Epoch [1/50], Step [200/938], Loss: 2.3758
Epoch [1/50], Step [300/938], Loss: 2.3289
Epoch [1/50], Step [400/938], Loss: 2.4227
Epoch [1/50], Step [500/938], Loss: 2.4071
Epoch [1/50], Step [600/938], Loss: 2.3446
Epoch [1/50], Step [700/938], Loss: 2.3758
Epoch [1/50], Step [800/938], Loss: 2.4227
Epoch [1/50], Step [900/938], Loss: 2.3289
Epoch [2/50], Step [100/938], Loss: 2.2977
Epoch [2/50], Step [200/938], Loss: 2.3758
Epoch [2/50], Step [300/938], Loss: 2.4696


KeyboardInterrupt: ignored

Testing