# Classification with Encrypted Neural Networks

In this tutorial, we'll look at how to achieve the <i>Model Hiding</i> application we discussed in the Introduction. Suppose Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We'll see how CrypTen allows Alice and Bob to coordinate and classify the data while meeting both of their privacy requirements.

To simulate this scenario, we'll begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.

## Initialization
Let's load some MNIST data, and train Alice's network on it. We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. 

In [1]:
import crypten
import torch

crypten.init()

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Run script that downloads the publicly available MNIST data, and splits the data as required.
%run ./mnist_utils.py --option train_v_test

In [3]:
# Alice creates and trains her network on her data
import torch.nn as nn
import torch.nn.functional as F

# Define Alice's network
class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
 
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out
    
model = AliceNet()

# Load Alice's data
data_alice = crypten.load('/tmp/alice_train.pth', src=0)
label_alice = crypten.load('/tmp/alice_train_labels.pth', src=0)
label_alice = label_alice.long()
data_alice_flat = data_alice.flatten(start_dim=1)

# Train Alice's network
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

num_examples = 60000
batch_size = 256
num_epochs = 2
log_accuracy = True

for i in range(num_epochs):
    for j in range(0, num_examples, batch_size):
        
        # get the mini-batch
        start, end = j, min(j+batch_size,num_examples)
        sample_flat = data_alice_flat[start:end,:]
        target = label_alice[start:end]
        
        # forward pass: compute prediction
        output = model(sample_flat)

        # compute loss
        loss = criterion(output, target)
        
        # zero gradients for learnable parameters
        optimizer.zero_grad()

        # backward pass: compute gradient with respect to model parameters
        loss.backward()

        # update model parameters
        optimizer.step()
        
    # log loss and accuracy on the last mini-batch of each epoch
    if log_accuracy:
        pred = output.argmax(1)
        correct = pred.eq(target)
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))
        print("Epoch {0} Loss: {1:.4f}".format(i, loss.item()))
        print("\tAccuracy: {0:.4f}".format(accuracy.item()))
        
torch.save(model, '/tmp/alice_model.ptr')

Epoch 0 Loss: 0.3269
	Accuracy: 92.7083
Epoch 1 Loss: 0.2534
	Accuracy: 97.9167
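The accuracy printed above is computed from the last mini-batch of each epoch (`output` and `target` still hold that batch when the `if log_accuracy` block runs), not from all 60,000 training examples. Because 60,000 is not a multiple of 256, that last batch is short. The plain-Python arithmetic below reuses the notebook's `num_examples` and `batch_size`; the correct counts 89 and 94 are not printed by the notebook, but they are the only whole numbers out of 96 that reproduce the logged 92.7083 and 97.9167.

```python
# Values taken from the training cell above.
num_examples = 60000
batch_size = 256

# Mini-batches start at j = 0, 256, 512, ...; the last one holds whatever
# remains (or a full batch if num_examples divides evenly).
last_batch_size = num_examples % batch_size or batch_size
print("Last mini-batch size:", last_batch_size)  # 96


def batch_accuracy(correct_count, size):
    """Same formula as the notebook: correct_count * (100.0 / size)."""
    return correct_count * (100.0 / size)


# Correct counts inferred from the logged accuracies, not printed by the notebook.
for epoch, correct in enumerate([89, 94]):
    print("Epoch {0} accuracy: {1:.4f}".format(epoch, batch_accuracy(correct, last_batch_size)))
```

For a per-epoch figure over the whole training set, one would instead accumulate correct predictions across all mini-batches before dividing.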


## Encryption
Alice now has a trained neural network that can classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. 

In CrypTen, encrypting a PyTorch network is straightforward: first, we call the function `from_pytorch`, which sets up a CrypTen network from the PyTorch network. Then, we call `encrypt` on the CrypTen network to encrypt its parameters. After encryption, the CrypTen network can also be decrypted (see the `decrypt` function).

In addition to the PyTorch network, the `from_pytorch` function also requires a dummy input of the shape of the model's input. The dummy input simply needs to be a `torch` tensor of the same shape; the values inside the tensor do not matter. (CrypTen converts the PyTorch network by exporting it to ONNX, and that export traces the network with an example input.) After calling the `encrypt` function, we'll be able to check that the model is encrypted. (We won't look into how the model is encrypted in this tutorial; we'll leave that to Tutorial 5.)

We'll also encrypt Bob's data; this step is identical to what we saw in Tutorial 3. Let's walk through an example below.
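The real protocol is left to Tutorial 5, but a toy example may help pin down what encrypting with `src=0` versus `src=1` means. In additive secret sharing, the party named by `src` encodes its values as integers and splits each one into random shares, one per party, that sum back to the encoded value. The sketch below is plain Python, not CrypTen's implementation: the ring of integers modulo `2**64`, the 16-bit fixed-point scale, and the helper names `encode`, `decode`, `share`, and `reconstruct` are assumptions chosen for illustration, and the seeded `random.Random` is for reproducibility only.

```python
import random

RING = 2 ** 64   # shares live in the integers modulo 2^64 (illustrative choice)
SCALE = 2 ** 16  # 16 fractional bits of fixed-point precision (illustrative choice)


def encode(x):
    """Map a float to a ring element using fixed-point encoding."""
    return round(x * SCALE) % RING


def decode(v):
    """Map a ring element back to a float; the upper half of the ring is negative."""
    if v >= RING // 2:
        v -= RING
    return v / SCALE


def share(values, src, world_size, rng):
    """Party `src` splits each value into `world_size` additive shares.

    Returns one list of shares per rank. Every rank other than `src` receives a
    uniformly random number, so a single share reveals nothing about the value.
    """
    shares = [[0] * len(values) for _ in range(world_size)]
    for i, x in enumerate(values):
        masks = 0
        for rank in range(world_size):
            if rank != src:
                shares[rank][i] = rng.randrange(RING)
                masks += shares[rank][i]
        shares[src][i] = (encode(x) - masks) % RING
    return shares


def reconstruct(shares):
    """Revealing a value: add up every rank's share and decode the sum."""
    return [decode(sum(column) % RING) for column in zip(*shares)]


# Seeded for reproducibility only; a real system needs a cryptographically secure source.
rng = random.Random(0)
weights = [0.5, -1.25, 3.0]  # stand-in for Alice's parameters, shared with src=0
pixels = [0.0, 0.75, 1.0]    # stand-in for Bob's data, shared with src=1

weight_shares = share(weights, src=0, world_size=2, rng=rng)
pixel_shares = share(pixels, src=1, world_size=2, rng=rng)

print(weight_shares[1])            # Bob's shares of Alice's weights: random-looking integers
print(reconstruct(weight_shares))  # [0.5, -1.25, 3.0]
print(reconstruct(pixel_shares))   # [0.0, 0.75, 1.0]
```

The only difference between the two calls is which rank holds the "real" share, which mirrors how `src` selects the party that supplies the plaintext.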

<i><small>(Technical note: Since Jupyter notebooks only run a single process, we use a custom decorator `mpc.run_multiprocess` to simulate a multi-party world below.)</small></i>

In [4]:
import crypten.mpc as mpc
import crypten.communicator as comm

@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():
    
    rank = comm.get().get_rank()

    if rank == 0:
        # Alice gets the trained model
        plaintext_model = model
    else:
        # Bob gets a dummy model with random parameters
        plaintext_model = AliceNet()

    # Encrypt the model
    # Create a dummy input with the same shape as the model input
    dummy_input = torch.empty((1, 784))
    # Construct a CrypTen network from the trained model and dummy_input
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    # Encrypt the CrypTen network. Alice has the real model, so we encrypt with src=0
    private_model.encrypt(src=0)
    # The model is now encrypted: we can check the model's 'encrypted' flag!
    if rank == 0:
        print("Encryption flag in CrypTen model:", private_model.encrypted)

    # Bob's test data is loaded on his side (src=1)
    plaintext_data = crypten.load('/tmp/bob_test.pth', src=1)

    # Encrypt the data
    # Bob has the real data, so we encrypt with src=1
    data_enc = crypten.cryptensor(plaintext_data, src=1)
    if rank == 0:
        print("Encryption of Data:", crypten.is_encrypted_tensor(data_enc))

z = encrypt_model_and_data()

Encryption flag in CrypTen model: True
Encryption of Data: True
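The `mpc.run_multiprocess` decorator runs the decorated function once per party, which is why `encrypt_model_and_data` branches on `rank` and why its local variables do not survive past the cell. The toy stand-in below, a hypothetical `run_parties` decorator, uses Python threads rather than processes and passes the rank as an argument instead of reading it from `comm.get().get_rank()`; it only mimics the control flow, not CrypTen's communicator.

```python
import threading


def run_parties(world_size):
    """Hypothetical stand-in for a multi-party launcher.

    Calls the decorated function once per rank, each in its own thread, and
    returns the per-rank results as a list ordered by rank.
    """
    def decorator(fn):
        def wrapper(*args, **kwargs):
            results = [None] * world_size

            def target(rank):
                results[rank] = fn(rank, *args, **kwargs)

            threads = [threading.Thread(target=target, args=(rank,)) for rank in range(world_size)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            return results
        return wrapper
    return decorator


@run_parties(world_size=2)
def pick_inputs(rank):
    # Branch on rank the way encrypt_model_and_data does: each party supplies its own input.
    if rank == 0:
        return "Alice provides the trained model (src=0)"
    return "Bob provides the test data (src=1)"


for line in pick_inputs():
    print(line)
```

Returning results explicitly is a design choice of this sketch; in the notebook, anything the parties compute inside the decorated function is simply gone once it returns.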


## Classifying Encrypted Data with Encrypted Model
We can finally use Alice's encrypted network to classify Bob's encrypted data. This step looks just like PyTorch inference, except that we use the encrypted network and data in place of their plaintext versions.
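Because both the weights and the input are secret-shared, every multiplication inside a linear layer multiplies two encrypted values. A textbook way to do this with additive shares is Beaver multiplication: a dealer hands out shares of a random triple `a`, `b`, `c = a * b`; the parties open the masked differences `e = x - a` and `f = y - b`; and each party then computes its share of `x * y = c + e*b + f*a + e*f` locally. The sketch below is plain Python over integers modulo `2**64` with a hypothetical in-process `beaver_triple` dealer. It omits the fixed-point rescaling a real implementation needs after each multiplication, and it is not CrypTen's code.

```python
import random

RING = 2 ** 64          # integers modulo 2^64 (illustrative choice)
rng = random.Random(1)  # seeded for reproducibility; not cryptographically secure


def share(x, world_size=2):
    """Split integer x into additive shares modulo RING."""
    parts = [rng.randrange(RING) for _ in range(world_size - 1)]
    parts.append((x - sum(parts)) % RING)
    return parts


def reveal(parts):
    """Open a shared value: sum the shares and map the result to a signed integer."""
    v = sum(parts) % RING
    return v - RING if v >= RING // 2 else v


def beaver_triple():
    """Hypothetical trusted dealer: shares of random a, b and of c = a * b."""
    a, b = rng.randrange(RING), rng.randrange(RING)
    return share(a), share(b), share(a * b % RING)


def mul(x_sh, y_sh):
    """Multiply two shared values with one Beaver triple."""
    a_sh, b_sh, c_sh = beaver_triple()
    # Open the masked differences e = x - a and f = y - b. Because a and b are
    # uniformly random, e and f reveal nothing about x or y.
    e = sum(xi - ai for xi, ai in zip(x_sh, a_sh)) % RING
    f = sum(yi - bi for yi, bi in zip(y_sh, b_sh)) % RING
    # Each party computes its share of c + e*b + f*a locally ...
    z_sh = [(ci + e * bi + f * ai) % RING for ai, bi, ci in zip(a_sh, b_sh, c_sh)]
    # ... and exactly one party adds the public term e*f, so the shares sum to x*y.
    z_sh[0] = (z_sh[0] + e * f) % RING
    return z_sh


def dot(w_shared, x_shared):
    """Shared dot product: products via Beaver triples, sums computed locally."""
    total = [0] * len(w_shared[0])
    for w_sh, x_sh in zip(w_shared, x_shared):
        total = [(t + p) % RING for t, p in zip(total, mul(w_sh, x_sh))]
    return total


weights = [3, -2, 5]  # Alice's (integer) weights
inputs = [4, 7, 1]    # Bob's (integer) inputs
w_shared = [share(w) for w in weights]
x_shared = [share(x) for x in inputs]
print(reveal(dot(w_shared, x_shared)))  # 3*4 + (-2)*7 + 5*1 = 3
```

Additions need no interaction, since each party just adds its own shares; only the multiplications consume triples and require the parties to exchange `e` and `f`.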

<small><i>(Technical note: We simulated a multi-party world with the `@mpc.run_multiprocess` decorator in the previous cell. As a result, however, the variables created there do not carry over from cell to cell as is customary in a notebook. Therefore, to illustrate the next steps, we re-create `private_model` and `data_enc`, both with `src=0`.)</i></small>

In [5]:
# The following code is required to demonstrate inference in this notebook.
# As Jupyter notebooks run only a single process, the model and the data both need to be encrypted
# with src=0 in order for the remaining code to run. In a regular CrypTen implementation
# (see the CrypTen examples folder), data_enc would be encrypted with src=1 as shown in the cell above.
dummy_input = torch.empty((1, 784))
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt(src=0)

# These next two lines would use src=1 outside Jupyter notebooks
plaintext_data = crypten.load('/tmp/bob_test.pth', src=0)
data_enc = crypten.cryptensor(plaintext_data, src=0)

In [6]:
# We run inference on the encrypted network with the encrypted data
private_model.eval()

with torch.no_grad():
    output_enc = private_model(data_enc)

The result of this classification is encrypted. Below, we simply check that the result is an encrypted tensor; in the next tutorial, we'll look into the tensor's values to confirm the encryption.

We can now decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted output of the model, and can both use this to obtain the labels. 
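Decrypting an additively shared tensor amounts to every party publishing its share and everyone summing them, which is why Alice and Bob can both read the output. The toy sketch below is plain Python; the `2**64` ring, the 16-bit fixed-point scale, and the helper names `share_matrix`, `reveal`, and `argmax_rows` are illustrative assumptions, not CrypTen's API. It reveals a small shared matrix of logits and then takes the per-row argmax, the same step `output.argmax(dim=1)` performs on the decrypted output.

```python
import random

RING = 2 ** 64   # integers modulo 2^64 (illustrative choice)
SCALE = 2 ** 16  # 16 fractional bits of fixed-point precision (illustrative choice)
rng = random.Random(2)  # seeded for reproducibility; not cryptographically secure


def share_matrix(rows, world_size=2):
    """Fixed-point encode every entry of a matrix and split it into additive shares."""
    shares = [[[0] * len(row) for row in rows] for _ in range(world_size)]
    for i, row in enumerate(rows):
        for j, x in enumerate(row):
            parts = [rng.randrange(RING) for _ in range(world_size - 1)]
            parts.append((round(x * SCALE) - sum(parts)) % RING)
            for rank in range(world_size):
                shares[rank][i][j] = parts[rank]
    return shares


def reveal(shares):
    """Plays the role of get_plain_text: sum all ranks' shares, then decode."""
    rows = []
    for per_rank_rows in zip(*shares):             # the same row from every rank
        row = []
        for per_rank_vals in zip(*per_rank_rows):  # the same entry from every rank
            v = sum(per_rank_vals) % RING
            row.append((v - RING if v >= RING // 2 else v) / SCALE)
        rows.append(row)
    return rows


def argmax_rows(matrix):
    """Index of the largest logit in each row, like output.argmax(dim=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


logits = [[0.25, -4.5, 3.75], [-3.5, 6.75, 1.0]]  # toy logits: 2 examples, 3 classes
output = reveal(share_matrix(logits))
print(output)               # [[0.25, -4.5, 3.75], [-3.5, 6.75, 1.0]]
print(argmax_rows(output))  # [2, 1]
```

The toy logits were chosen to be exactly representable with 16 fractional bits, so the round trip is exact; arbitrary floats would come back rounded to the nearest multiple of `1 / SCALE`.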

In [7]:
# The results are encrypted:
print("Output tensor encrypted:", crypten.is_encrypted_tensor(output_enc))

# Decrypting the result
output = output_enc.get_plain_text()
print("Decrypted output:\n", output)

# Obtaining the labels
pred = output.argmax(dim=1)
print("Decrypted labels:\n", pred)

Output tensor encrypted: True
Decrypted output:
 tensor([[ 0.3764, -4.4840,  3.9561,  ..., 10.0836, -1.3179,  1.5138],
        [ 2.1316, -0.5320,  7.6382,  ..., -5.9090,  1.1773, -7.3301],
        [-3.6839,  6.7675,  1.1226,  ...,  0.3997,  0.3834, -2.1214],
        ...,
        [-4.4886, -5.7316, -2.8467,  ...,  0.3242,  3.6152,  4.9825],
        [-0.4876,  0.5972, -2.1670,  ..., -5.4159,  3.6688, -3.0831],
        [ 1.4960, -6.8193,  4.1900,  ..., -6.5355, -1.7503, -4.0193]])
Decrypted labels:
 tensor([7, 2, 1,  ..., 4, 5, 6])


In [8]:
# Finally, we'll compute the accuracy of the output:
output = output_enc.get_plain_text()
label_bob = crypten.load('/tmp/bob_test_labels.pth')
label_bob = label_bob.long()

with torch.no_grad():
    pred = output.argmax(1)
    correct = pred.eq(label_bob)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    print("Accuracy {0:.4f}".format(accuracy.item()))

Accuracy 94.4000


This completes our tutorial. While we have used a simple network here to illustrate the concepts, CrypTen provides the primitives needed to encrypt substantially more complex networks. In our examples section, we demonstrate how CrypTen can be used to encrypt LeNet and ResNet, among others.