*Based on this tutorial https://www.zama.ai/post/quantization-of-neural-networks-for-fully-homomorphic-encryption*

## Train a 3-layer MNIST model

In [46]:
import numpy as np
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from threading import Thread
import matplotlib.pyplot as plt

from concrete.ml.torch import NumpyModule
from concrete.ml.quantization import PostTrainingAffineQuantization
from concrete.ml.quantization import QuantizedArray

In [2]:
batch_size = 16
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_data = torchvision.datasets.MNIST('./data', train=True, download=False,
                    transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

test_data = torchvision.datasets.MNIST('./data', train=False,
                    transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, 10000, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(in_features=28*28, out_features=128, bias=False)
        self.sigmoid1 = nn.Sigmoid()
        self.fc2 = nn.Linear(in_features=128, out_features=64, bias=False)
        self.sigmoid2 = nn.Sigmoid()
        self.fc3 = nn.Linear(in_features=64, out_features=10, bias=False)

    def forward(self, x):
        out = self.fc1(x)
        out = self.sigmoid1(out)
        out = self.fc2(out)
        out = self.sigmoid2(out)
        out = self.fc3(out)
        return out
        
net = Model()

def test(epoch):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs.reshape(inputs.shape[0], inputs.shape[1]* inputs.shape[2]*inputs.shape[3])
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print(f"Epoch {epoch}, Test accuracy: {np.round(100.*correct/total, 2)}, Correct: {correct}, Total: {total}")

In [3]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001,momentum=0.9, weight_decay=5e-4)
net = net.to(device)
net.train()

def train(epochs):
    for epoch in range(epochs):
        for i, (inputs, targets) in enumerate(train_loader, 0):
            inputs, targets = inputs.to(device), targets.to(device)
            inputs = inputs.reshape(inputs.shape[0], inputs.shape[1]*inputs.shape[2]*inputs.shape[3])
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        test(epoch)

    print('Finished Training. Model saved to "mnist_97p.pt"')
    torch.save(net.state_dict(), "mnist_97p.pt")

In [5]:
train_thread = Thread(target=train, name="train mnist", args=(50, ))
train_thread.is_alive()

False

In [6]:
train_thread.start()

True

Epoch 0, Test accuracy: 86.98, Correct: 8698, Total: 10000


In [9]:
train_thread.is_alive()

True

Epoch 18, Test accuracy: 96.38, Correct: 9638, Total: 10000
Epoch 19, Test accuracy: 96.56, Correct: 9656, Total: 10000
Epoch 20, Test accuracy: 96.53, Correct: 9653, Total: 10000
Epoch 21, Test accuracy: 96.59, Correct: 9659, Total: 10000
Epoch 22, Test accuracy: 96.66, Correct: 9666, Total: 10000
Epoch 23, Test accuracy: 96.79, Correct: 9679, Total: 10000
Epoch 24, Test accuracy: 96.78, Correct: 9678, Total: 10000
Epoch 25, Test accuracy: 96.79, Correct: 9679, Total: 10000
Epoch 26, Test accuracy: 96.8, Correct: 9680, Total: 10000
Epoch 27, Test accuracy: 96.92, Correct: 9692, Total: 10000
Epoch 28, Test accuracy: 97.01, Correct: 9701, Total: 10000
Epoch 29, Test accuracy: 96.95, Correct: 9695, Total: 10000
Epoch 30, Test accuracy: 97.03, Correct: 9703, Total: 10000
Epoch 31, Test accuracy: 97.1, Correct: 9710, Total: 10000
Epoch 32, Test accuracy: 97.17, Correct: 9717, Total: 10000
Epoch 33, Test accuracy: 97.14, Correct: 9714, Total: 10000
Epoch 34, Test accuracy: 97.21, Correct: 9

## Load and eval

In [10]:
train_thread.join()
PATH = "./mnist_97p.pt"
torch_fc_model = Model()
torch_fc_model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
torch_fc_model.eval()

Epoch 46, Test accuracy: 97.44, Correct: 9744, Total: 10000
Epoch 47, Test accuracy: 97.41, Correct: 9741, Total: 10000
Epoch 48, Test accuracy: 97.48, Correct: 9748, Total: 10000
Epoch 49, Test accuracy: 97.45, Correct: 9745, Total: 10000
Finished Training. Model saved to "mnist_97p.pt"


Model(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (sigmoid1): Sigmoid()
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (sigmoid2): Sigmoid()
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [14]:
# Create some dummy input for the tracing.
dummy_data = torch.randn(1, 28*28)

# Carry torch model to numpy (required in by concrete)
numpy_fc_model = NumpyModule(torch_fc_model, dummy_data)

# Create random inputs of (n_examples, n_features)
mnist_test_data, mnist_test_target = next(iter(test_loader))
mnist_test_data = mnist_test_data.detach().numpy()
mnist_test_target = mnist_test_target.detach().numpy()
mnist_test_data = mnist_test_data.reshape(mnist_test_data.shape[0], mnist_test_data.shape[1]* mnist_test_data.shape[2]*mnist_test_data.shape[3])

# Check that both model give same output
# Accuracy pytorch
(torch_fc_model(torch.from_numpy(mnist_test_data)).detach().numpy().argmax(1) == mnist_test_target).mean()
# Output: 0.9733

# Accuracy numpy
(numpy_fc_model(mnist_test_data).argmax(1) == mnist_test_target).mean()
# Output: 0.9733

verbose: False, log level: Level.ERROR



0.9745

## Quantized inference

In [19]:
n_bits = 6  # 64
is_signed = False

Quantizing our model 

In [21]:
pt_quant = PostTrainingAffineQuantization(n_bits = n_bits, 
                                          numpy_model = numpy_fc_model)

In [44]:
pt_quant

<concrete.ml.quantization.post_training.PostTrainingAffineQuantization at 0x7f1fc19819d0>

Calibrate layers and activations

In [22]:
quant_module = pt_quant.quantize_module(mnist_test_data)

Quantize the input

In [23]:
q_mnist_test_data = QuantizedArray(n_bits = n_bits, 
                                   values=mnist_test_data, 
                                   is_signed=is_signed)

Compare dequantized input value vs real input values

In [74]:
arg_diff_values = (mnist_test_data != -0.42421296)
# Real input
print(f"real value = {mnist_test_data[arg_diff_values][:16]} \n")
print(f"dequantized value = {mnist_test_data[arg_diff_values][:16]} \n")
print(f"quantized value = {q_mnist_test_data.qvalues[arg_diff_values][:16]}")


real value = [0.64495873 1.9305104  1.5995764  1.4977505  0.33948106 0.03400347
 2.401455   2.8087585  2.8087585  2.8087585  2.8087585  2.6432915
 2.0959773  2.0959773  2.0959773  2.0959773 ] 

dequantized value = [0.64495873 1.9305104  1.5995764  1.4977505  0.33948106 0.03400347
 2.401455   2.8087585  2.8087585  2.8087585  2.8087585  2.6432915
 2.0959773  2.0959773  2.0959773  2.0959773 ] 

quantized value = [21 45 39 37 15  9 55 63 63 63 63 59 49 49 49 49]


Check the quantized weights for the first layer

In [77]:
next(iter(quant_module.quant_layers_dict.values()))[1].constant_inputs[1].qvalues[0][:16]

array([ 2,  0,  0,  0,  0, -1,  1,  0,  0,  2,  0,  1,  1,  1,  2,  0])

Make sure all input values are integers with 2**6 (64) values

In [78]:
np.unique(q_mnist_test_data.qvalues[arg_diff_values])

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [81]:
# Accuracy Quantized Numpy (6 bits)
(quant_module.quantized_forward(q_mnist_test_data.qvalues).argmax(1) == mnist_test_target).mean()

0.4507