<a href="https://colab.research.google.com/github/inspire-lab/CyberAI-labs/blob/main/category-PrivateAI/CNN-inference-homomorphic-encryption/Encrypted_Convolution_on_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Encrypted Convolution on MNIST

In this lab we perform encrypted evaluation of MNIST examples using a convolutional neural network (CNN).

As in the previous lab, we will use the CKKS scheme extensively.

We start by explaining how the different layers can be evaluated on encrypted data. Next, we train a PyTorch model on MNIST, then implement an equivalent model using TenSEAL that can evaluate encrypted inputs.

## Machine Learning Model

With the MNIST dataset in hand, we can use a simple neural network composed of a convolutional layer followed by two linear layers. We use the square activation function because it is simple and easy to evaluate, which matters given the limited number of multiplications available with the CKKS scheme.

Keep in mind that the input to the model needs to be encrypted using CKKS, but the parameters of the model don't: they are kept in plain during the whole protocol.

### Model Description
The model is the following sequence of layers:

- **Conv:** Convolution with 4 kernels. Shape of the kernel is 7x7. Strides are 3x3.
- **Activation:** Square activation function.
- **Linear Layer 1:** Input size: 256. Output size: 64.
- **Activation:** Square activation function.
- **Linear Layer 2:** Input size: 64. Output size: 10.
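
The input size of 256 for the first linear layer follows from the convolution geometry. The short check below assumes the standard 28x28 MNIST image size and no padding; the helper name `conv_output_size` is ours and only serves the illustration.

```python
def conv_output_size(input_size: int, kernel_size: int, stride: int) -> int:
    """Number of window positions along one axis for a convolution without padding."""
    return (input_size - kernel_size) // stride + 1


# 28x28 image, 7x7 kernel, stride 3 (values taken from the model description)
side = conv_output_size(28, 7, 3)
windows_per_kernel = side * side          # windows produced by one kernel
flattened_size = 4 * windows_per_kernel   # 4 kernels packed into one vector

print(side, windows_per_kernel, flattened_size)  # 8 64 256
```

Each kernel therefore produces 64 values, and the 4 kernels together produce the 256 inputs of the first linear layer.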


### Input Representation

To keep memory and computation to a minimum, we will mostly try to use a single ciphertext. This is not always possible, and we often lose some flexibility. For this model, there are two different representations: one for the convolution and one for the linear layers. The former is explained briefly in the convolution section. For the latter, the input vector of the linear layer is simply replicated many times to fill the slots of the ciphertext, so a single ciphertext contains the whole input for the linear layer.
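
To get a feeling for the replicated representation, here is a plaintext sketch in which an ordinary Python list stands in for the ciphertext slots. It assumes 4096 slots, which is what CKKS offers for a polynomial modulus degree of 8192 (half the degree). `replicate` and `rotate_left` are illustrative helpers of ours, not TenSEAL functions.

```python
def replicate(vec, slots):
    """Repeat vec until it fills all slots (this sketch needs len(vec) to divide slots)."""
    assert slots % len(vec) == 0, "vector length must divide the slot count"
    return vec * (slots // len(vec))


def rotate_left(vec, shift):
    """Cyclically rotate a list to the left by `shift` positions."""
    shift %= len(vec)
    return vec[shift:] + vec[:shift]


slots = 4096                      # assumed slot count for poly_modulus_degree=8192
x = list(range(256))              # stand-in for the 256 inputs of the linear layer
packed = replicate(x, slots)      # 16 copies of x, one after the other

# Rotating the full slot vector acts like a cyclic rotation of x in the first 256 slots
rotated = rotate_left(packed, 3)
print(rotated[:256] == rotate_left(x, 3))  # True
```

Because of the replication, rotating the whole slot vector behaves like a cyclic rotation of the 256-element input, which is convenient for rotation-based algorithms such as the matrix multiplication described in the Linear Layer section.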


### Convolution

There are different ways of doing convolution. One of them is a well-known algorithm that translates the 2D convolution into a single matrix multiplication. This operation is often referred to as image-to-column (im2col) convolution and is depicted in *Figure 1*.

<div align="center">
<img src="https://github.com/OpenMined/TenSEAL/blob/main/tutorials/assets/im2col_conv2d.png?raw=true" width="50%"/>
<div><b>Figure1:</b> Image to column convolution</div>
</div>

However, this requires arranging the elements of the input matrix in a special way, and since we can't easily do that with a ciphertext, we have to do it as a pre-processing step before encryption. This also means that only a single convolution can be performed. To perform the convolution, we first apply *im2col* encoding to the input matrix and encrypt the result into a single ciphertext. Note that the matrix is flattened into a vector using a vertical scan. We then perform a matrix multiplication between an encrypted matrix (the input image encoded in a ciphertext) and a plain vector (the flattened kernel of the convolution). This is done by first constructing a new flattened kernel that replicates every element of the kernel $n$ times, where $n$ is the number of windows. Then we perform a ciphertext-plaintext multiplication, followed by a sequence of rotate-and-sum operations that sum the elements of the same window. The process is depicted in *Figure 2* and *Figure 3*.

<div align="center">
<img src="https://github.com/OpenMined/TenSEAL/blob/main/tutorials/assets/im2col_conv2d_ckks1.png?raw=true" width="50%"/>
<div><b>Figure2:</b> Image to column convolution with CKKS - step 1</div>
</div>

<div align="center">
<img src="https://github.com/OpenMined/TenSEAL/blob/main/tutorials/assets/im2col_conv2d_ckks2.png?raw=true" width="50%"/>
<div><b>Figure3:</b> Image to column convolution with CKKS - step 2</div>
</div>
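
The steps above can be simulated on plain Python lists, with a list playing the role of the ciphertext slots. This is a minimal sketch under our own assumptions, not TenSEAL's implementation: the helper names are hypothetical, and the number of kernel elements is assumed to be a power of two so that the rotate-and-sum loop stays simple (a 7x7 kernel would need padding, which is not shown).

```python
def im2col(image, kh, kw, stride):
    """One row per window; each row holds the kh*kw pixels of that window."""
    rows = []
    for top in range(0, len(image) - kh + 1, stride):
        for left in range(0, len(image[0]) - kw + 1, stride):
            rows.append([image[top + i][left + j]
                         for i in range(kh) for j in range(kw)])
    return rows


def rotate_left(vec, shift):
    """Cyclically rotate a list to the left by `shift` positions."""
    shift %= len(vec)
    return vec[shift:] + vec[:shift]


def conv_rotate_and_sum(image, kernel, stride):
    kh, kw = len(kernel), len(kernel[0])
    windows = im2col(image, kh, kw, stride)
    n, k = len(windows), kh * kw
    assert k & (k - 1) == 0, "sketch assumes a power-of-two number of kernel elements"

    # Vertical scan of the im2col matrix: slot j*n + w holds element j of window w
    packed = [windows[w][j] for j in range(k) for w in range(n)]
    # Flattened kernel: every kernel element replicated n times
    flat_kernel = [kernel[j // kw][j % kw] for j in range(k) for _ in range(n)]

    # Stand-in for the single ciphertext-plaintext multiplication
    prod = [a * b for a, b in zip(packed, flat_kernel)]

    # Rotate and sum: fold the k blocks of n slots onto the first block
    shift = (k // 2) * n
    while shift >= n:
        prod = [a + b for a, b in zip(prod, rotate_left(prod, shift))]
        shift //= 2
    return prod[:n]  # the first n slots hold one output value per window


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
kernel = [[1, 0], [0, 1]]
print(conv_rotate_and_sum(image, kernel, stride=1))  # [6, 8, 12, 14]
```

Only element-wise multiplication, rotation, and addition are applied to the packed vector, which are the operations available on a CKKS ciphertext.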

If multiple kernels are used, we perform this operation once per kernel, yielding one output ciphertext per kernel. These ciphertexts can later be combined (using a single multiplication) into a flattened vector. Here, every convolution outputs a ciphertext with 64 useful slots, so combining the 4 kernel outputs yields a ciphertext with 256 useful slots, which is the input for the first linear layer. The algorithm requires a single multiplication and $\log_2(n)$ ciphertext rotations, where $n$ is the number of windows in the convolution.
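
One way such a combination could work is sketched below on plain Python lists: each output is multiplied by a 0/1 mask that keeps only its useful slots, rotated to its own offset, and added to the result. This illustrates the idea under our own assumptions and is not a description of TenSEAL's internals; `pack` and `rotate_right` are hypothetical helpers.

```python
def rotate_right(vec, shift):
    """Cyclically rotate a list to the right by `shift` positions."""
    shift %= len(vec)
    return vec[len(vec) - shift:] + vec[:len(vec) - shift]


def pack(vectors, useful):
    """Combine the first `useful` slots of each vector into one flat vector."""
    size = len(vectors[0])
    assert useful * len(vectors) <= size, "not enough slots to hold every output"
    mask = [1] * useful + [0] * (size - useful)
    packed = [0] * size
    for index, vec in enumerate(vectors):
        # Stand-in for one plaintext multiplication: zero out the non-useful slots
        masked = [v * m for v, m in zip(vec, mask)]
        # Move the useful slots to their position in the flattened vector
        shifted = rotate_right(masked, index * useful)
        packed = [p + s for p, s in zip(packed, shifted)]
    return packed


# Three toy outputs with 2 useful slots each; the 9s stand for leftover values
outputs = [
    [1, 2, 9, 9, 9, 9, 9, 9],
    [3, 4, 9, 9, 9, 9, 9, 9],
    [5, 6, 9, 9, 9, 9, 9, 9],
]
print(pack(outputs, useful=2))  # [1, 2, 3, 4, 5, 6, 0, 0]
```

Every input goes through exactly one multiplication (by the mask), which is consistent with the single multiplication mentioned above.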

### Linear Layer
A linear layer boils down to a vector-matrix multiplication followed by the addition of a bias. The matrix and the bias are not encrypted. The vector-matrix multiplication is implemented using the diagonal method of [Halevi and Shoup](https://link.springer.com/chapter/10.1007/978-3-662-44371-2_31). It accumulates multiple ciphertext-plaintext multiplications, each with a slightly different rotation. We iterate over every diagonal of the plain matrix and multiply it with the ciphertext rotated $n$ slots to the left, where $n$ is the index (0-indexed) of the diagonal. The process is depicted in *Figure 4*. The algorithm runs in $O(n)$, where $n$ is the size of the encrypted vector.

<div align="center">
<img src="https://github.com/OpenMined/TenSEAL/blob/main/tutorials/assets/vec-matmul.png?raw=true" width="65%"/>
<div><b>Figure4:</b> Vector-Matrix Multiplication</div>
</div>
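
The diagonal method can be checked on plain Python lists. The sketch below assumes a square matrix to keep the indexing short; handling the non-square shapes of this model (such as 256x64) needs extra care that is not shown. The function names are ours, not TenSEAL's.

```python
def rotate_left(vec, shift):
    """Cyclically rotate a list to the left by `shift` positions."""
    shift %= len(vec)
    return vec[shift:] + vec[:shift]


def vec_matmul_diagonal(x, matrix):
    """Compute x @ matrix using only rotations, element-wise products and sums."""
    n = len(x)
    assert all(len(row) == n for row in matrix) and len(matrix) == n
    result = [0] * n
    for d in range(n):
        # d-th generalized diagonal: entry j is matrix[(j + d) % n][j]
        diagonal = [matrix[(j + d) % n][j] for j in range(n)]
        # the encrypted vector rotated d slots to the left
        rotated = rotate_left(x, d)
        # one ciphertext-plaintext multiplication, accumulated into the result
        result = [r + a * b for r, a, b in zip(result, rotated, diagonal)]
    return result


def vec_matmul_direct(x, matrix):
    """Reference implementation for comparison."""
    n = len(x)
    return [sum(x[i] * matrix[i][j] for i in range(n)) for j in range(n)]


x = [1, 2, 3]
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(vec_matmul_diagonal(x, matrix))  # [30, 36, 42]
print(vec_matmul_direct(x, matrix))    # [30, 36, 42]
```

The loop performs one rotation and one plaintext multiplication per diagonal, which matches the linear cost stated above.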

### Square Activation
The square activation is straightforward: we just multiply a ciphertext by itself.


Putting these operations together, the evaluation requires exactly 6 multiplications: 2 for the convolution, 1 for the first square activation, 1 for the first linear layer, 1 for the second square activation, and 1 for the last linear layer.

## Training

Now that we know how we can implement such a model via HE, we will start using a library called [TenSEAL](https://github.com/OpenMined/TenSEAL) that implements all these operations we have been describing. But first, we need to train a plain PyTorch model to classify the MNIST dataset.

In [None]:
import torch
from torch.utils.data import Subset
import numpy as np
import torchvision.datasets as datasets  # Import datasets from torchvision
import torchvision.transforms as transforms

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor())
test_data = Subset(test_data,np.arange(100))


In [None]:
len(test_data)

100

In [None]:

from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

torch.manual_seed(73)

batch_size = 64

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

class ConvNet(torch.nn.Module):
    def __init__(self, hidden=64, output=10):
        super(ConvNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size=7, padding=0, stride=3)
        self.fc1 = torch.nn.Linear(256, hidden)
        self.fc2 = torch.nn.Linear(hidden, output)

    def forward(self, x):
        x = self.conv1(x)
        # the model uses the square activation function
        x = x * x
        # flattening while keeping the batch axis
        x = x.view(-1, 256)
        x = self.fc1(x)
        x = x * x
        x = self.fc2(x)
        return x


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    # model in training mode
    model.train()
    for epoch in range(1, n_epochs+1):

        train_loss = 0.0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # calculate average losses
        train_loss = train_loss / len(train_loader)

        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))

    # model in evaluation mode
    model.eval()
    return model


model = ConvNet()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model = train(model, train_loader, criterion, optimizer, 10)

Epoch: 1 	Training Loss: 0.397561
Epoch: 2 	Training Loss: 0.130699
Epoch: 3 	Training Loss: 0.088399
Epoch: 4 	Training Loss: 0.071318
Epoch: 5 	Training Loss: 0.058989
Epoch: 6 	Training Loss: 0.050542
Epoch: 7 	Training Loss: 0.044438
Epoch: 8 	Training Loss: 0.038261
Epoch: 9 	Training Loss: 0.034641
Epoch: 10 	Training Loss: 0.030696


Then test its accuracy on the test set:

In [None]:
def test(model, test_loader, criterion):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    # model in evaluation mode
    model.eval()

    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # calculate and print avg test loss
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

test(model, test_loader, criterion)

Test Loss: 0.006644

Test Accuracy of 0: 100% (8/8)
Test Accuracy of 1: 100% (14/14)
Test Accuracy of 2: 100% (8/8)
Test Accuracy of 3: 100% (11/11)
Test Accuracy of 4: 100% (14/14)
Test Accuracy of 5: 100% (7/7)
Test Accuracy of 6: 100% (10/10)
Test Accuracy of 7: 100% (15/15)
Test Accuracy of 8: 100% (2/2)
Test Accuracy of 9: 100% (11/11)

Test Accuracy (Overall): 100% (100/100)


In [None]:
!pip install tenseal==0.3.15

Collecting tenseal==0.3.15
  Downloading tenseal-0.3.15-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.2 kB)
Downloading tenseal-0.3.15-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (4.8 MB)
Installing collected packages: tenseal
Successfully installed tenseal-0.3.15


## Encrypted Evaluation

Now start the encrypted evaluation that will use the pre-trained model:

In [None]:
"""
A PyTorch-like model built from operations implemented in TenSEAL.
    - .mm() performs the vector-matrix multiplication explained above.
    - the + operator adds a plain vector as a bias.
    - .conv2d_im2col() performs a single convolution operation.
    - .square_() squares the encrypted vector in place.
"""

import tenseal as ts


class EncConvNet:
    def __init__(self, torch_nn):
        self.conv1_weight = torch_nn.conv1.weight.data.view(
            torch_nn.conv1.out_channels, torch_nn.conv1.kernel_size[0],
            torch_nn.conv1.kernel_size[1]
        ).tolist()
        self.conv1_bias = torch_nn.conv1.bias.data.tolist()

        self.fc1_weight = torch_nn.fc1.weight.T.data.tolist()
        self.fc1_bias = torch_nn.fc1.bias.data.tolist()

        self.fc2_weight = torch_nn.fc2.weight.T.data.tolist()
        self.fc2_bias = torch_nn.fc2.bias.data.tolist()


    def forward(self, enc_x, windows_nb):
        # conv layer
        enc_channels = []
        for kernel, bias in zip(self.conv1_weight, self.conv1_bias):
            y = enc_x.conv2d_im2col(kernel, windows_nb) + bias
            enc_channels.append(y)
        # pack all channels into a single flattened vector
        enc_x = ts.CKKSVector.pack_vectors(enc_channels)
        # square activation
        enc_x.square_()
        # fc1 layer
        enc_x = enc_x.mm(self.fc1_weight) + self.fc1_bias
        # square activation
        enc_x.square_()
        # fc2 layer
        enc_x = enc_x.mm(self.fc2_weight) + self.fc2_bias
        return enc_x

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def enc_test(context, model, test_loader, criterion, kernel_shape, stride):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    for data, target in test_loader:
        # Encoding and encryption
        x_enc, windows_nb = ts.im2col_encoding(
            context, data.view(28, 28).tolist(), kernel_shape[0],
            kernel_shape[1], stride
        )
        # Encrypted evaluation (use the model passed as an argument, not the global)
        enc_output = model(x_enc, windows_nb)
        # Decryption of result
        output = enc_output.decrypt()
        output = torch.tensor(output).view(1, -1)

        # compute loss
        loss = criterion(output, target)
        test_loss += loss.item()

        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        label = target.data[0]
        class_correct[label] += correct.item()
        class_total[label] += 1


    # calculate and print avg test loss
    test_loss = test_loss / sum(class_total)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(10):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% '
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )


# Load one element at a time
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
# required for encoding
kernel_shape = model.conv1.kernel_size
stride = model.conv1.stride[0]

Choosing the parameters isn't easy, so we list some intuition here for why we have chosen these parameters exactly:

1. For a given security level (e.g. 128-bit security) and a polynomial modulus degree (e.g. 8192), there is an upper bound on the bit count of the coefficient modulus (`sum(coeff_mod_bit_sizes)`). If that bound is exceeded, a higher polynomial modulus degree (e.g. 16384) is needed to keep the required security level.
2. The multiplicative depth is controlled by the number of primes constituting our coefficient modulus.
3. All elements of `coeff_mod_bit_sizes[1: -1]` should be equal in TenSEAL, since it takes care of rescaling ciphertexts. We also want the scale used during encryption to have the same number of bits (e.g. a scale of $2^{26}$ for 26-bit primes).
4. The scale is what controls the precision of the fractional part, since it's the value that plaintexts are multiplied with before being encoded into a polynomial of integer coefficients.
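
The effect of the scale on precision (point 4) can be illustrated with a simplified fixed-point model. Real CKKS encoding maps a whole vector to a polynomial and involves more than this; the sketch only models the "multiply by the scale and round to an integer" step, and `encode` and `decode` are hypothetical helpers.

```python
def encode(value: float, bits_scale: int) -> int:
    """Multiply by the scale 2**bits_scale and round to the nearest integer."""
    return round(value * 2 ** bits_scale)


def decode(encoded: int, bits_scale: int) -> float:
    """Divide by the scale to recover an approximation of the original value."""
    return encoded / 2 ** bits_scale


value = 0.123456789
for bits in (10, 20, 26):
    approx = decode(encode(value, bits), bits)
    # the rounding error is at most half a unit of the scaled integer grid
    print(f"{bits} bits: approx={approx!r}, error={abs(approx - value):.3e}")
```

In this model the rounding error is bounded by $2^{-(\text{bits\_scale}+1)}$, so every extra bit of scale halves the worst-case error of the fractional part.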

Starting with a scale of more than 20 bits, each of the 6 middle primes must have that same bit size, so the middle primes alone already exceed 120 bits. With this lower bound on the coefficient modulus and a security level of 128 bits, we need a polynomial modulus degree of at least 8192. At that degree the coefficient modulus can hold at most 218 bits in total; going beyond that would require a higher degree. After trying different values for the precision and adjusting the coefficient modulus while studying the loss and accuracy, we end up with 26 bits for the scale and the middle primes. We also have 5 bits (31 - 26) for the integer part in the last coefficient modulus, which should be enough for our use case, since the output values aren't that big.
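
The bit budget can be checked with a few lines of arithmetic. The bounds in the table below are the 128-bit limits as we recall them from Microsoft SEAL's defaults; only the 218-bit limit for degree 8192 is stated in this lab, so treat the other entries as assumptions and verify them against the library. `smallest_degree` is a hypothetical helper.

```python
# Assumed maximum total coefficient modulus bits for 128-bit security
# (only the 8192 entry is confirmed by the text of this lab)
MAX_COEFF_MODULUS_BITS_128 = {4096: 109, 8192: 218, 16384: 438}


def smallest_degree(coeff_mod_bit_sizes):
    """Smallest polynomial modulus degree whose bit budget fits the given primes."""
    total = sum(coeff_mod_bit_sizes)
    for degree in sorted(MAX_COEFF_MODULUS_BITS_128):
        if total <= MAX_COEFF_MODULUS_BITS_128[degree]:
            return degree
    raise ValueError(f"{total} bits do not fit any degree in the table")


bits_scale = 26
coeff_mod_bit_sizes = [31] + [bits_scale] * 6 + [31]

print(sum(coeff_mod_bit_sizes))              # 218
print(smallest_degree(coeff_mod_bit_sizes))  # 8192
print(len(coeff_mod_bit_sizes) - 2)          # 6 middle primes, one per multiplication

# One more bit of scale would overflow the budget of degree 8192
print(smallest_degree([31] + [27] * 6 + [31]))  # 16384
```

With 26-bit primes the budget of degree 8192 is used exactly (31 + 6 * 26 + 31 = 218), and the 6 middle primes match the 6 multiplications counted earlier.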

In [None]:
## Encryption Parameters

# controls precision of the fractional part
bits_scale = 26

# Create TenSEAL context
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[31, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, bits_scale, 31]
)

# set the scale
context.global_scale = pow(2, bits_scale)

# galois keys are required to do ciphertext rotations
context.generate_galois_keys()

In [None]:
enc_model = EncConvNet(model)
enc_test(context, enc_model, test_loader, criterion, kernel_shape, stride)

Test Loss: 0.008425

Test Accuracy of 0: 100% (8/8)
Test Accuracy of 1: 100% (14/14)
Test Accuracy of 2: 100% (8/8)
Test Accuracy of 3: 100% (11/11)
Test Accuracy of 4: 100% (14/14)
Test Accuracy of 5: 100% (7/7)
Test Accuracy of 6: 100% (10/10)
Test Accuracy of 7: 100% (15/15)
Test Accuracy of 8: 100% (2/2)
Test Accuracy of 9: 100% (11/11)

Test Accuracy (Overall): 100% (100/100)


## References

https://github.com/OpenMined/TenSEAL