<a href="https://colab.research.google.com/github/inspire-lab/CyberAI-labs/blob/main/category-PrivateAI/Training-Secure-MPC/MPC_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training an Encrypted Neural Network using CrypTen

In this lab, we will walk through an example of how we can train a neural network with CrypTen. This is particularly relevant for the <i>Feature Aggregation</i>, <i>Data Labeling</i> and <i>Data Augmentation</i> use cases. We will focus on the usual two-party setting and show how to train a neural network for digit classification on the MNIST dataset.

For concreteness, this lab steps through the <i>Feature Aggregation</i> use case: Alice and Bob each hold part of the features of the dataset and wish to train a neural network on their combined data while keeping their data private.
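CrypTen keeps Alice's and Bob's data private with secure multiparty computation. Its MPC tensors build on arithmetic secret sharing: each value is split into random additive shares, one per party, so that neither party alone learns anything about it. The toy two-party sketch below is for intuition only and is not CrypTen's implementation. The 64-bit ring, the 16-bit fixed-point scale, and all function names are illustrative assumptions.

```python
import random

# Illustrative parameters (assumptions, not CrypTen's internals)
MODULUS = 2 ** 64   # shares are integers modulo 2**64
SCALE = 2 ** 16     # fixed-point encoding with 16 fractional bits


def encode(x: float) -> int:
    """Map a real number to a ring element via fixed-point encoding."""
    return round(x * SCALE) % MODULUS


def decode(v: int) -> float:
    """Map a ring element back to a real number (upper half of the ring = negatives)."""
    if v >= MODULUS // 2:
        v -= MODULUS
    return v / SCALE


def share(x: float, rng: random.Random):
    """Split x into two additive shares; each share on its own is uniformly random."""
    secret = encode(x)
    share_alice = rng.randrange(MODULUS)
    share_bob = (secret - share_alice) % MODULUS
    return share_alice, share_bob


def reconstruct(share_alice: int, share_bob: int) -> float:
    """The secret is only recovered when both shares are combined."""
    return decode((share_alice + share_bob) % MODULUS)


def add_shared(a, b):
    """Addition is local: each party adds the two shares it already holds."""
    return (a[0] + b[0]) % MODULUS, (a[1] + b[1]) % MODULUS


rng = random.Random(0)
x_shares = share(1.5, rng)     # e.g. a feature value contributed by Alice
y_shares = share(-0.25, rng)   # e.g. a feature value contributed by Bob
z_shares = add_shared(x_shares, y_shares)
print(reconstruct(*z_shares))  # 1.25
```

Because addition only touches each party's own shares, it needs no communication; in protocols of this kind, multiplications and comparisons (for example inside a ReLU) are the steps where the parties must exchange messages.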

## Setup
We'll begin by installing, importing and initializing the `crypten` and `torch` libraries.  

We will use the MNIST dataset to demonstrate how Alice and Bob can learn without revealing protected information. For reference, each MNIST example has `28 x 28` features. Let's assume Alice has the first `28 x 20` features and Bob has the last `28 x 8` features. One way to think of this split is that Alice has (roughly) the left 2/3rds of each image (the first 20 of its 28 columns), while Bob has the right 1/3rd (the last 8 columns). We'll use the helper script `mnist_utils.py`, which downloads the publicly available MNIST data and splits it as required.
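A minimal sketch of this feature split on a single image, assuming (as the `crypten.cat(..., dim=2)` call in the complete example suggests) that each image is cut along its last dimension after 20 entries. The image is a hypothetical nested list standing in for one `28 x 28` MNIST tensor, and the helper names are illustrative, not part of `mnist_utils.py`.

```python
HEIGHT, WIDTH, SPLIT = 28, 28, 20

# Hypothetical 28 x 28 image stored row by row; each pixel holds a unique value
image = [[r * WIDTH + c for c in range(WIDTH)] for r in range(HEIGHT)]


def split_columns(img, split):
    """Alice keeps entries [0, split) of every row, Bob keeps entries [split, WIDTH)."""
    alice = [row[:split] for row in img]
    bob = [row[split:] for row in img]
    return alice, bob


def combine_columns(alice, bob):
    """Glue each row back together, the analogue of concatenating along the last dimension."""
    return [a_row + b_row for a_row, b_row in zip(alice, bob)]


alice_part, bob_part = split_columns(image, SPLIT)
print(len(alice_part), len(alice_part[0]))           # 28 20
print(len(bob_part), len(bob_part[0]))               # 28 8
print(combine_columns(alice_part, bob_part) == image)  # True
```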

For simplicity, we restrict our problem to binary classification: we learn to distinguish the digit 0 from the non-zero digits. For speed of execution in the notebook, we only create a dataset of 100 examples.
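The `--binary` and `--reduced 100` flags passed to `mnist_utils.py` presumably perform a relabeling and truncation along these lines. This is a hypothetical sketch: the convention that the digit 0 maps to class 0 and every other digit to class 1 is an assumption, and the helper script may use a different one.

```python
def to_binary_labels(digits, reduced=None):
    """Map digit labels to {0: the digit is 0, 1: the digit is non-zero} (assumed convention)."""
    if reduced is not None:
        digits = digits[:reduced]  # keep only the first `reduced` examples
    return [0 if d == 0 else 1 for d in digits]


# Hypothetical digit labels
print(to_binary_labels([5, 0, 4, 1, 9, 2], reduced=4))  # [1, 0, 1, 1]
```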


Note: We install CrypTen without its dependencies (`--no-deps`) because the default pip command causes an sklearn error; the other dependencies are installed explicitly by the first command.

In [None]:
# Quote the version specifiers so the shell does not treat ">" as output redirection
!pip install torch torchvision "omegaconf>=2.0.6" "onnx>=1.7.0" "pandas>=1.2.2" "pyyaml>=5.3.1" tensorboard future "scipy>=1.6.0"
!pip install --no-deps crypten

Collecting crypten
  Downloading crypten-0.4.1-py3-none-any.whl.metadata (7.4 kB)
Downloading crypten-0.4.1-py3-none-any.whl (259 kB)
Installing collected packages: crypten
Successfully installed crypten-0.4.1


In [None]:
import crypten
import torch

crypten.init()
torch.set_num_threads(1)

In [None]:
%run ./mnist_utils.py --option features --reduced 100 --binary

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 496kB/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.53MB/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 8.69MB/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



Next, we define the network architecture; the following section describes how to train it on encrypted data.

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Define an example network
class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 12 * 12, 100)
        self.fc2 = nn.Linear(100, 2) # For binary classification, final layer needs only 2 outputs

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 16 * 12 * 12)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out

# Register ExampleNet as a safe class so CrypTen's restricted deserialization can load it
crypten.common.serial.register_safe_class(ExampleNet)
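The flattened size `16 * 12 * 12` used in `ExampleNet` follows from standard convolution and pooling arithmetic: a 5x5 kernel without padding shrinks each 28-pixel side to 28 - 5 + 1 = 24, and 2x2 max pooling (whose stride defaults to the kernel size in PyTorch) halves that to 12. A small sketch of the calculation; the helper names are hypothetical.

```python
def conv2d_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a convolution: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial output size of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


side = conv2d_out(28, kernel=5, padding=0)  # 24
side = pool_out(side, kernel=2)             # 12
flat_features = 16 * side * side            # conv1 produces 16 channels
print(side, flat_features)                  # 12 2304
```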

## Encrypted Training

We'll first discuss how the training loop in CrypTen differs from PyTorch. Then, we'll go through a complete example to illustrate training on encrypted data end to end.

### How does CrypTen training differ from PyTorch training?

There are two main ways in which a CrypTen training loop differs from a PyTorch training loop. We describe them first, and then illustrate them with small examples below.

<i>(1) Use one-hot encoding</i>: CrypTen training requires all labels to use one-hot encoding. This means that when using standard datasets such as MNIST, we need to modify the labels to use one-hot encoding.

<i>(2) Directly update parameters</i>: CrypTen does not use the PyTorch optimizers in `torch.optim`. Instead, CrypTen computes gradients with its own encrypted `backward` function and then updates the parameters directly (in this lab, with `model.update_parameters(learning_rate)`). As we will see below, using SGD in CrypTen is very similar to using the PyTorch optimizers.
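To see why point (1) matters here: the network outputs two scores per example and `MSELoss` compares them element by element with the target, so each integer label has to become a length-2 vector. Indexing the identity matrix with the labels, as `torch.eye(2)[labels]` does in the code below, picks row `i` for class `i`. A plain-Python sketch of the same idea (the function name is hypothetical):

```python
def one_hot(labels, num_classes=2):
    """Return row `label` of the identity matrix for each label."""
    eye = [[1.0 if r == c else 0.0 for c in range(num_classes)] for r in range(num_classes)]
    return [list(eye[label]) for label in labels]  # copy rows so results do not alias


print(one_hot([0, 1, 1, 0]))
# [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
```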

We now show some small examples to illustrate these differences. Throughout, we assume Alice has the rank 0 process and Bob has the rank 1 process.

In [None]:
# Define source argument values for Alice and Bob
ALICE = 0
BOB = 1

In [None]:
# Load Alice's data
data_alice_enc = crypten.load_from_party('./alice_train.pth', src=ALICE)



In [None]:
# We'll now set up the data for our small example below.
# For illustration purposes, we create a small random dataset
# and encrypt all of it from source ALICE
x_small = torch.rand(100, 1, 28, 28)
y_small = torch.randint(0, 2, (100,))  # random binary labels in {0, 1}

# Transform labels into one-hot encoding
label_eye = torch.eye(2)
y_one_hot = label_eye[y_small]

# Transform all data to CrypTensors
x_train = crypten.cryptensor(x_small, src=ALICE)
y_train = crypten.cryptensor(y_one_hot, src=ALICE)

# Instantiate and encrypt a CrypTen model
model_plaintext = ExampleNet()
dummy_input = torch.empty(1, 1, 28, 28)
model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
model.encrypt()



Graph encrypted module

In [None]:
# Example: Stochastic Gradient Descent in CrypTen

model.train() # Change to training mode
loss = crypten.nn.MSELoss() # Choose a loss function

# Set parameters: learning rate, num_epochs
learning_rate = 0.001
num_epochs = 10

# Train the model: SGD on encrypted data
for i in range(num_epochs):

    # forward pass
    output = model(x_train)
    loss_value = loss(output, y_train)

    # set gradients to zero
    model.zero_grad()

    # perform backward pass
    loss_value.backward()

    # update parameters directly: p <- p - learning_rate * p.grad
    model.update_parameters(learning_rate)

    # examine the loss after each epoch
    print("Epoch: {0:d} Loss: {1:.4f}".format(i, loss_value.get_plain_text()))

Epoch: 0 Loss: 0.5361
Epoch: 1 Loss: 0.5044
Epoch: 2 Loss: 0.4750
Epoch: 3 Loss: 0.4470
Epoch: 4 Loss: 0.4209
Epoch: 5 Loss: 0.3970
Epoch: 6 Loss: 0.3745
Epoch: 7 Loss: 0.3536
Epoch: 8 Loss: 0.3346
Epoch: 9 Loss: 0.3171
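Stripped of encryption, the loop above is plain gradient descent: a forward pass, a backward pass that produces gradients, and the direct update `p <- p - learning_rate * grad` that `update_parameters` applies. The sketch below runs that pattern in ordinary Python on a hypothetical 1-D linear regression, with the MSE gradients written out by hand; it only mirrors the structure of the CrypTen loop and does not reproduce CrypTen's internals.

```python
# Hypothetical 1-D regression data: targets follow y = 2x + 1
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

w, b = 0.0, 0.0        # model parameters
learning_rate = 0.05


def mse_and_grads(w, b):
    """Forward pass plus a hand-written backward pass for mean squared error."""
    n = len(xs)
    preds = [w * x + b for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    return loss, grad_w, grad_b


losses = []
for epoch in range(5):
    # Gradients are recomputed from scratch each step, so nothing needs zeroing here
    loss, grad_w, grad_b = mse_and_grads(w, b)
    w -= learning_rate * grad_w  # direct parameter update
    b -= learning_rate * grad_b
    losses.append(loss)
    print(f"Epoch: {epoch} Loss: {loss:.4f}")
```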


### A Complete Example

We now put these pieces together for a complete example of training a network in a multi-party setting.

We'll assume Alice has the rank 0 process, and Bob has the rank 1 process; so we'll load and encrypt Alice's data with `src=0`, and load and encrypt Bob's data with `src=1`. We'll then initialize a plaintext model and convert it to an encrypted model. We'll finally define our loss function, training parameters, and run SGD on the encrypted data. For the purposes of this lab we train on 100 samples; training should complete in ~3 minutes per epoch.
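Combining Alice's and Bob's encrypted features with `crypten.cat` should not require the parties to reveal anything: with additive secret sharing, each party simply concatenates the shares it already holds. A toy sketch of that property, assuming a 64-bit ring and hypothetical, already fixed-point-encoded feature values; this illustrates the principle and is not CrypTen's code.

```python
import random

MODULUS = 2 ** 64  # illustrative ring size (assumption)


def share_vector(values, rng):
    """Additively share a list of ring elements between party 0 and party 1."""
    shares0 = [rng.randrange(MODULUS) for _ in values]
    shares1 = [(v - s) % MODULUS for v, s in zip(values, shares0)]
    return shares0, shares1


def reconstruct_vector(shares0, shares1):
    """Add the two parties' shares element-wise to recover the values."""
    return [(a + b) % MODULUS for a, b in zip(shares0, shares1)]


rng = random.Random(42)
alice_row = [10, 20, 30]  # hypothetical features contributed by Alice
bob_row = [40, 50]        # hypothetical features contributed by Bob

a0, a1 = share_vector(alice_row, rng)
b0, b1 = share_vector(bob_row, rng)

# Each party concatenates only its own shares: no messages are exchanged
party0_combined = a0 + b0
party1_combined = a1 + b1

print(reconstruct_vector(party0_combined, party1_combined))  # [10, 20, 30, 40, 50]
```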

In [None]:
import crypten.mpc as mpc
import crypten.communicator as comm

# Convert labels to one-hot encoding
# Since labels are public in this use case, we simply load them as plaintext torch tensors
labels = torch.load('./train_labels.pth')
labels = labels.long()
label_eye = torch.eye(2)
labels_one_hot = label_eye[labels]

@mpc.run_multiprocess(world_size=2)
def run_encrypted_training():
    # Load data:
    x_alice_enc = crypten.load_from_party('./alice_train.pth', src=ALICE)
    x_bob_enc = crypten.load_from_party('./bob_train.pth', src=BOB)

    crypten.print(x_alice_enc.size())
    crypten.print(x_bob_enc.size())

    # Combine the feature sets
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)

    # Reshape to match the network architecture: [N, 28, 28] -> [N, 1, 28, 28]
    x_combined_enc = x_combined_enc.unsqueeze(1)

    # Initialize a plaintext model and convert it to a CrypTen model
    pytorch_model = ExampleNet()
    dummy_input = torch.empty(1, 1, 28, 28)
    model = crypten.nn.from_pytorch(pytorch_model, dummy_input)
    model.encrypt()
    # Set train mode
    model.train()

    # Define a loss function
    loss = crypten.nn.MSELoss()

    # Define training parameters
    learning_rate = 0.001
    num_epochs = 10
    batch_size = 10
    num_batches = x_combined_enc.size(0) // batch_size

    rank = comm.get().get_rank()
    for i in range(num_epochs):
        crypten.print(f"Epoch {i} in progress:")

        for batch in range(num_batches):
            # define the start and end of the training mini-batch
            start, end = batch * batch_size, (batch + 1) * batch_size

            # construct CrypTensors out of training examples / labels
            x_train = x_combined_enc[start:end]
            y_batch = labels_one_hot[start:end]
            y_train = crypten.cryptensor(y_batch, requires_grad=True)

            # perform forward pass:
            output = model(x_train)
            loss_value = loss(output, y_train)

            # set gradients to "zero"
            model.zero_grad()

            # perform backward pass:
            loss_value.backward()

            # update parameters
            model.update_parameters(learning_rate)

            # Print progress every batch:
            batch_loss = loss_value.get_plain_text()
            crypten.print(f"\tBatch {(batch + 1)} of {num_batches} Loss {batch_loss.item():.4f}")

run_encrypted_training()



torch.Size([100, 28, 20])
torch.Size([100, 28, 8])
Epoch 0 in progress:
	Batch 1 of 10 Loss 0.4424
	Batch 2 of 10 Loss 0.3554
	Batch 3 of 10 Loss 0.3885
	Batch 4 of 10 Loss 0.3497
	Batch 5 of 10 Loss 0.2915
	Batch 6 of 10 Loss 0.2872
	Batch 7 of 10 Loss 0.2823
	Batch 8 of 10 Loss 0.2705
	Batch 9 of 10 Loss 0.2822
	Batch 10 of 10 Loss 0.2055
Epoch 1 in progress:
	Batch 1 of 10 Loss 0.2413
	Batch 2 of 10 Loss 0.1625
	Batch 3 of 10 Loss 0.2095
	Batch 4 of 10 Loss 0.2300
	Batch 5 of 10 Loss 0.1457
	Batch 6 of 10 Loss 0.1758
	Batch 7 of 10 Loss 0.2108
	Batch 8 of 10 Loss 0.1678
	Batch 9 of 10 Loss 0.1996
	Batch 10 of 10 Loss 0.1271
Epoch 2 in progress:
	Batch 1 of 10 Loss 0.1560
	Batch 2 of 10 Loss 0.0860
	Batch 3 of 10 Loss 0.1378
	Batch 4 of 10 Loss 0.1842
	Batch 5 of 10 Loss 0.0828
	Batch 6 of 10 Loss 0.1283
	Batch 7 of 10 Loss 0.1825
	Batch 8 of 10 Loss 0.1209
	Batch 9 of 10 Loss 0.1651
	Batch 10 of 10 Loss 0.0941
Epoch 3 in progress:
	Batch 1 of 10 Loss 0.1151
	Batch 2 of 10 Loss 0.052

[None, None]

Although individual batch losses fluctuate, the average batch loss decreases across epochs, as we expect during training.
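This can be checked directly against the log above. The sketch copies the printed batch losses for epochs 0 to 2 (the epoch 3 output is truncated) and averages them per epoch.

```python
from statistics import mean

# Batch losses copied from the training output above
batch_losses = {
    0: [0.4424, 0.3554, 0.3885, 0.3497, 0.2915, 0.2872, 0.2823, 0.2705, 0.2822, 0.2055],
    1: [0.2413, 0.1625, 0.2095, 0.2300, 0.1457, 0.1758, 0.2108, 0.1678, 0.1996, 0.1271],
    2: [0.1560, 0.0860, 0.1378, 0.1842, 0.0828, 0.1283, 0.1825, 0.1209, 0.1651, 0.0941],
}

for epoch, losses in batch_losses.items():
    print(f"Epoch {epoch}: mean batch loss {mean(losses):.4f}")
# Epoch 0: mean batch loss 0.3155
# Epoch 1: mean batch loss 0.1870
# Epoch 2: mean batch loss 0.1338
```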
