Original code taken from https://medium.com/@aungkyawmyint_26195/multi-layer-perceptron-mnist-pytorch-463f795b897a

New things:
* Used a binary transformer with thresholding
* Added functionality for switching between different architectures

Other
* replaced some magic numbers with config variables
* Added variable to switch between architectures

In [None]:
import torch
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
from config import *


from common import plot_batch
%matplotlib inline

# Settings

In [None]:
from models.mnist_binary_conf import * # load configuration file for model

# Data Preparation

In [None]:
# Load the MNIST train and test splits. `transform`, `sample_size`,
# `valid_size`, `batch_size` and `num_workers` are expected to come from the
# star-imported config modules.
train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# sample_size has to lie in [0, len(train_data)]
assert 0 <= sample_size  # Error: Invalid TRAIN_SIZE
assert len(train_data) >= sample_size  # Error: Invalid TRAIN_SIZE

# A sample_size of 0 means "use the complete training set".
num_train = len(train_data) if sample_size == 0 else sample_size

# Shuffle the first num_train indices, then carve off the leading
# valid_size fraction for validation; the remainder is used for training.
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
valid_index = indices[:split]
train_index = indices[split:]

# Samplers that draw (in random order) only from their own index subset.
train_sampler = SubsetRandomSampler(train_index)
valid_sampler = SubsetRandomSampler(valid_index)

# Data loaders: training and validation share train_data but use disjoint
# samplers; the test loader iterates over the whole test split.
loader_options = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, **loader_options)
valid_loader = torch.utils.data.DataLoader(train_data, sampler=valid_sampler, **loader_options)
test_loader = torch.utils.data.DataLoader(test_data, **loader_options)

# Visualize a batch of training data

In [None]:
# Pull a single batch from the training loader and plot it.
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators implement __next__, while the
# Python-2 style `.next()` method is not available on current PyTorch versions.
images, labels = next(dataiter)
plot_batch(images, labels, label_color="white")

# Define Network Architecture

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Network whose layers and forward pass are supplied by the loaded conf file.

    Construction delegates to the conf-level ``layers`` function and the
    forward pass to the conf-level ``forward`` function.
    """

    def __init__(self):
        super().__init__()
        # `layers` is defined in the loaded conf file; it receives this
        # module so it can set the network up.
        layers(self)

    def forward(self, x):
        """Return the result of the conf file's ``forward`` for input ``x``."""
        result = forward(self, x)  # module-level `forward` from the conf file
        return result

class Net_Relu(nn.Module):
    """Same construction as ``Net``, but the forward pass uses ``forward_relu``.

    Both ``layers`` and ``forward_relu`` are defined in the loaded conf file.
    """

    def __init__(self):
        super().__init__()
        # `layers` is defined in the loaded conf file; it receives this
        # module so it can set the network up.
        layers(self)

    def forward(self, x):
        """Return the result of the conf file's ``forward_relu`` for input ``x``."""
        result = forward_relu(self, x)
        return result
    
# Instantiate the architecture to work with. Net_Relu is the active choice;
# use `model = Net()` instead to switch to the other variant.
model = Net_Relu()
# File name under which this model's parameters are saved and loaded later on.
model_file_name = file_name_relu
print(model)

# Specify Loss function and Optimizer

In [None]:
# Loss: categorical cross-entropy on the model outputs.
criterion = nn.CrossEntropyLoss()

# Optimizer: stochastic gradient descent over all model parameters,
# with a fixed learning rate of 0.01.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

# Train the network

In [None]:
# Train the model and checkpoint its parameters whenever the average
# validation loss of an epoch is at least as good as the best one seen so far.
# Set ENABLE_TRAIN to False to skip training when re-running the whole notebook.
ENABLE_TRAIN = True
# Overrides where the model parameters are saved to.
model_file_name = MODEL_DIR + "model_relu_peter.pth"
if ENABLE_TRAIN:
    # best (lowest) average validation loss seen so far
    valid_loss_min = np.inf  # np.inf instead of np.Inf: the alias was removed in NumPy 2.0
    for epoch in range(n_epochs):
        # running sums of (batch loss * batch size) for this epoch
        train_loss = 0.0
        valid_loss = 0.0

        ###################
        # train the model #
        ###################
        model.train()  # prep model for training
        for data, label in train_loader:
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, label)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # weight the batch-mean loss by the batch size so the epoch
            # average stays correct when the last batch is smaller
            train_loss += loss.item() * data.size(0)

        ######################
        # validate the model #
        ######################
        model.eval()  # prep model for evaluation
        # no gradients are needed for validation; skipping them saves memory
        with torch.no_grad():
            for data, label in valid_loader:
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the loss
                loss = criterion(output, label)
                # accumulate over ALL validation batches (previously this was
                # overwritten each iteration, so only the last batch counted)
                valid_loss += loss.item() * data.size(0)

        # average loss per sample over the epoch
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)

        print("Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}".format(
            epoch+1,
            train_loss,
            valid_loss
            ))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print("Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
            valid_loss_min,
            valid_loss))
            torch.save(model.state_dict(), model_file_name)
            valid_loss_min = valid_loss

    print("### --------- TRAINING DONE --------- ###")
