# Under the Hood of Encrypted Neural Networks

This tutorial is optional, and can be skipped without loss of continuity.

In this tutorial, we'll take a look at how CrypTen performs inference with an encrypted neural network on encrypted data. We'll see how the data remains encrypted through all the operations, and yet we are able to obtain accurate results after the computation.

In [1]:
import crypten
import torch

crypten.init() 
torch.set_num_threads(1)

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Keep track of all created temporary files so that we can clean up at the end
temp_files = []

## A Simple Linear Layer
We'll start by examining how a single Linear layer works in CrypTen. We'll instantiate a torch Linear layer, convert it to a CrypTen layer, encrypt it, and step some toy data through it. As in earlier tutorials, we'll assume Alice has the rank 0 process and Bob has the rank 1 process. We'll also assume Alice has the layer and Bob has the data.
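
Before running the cells, it may help to have a rough mental model of what `encrypt(src=ALICE)` and `get_plain_text()` do. The sketch below is a simplified, stdlib-only illustration that assumes CrypTen-style arithmetic secret sharing: a real value is first turned into a fixed-point integer (we assume 16 fractional bits), and the party that owns it (`src`) splits it into random-looking shares that add up to the encoded value modulo 2^64. The helper names (`encode`, `share`, `reconstruct`, etc.) are hypothetical and are not part of CrypTen's API, and CrypTen's actual sharing routine may differ in its details (for example, in how the randomness is generated).

```python
import secrets

RING_BITS = 64            # assumption: shares live in the ring Z_{2^64}
MOD = 1 << RING_BITS
SCALE = 1 << 16           # assumption: 16 fractional bits of fixed-point precision


def encode(x: float) -> int:
    """Map a real number to a fixed-point integer in Z_{2^64}."""
    return round(x * SCALE) % MOD


def to_signed(v: int) -> int:
    """Read a ring element as a signed 64-bit integer (like the printed shares)."""
    return v - MOD if v >= MOD // 2 else v


def decode(v: int) -> float:
    """Map a fixed-point ring element back to a real number."""
    return to_signed(v) / SCALE


def share(v: int, num_parties: int = 2):
    """The owning party (src) splits v into random shares that sum to v mod 2^64."""
    shares = [secrets.randbelow(MOD) for _ in range(num_parties - 1)]
    shares.append((v - sum(shares)) % MOD)
    return shares


def reconstruct(shares) -> int:
    """Conceptually what decryption does: add up every party's share."""
    return sum(shares) % MOD


weight = 0.1231                               # one of Alice's plaintext weights
alice_share, bob_share = share(encode(weight))
print("Alice's share:", to_signed(alice_share))   # looks like random noise
print("Bob's share:  ", to_signed(bob_share))
print("Decrypted:    ", decode(reconstruct([alice_share, bob_share])))
```

Each share on its own is uniformly random, which is why the encrypted weights printed below look like arbitrary 64-bit integers; only the sum of all shares carries the value, and decoding recovers it up to a rounding error of at most 2^-17.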

In [2]:
# Define ALICE and BOB src values
ALICE = 0
BOB = 1

In [3]:
import torch.nn as nn

# Instantiate single Linear layer
layer_linear = nn.Linear(4, 2)

# The weights and the bias are initialized to small random values
print("Plaintext Weights:", layer_linear._parameters['weight'])
print("Plaintext Bias:", layer_linear._parameters['bias'])

# Save the plaintext layer
layer_linear_file = "/tmp/tutorial5_layer_alice1.pth"
crypten.save(layer_linear, layer_linear_file)
temp_files.append(layer_linear_file) 

# Generate some toy data
features = 4
examples = 3
toy_data = torch.rand(examples, features)

# Save the plaintext toy data
toy_data_file = "/tmp/tutorial5_data_bob1.pth"
crypten.save(toy_data, toy_data_file)
temp_files.append(toy_data_file)

Plaintext Weights: Parameter containing:
tensor([[ 0.1231, -0.3265,  0.2590, -0.0752],
        [-0.2957,  0.1935,  0.2226, -0.0387]], requires_grad=True)
Plaintext Bias: Parameter containing:
tensor([ 0.1789, -0.0787], requires_grad=True)


In [4]:
import crypten.mpc as mpc
import crypten.communicator as comm

@mpc.run_multiprocess(world_size=2)
def forward_single_encrypted_layer():
    rank = comm.get().get_rank()
    
    # Load and encrypt the layer
    layer = crypten.load(layer_linear_file, dummy_model=nn.Linear(4, 2), src=ALICE)
    layer_enc = crypten.nn.from_pytorch(layer, dummy_input=torch.empty((1,4)))
    layer_enc.encrypt(src=ALICE)
    
    # Note that layer parameters are encrypted:
    if rank == 0:  # Print once for readability
        print("Weights:\n", layer_enc.weight.share)
        print("Bias:\n", layer_enc.bias.share)
        print()
    
    # Load and encrypt data
    data_enc = crypten.load(toy_data_file, src=BOB)
    
    # Apply the encrypted layer (linear transformation):
    result_enc = layer_enc.forward(data_enc)
    
    # Decrypt the result:
    result = result_enc.get_plain_text()
    
    # Examine the result
    if rank == 0: # Print once for readability
        print("Decrypted result:\n", result)
        
forward_single_encrypted_layer()

Weights:
 tensor([[ 2898316751075970645, -4579388503044133723,  5038242173536375749,
         -1490107348288071119],
        [-1557770914353334546, -2564120523467162291, -6444290604699277632,
         -1291727171320069465]])
Bias:
 tensor([2530322031011295021, 3780145476852328879])

Decrypted result:
 tensor([[ 0.4359, -0.0401],
        [ 0.3171, -0.1922],
        [ 0.2721, -0.2406]])


[None, None]

We can see that the application of the encrypted linear layer on the encrypted data produces an encrypted result, which we can then decrypt to get the values in plaintext.
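
How can a linear layer be applied when both the weights and the input are secret-shared? Additions can be done share by share, but a product of two secret values needs a small interactive protocol. A standard technique, and the one we assume here for illustration, is Beaver multiplication: a trusted dealer hands out shares of a random triple (a, b, c = a·b), and the parties only ever open the masked differences x - a and y - b. The stdlib sketch below computes one output unit of y = w·x + bias for two parties. The dealer, the helper names, the input row `x`, and the final division by SCALE² (used instead of the truncation step a real system performs on shares) are all simplifying assumptions, not CrypTen's implementation.

```python
import secrets

MOD = 1 << 64          # assumption: 64-bit ring
SCALE = 1 << 16        # assumption: 16 fractional bits


def encode(x):
    return round(x * SCALE) % MOD


def decode(v, scale=SCALE):
    signed = v - MOD if v >= MOD // 2 else v
    return signed / scale


def share(v):
    r = secrets.randbelow(MOD)
    return [r, (v - r) % MOD]


def open_value(shares):
    # Both parties broadcast their shares and add them up.
    return sum(shares) % MOD


def beaver_mul(x_sh, y_sh):
    """Multiply two secret-shared ring elements using a dealer-provided triple."""
    a, b = secrets.randbelow(MOD), secrets.randbelow(MOD)
    a_sh, b_sh, c_sh = share(a), share(b), share(a * b % MOD)   # hypothetical trusted dealer
    # Open e = x - a and d = y - b; a and b are uniformly random, so these hide x and y.
    e = open_value([(x_sh[i] - a_sh[i]) % MOD for i in range(2)])
    d = open_value([(y_sh[i] - b_sh[i]) % MOD for i in range(2)])
    # x*y = c + e*b + d*a + e*d; only party 0 adds the public e*d term.
    z_sh = [(c_sh[i] + e * b_sh[i] + d * a_sh[i]) % MOD for i in range(2)]
    z_sh[0] = (z_sh[0] + e * d) % MOD
    return z_sh


def private_linear_output(w_sh, x_sh, bias_sh):
    """One output unit of y = w . x + bias, with w, x and bias all secret-shared."""
    acc = [0, 0]
    for wj, xj in zip(w_sh, x_sh):
        prod = beaver_mul(wj, xj)                  # product carries scale SCALE**2
        acc = [(acc[i] + prod[i]) % MOD for i in range(2)]
    # Bring the bias to the same scale; multiplying by a public constant is local.
    return [(acc[i] + bias_sh[i] * SCALE) % MOD for i in range(2)]


weights = [0.1231, -0.3265, 0.2590, -0.0752]   # Alice's first weight row, from the output above
bias = 0.1789
x = [0.5, 0.25, 0.75, 1.0]                      # hypothetical input row for Bob

w_sh = [share(encode(w)) for w in weights]
x_sh = [share(encode(v)) for v in x]
y_sh = private_linear_output(w_sh, x_sh, share(encode(bias)))

print("MPC result:      ", decode(open_value(y_sh), scale=SCALE * SCALE))
print("Plaintext result:", sum(w * v for w, v in zip(weights, x)) + bias)
```

The two printed values agree up to the fixed-point rounding of the encoded inputs. The dealer sees only the random triple, never x or y, and each party sees only its own shares plus the masked values e and d.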

Let's look at a second linear transformation, to give a flavor of how accuracy is preserved even when the data and the layer are encrypted. We'll look at a uniform scaling transformation, in which all tensor elements are multiplied by the same scalar factor. Again, we'll assume Alice has the layer and the rank 0 process, and Bob has the data and the rank 1 process.

In [5]:
# Initialize a linear layer with random weights
layer_scale = nn.Linear(3, 3)

# Construct a uniform scaling matrix: we'll scale by factor 5
factor = 5
with torch.no_grad():
    # Overwrite the random initialization in place, keeping the nn.Parameter objects
    layer_scale.weight.copy_(torch.eye(3) * factor)
    layer_scale.bias.zero_()

# Save the plaintext layer
layer_scale_file = "/tmp/tutorial5_layer_alice2.pth"
crypten.save(layer_scale, layer_scale_file)
temp_files.append(layer_scale_file)

# Construct some toy data
features = 3
examples = 2
toy_data = torch.ones(examples, features)

# Save the plaintext toy data
toy_data_file = "/tmp/tutorial5_data_bob2.pth"
crypten.save(toy_data, toy_data_file)
temp_files.append(toy_data_file)

In [6]:
@mpc.run_multiprocess(world_size=2)
def forward_scaling_layer():
    rank = comm.get().get_rank()
    
    # Load and encrypt the layer
    layer = crypten.load(layer_scale_file, dummy_model=nn.Linear(3, 3), src=ALICE)
    layer_enc = crypten.nn.from_pytorch(layer, dummy_input=torch.empty((1,3)))
    layer_enc.encrypt(src=ALICE)
    
    # Load and encrypt data
    data_enc = crypten.load(toy_data_file, src=BOB)   
    
    # Note that layer parameters are (still) encrypted:
    if rank == 0:  # Print once for readability
        print("Weights:\n", layer_enc.weight.share)
        print("Bias:\n", layer_enc.bias.share)
        print()

    # Apply the encrypted scaling transformation
    result_enc = layer_enc.forward(data_enc)

    # Decrypt the result:
    result = result_enc.get_plain_text()
    
    # Since both parties have the same decrypted values, print only rank 0 for readability
    if rank == 0:
        print("Plaintext result:\n", result)
        
z = forward_scaling_layer()

Weights:
 tensor([[ 3824570129683350441, -8540765741594476771, -3033476021071207225],
        [-3600177681891051605, -5879382445495720845,   981762740632649929],
        [-8080559122370361369, -1390179875616073746,  5513490018807349512]])
Bias:
 tensor([-1285976053388686819, -5042828109819508063, -6784293398454048581])

Plaintext result:
 tensor([[5., 5., 5.],
        [5., 5., 5.]])


The resulting plaintext tensor is correctly scaled, even though we applied the encrypted transformation on the encrypted input! 
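
Why is the result exactly 5 here, while the first example only matched up to small errors? Assuming a fixed-point encoding with 16 fractional bits (an assumption about CrypTen's default precision), both 1 and 5 are encoded without any rounding, so their product decodes exactly. Values that are not multiples of 2^-16, such as the weights in the first example, pick up a rounding error of at most 2^-17 each when encoded. A small stdlib sketch of this arithmetic, with hypothetical helper names:

```python
SCALE = 1 << 16   # assumption: 16 fractional bits of fixed-point precision


def encode(x: float) -> int:
    """Round a real number to the nearest fixed-point integer."""
    return round(x * SCALE)


def encoding_error(x: float) -> float:
    """Absolute error introduced by encoding and then decoding x."""
    return abs(encode(x) / SCALE - x)


# The scaling example: 5 * 1 survives the round trip exactly.
product = encode(5.0) * encode(1.0)          # the product carries a scale of SCALE**2
print("5 * 1 ->", product / SCALE**2)

# Weights that are not multiples of 2**-16 pick up a small rounding error.
for w in [0.1231, -0.3265, 0.2590, -0.0752]:
    print(f"{w:+.4f}: encoding error {encoding_error(w):.2e} (bound {0.5 / SCALE:.2e})")
```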

## Multi-layer Neural Networks
Let's now look at how the encrypted input moves through an encrypted multi-layer neural network. 

For ease of explanation, we'll first step through a network with only two linear layers and ReLU activations. Again, we'll assume Alice has a network and Bob has some data, and they wish to run encrypted inference. 

To simulate this, we'll once again generate some toy data and train Alice's network on it. Then we'll encrypt Alice's network and Bob's data, and step the encrypted data through every layer in the network. Through this, we'll see how the computations are applied even though both the network and the data are encrypted.

### Setup
As in Tutorial 3, we will first generate 1000 ground truth samples using 50 features and a randomly generated hyperplane to separate positive and negative examples. We will then modify the labels so that they are all non-negative. Finally, we will split the data so that the first 900 samples belong to Alice and the last 100 samples belong to Bob.

In [7]:
# Setup
features = 50
examples = 1000

# Set random seed for reproducibility
torch.manual_seed(1)

# Generate toy data and separating hyperplane
data = torch.randn(examples, features)
w_true = torch.randn(1, features)
b_true = torch.randn(1)
labels = w_true.matmul(data.t()).add(b_true).sign()

# Change labels to non-negative values
labels_nn = torch.where(labels==-1, torch.zeros(labels.size()), labels)
labels_nn = labels_nn.squeeze().long()

# Split data into Alice's and Bob's portions:
data_alice, labels_alice = data[:900], labels_nn[:900]
data_bob, labels_bob = data[900:], labels_nn[900:]

In [8]:
# Define Alice's network
import torch.nn as nn
import torch.nn.functional as F

class AliceNet(nn.Module):
    def __init__(self):
        super(AliceNet, self).__init__()
        self.fc1 = nn.Linear(50, 20)
        self.fc2 = nn.Linear(20, 2)
        
    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        return out

In [9]:
# Train and save Alice's network
model = AliceNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for i in range(500):
    # Forward pass: compute prediction
    output = model(data_alice)
    
    # Compute and print loss
    loss = criterion(output, labels_alice)
    if i % 100 == 99:
        print("Epoch", i, "Loss:", loss.item())
    
    # Zero gradients for learnable parameters
    optimizer.zero_grad()
    
    # Backward pass: compute gradient with respect to model parameters
    loss.backward()
    
    # Update model parameters
    optimizer.step()

sample_trained_model_file = '/tmp/tutorial5_alice_model.pth'
torch.save(model, sample_trained_model_file)
temp_files.append(sample_trained_model_file)

Epoch 99 Loss: 0.2470429241657257
Epoch 199 Loss: 0.08965438604354858
Epoch 299 Loss: 0.05166155472397804
Epoch 399 Loss: 0.03510778397321701
Epoch 499 Loss: 0.026072446256875992


### Stepping through a Multi-layer Network

Let's now look at what happens when we load the network Alice has trained and encrypt it. First, we'll look at how the network structure changes when we convert it from a PyTorch network to a CrypTen network.

In [10]:
# Load Alice's trained network, with Alice as the source party
model_plaintext = crypten.load(sample_trained_model_file, dummy_model=AliceNet(), src=ALICE)

# Convert the trained network to a CrypTen network
private_model = crypten.nn.from_pytorch(model_plaintext, dummy_input=torch.empty((1, 50)))
# Encrypt the network
private_model.encrypt(src=ALICE)

# Examine the structure of the encrypted CrypTen network
for name, curr_module in private_model._modules.items():
    print("Name:", name, "\tModule:", curr_module)

Name: 5 	Module: <crypten.nn.module.Linear object at 0x7fb688db1750>
Name: 6 	Module: <crypten.nn.module.ReLU object at 0x7fb688de5610>
Name: output 	Module: <crypten.nn.module.Linear object at 0x7fb688de5650>


We see that the encrypted network has 3 modules, named '5', '6' and 'output', denoting the first Linear layer, the ReLU activation, and the second Linear layer respectively. These modules are encrypted just as the layers in the previous section were. 

Now let's encrypt Bob's data, and step it through each encrypted module. For readability, we will use only 3 examples from Bob's data to illustrate the inference. Note how Bob's data remains encrypted after each individual layer's computation!
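
Why does ReLU need different treatment from the linear layers? Addition is compatible with secret sharing: adding shares gives shares of the sum. ReLU is non-linear, so applying it to each party's share separately does not produce a share of ReLU(x). The stdlib sketch below demonstrates this under the same assumptions as before (a 64-bit ring and 16 fractional bits); the helper names are hypothetical, and the sketch only shows the problem, not the interactive comparison protocol an MPC framework uses to evaluate ReLU on shares.

```python
import secrets

MOD = 1 << 64      # assumption: 64-bit ring
SCALE = 1 << 16    # assumption: 16 fractional bits


def encode(x):
    return round(x * SCALE) % MOD


def decode(v):
    return (v - MOD if v >= MOD // 2 else v) / SCALE


def share(v):
    r = secrets.randbelow(MOD)
    return [r, (v - r) % MOD]


def relu_ring(v):
    """ReLU applied to a single ring element read as a signed fixed-point value."""
    return v if decode(v) > 0 else 0


x, y = -1.5, 2.25
x_sh, y_sh = share(encode(x)), share(encode(y))

# Linear operations work share by share: the sum of the shares is a sharing of x + y.
sum_sh = [(x_sh[i] + y_sh[i]) % MOD for i in range(2)]
print("x + y from shares:   ", decode(sum(sum_sh) % MOD))

# ReLU does not: applying it to each share separately gives a meaningless value.
naive = sum(relu_ring(s) for s in x_sh) % MOD
print("naive ReLU of shares:", decode(naive))
print("true ReLU(x):        ", max(x, 0.0))
```

Because the sign of x cannot be read off any single share, evaluating ReLU requires the parties to interact, which is one reason non-linear layers are typically more expensive than linear ones in encrypted inference.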

In [11]:
# Pre-processing: Select only the first three examples in Bob's data for readability
data = data_bob[:3]
sample_data_bob_file = '/tmp/tutorial5_data_bob3.pth'
torch.save(data, sample_data_bob_file)
temp_files.append(sample_data_bob_file)

In [12]:
@mpc.run_multiprocess(world_size=2)
def step_through_two_layers():    
    rank = comm.get().get_rank()

    # Load and encrypt the network
    model = crypten.load(sample_trained_model_file, dummy_model=AliceNet(), src=ALICE)
    private_model = crypten.nn.from_pytorch(model, dummy_input=torch.empty((1, 50)))
    private_model.encrypt(src=ALICE)

    # Load and encrypt the data
    data_enc = crypten.load(sample_data_bob_file, src=BOB)

    # Forward through the first layer
    out_enc = private_model._modules['5'].forward(data_enc)
    print("Rank: {} First Linear Layer: Output Encrypted: {}\n".format(rank, crypten.is_encrypted_tensor(out_enc)))
    print("Rank: {} Shares after First Linear Layer:{}\n".format(rank, out_enc.share))

    # Apply ReLU activation
    out_enc = private_model._modules['6'].forward(out_enc)
    print("Rank: {} ReLU:\n Output Encrypted: {}\n".format(rank, crypten.is_encrypted_tensor(out_enc)))
    print("Rank: {} Shares after ReLU: {}\n".format(rank, out_enc.share))

    # Forward through the second Linear layer
    out_enc = private_model._modules['output'].forward(out_enc)
    print("Rank: {} Second Linear layer:\n Output Encrypted: {}\n".format(rank, crypten.is_encrypted_tensor(out_enc)))
    print("Rank: {} Shares after Second Linear layer:{}\n".format(rank, out_enc.share))

    # Decrypt the output
    out_dec = out_enc.get_plain_text()
    # Since both parties have the same decrypted results, only print the rank 0 output
    if rank == 0:
        print("Decrypted output:\n Output Encrypted:", crypten.is_encrypted_tensor(out_dec))
        print("Tensors:\n", out_dec)
    
z = step_through_two_layers()

Rank: 0 First Linear Layer: Output Encrypted: True
Rank: 1 First Linear Layer: Output Encrypted: True


Rank: 1 Shares after First Linear Layer:tensor([[-6301967327924596124,  1441054799224171566,  6544478469984047445,
          5552031352886356103,  5272638317470990646,  7596515253675258249,
           -69400202746828565,  5828418225954818579, -4749428913547144580,
          6606576454575994085, -8611449871174003033, -6879516893031986420,
         -6036166085005125285, -4328313888449389705, -1251365986125740619,
          1495196102692730834,  6980508649455116950, -3898827469711590256,
         -3534438480830590527,  7371155115760810984],
        [-6302093306442423781,  1441068736631097953,  6544305404511755054,
          5552029867695634565,  5272413795589744532,  7596469503207299393,
           -69427899322696412,  5828641617418962017, -4749404070611225973,
          6606436629940671374, -8611581346037569523, -6879276243271746212,
         -6036089440109381038, -4328277524632069842, ...

Again, we emphasize that the output of each layer is an encrypted tensor. Only after the final call to `get_plain_text` do we get the plaintext tensor.

### From PyTorch to CrypTen: Structural Changes in Network Architecture 

We have used a simple two-layer network in the above example, but the same ideas apply to more complex networks and operations. However, in more complex networks, there may not always be a one-to-one mapping between the PyTorch layers and the CrypTen layers. This is because CrypTen uses PyTorch's ONNX export to convert PyTorch models into CrypTen models.
As an example, we'll take a typical network used to classify digits in MNIST data, and look at what happens to its structure when we convert it to a CrypTen module. (As we only wish to illustrate the structural changes in layers, we will not train this network on data; we will just use it with its randomly initialized weights.)
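
The `16 * 4 * 4` input size of `fc1` in the network below follows from the standard output-size formula for convolution and pooling layers without dilation, floor((n + 2p - k) / s) + 1, applied to a 28×28 MNIST image. A quick stdlib check (the helper `out_size` is hypothetical, written only for this illustration):

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


n = 28
n = out_size(n, kernel=5)            # conv1, kernel 5, no padding: 28 -> 24
n = out_size(n, kernel=2, stride=2)  # avg_pool2d(out, 2): 24 -> 12
n = out_size(n, kernel=5)            # conv2, kernel 5, no padding: 12 -> 8
n = out_size(n, kernel=2, stride=2)  # avg_pool2d(out, 2): 8 -> 4
channels = 16
print("flattened size:", channels * n * n)
```

The same shape reasoning is why the `dummy_input` passed to `crypten.nn.from_pytorch` must have the shape `(1, 1, 28, 28)`: the conversion traces the model on an input of that shape.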

In [13]:
# Define Alice's network
class AliceNet2(nn.Module):
    def __init__(self):
        super(AliceNet2, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 4 * 4, 100)
        self.fc2 = nn.Linear(100, 10)
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.batchnorm2 = nn.BatchNorm2d(16)
        self.batchnorm3 = nn.BatchNorm1d(100)
 
    def forward(self, x):
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = F.relu(out)
        out = F.avg_pool2d(out, 2)
        out = self.conv2(out)
        out = self.batchnorm2(out)
        out = F.relu(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.batchnorm3(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out
    
model = AliceNet2()

# Let's encrypt the complex network. 
# Create dummy input of the correct input shape for the model
dummy_input = torch.empty((1, 1, 28, 28))

# Convert the network to a CrypTen network and encrypt it
private_model = crypten.nn.from_pytorch(model, dummy_input)
private_model.encrypt(src=ALICE)

# Examine the structure of the encrypted network
for name, curr_module in private_model._modules.items():
    print("Name:", name, "\tModule:", curr_module)

Name: 24 	Module: <crypten.nn.module.Conv2d object at 0x7fb6890a62d0>
Name: 25 	Module: <crypten.nn.module._BatchNorm object at 0x7fb6890a6190>
Name: 26 	Module: <crypten.nn.module.ReLU object at 0x7fb6890a6310>
Name: 27 	Module: <crypten.nn.module._ConstantPad object at 0x7fb6890a6c10>
Name: 28 	Module: <crypten.nn.module.AvgPool2d object at 0x7fb6890a6290>
Name: 29 	Module: <crypten.nn.module.Conv2d object at 0x7fb6b88de510>
Name: 30 	Module: <crypten.nn.module._BatchNorm object at 0x7fb6890a6e50>
Name: 31 	Module: <crypten.nn.module.ReLU object at 0x7fb6890a6150>
Name: 32 	Module: <crypten.nn.module._ConstantPad object at 0x7fb6890a61d0>
Name: 33 	Module: <crypten.nn.module.AvgPool2d object at 0x7fb6890a6f90>
Name: 34 	Module: <crypten.nn.module.Constant object at 0x7fb6890a6c50>
Name: 35 	Module: <crypten.nn.module.Shape object at 0x7fb6890a6ed0>
Name: 36 	Module: <crypten.nn.module.Gather object at 0x7fb6890a6d90>
Name: 37 	Module: <crypten.nn.module.Constant object at 0x7fb6890a6...

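Notice how the conversion has split some of the layers in the PyTorch module into several CrypTen modules: a single PyTorch operation may correspond to one or more CrypTen modules. For example, each `F.avg_pool2d` call appears as a `_ConstantPad` module followed by an `AvgPool2d` module, and the modules that follow the second pooling layer (`Constant`, `Shape`, `Gather`, ...) appear to come from the flattening step `out.view(out.size(0), -1)`. These splits are a consequence of how the model is represented in ONNX during the conversion.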

Before exiting this tutorial, please clean up the files generated using the following code.

In [14]:
import os
for fn in temp_files:
    if os.path.exists(fn):
        os.remove(fn)