In this tutorial, we'll look at how to achieve the <i>Model Hiding</i> application we discussed in the Introduction. Suppose Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We'll see how CrypTen allows Alice and Bob to coordinate and classify the data while meeting both of their privacy requirements.

To simulate this scenario, we'll begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data, respectively, classify the encrypted data, and finally decrypt the labels.

### Initialization
Let's load some MNIST data and train Alice's network on it.

```python
import crypten
import torch

crypten.init()
```

```python
from torchvision import datasets, transforms

# Download the MNIST training and test splits
mnist_train = datasets.MNIST("/tmp", download=True, train=True)
mnist_test = datasets.MNIST("/tmp", download=True, train=False)

# Compute normalization factors: a single mean and standard deviation over all pixels
data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
data_mean, data_std = data_all.mean(), data_all.std()
tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)
```
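`transforms.functional.normalize` applies the element-wise transform `(x - mean) / std`; here one scalar mean and one scalar standard deviation, computed over every pixel in both splits, are used for all pixels. The plain-Python sketch below reproduces that arithmetic on a few toy pixel values (not MNIST data) so it runs without torch. The helper `normalize` is hypothetical, and `statistics.stdev` is used because, like `torch.Tensor.std()` by default, it computes the sample (n - 1) standard deviation.

```python
import statistics


def normalize(values, mean, std):
    """Hypothetical helper: apply (x - mean) / std to every value."""
    return [(x - mean) / std for x in values]


# Toy pixel intensities in the MNIST range 0-255 (made up for illustration).
pixels = [0.0, 0.0, 128.0, 255.0]
mean = statistics.mean(pixels)
std = statistics.stdev(pixels)  # sample standard deviation, as torch.Tensor.std() uses by default
normalized = normalize(pixels, mean, std)
print([round(v, 4) for v in normalized])
```

After the transform the values have mean 0 and sample standard deviation 1.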

```python
# Define Alice's data (the MNIST training set) and Bob's data (the MNIST test set)
data_alice = mnist_train.data
data_bob = mnist_test.data

label_alice = mnist_train.targets
label_bob = mnist_test.targets

# Normalize the data
data_alice_norm = transforms.functional.normalize(data_alice.float(), tensor_mean, tensor_std)
data_bob_norm = transforms.functional.normalize(data_bob.float(), tensor_mean, tensor_std)

# Flatten each 28x28 image into a 784-dimensional vector
data_alice_flat = data_alice_norm.flatten(start_dim=1)
data_bob_flat = data_bob_norm.flatten(start_dim=1)
```

```python
# Alice creates and trains her network on her data
import torch.nn as nn
import torch.nn.functional as F

# Define Alice's network: a three-layer fully connected classifier
class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out

model = AliceNet()

# Train Alice's network
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

num_examples = 60000
batch_size = 256
num_epochs = 5
log_accuracy = True

for i in range(num_epochs):
    for j in range(0, num_examples, batch_size):

        # Get the mini-batch
        start, end = j, min(j + batch_size, num_examples)
        sample_flat = data_alice_flat[start:end, :]
        target = label_alice[start:end]

        # Forward pass: compute predictions
        output = model(sample_flat)

        # Compute the loss
        loss = criterion(output, target)

        # Zero the gradients of the learnable parameters
        optimizer.zero_grad()

        # Backward pass: compute gradients with respect to the model parameters
        loss.backward()

        # Update the model parameters
        optimizer.step()

    # Log the loss and accuracy of the last mini-batch of each epoch
    if log_accuracy:
        pred = output.argmax(1)
        correct = pred.eq(target)
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))
        print("Epoch", i, "Loss:", loss.item())
        print("\tAccuracy:", accuracy)
```


```
Epoch 0 Loss: 0.32299792766571045
	Accuracy: tensor([94.7917])
Epoch 1 Loss: 0.2552248537540436
	Accuracy: tensor([97.9167])
Epoch 2 Loss: 0.21967248618602753
	Accuracy: tensor([97.9167])
Epoch 3 Loss: 0.19773970544338226
	Accuracy: tensor([97.9167])
Epoch 4 Loss: 0.1811065673828125
	Accuracy: tensor([98.9583])
```
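The accuracy logged during training is computed on the last mini-batch of each epoch only, because `output` and `target` hold whatever the inner loop produced last. Since 60000 is not a multiple of 256, that final batch contains 96 examples, so every logged accuracy is a multiple of 100/96. A quick worked check in plain Python (`batch_accuracy` is a hypothetical helper mirroring the formula in the loop):

```python
num_examples = 60000
batch_size = 256

# The inner loop's final slice is [59904:60000], i.e. the remainder after full batches.
last_batch_size = num_examples % batch_size


def batch_accuracy(correct, total):
    """Percentage of correct predictions, as computed in the training loop."""
    return correct * 100.0 / total


print(last_batch_size)  # 96
# 91, 94 and 95 correct predictions out of 96 give the logged 94.7917, 97.9167 and 98.9583.
for correct in (91, 94, 95):
    print(correct, round(batch_accuracy(correct, last_batch_size), 4))
```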


### Encryption
Alice now has a trained neural network that can classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. 

In CrypTen, encrypting a PyTorch network is straightforward: first, we call the function `from_pytorch`, which sets up a CrypTen network from the PyTorch network. Then, we call `encrypt` on the CrypTen network to encrypt its parameters. After encryption, the CrypTen network can also be decrypted (see the `decrypt` function).
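Conceptually, CrypTen's MPC backend "encrypts" a value by encoding it as a fixed-point integer and splitting it into random additive secret shares, one per party; no single share reveals the value, and decrypting amounts to adding all the shares back together. The toy sketch below illustrates that idea in plain Python. It is not CrypTen's implementation: the 2^64 ring, the 16 fractional bits, and the helper names `encode`, `decode`, `share` and `reconstruct` are assumptions chosen for illustration.

```python
import random

RING = 2 ** 64   # shares are integers modulo 2^64 (assumption for this sketch)
SCALE = 2 ** 16  # fixed-point scale: 16 fractional bits (assumption for this sketch)


def encode(x):
    """Scale a float and round it to an integer in the ring."""
    return round(x * SCALE) % RING


def decode(v):
    """Map a ring element back to a float; the upper half of the ring encodes negatives."""
    if v >= RING // 2:
        v -= RING
    return v / SCALE


def share(x, num_parties=2):
    """Split encode(x) into num_parties random shares that sum to it modulo RING."""
    shares = [random.randrange(RING) for _ in range(num_parties - 1)]
    shares.append((encode(x) - sum(shares)) % RING)
    return shares


def reconstruct(shares):
    """'Decrypt' by adding up every party's share and decoding the sum."""
    return decode(sum(shares) % RING)


weight = -1.2345  # a hypothetical model parameter
weight_shares = share(weight)
print(weight_shares)               # each share on its own looks like a random 64-bit integer
print(reconstruct(weight_shares))  # -1.2345 up to a rounding error of at most 2**-17
```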

In addition to the PyTorch network, the `from_pytorch` function also requires a dummy input with the shape of the model's input; this is similar to what we saw in Tutorial 3. (In the next tutorial, we'll take a closer look at Alice's CrypTen network to understand how the parameters of each layer are encrypted.)
We'll also encrypt Bob's data; again, this step is identical to what we saw in Tutorial 3. We'll walk through an example below.

<i><small>(Technical note: Since Jupyter notebooks only run a single process, we use a custom decorator `mpc.run_multiprocess` to simulate a multi-party world below.)</small></i>
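The decorator follows a single-program, multiple-data pattern: the same function runs once per party, and each copy branches on its own rank. The sketch below mimics that pattern with a hypothetical `run_as_parties` decorator built on Python threads. It is not how `mpc.run_multiprocess` works internally: CrypTen launches separate processes, whose variables are isolated from the notebook, whereas threads share memory, and here the rank is passed as an argument instead of being read from `comm.get().get_rank()`.

```python
import threading


def run_as_parties(world_size):
    """Hypothetical decorator: run the wrapped function once per rank, each in its own thread."""
    def decorator(fn):
        def wrapper(*args, **kwargs):
            results = [None] * world_size

            def worker(rank):
                # Each "party" receives its own rank and stores its return value.
                results[rank] = fn(rank, *args, **kwargs)

            threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            return results
        return wrapper
    return decorator


@run_as_parties(world_size=2)
def describe_party(rank):
    # Same code on every party; behavior depends only on the rank.
    return "Alice holds the model" if rank == 0 else "Bob holds the data"


print(describe_party())  # ['Alice holds the model', 'Bob holds the data']
```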

```python
import crypten.mpc as mpc
import crypten.communicator as comm

@mpc.run_multiprocess(world_size=2)
def encrypt_model_and_data():

    rank = comm.get().get_rank()

    if rank == 0:
        # Alice gets the trained model
        plaintext_model = model
    else:
        # Bob gets a dummy model with random parameters
        plaintext_model = AliceNet()

    # Encrypt the model
    # Create a dummy input with the same shape as the model input
    dummy_input = torch.empty((1, 784))
    # Construct a CrypTen network from the PyTorch model and the dummy input
    private_model = crypten.nn.from_pytorch(plaintext_model, dummy_input)
    # Encrypt the CrypTen network. Alice has the real model, so we encrypt with src=0
    private_model.encrypt(src=0)
    # The model is now encrypted: we can check the model's 'encrypted' flag
    if rank == 0:
        print("Encryption flag in CrypTen model:", private_model.encrypted)

    if rank == 1:
        # Bob gets the real data
        plaintext_data = data_bob_flat
    else:
        # Alice gets an empty torch tensor of the same shape
        plaintext_data = torch.empty(data_bob_flat.size())

    # Encrypt the data
    # Bob has the real data, so we encrypt with src=1
    data_enc = crypten.cryptensor(plaintext_data, src=1)
    if rank == 0:
        print("Encryption of Data:", crypten.is_encrypted_tensor(data_enc))

z = encrypt_model_and_data()
```

```
Encryption flag in CrypTen model: True
Encryption of Data: True
```
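The `src` argument names the party whose input is actually encrypted. Every party makes the same call, but the non-source parties only supply a placeholder of the right shape, which is why Bob passes a randomly initialized `AliceNet` and Alice passes an empty tensor. The toy sketch below shows why the placeholder does not affect the result under additive secret sharing. The function `share_from_src` and the 2^64 modulus are hypothetical, and for brevity one function computes all the shares, whereas in a real protocol the parties exchange messages so that no one party sees the full set.

```python
import random

RING = 2 ** 64  # toy modulus for the shares (assumption for this sketch)


def share_from_src(inputs_by_rank, src, rng):
    """Return one additive share per party of the value held by party `src`.

    inputs_by_rank[r] is what party r passed in; only inputs_by_rank[src] is used,
    so the other entries can be arbitrary placeholders.
    """
    secret = inputs_by_rank[src] % RING
    shares = [rng.randrange(RING) for _ in range(len(inputs_by_rank) - 1)]
    shares.append((secret - sum(shares)) % RING)
    return shares


rng = random.Random(0)
# Bob (rank 1) owns the value 42; Alice (rank 0) passes two different placeholders.
shares_a = share_from_src([0, 42], src=1, rng=rng)
shares_b = share_from_src([999, 42], src=1, rng=rng)
print(sum(shares_a) % RING, sum(shares_b) % RING)  # 42 42
```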


### Classifying Encrypted Data with Encrypted Model
We can finally use Alice's encrypted network to classify Bob's encrypted data. This step is identical to PyTorch, except we'll use the encrypted network and data instead of the plaintext versions that PyTorch uses. 
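Under the hood, a linear layer now multiplies two secret-shared quantities, Alice's weights and Bob's inputs, so neither party can compute the product locally. A standard MPC technique for this, used by CrypTen among others, is Beaver-triple multiplication. The toy sketch below multiplies two shared integers with a triple produced by a trusted dealer; the dealer, the 2^64 ring, and the helper names are assumptions, and the sketch omits both the fixed-point truncation needed after multiplying two scaled values and the separate protocols used for non-linear layers such as ReLU.

```python
import random

RING = 2 ** 64  # toy ring for integer shares (assumption for this sketch)


def share(x, rng):
    """Split integer x into two additive shares modulo RING."""
    s0 = rng.randrange(RING)
    return [s0, (x - s0) % RING]


def reveal(shares):
    """Open a shared value by summing the shares."""
    return sum(shares) % RING


def beaver_multiply(x_sh, y_sh, rng):
    """Multiply two secret-shared integers using a Beaver triple (a, b, c = a * b).

    A trusted dealer generates the triple here, which is an assumption of this sketch.
    """
    a, b = rng.randrange(RING), rng.randrange(RING)
    a_sh, b_sh, c_sh = share(a, rng), share(b, rng), share(a * b % RING, rng)
    # The parties open the masked differences epsilon = x - a and delta = y - b;
    # these reveal nothing about x and y because a and b are uniformly random.
    epsilon = reveal([(x_sh[i] - a_sh[i]) % RING for i in range(2)])
    delta = reveal([(y_sh[i] - b_sh[i]) % RING for i in range(2)])
    # Each party computes its share of x * y locally; party 0 also adds epsilon * delta.
    z_sh = [(c_sh[i] + epsilon * b_sh[i] + delta * a_sh[i]) % RING for i in range(2)]
    z_sh[0] = (z_sh[0] + epsilon * delta) % RING
    return z_sh


rng = random.Random(0)
x_sh, y_sh = share(6, rng), share(7, rng)
print(reveal(beaver_multiply(x_sh, y_sh, rng)))  # 42
```

The identity behind it: c + epsilon*b + delta*a + epsilon*delta = xy, so the parties end up with shares of the product without ever opening x or y.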

<small><i>(Technical note: We simulated a multi-party world with the `@mpc.run_multiprocess` decorator in the encryption cell above. As a result, however, the variables created inside that cell do not carry over to later cells, as would be customary in a notebook. Therefore, to illustrate the next steps, we reinitialize `private_model` and `data_enc`, both with `src=0`.)</i></small>

```python
# The following code is required to demonstrate encrypted inference in this notebook.
# As Jupyter notebooks run only a single process, the model and the data both need to be encrypted
# with src=0 for the remaining code to run. In a regular CrypTen implementation
# (see the CrypTen examples folder), data_enc would be encrypted with src=1, as shown in the cell above.
dummy_input = torch.empty((1, 784))
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt(src=0)
data_enc = crypten.cryptensor(data_bob_flat, src=0)  # This would use src=1 outside Jupyter notebooks
```

```python
# Run inference on the encrypted network with the encrypted data
private_model.eval()

with torch.no_grad():
    output_enc = private_model(data_enc)
```


The result of this classification is encrypted. To see this, we'll simply check whether the result is an encrypted tensor; in the next tutorial, we'll look at the tensor's values to confirm the encryption.

We can now decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted output of the model, and can both use this to obtain the labels. 
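Decryption is symmetric: each party sends its shares to the other, and both add the shares to recover the same plaintext logits, from which either side can take the argmax to obtain labels. The toy two-party sketch below illustrates this with made-up logits for a single 5-class example; the ring size, the 16-bit precision, and every helper name are assumptions for illustration, not CrypTen's `get_plain_text` implementation.

```python
import random

RING = 2 ** 64   # toy ring (assumption for this sketch)
SCALE = 2 ** 16  # 16 fractional bits (assumption for this sketch)


def encode(x):
    return round(x * SCALE) % RING


def decode(v):
    return (v - RING if v >= RING // 2 else v) / SCALE


def share_vector(values, rng):
    """Two-party additive sharing of a vector of floats."""
    s0 = [rng.randrange(RING) for _ in values]
    s1 = [(encode(v) - r) % RING for v, r in zip(values, s0)]
    return s0, s1


def open_vector(own, received):
    """What one party computes after receiving the other party's shares."""
    return [decode((a + b) % RING) for a, b in zip(own, received)]


def argmax(row):
    return max(range(len(row)), key=lambda k: row[k])


rng = random.Random(0)
logits = [0.5, -1.25, 2.0, 3.75, -0.5]  # made-up logits, 5 classes for brevity
alice_share, bob_share = share_vector(logits, rng)

# Each party sends its share to the other; both reconstruct the same plaintext.
alice_view = open_vector(alice_share, bob_share)
bob_view = open_vector(bob_share, alice_share)
print(alice_view == bob_view, argmax(alice_view))  # True 3
```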

```python
# The results are encrypted:
print("Output tensor encrypted:", crypten.is_encrypted_tensor(output_enc))

# Decrypt the result
output = output_enc.get_plain_text()
print("Decrypted output:\n", output)

# Obtain the labels
pred = output.argmax(dim=1)
print("Decrypted labels:\n", pred)
```

```
Output tensor encrypted: True
Decrypted output:
 tensor([[-0.7364, -4.2769,  3.2639,  ..., 10.7419,  0.7018,  1.0107],
        [ 0.5365,  1.4743,  9.1954,  ..., -3.7552,  1.2168, -7.6075],
        [-4.3711,  6.7733,  1.0815,  ...,  1.2077, -0.3217, -2.0999],
        ...,
        [-5.5424, -6.8985, -3.4222,  ...,  1.9314,  4.1241,  4.9593],
        [ 1.6874, -1.1822, -3.0257,  ..., -4.0624,  3.7262, -4.6972],
        [ 2.2276, -5.1252,  1.8377,  ..., -5.8095, -1.3235, -5.8347]])
Decrypted labels:
 tensor([7, 2, 1,  ..., 4, 5, 6])
```


```python
# Finally, compute the accuracy of the decrypted output
output = output_enc.get_plain_text()

with torch.no_grad():
    pred = output.argmax(1)
    correct = pred.eq(label_bob)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    print("Accuracy:", accuracy)
```

```
Accuracy: tensor([96.7300])
```


This completes our tutorial. While we have used a simple network here to illustrate the concepts, CrypTen provides primitives to allow for encryption of substantially more complex networks. In our examples section, we demonstrate how CrypTen can be used to encrypt LeNet and ResNet, among others.