else:
    print("### --------- DID NOT TRAIN --------- ###")
    print("Maybe you forgot to enable training for this notebook?")
    print("Check the top of this cell.")

## Training a crypten model

In [None]:
# Variant of the training loop that updates parameters through the model
# itself (model.zero_grad / model.update_parameters) instead of the optimizer.
# Disabled by default so the whole notebook can be re-run; set to True to train.
ENABLE_TRAIN = False
# To override where the model params are saved to, set e.g.:
#   model_file_name = MODEL_DIR + "model_relu_peter.pth"

# The crypten.nn.CrossEntropyLoss alternative is currently not used here.
criterion = nn.CrossEntropyLoss()
if ENABLE_TRAIN:
    # best (lowest) average validation loss seen so far
    valid_loss_min = np.inf  # np.inf instead of np.Inf: the alias was removed in NumPy 2.0
    for epoch in range(n_epochs):
        # running sums of (batch loss * batch size) for this epoch
        train_loss = 0.0
        valid_loss = 0.0

        ###################
        # train the model #
        ###################
        model.train()  # prep model for training
        for data, label in train_loader:
            # clear the gradients held by the model
            model.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, label)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # NOTE(review): update_parameters is not a torch.nn.Module method;
            # this presumably expects `model` to be a CrypTen module - confirm
            # before enabling this cell with the plain PyTorch model.
            model.update_parameters(learning_rate)
            # weight the batch-mean loss by the batch size
            train_loss += loss.item() * data.size(0)

        ######################
        # validate the model #
        ######################
        model.eval()  # prep model for evaluation
        for data, label in valid_loader:
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, label)
            # accumulate over ALL validation batches (previously this was
            # overwritten each iteration, so only the last batch counted)
            valid_loss += loss.item() * data.size(0)

        # average loss per sample over the epoch
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)

        print("Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}".format(
            epoch+1,
            train_loss,
            valid_loss
            ))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print("Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...".format(
            valid_loss_min,
            valid_loss))
            torch.save(model.state_dict(), model_file_name)
            valid_loss_min = valid_loss

    print("### --------- TRAINING DONE --------- ###")
else:
    print("### --------- DID NOT TRAIN --------- ###")
    print("Maybe you forgot to enable training for this notebook?")
    print("Check the top of this cell.")

# Load the Model with Lowest Validation Loss

In [None]:
# Read the saved parameters back from model_file_name and install them
# in the current model instance.
saved_state = torch.load(model_file_name)
model.load_state_dict(saved_state)

### Error when loading

I got this error when loading Mario's model...

```bash
RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:132)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7f13cbfd0193 in /home/peter/.local/share/virtualenvs/Ex3-pG_9TV2D/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7f13cf1589eb in /home/peter/.local/share/virtualenvs/Ex3-pG_9TV2D/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7f13cf159c04 in /home/peter/.local/share/virtualenvs/Ex3-pG_9TV2D/lib/python3.7/site-packages/torch/lib/libtorch.so)
```

# Test the trained Network

## Single party/standard variant

In [None]:
# Evaluate the model on the test set: average loss plus per-class and
# overall accuracy. NUM_CLASSES is expected to come from the config.

test_loss = 0.0
class_correct = [0] * NUM_CLASSES  # correctly classified samples per class
class_total = [0] * NUM_CLASSES    # samples seen per class

