In this tutorial, we'll look at how to achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, let's say Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We'll see how CrypTen allows Alice and Bob to coordinate and classify the data while both of their privacy requirements are met.

To simulate this scenario, we'll begin by having Alice train a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data, and finally decrypt the labels.

### Initialization
Let's load some MNIST data, and train Alice's network on it.

In [1]:
import crypten
import torch

#initialize CrypTen's communicator (this is what produces the log lines below)
crypten.init()

INFO:root:DistributedCommunicator with rank 0
INFO:root:World size = 1


In [2]:
from torchvision import datasets, transforms

#download data
mnist_train = datasets.MNIST("/tmp", download=True, train=True)
mnist_test = datasets.MNIST("/tmp", download=True, train=False)

#compute normalization factors
data_all = torch.cat([mnist_train.data, mnist_test.data]).float()
data_mean, data_std = data_all.mean(), data_all.std()
tensor_mean, tensor_std = data_mean.unsqueeze(0), data_std.unsqueeze(0)

In [3]:
#Let's define Alice's and Bob's data
data_alice = mnist_train.data
data_bob = mnist_test.data

label_alice = mnist_train.targets
label_bob = mnist_test.targets

#Normalize the data
data_alice_norm = transforms.functional.normalize(data_alice.float(), tensor_mean, tensor_std)
data_bob_norm = transforms.functional.normalize(data_bob.float(), tensor_mean, tensor_std)

#Flatten the data
data_alice_flat = data_alice_norm.flatten(start_dim=1)
data_bob_flat = data_bob_norm.flatten(start_dim=1)

In [4]:
#Alice creates and trains her network on her data

import torch.nn as nn
import torch.nn.functional as F

#Define Alice's network
class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)
        self.batchnorm1 = nn.BatchNorm1d(128)
        self.batchnorm2 = nn.BatchNorm1d(128)
 
    def forward(self, x):
        out = self.fc1(x)
        out = self.batchnorm1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = self.batchnorm2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out
    
model = AliceNet()

#Train Alice's network
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

num_examples = 60000
batch_size = 256
num_epochs = 25
log_accuracy = True

for i in range(num_epochs):
    for j in range(0, num_examples, batch_size):
        
        #get the mini-batch
        start, end = j, min(j+batch_size,num_examples)
        sample_flat = data_alice_flat[start:end,:]
        target = label_alice[start:end]
        
        #forward pass: compute prediction
        output = model(sample_flat)

        #compute loss
        loss = criterion(output, target)
        
        #zero gradients for learnable parameters
        optimizer.zero_grad()

        #backward pass: compute gradient with respect to model parameters
        loss.backward()

        #update model parameters
        optimizer.step()
        
    #log loss and accuracy every epoch (both come from the last mini-batch only)
    if log_accuracy:
        pred = output.argmax(1)
        correct = pred.eq(target)
        correct_count = correct.sum(0, keepdim=True).float()
        accuracy = correct_count.mul_(100.0 / output.size(0))
        print("Epoch", i, "Loss:", loss.item())
        print("\tAccuracy:", accuracy)


Epoch 0 Loss: 0.16712146997451782
	Accuracy: tensor([96.8750])
Epoch 1 Loss: 0.13901349902153015
	Accuracy: tensor([97.9167])
Epoch 2 Loss: 0.1161004975438118
	Accuracy: tensor([97.9167])
Epoch 3 Loss: 0.09419527649879456
	Accuracy: tensor([98.9583])
Epoch 4 Loss: 0.08182505518198013
	Accuracy: tensor([98.9583])
Epoch 5 Loss: 0.06977503001689911
	Accuracy: tensor([98.9583])
Epoch 6 Loss: 0.05783069133758545
	Accuracy: tensor([98.9583])
Epoch 7 Loss: 0.043654829263687134
	Accuracy: tensor([98.9583])
Epoch 8 Loss: 0.027489224448800087
	Accuracy: tensor([98.9583])
Epoch 9 Loss: 0.016932297497987747
	Accuracy: tensor([98.9583])
Epoch 10 Loss: 0.010210790671408176
	Accuracy: tensor([100.])
Epoch 11 Loss: 0.006478141527622938
	Accuracy: tensor([100.])
Epoch 12 Loss: 0.0046099042519927025
	Accuracy: tensor([100.])
Epoch 13 Loss: 0.0036421415861696005
	Accuracy: tensor([100.])
Epoch 14 Loss: 0.002934912219643593
	Accuracy: tensor([100.])
Epoch 15 Loss: 0.002525407588109374
	Accuracy: tensor([1
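
The per-epoch accuracy printed above is computed from `output` and `target` as they stand when the inner loop finishes, that is, from the last mini-batch only rather than from the whole training set. A quick arithmetic check with the same `num_examples` and `batch_size` shows that this last batch holds 96 examples, which is why the logged values move in steps of 1/96 (for example, 93/96 = 96.875%). The helper name `batch_bounds` is introduced here purely for illustration.

```python
# Same values as in the training cell above.
num_examples = 60000
batch_size = 256


def batch_bounds(num_examples, batch_size):
    """Return the (start, end) index pairs visited by the training loop."""
    return [
        (start, min(start + batch_size, num_examples))
        for start in range(0, num_examples, batch_size)
    ]


bounds = batch_bounds(num_examples, batch_size)
last_start, last_end = bounds[-1]
last_size = last_end - last_start

print("number of mini-batches:", len(bounds))    # 235
print("size of last mini-batch:", last_size)     # 96
print("93 correct out of 96:", 100.0 * 93 / last_size)  # 96.875
```

For a less noisy training-accuracy estimate, one could accumulate the correct counts over all mini-batches of an epoch instead.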

### Encryption
Alice now has a trained neural network that can classify data. Let's see how we can use CrypTen to encrypt this network, so it can be used to classify data without revealing its parameters. 

In CrypTen, encrypting a PyTorch network is straightforward: first, we call the function ```from_pytorch```, which sets up a CrypTen network from the PyTorch network. Then, we call ```encrypt``` on the CrypTen network to encrypt its parameters. After encryption, the CrypTen network can also be decrypted (see the ```decrypt``` function). <b>TODO: encrypt needs a src, and Bob needs a dummy model.</b>

In addition to the PyTorch network, the ```from_pytorch``` function also requires a dummy input with the same shape as the model's input (this is similar to the requirement we saw in Tutorial 3).

We'll walk through an example below:

In [5]:
#Alice encrypts her network

#Create a dummy input with the same shape as the model input
dummy_input = torch.empty((1, 784))

#Construct a CrypTen network with the trained model and dummy_input
alice_private_model = crypten.nn.from_pytorch(model, dummy_input)

#Encrypt the CrypTen network
alice_private_model.encrypt()

#Alice's model is now encrypted: we can check the model's 'encrypted' flag!
print("Encryption flag in Alice's CrypTen model:", alice_private_model.encrypted)

Encryption flag in Alice's CrypTen model: True


In the next tutorial, we'll take a closer look at Alice's CrypTen network, to understand the details of how the parameters of each layer are encrypted.

Let's also encrypt Bob's data (this step is identical to what we've seen in Tutorial 3).

In [6]:
#Bob encrypts his data. In a two-party run Bob would be src=1;
#this single-process demo (world size 1) uses src=0.
data_bob_enc = crypten.cryptensor(data_bob_flat, src=0) #TODO: src = 0 for now
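
To make "encrypted" a little more concrete: CrypTen's encryption is based on secret sharing, where a value is split into random-looking shares held by different parties. The sketch below is a toy, pure-Python illustration of additive secret sharing for a single integer. It is not CrypTen's implementation; the names `MODULUS`, `share`, `reconstruct` and `add_shares` are hypothetical and chosen only for this example.

```python
import secrets

# Toy ring size (an assumption for this sketch, not a statement about CrypTen internals).
MODULUS = 2 ** 64


def share(value, num_parties=2):
    """Split an integer into additive shares that sum to it modulo MODULUS."""
    shares = [secrets.randbelow(MODULUS) for _ in range(num_parties - 1)]
    last = (value - sum(shares)) % MODULUS
    return shares + [last]


def reconstruct(shares):
    """Recombine all shares and map the result back to a signed integer."""
    total = sum(shares) % MODULUS
    return total - MODULUS if total >= MODULUS // 2 else total


def add_shares(a_shares, b_shares):
    """Each party adds its own two shares locally; no communication is needed."""
    return [(a + b) % MODULUS for a, b in zip(a_shares, b_shares)]


x_shares = share(7)    # e.g. a value contributed by one party
y_shares = share(-3)   # e.g. a value contributed by the other party

# A single share is just a random-looking number; only all shares together reveal the value.
print("one share of x:", x_shares[0])
print("reconstructed x:", reconstruct(x_shares))
print("reconstructed x + y:", reconstruct(add_shares(x_shares, y_shares)))
```

The point of the sketch is that addition can be carried out share by share, so the parties can compute on the data without any of them seeing the underlying values until the shares are deliberately recombined.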

### Classifying Encrypted Data with Encrypted Model
We can finally use Alice's encrypted network to classify Bob's encrypted data. This step is identical to PyTorch, except we'll use the encrypted network and data instead of the plaintext versions that PyTorch uses. Thus, we can just do: 

In [7]:
#Alice runs inference on her encrypted model with Bob's encrypted data
alice_private_model.eval()

with torch.no_grad():
    output_enc = alice_private_model(data_bob_enc)


The result of this classification is encrypted. For now, let's just check that the result is an encrypted tensor; in the next tutorial, we'll look into the values of the tensor and confirm the encryption.

Finally, we'll decrypt the result. As we discussed before, Alice and Bob both have access to the decrypted result. 

In [8]:
#The results are encrypted: 
print("Output tensor encrypted:", crypten.is_encrypted_tensor(output_enc)) 
print()

#Decrypting the result
print("Decrypted output:\n", output_enc.get_plain_text())

Output tensor encrypted: True

Decrypted output:
 tensor([[-1.7982, -1.4551,  0.0343,  ..., 11.4704, -2.6817, -2.0960],
        [-0.3612,  1.1524, 14.0475,  ..., -3.0585, -1.7713, -6.4185],
        [-1.4942, 11.0817,  0.4863,  ..., -0.2121,  0.3387, -2.5970],
        ...,
        [-3.4620, -2.6589, -2.7017,  ...,  0.1182,  0.3336, -2.0951],
        [ 0.5075, -2.0209, -6.8010,  ..., -3.7700,  3.5046, -5.3319],
        [-0.1711, -2.8731, -2.3386,  ..., -2.8373, -2.0506, -3.5235]])


In [9]:
#Finally, we'll compute the accuracy of the output:
output = output_enc.get_plain_text()

with torch.no_grad():
    pred = output.argmax(1)
    correct = pred.eq(label_bob)
    correct_count = correct.sum(0, keepdim=True).float()
    accuracy = correct_count.mul_(100.0 / output.size(0))
    print("Accuracy:", accuracy)

Accuracy: tensor([98.1100])
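
The accuracy computation above boils down to three steps: take the index of the largest logit in each row as the predicted class, compare it with the true label, and report the fraction of matches as a percentage. Here is a minimal plain-Python sketch of the same logic. The logits and labels are made-up values for illustration (not the model's output), and the helper names `argmax` and `accuracy_percent` are hypothetical.

```python
def argmax(row):
    """Index of the largest entry in a list of scores."""
    return max(range(len(row)), key=lambda k: row[k])


def accuracy_percent(logits, labels):
    """Percentage of rows whose highest-scoring class matches the label."""
    preds = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return 100.0 * correct / len(labels)


# Hypothetical 3-class logits for four examples.
logits = [
    [0.1, 2.0, -1.0],   # predicted class 1
    [1.5, 0.2, 0.3],    # predicted class 0
    [-0.4, 0.0, 0.9],   # predicted class 2
    [0.7, 0.6, 0.1],    # predicted class 0 (label is 1, so this one is wrong)
]
labels = [1, 0, 2, 1]

print(accuracy_percent(logits, labels))  # 75.0
```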


This completes our tutorial. While we have used a simple network here to illustrate the concepts, CrypTen provides primitives to allow for encryption of substantially more complex networks. In our examples section, we demonstrate how CrypTen can be used to encrypt LeNet and ResNet, among others.