# Training an Encrypted Neural Network

In this tutorial, we will walk through an example of how we can train a neural network with CrypTen. Like in Tutorial 3, this is particularly relevant for the <i>Feature Aggregation</i> and <i>Data Augmentation</i> scenarios. We will focus on the usual two-party setting and show how we can train an accurate neural network for digit classification on the MNIST data.

The tutorial will step through the <i>Feature Aggregation</i> scenario: Alice and Bob each have part of the features of the data set, and wish to train a neural network on their combined data, while keeping their data private. 

## Initialization
As usual, we'll begin by importing and initializing the `crypten` and `torch` libraries.  

As in Tutorial 3, we'll use the MNIST data to demonstrate how Alice and Bob can learn without revealing protected information. For reference, each example in the MNIST data has `28 x 28` features, and there are 60000 examples in the training data and 10000 examples in the test data.

As in Tutorial 3, let's assume Alice has the first `28 x 20` features and Bob has the last `28 x 8` features. One way to think of this split is that Alice has the (roughly) top 2/3rds of each image, while Bob has the bottom 1/3rd of each image. We'll again use our helper script `mnist_utils.py`, which downloads the publicly available MNIST data and splits the data as required.
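To make the shapes concrete, here is a minimal plain-Python sketch of such a feature split. It is for illustration only: the real split is performed by `mnist_utils.py`, and we assume here that the split is taken along the last dimension of each example. The toy image and all variable names are hypothetical.

```python
# Plain-Python sketch of a 28 x 20 / 28 x 8 feature split (no CrypTen or PyTorch needed).
# Assumption: the split is taken along the last dimension of each example.
ROWS, COLS, ALICE_COLS = 28, 28, 20

# A toy "image": each pixel value encodes its (row, column) position.
image = [[r * COLS + c for c in range(COLS)] for r in range(ROWS)]

# Alice keeps the first 20 entries along the last dimension, Bob the last 8.
alice_part = [row[:ALICE_COLS] for row in image]
bob_part = [row[ALICE_COLS:] for row in image]

# Recombining along the same dimension restores the full 28 x 28 example.
combined = [a + b for a, b in zip(alice_part, bob_part)]

print(len(alice_part), len(alice_part[0]))  # 28 20
print(len(bob_part), len(bob_part[0]))      # 28 8
print(combined == image)                    # True
```

Neither part alone contains the full example, which is why the two parties need to combine their (encrypted) features before training.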

In [1]:
import crypten
import torch
from tqdm.notebook import tqdm

crypten.init()

In [2]:
%run ../examples/mnist_utils.py --option features

This tutorial will essentially follow the same steps as in Tutorial 3 to load and encrypt Alice's and Bob's data, and then combine their encrypted data (i.e., Steps (a) to (c)). Step (d) is different, as we are now training a neural network instead of a linear SVM. 

We'll define the network architecture below, and then describe how to train it on encrypted data in the next section. For simplicity, we will restrict our problem to binary classification: we'll simply learn how to distinguish between 0 and non-zero digits. 

In [3]:
import torch.nn as nn
import torch.nn.functional as F

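# Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        # 1 input channel -> 16 feature maps; 5x5 kernel, no padding: 28x28 -> 24x24
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        # after 2x2 max pooling the maps are 12x12, so the flattened size is 16 * 12 * 12
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        # two outputs: digit 0 vs. non-zero digit
        self.fc2 = nn.Linear(100, 2)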
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
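
The input size of `fc1` follows from the layer shapes. As a quick check, the arithmetic can be sketched in plain Python; the helper names `conv_out` and `pool_out` are hypothetical and only reproduce the standard output-size formulas for a convolution and for a pooling layer whose stride equals its kernel size.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial output size of pooling with stride equal to the kernel size."""
    return size // kernel


conv_side = conv_out(28, kernel=5, padding=0)   # after conv1
pooled_side = pool_out(conv_side, kernel=2)     # after max_pool2d(out, 2)
flattened = 16 * pooled_side * pooled_side      # 16 feature maps, flattened

print(conv_side, pooled_side, flattened)  # 24 12 2304
```

The result, 2304, equals the `16 * 12 * 12` passed to `nn.Linear` above.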

## Encrypted Training

After all the material we've covered in earlier tutorials, we only need to remember a few additional items in order to be able to train on encrypted data. We'll describe these items first, and then illustrate them with small examples below. After that, we'll demonstrate how encrypted training works end-to-end.
<ul>
<li>We need to transform the input data to `AutogradCrypTensors` from `CrypTensors` before calling the forward pass.  (`AutogradCrypTensors` allow the CrypTensors to store gradients and thus enable backpropagation.) As we show in the examples below, this is easily done by simply calling the `AutogradCrypTensor` constructor with the previously encrypted `CrypTensor`.</li>
<li>CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.</li>
<li>CrypTen does not use the PyTorch optimizers. It instead directly implements stochastic gradient descent on encrypted data. As we show in the examples below, using SGD in CrypTen is very similar to using the PyTorch optimizers.</li>
</ul>
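
As a reminder of what one-hot encoding produces, here is a minimal plain-Python sketch. The `one_hot` helper and the toy digit labels are hypothetical; they only illustrate the label transformation, not CrypTen's API.

```python
def one_hot(labels, num_classes):
    """Return one row per label, with a 1.0 in the column of that label's class."""
    return [
        [1.0 if k == label else 0.0 for k in range(num_classes)]
        for label in labels
    ]


# Toy digit labels, reduced to the binary task: 0 stays 0, any other digit becomes 1.
digits = [0, 7, 3, 0]
binary = [0 if d == 0 else 1 for d in digits]

encoded = one_hot(binary, num_classes=2)
print(binary)   # [0, 1, 1, 0]
print(encoded)  # [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
```

Each row has the same width as the network output (two entries here), so a loss can compare the output with the label entry by entry.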

In [4]:
# Example: Transforming input data into AutogradCrypTensors
from crypten.autograd_cryptensor import AutogradCrypTensor

# Load Alice's data 
data_alice = crypten.load('/tmp/alice_train.pth', src=0)
# Create a CrypTensor
data_alice_enc = crypten.cryptensor(data_alice, src=0)
# Create an AutogradCrypTensor from the CrypTensor
data_alice_enc_auto = AutogradCrypTensor(data_alice_enc)

In [5]:
# We'll now set up the data for our small example below
# For illustration purposes, we will create random toy data
x_small = torch.rand(100, 1, 28, 28)
# Binary labels drawn from {0, 1} (the upper bound of randint is exclusive)
y_small = torch.randint(2, (100,))

# Transform labels into one-hot encoding
label_eye = torch.eye(2)
y_one_hot = label_eye[y_small]

# Transform all data to AutogradCrypTensors
x_train = AutogradCrypTensor(crypten.cryptensor(x_small, src=0))
y_train = AutogradCrypTensor(crypten.cryptensor(y_one_hot))

In [6]:
# Example: Stochastic Gradient Descent in CrypTen
model_plaintext = ExampleNet()

# Encrypt the model: This step is identical to Tutorial 4
dummy_input = torch.empty((1, 1, 28, 28))
model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
model.train()
model.encrypt()

# Choose a loss function
loss = crypten.nn.MSELoss()

# Set parameters: learning rate, num_epochs
learning_rate = 0.001
num_epochs = 2

# Train the model: SGD on encrypted data
for i in range(num_epochs):

    # forward pass
    output = model(x_train)
    loss_value = loss(output, y_train)
    
    # set gradients to zero
    model.zero_grad()

    # perform backward pass
    loss_value.backward()

    # update parameters
    model.update_parameters(learning_rate) 
    
    # examine the loss after each epoch
    print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))

Epoch: 0 Loss: 0.6192
Epoch: 1 Loss: 0.6033
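
The loop above follows the usual pattern of forward pass, gradient reset, backward pass, and parameter update. To show what each step does numerically, here is a minimal plaintext sketch of the same pattern in plain Python. It fits a hypothetical one-dimensional linear model with a mean-squared-error loss and hand-written gradients; it is an illustration of SGD in general, not of how CrypTen implements it.

```python
# Toy data generated from y = 2x + 1 (hypothetical values for illustration).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

w, b = 0.0, 0.0          # model parameters
learning_rate = 0.05
num_epochs = 200
losses = []

for epoch in range(num_epochs):
    # forward pass: predictions and mean squared error
    preds = [w * x + b for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    losses.append(loss)

    # set gradients to zero
    grad_w, grad_b = 0.0, 0.0

    # backward pass: accumulate d(loss)/dw and d(loss)/db
    for x, p, y in zip(xs, preds, ys):
        grad_w += 2 * (p - y) * x / len(xs)
        grad_b += 2 * (p - y) / len(xs)

    # update parameters: plain gradient step
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

In the encrypted loop, `loss_value.backward()` plays the role of the backward pass and `model.update_parameters(learning_rate)` the role of the final gradient step.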


We can now put these pieces together for a complete example that trains a network from scratch in a multi-party setting. As in Tutorial 3, we'll assume Alice has the rank 0 process, and Bob has the rank 1 process; so we'll load and encrypt Alice's data with `src=0`, and load and encrypt Bob's data with `src=1`. We'll then initialize a plaintext model and convert it to an encrypted model, just as we did in Tutorial 4. We'll finally define our loss function, training parameters, and run SGD on the encrypted data. For the purposes of this tutorial we train on 100 samples; training should complete in ~3 minutes per epoch.

<small><i>(Technical note: Since Jupyter notebooks only run a single process, we use a custom decorator, `mpc.run_multiprocess`, to simulate a multi-party world below.)</i></small>
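
The training loop in the next cell processes the 100 examples in mini-batches of 5. A short plain-Python sketch of how the batch boundaries are formed is given here; the list names are hypothetical, and the second case (103 examples) is only included to show why the end index is clamped with `min`.

```python
num_examples = 100
batch_size = 5

# (start, end) index pairs for each mini-batch; end is clamped to the data size
batches = [
    (start, min(start + batch_size, num_examples))
    for start in range(0, num_examples, batch_size)
]
print(len(batches))             # 20
print(batches[0], batches[-1])  # (0, 5) (95, 100)

# With a size that is not a multiple of batch_size, the last batch is shorter.
uneven = [(s, min(s + batch_size, 103)) for s in range(0, 103, batch_size)]
print(uneven[-1])               # (100, 103)
```

With 100 examples and a batch size of 5, each epoch performs 20 parameter updates.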

In [7]:
import crypten.mpc as mpc
import crypten.communicator as comm

@mpc.run_multiprocess(world_size=2)
def run_encrypted_training():
    # Load data:
    # Alice's data gets loaded with src=0
    x_alice = crypten.load('/tmp/alice_train.pth', src=0)
    # Bob's data gets loaded with src=1
    x_bob = crypten.load('/tmp/bob_train.pth', src=1)

    # Encrypt the data: 
    # Alice's tensor gets encrypted with src=0
    x_alice_enc = crypten.cryptensor(x_alice, src=0)
    # Bob's tensor gets encrypted with src=1
    x_bob_enc = crypten.cryptensor(x_bob, src=1)
    
    # using crypten.cat to combine the feature sets: identical to Tutorial 3
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    x_combined_enc = x_combined_enc.unsqueeze(1)
    
    # Restrict training to 100 examples for speed
    x_small = x_combined_enc[:100]
    
    # Load labels and restrict to the 100 examples
    y_all = crypten.load('/tmp/train_labels.pth')
    y_small = (y_all[:100]).long()
    # Modify the labels so that:
    # all non-zero digits have class label 1.
    # all zero digits have class label 0
    y_small[y_small == 0] = 0
    y_small[y_small != 0] = 1
    
    # Initialize a plaintext model and encrypt: identical to Tutorial 4
    model_plaintext = ExampleNet()
    dummy_input = torch.empty((1, 1, 28, 28))
    model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
    model.train()
    model.encrypt()
    
    # Define a loss function
    loss = crypten.nn.MSELoss()

    # Define training parameters
    num_epochs = 3
    learning_rate = 0.001
    num_examples = x_small.size(0)
    log_progress = 5
    batch_size = log_progress 
    
    rank = comm.get().get_rank()
    for i in range(num_epochs):  
        last_progress_logged = 0
        # Print output only from rank 0 process for readability
        if rank == 0:
            print(f"Epoch {i} in progress:")
        
        pbar = tqdm(range(0, num_examples, batch_size), leave=False)
        for j in pbar:
            
            # define the start and end of the training mini-batch
            start, end = j, min(j + batch_size, num_examples)
            
            # construct AutogradCrypTensors out of training examples
            x_train = AutogradCrypTensor(x_small[start:end])
            # one-hot encode the labels by indexing rows of the 2 x 2 identity matrix
            y_one_hot = torch.eye(2)[y_small[start:end]]
            y_train = AutogradCrypTensor(crypten.cryptensor(y_one_hot))
            
            # perform forward pass:
            output = model(x_train)
            loss_value = loss(output, y_train)
            
            # set gradients to "zero" 
            model.zero_grad()

            # perform backward pass: 
            loss_value.backward()

            # update parameters
            model.update_parameters(learning_rate)  
            
            # log progress every x examples:
            if j+batch_size - last_progress_logged >= log_progress:
                last_progress_logged += log_progress
                pbar.set_description(f"Loss {loss_value.get_plain_text().item():.4f}")
                          
        # compute accuracy on the last mini-batch of each epoch
        pred = output.get_plain_text().argmax(1)
        correct = pred.eq(y_small[start:end])
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))
        print("Epoch {0:d} completed. Loss: {1:.4f} Accuracy: {2:.4f} \n".format(i, 
                                                                                loss_value.get_plain_text().item(), 
                                                                                accuracy.item()))

z = run_encrypted_training()

Epoch 0 in progress:


Epoch 0 completed. Loss: 0.1800 Accuracy: 80.0000 
Epoch 0 completed. Loss: 0.1800 Accuracy: 80.0000 


Epoch 1 in progress:


Epoch 1 completed. Loss: 0.1436 Accuracy: 80.0000 
Epoch 1 completed. Loss: 0.1436 Accuracy: 80.0000 


Epoch 2 in progress:


Epoch 2 completed. Loss: 0.1215 Accuracy: 80.0000 
Epoch 2 completed. Loss: 0.1215 Accuracy: 80.0000 
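
Note that the accuracy above is computed from the final mini-batch of each epoch only (`output` and `y_small[start:end]` both refer to the last five examples), so it can only change in steps of 20%. A sketch of how accuracy could instead be accumulated over all mini-batches is given here in plain Python. The helper names and the toy outputs are hypothetical, and the outputs are assumed to be already decrypted.

```python
def argmax(row):
    """Index of the largest entry in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_percent(batches):
    """Accuracy over all (outputs, labels) mini-batches, in percent."""
    correct = 0
    total = 0
    for outputs, labels in batches:
        preds = [argmax(row) for row in outputs]
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return 100.0 * correct / total


# Two toy mini-batches of decrypted network outputs with their integer labels.
toy_batches = [
    ([[0.9, 0.1], [0.2, 0.8]], [0, 1]),  # both predictions correct
    ([[0.4, 0.6], [0.7, 0.3]], [1, 1]),  # second prediction wrong
]
print(accuracy_percent(toy_batches))  # 75.0
```

In the encrypted setting, each batch output would first need to be decrypted with `get_plain_text()`, as the training loop already does for the last batch.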