model.eval()  # prep model for evaluation
# no gradients are needed for evaluation; skipping them saves memory
with torch.no_grad():
    for data, target in test_loader:
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the loss
        loss = criterion(output, target)
        # weight the batch-mean loss by the batch size
        test_loss += loss.item() * data.size(0)
        # predicted class = index of the largest output per sample
        _, pred = torch.max(output, 1)
        # Flatten instead of squeezing: squeeze() turns a batch of one into a
        # 0-d tensor, which cannot be indexed per sample.
        correct = pred.eq(target.view_as(pred)).reshape(-1)
        # tally hits and totals per true class
        for label, hit in zip(target.reshape(-1).tolist(), correct.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

# calculate and print avg test loss
test_loss = test_loss / len(test_loader.sampler)
print(f"Test Loss: {test_loss:.6}\n")
# Print accuracy per class
for i in range(NUM_CLASSES):
    if class_total[i] > 0:
        print(f"Test Accuracy of {i:5}: "
              f"{100 * class_correct[i] / class_total[i]:3.0f}% "
              f"({class_correct[i]:4} / {class_total[i]:4} )")
    else:
        # Label the class by its index, like the branch above; the previous
        # `classes[i]` referenced a name that is not defined in this notebook.
        print(f"Test Accuracy of {i:5}: N/A (no training examples)")
# Print overall accuracy
print(f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% "
      f"( {np.sum(class_correct)} / {np.sum(class_total)} )")

## MPC variant

Ideally, we should see the same test results here as well; it is just going to take much longer.

In [None]:
import sys  # needed for the version check below (was used without being imported)

import crypten

# This notebook insists on Python 3.7 before initialising CrypTen. An explicit
# raise is used instead of assert so the check also runs under `python -O`.
if sys.version_info[:2] != (3, 7):
    raise RuntimeError("python 3.7 is required!")

print(f"Okay, good! You have: {sys.version_info[:3]}")
# Now we can init crypten!
crypten.init()

In [None]:
import pathlib
import crypten.communicator as comm # the communicator is similar to the MPI communicator for example
from crypten import mpc
from multiprocessing import Barrier
from tqdm import tqdm
from time import time

from ex3_lib.dir_setup import POSSIBLE_PARTICIPANTS, check_and_mkdir

import warnings

# Suppress all warnings from here on.
warnings.filterwarnings(action="ignore")

# Prepare the ./log directory (the MPC test writes its per-rank logs there).
# check_and_mkdir presumably creates it when missing - see ex3_lib.dir_setup.
log_dir = pathlib.Path("./log")
check_and_mkdir(log_dir)

In [None]:
def convert_legacy_config(namespace=None):
    """Define the current config names from their legacy equivalents.

    Older model conf files use ``MNIST_IMG_HEIGHT`` / ``MNIST_IMG_HWIDTH``
    and may lack ``IMAGE_TYPE`` / ``NUM_CHANNELS``. This fills in whatever is
    missing so later cells can rely on ``IMG_HEIGHT``, ``IMG_WIDTH``,
    ``IMAGE_TYPE`` and ``NUM_CHANNELS``. Names that already exist are never
    overwritten.

    The previous implementation inspected ``locals()`` of a fresh function
    frame (always empty) and assigned plain local variables, so it never
    changed anything visible to the caller.

    Args:
        namespace: Mapping to read from and write to. Defaults to this
            module's global namespace (the notebook's variables).

    Returns:
        None. ``namespace`` is updated in place.
    """
    if namespace is None:
        namespace = globals()

    # Image size: prefer an existing value, then the legacy name, then 28.
    if "IMG_HEIGHT" not in namespace:
        namespace["IMG_HEIGHT"] = namespace.get("MNIST_IMG_HEIGHT", 28)
    if "IMG_WIDTH" not in namespace:
        # NOTE(review): "MNIST_IMG_HWIDTH" is the spelling this notebook has
        # always used for the legacy name - confirm against the conf files.
        namespace["IMG_WIDTH"] = namespace.get("MNIST_IMG_HWIDTH", 28)

    # Without an explicit image type, assume single-channel grayscale.
    # NUM_CHANNELS is only defaulted in that case, so an explicit IMAGE_TYPE
    # is never paired with a guessed channel count.
    if "IMAGE_TYPE" not in namespace:
        namespace["IMAGE_TYPE"] = "grayscale"
        namespace.setdefault("NUM_CHANNELS", 1)

In [None]:
# LEGACY: make sure the config names used below exist.
convert_legacy_config()
# Dummy input used to convert the PyTorch model to a CrypTen model. PyTorch
# image batches are laid out as [batch, channels, height, width], so height
# comes before width (identical for square images such as 28x28 MNIST).
dummy_image = torch.empty([1, NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH])
model_enc = crypten.nn.from_pytorch(model, dummy_image)

In [None]:
# Module-level setup for the multi-party (MPC) test run below.
num_participants = 2
participants = POSSIBLE_PARTICIPANTS[:num_participants]
torch.set_num_threads(1)

# Sanity check: the slice above yields fewer names when
# POSSIBLE_PARTICIPANTS is shorter than num_participants.
if len(participants) != num_participants:
    raise ValueError(
        f"expected {num_participants} participants, got {len(participants)}"
    )

# per-class tallies filled in by test_model_mpc
class_correct = [0] * NUM_CLASSES
class_total = [0] * NUM_CLASSES
runtime = 0

convert_legacy_config() # LEGACY
# Dummy input used to convert the PyTorch model to a CrypTen model. PyTorch
# image batches are laid out as [batch, channels, height, width], so height
# comes before width (identical for square images such as 28x28 MNIST).
dummy_image = torch.empty([1, NUM_CHANNELS, IMG_HEIGHT, IMG_WIDTH])
model_mpc = crypten.nn.from_pytorch(model, dummy_image)

# Barriers sized for all parties; test_model_mpc waits on them before and
# after the evaluation loop.
before_test = Barrier(num_participants)
after_test = Barrier(num_participants)
# NOTE: this rebinds the notebook-wide `criterion` to CrypTen's MSE loss, so
# cells executed after this one no longer use nn.CrossEntropyLoss.
criterion = crypten.nn.MSELoss()
def _format_test_report(test_loss, class_correct, class_total):
    """Return the test summary text: average loss, per-class and overall accuracy."""
    lines = [f"Test Loss: {test_loss:.6}\n"]
    for i, (hits, total) in enumerate(zip(class_correct, class_total)):
        if total > 0:
            lines.append(f"Test Accuracy of {i:5}: "
                         f"{100 * hits / total:3.0f}% "
                         f"({hits:4} / {total:4} )")
        else:
            # Label the class by its index, like the branch above; the previous
            # `classes[i]` referenced a name that is not defined in this notebook.
            lines.append(f"Test Accuracy of {i:5}: N/A (no training examples)")
    lines.append(
        f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% "
        f"( {np.sum(class_correct)} / {np.sum(class_total)} )"
    )
    return "\n".join(lines)


@mpc.run_multiprocess(world_size=num_participants)
def test_model_mpc():
    """Evaluate ``model_mpc`` on the test set in every party's process.

    Each process encrypts the model, runs the test loader through it, tallies
    per-class hits into ``class_correct`` / ``class_total`` and writes its
    summary to ``log/test_log_rank<rank>``. Rank 0 additionally prints
    progress messages, the evaluation runtime and the summary.
    """
    pid = comm.get().get_rank()
    ws = comm.get().world_size
    name = participants[pid]
    is_main = pid == 0
    if is_main:
        print(f"Hello from the main process (rank#{pid} of {ws})!")
        print(f"My name is {name}.")
        print(f"My colleagues today are: ")
        print(participants)

    model_mpc.encrypt(src=0)

    before_test.wait()
    if is_main:
        print("Gonna evaluate now...")

    test_loss = 0.0
    model_mpc.eval() # prep model for evaluation
    start = time()
    for data, target in tqdm(test_loader, position=0):
        # NOTE(review): every rank passes its own rank as `src` - confirm this
        # is intended rather than a single common source.
        data_enc = crypten.cryptensor(data, src=pid)
        target_enc = crypten.cryptensor(target, src=pid)
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model_mpc(data_enc)
        # convert outputs to the predicted class index (still encrypted)
        pred = output.argmax(dim=1, one_hot=False)
        if is_main and pred.shape != target.shape:
            print((pred.shape, target_enc.shape))
        # loss between predicted class index and target, decrypted
        loss = criterion(pred, target_enc).get_plain_text()
        # weight the batch loss by the batch size
        test_loss += loss.item()*data.size(0)
        # decrypt predictions and compare them to the true labels
        pred = pred.get_plain_text()
        # Flatten instead of squeezing: squeeze() turns a batch of one into a
        # 0-d tensor, which cannot be indexed per sample.
        correct = pred.eq(target.view_as(pred)).reshape(-1)
        # tally hits and totals per true class
        for label, hit in zip(target.reshape(-1).tolist(), correct.tolist()):
            class_correct[label] += hit
            class_total[label] += 1
    runtime = time() - start

    if is_main:
        print("Done evaluating...")
        # previously the runtime was measured but never reported
        print(f"Evaluation runtime: {runtime:.2f} s")

    after_test.wait()

    if is_main:
        print("Ouputing information...")

    # average test loss per sample
    test_loss = test_loss/len(test_loader.sampler)
    report = _format_test_report(test_loss, class_correct, class_total)
    if is_main:
        print(report)

    # every rank writes its own log file
    LOG_STR = f"Rank: {pid}\nWorld_Size: {ws}\n\n" + report
    with open(f"log/test_log_rank{pid}", "w") as f:
        f.write(LOG_STR)

test_model_mpc()

  6%|▌         | 30/500 [00:44<12:19,  1.57s/it]

# Visualize Sample Test Results

In [None]:
# Pull a single batch of test images, classify it and plot the images
# together with their true and predicted labels.
dataiter = iter(test_loader)
# Use the built-in next(): DataLoader iterators implement __next__, while the
# Python-2 style `.next()` method is not available on current PyTorch versions.
images, labels = next(dataiter)
# inference only, so no gradients are needed
with torch.no_grad():
    output = model(images)
# predicted class = index of the largest output per sample
_, preds = torch.max(output, 1)

plot_batch(images, labels, preds)