# Training an Encrypted Neural Network

In this tutorial, we will walk through an example of how we can train a neural network with CrypTen. This is particularly relevant for the <i>Feature Aggregation</i>, <i>Data Labeling</i> and <i>Data Augmentation</i> use cases. We will focus on the usual two-party setting and show how we can train an accurate neural network for digit classification on the MNIST data.

For concreteness, this tutorial will step through the <i>Feature Aggregation</i> use cases: Alice and Bob each have part of the features of the data set, and wish to train a neural network on their combined data, while keeping their data private. 

## Setup
As usual, we'll begin by importing and initializing the `crypten` and `torch` libraries.  

We will use the MNIST dataset to demonstrate how Alice and Bob can learn without revealing protected information. For reference, the feature size of each example in the MNIST data is `28 x 28`. Let's assume Alice has the first `28 x 20` features and Bob has last `28 x 8` features. One way to think of this split is that Alice has the (roughly) top 2/3rds of each image, while Bob has the bottom 1/3rd of each image. We'll again use our helper script `mnist_utils.py` that downloads the publicly available MNIST data, and splits the data as required.

For simplicity, we will restrict our problem to binary classification: we'll simply learn how to distinguish between 0 and non-zero digits. For speed of execution in the notebook, we will only create a dataset of a 100 examples.

In [1]:
import crypten
import torch

crypten.init()
torch.set_num_threads(1)

In [2]:
%run ./mnist_utils.py --option features --reduced 100 --binary

Next, we'll define the network architecture below, and then describe how to train it on encrypted data in the next section. 

In [3]:
import torch.nn as nn
import torch.nn.functional as F

#Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        self.fc2 = nn.Linear(100, 2) # For binary classification, final layer needs only 2 outputs
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 16 * 12 * 12)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out

## Encrypted Training

After all the material we've covered in earlier tutorials, we only need to know a few additional items for encrypted training. We'll first discuss how the training loop in CrypTen differs from PyTorch. Then, we'll go through a complete example to illustrate training on encrypted data from end-to-end.

### How does CrypTen training differ from PyTorch training?

There are two main ways implementing a CrypTen training loop differs from a PyTorch training loop. We'll describe these items first, and then illustrate them with small examples below.

<i>(1) Use one-hot encoding</i>: CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.

<i>(2) Directly update parameters</i>: CrypTen does not use the PyTorch optimizers. Instead, CrypTen implements encrypted SGD by implementing its own `backward` function, followed by directly updating the parameters. As we will see below, using SGD in CrypTen is very similar to using the PyTorch optimizers.

We now show some small examples to illustrate these differences. As before, we will assume Alice has the rank 0 process and Bob has the rank 1 process.

In [4]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1

In [5]:
# Load Alice's data 
data_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)

In [6]:
# We'll now set up the data for our small example below
# For illustration purposes, we will create toy data
# and encrypt all of it from source ALICE
x_small = torch.rand(100, 1, 28, 28)
y_small = torch.randint(1, (100,))

# Transform labels into one-hot encoding
label_eye = torch.eye(2)
y_one_hot = label_eye[y_small]

# Transform all data to CrypTensors
x_train = crypten.cryptensor(x_small, src=ALICE)
y_train = crypten.cryptensor(y_one_hot)

# Instantiate and encrypt a CrypTen model
model_plaintext = ExampleNet()
dummy_input = torch.empty(1, 1, 28, 28)
model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
model.encrypt()

Graph encrypted module

In [7]:
# Example: Stochastic Gradient Descent in CrypTen

model.train() # Change to training mode
loss = crypten.nn.MSELoss() # Choose loss functions

# Set parameters: learning rate, num_epochs
learning_rate = 0.001
num_epochs = 2

# Train the model: SGD on encrypted data
for i in range(num_epochs):

    # forward pass
    output = model(x_train)
    loss_value = loss(output, y_train)
    
    # set gradients to zero
    model.zero_grad()

    # perform backward pass
    loss_value.backward()

    # update parameters
    model.update_parameters(learning_rate) 
    
    # examine the loss after each epoch
    print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))

Epoch: 0 Loss: 0.4035
Epoch: 1 Loss: 0.3672


### A Complete Example

We now put these pieces together for a complete example of training a network in a multi-party setting. 

As in Tutorial 3, we'll assume Alice has the rank 0 process, and Bob has the rank 1 process; so we'll load and encrypt Alice's data with `src=0`, and load and encrypt Bob's data with `src=1`. We'll then initialize a plaintext model and convert it to an encrypted model, just as we did in Tutorial 4. We'll finally define our loss function, training parameters, and run SGD on the encrypted data. For the purposes of this tutorial we train on 100 samples; training should complete in ~3 minutes per epoch.

In [13]:
import crypten.mpc as mpc
import crypten.communicator as comm

# Convert labels to one-hot encoding
# Since labels are public in this use case, we will simply use them from loaded torch tensors
labels = torch.load('/tmp/train_labels.pth')
labels = labels.long()
labels_one_hot = label_eye[labels]

@mpc.run_multiprocess(world_size=2)
def run_encrypted_training():
    # Load data:
    x_alice_enc = crypten.load_from_party('/tmp/alice_train.pth', src=ALICE)
    x_bob_enc = crypten.load_from_party('/tmp/bob_train.pth', src=BOB)
    
    crypten.print(x_alice_enc.size())
    crypten.print(x_bob_enc.size())
    
    # Combine the feature sets: identical to Tutorial 3
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    
    # Reshape to match the network architecture
    x_combined_enc = x_combined_enc.unsqueeze(1)
    
    
    # Initialize a plaintext model and convert to CrypTen model
    pytorch_model = ExampleNet()
    model = crypten.nn.from_pytorch(pytorch_model, dummy_input)
    model.encrypt()
    """
    # Set train mode
    model.train()
  
    # Define a loss function
    loss = crypten.nn.MSELoss()

    # Define training parameters
    learning_rate = 0.001
    num_epochs = 2
    batch_size = 10
    num_batches = x_combined_enc.size(0) // batch_size
    
    rank = comm.get().get_rank()
    for i in range(num_epochs): 
        crypten.print(f"Epoch {i} in progress:")       
        
        for batch in range(num_batches):
            # define the start and end of the training mini-batch
            start, end = batch * batch_size, (batch + 1) * batch_size
                                    
            # construct CrypTensors out of training examples / labels
            x_train = x_combined_enc[start:end]
            y_batch = labels_one_hot[start:end]
            y_train = crypten.cryptensor(y_batch, requires_grad=True)
            
            # perform forward pass:
            output = model(x_train)
            loss_value = loss(output, y_train)
            
            # set gradients to "zero" 
            model.zero_grad()

            # perform backward pass: 
            loss_value.backward()

            # update parameters
            model.update_parameters(learning_rate)
            
            # Print progress every batch:
            batch_loss = loss_value.get_plain_text()
            crypten.print(f"\tBatch {(batch + 1)} of {num_batches} Loss {batch_loss.item():.4f}")
    """

run_encrypted_training()

torch.Size([100, 28, 20])
torch.Size([100, 28, 8])


ERROR:root:One of the parties failed. Check past logs


We see that the average batch loss decreases across the epochs, as we expect during training.

This completes our tutorial. Before exiting this tutorial, please clean up the files generated using the following code.

In [0]:
import os

filenames = ['/tmp/alice_train.pth', 
             '/tmp/bob_train.pth', 
             '/tmp/alice_test.pth',
             '/tmp/bob_test.pth', 
             '/tmp/train_labels.pth',
             '/tmp/test_labels.pth']

for fn in filenames:
    if os.path.exists(fn): os.remove(fn)