### Example: prediction with homomorphic encryption and a neural network on the MNIST dataset

Outline:

- Train the model on unencrypted data
- Run a test prediction on unencrypted data
- Use the trained model's weights to predict on encrypted data

We import the libraries needed for the plaintext model.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time

We train the plaintext model on unencrypted data. For hardware reasons we set `test_batch_size` to one.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

print(device)

train_batch_size = 64
test_batch_size = 1

train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=train_batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=test_batch_size, shuffle=False
)

examples = iter(train_loader)
example_data, example_targets = next(examples)

print(f"Batch shape: {example_data.shape}")  # Should be [batch_size, channels, height, width]
print(f"Target shape: {example_targets.shape}")  # Should be [batch_size]


class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(28*28, 64)
        self.sigm = nn.Sigmoid()
        self.fc2 = nn.Linear(64,10)


    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.sigm(x)
        x = self.fc2(x)
        return x


model = Net().to(device)
print(model)

learning_rate = 1e-2
losses = []

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

num_epochs = 10

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)           # Step 1: Get predictions
        loss = criterion(outputs, labels) # Step 2: Measure error
        # Backward pass and optimize
        optimizer.zero_grad()             # Step 3: Clear old gradients
        loss.backward()                   # Step 4: Compute new gradients
        optimizer.step()                  # Step 5: Update model weights

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    losses.append(running_loss/len(train_loader))

cuda




Batch shape: torch.Size([64, 1, 28, 28])
Target shape: torch.Size([64])
Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (sigm): Sigmoid()
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)
Epoch [1/10], Loss: 2.1845
Epoch [2/10], Loss: 1.7424
Epoch [3/10], Loss: 1.2242
Epoch [4/10], Loss: 0.9166
Epoch [5/10], Loss: 0.7478
Epoch [6/10], Loss: 0.6449
Epoch [7/10], Loss: 0.5761
Epoch [8/10], Loss: 0.5265
Epoch [9/10], Loss: 0.4893
Epoch [10/10], Loss: 0.4601


We check the model's accuracy.

In [29]:
model.eval()
correct = 0
total = 0

start = time.time()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if total == 100: break

print(f"processing time for 100 images: {time.time() - start}")

clean_acc = 100 * correct / total
print(f"Accuracy on clean test images: {clean_acc:.2f}% (expect ~90%)")

processing time for 100 images: 0.06319689750671387
Accuracy on clean test images: 88.00% (expect ~90%)


In [None]:
!pip install tenseal

Successfully installed tenseal-0.3.16


We create the context and the Galois keys. Then we define the forward pass on encrypted vectors, using the weights and biases of the previous model. CKKS supports only additions and multiplications on ciphertexts, so the sigmoid function cannot be evaluated directly; we use a degree-3 polynomial approximation instead.

In [None]:
import tenseal as ts

context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes = [30, 20, 20, 20, 20, 30]
)

context.global_scale = 2**20

context.generate_galois_keys()

class EncryptedNet:
    def __init__(self, torch_nn : Net):

        # nn.Linear stores weights as (out, in); transpose to (in, out) for vector-matrix products
        self.fc1_weight = torch_nn.fc1.weight.T.data.tolist()
        self.fc1_bias = torch_nn.fc1.bias.data.tolist()
        self.fc2_weight = torch_nn.fc2.weight.T.data.tolist()
        self.fc2_bias = torch_nn.fc2.bias.data.tolist()

    def forward(self, enc_x):
        enc_x = enc_x.mm(self.fc1_weight) + self.fc1_bias
        enc_x = self.encrypted_sigmoid(enc_x)
        enc_x = enc_x.mm(self.fc2_weight) + self.fc2_bias
        return enc_x

    def encrypted_sigmoid(self, enc_x):
        return enc_x.polyval([0.5, 0.197, 0, -0.004])

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
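To get a feel for how far the polynomial used in `encrypted_sigmoid` drifts from the true sigmoid, the sketch below compares the two on plain floats. It assumes that `polyval` takes coefficients in ascending order of power, so the polynomial is 0.5 + 0.197x - 0.004x³. The helper names (`poly_sigmoid`, `max_abs_error`) are illustrative and not part of the notebook.

```python
import math

# Coefficients in ascending order of power, as passed to polyval above (assumption)
COEFFS = [0.5, 0.197, 0.0, -0.004]


def sigmoid(x: float) -> float:
    """Exact logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


def poly_sigmoid(x: float) -> float:
    """Horner evaluation of 0.5 + 0.197*x - 0.004*x**3."""
    result = 0.0
    for c in reversed(COEFFS):
        result = result * x + c
    return result


def max_abs_error(lo: float, hi: float, steps: int = 1001) -> float:
    """Largest absolute gap between the two functions on an evenly spaced grid."""
    xs = [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]
    return max(abs(sigmoid(x) - poly_sigmoid(x)) for x in xs)


print(f"max error on [-5, 5]:   {max_abs_error(-5, 5):.3f}")
print(f"max error on [-10, 10]: {max_abs_error(-10, 10):.3f}")
```

The approximation is only usable on a bounded interval around zero; outside it the cubic term dominates and the gap grows quickly, so the pre-activations of `fc1` need to stay in a moderate range for the encrypted model to agree with the plaintext one.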

We create the model that operates on encrypted data.

In [None]:
enc_model = EncryptedNet(model)
print(enc_model)

<__main__.EncryptedNet object at 0x7ae2c1a4c6d0>
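When debugging such a model, it can help to have a plaintext reference that performs the same arithmetic as `EncryptedNet.forward` (matrix product, bias, polynomial activation, matrix product, bias) without any encryption. The sketch below does this in NumPy. The random weights are hypothetical stand-ins for the trained parameters, and `plain_forward` is an illustrative name, not something defined in the notebook.

```python
import numpy as np


def poly_sigmoid(x):
    # Same polynomial as encrypted_sigmoid: 0.5 + 0.197*x - 0.004*x**3
    return 0.5 + 0.197 * x - 0.004 * x**3


def plain_forward(x, w1, b1, w2, b2):
    """Plaintext mirror of the encrypted forward pass.

    x: (784,), w1: (784, 64), b1: (64,), w2: (64, 10), b2: (10,)
    """
    hidden = poly_sigmoid(x @ w1 + b1)
    return hidden @ w2 + b2


# Hypothetical random parameters standing in for the trained model's weights
rng = np.random.default_rng(0)
w1 = rng.normal(scale=0.05, size=(784, 64))
b1 = np.zeros(64)
w2 = rng.normal(scale=0.05, size=(64, 10))
b2 = np.zeros(10)

x = rng.random(784)  # stand-in for one flattened MNIST image in [0, 1)
logits = plain_forward(x, w1, b1, w2, b2)
print(logits.shape, int(np.argmax(logits)))
```

In the notebook itself the parameters would presumably come from `enc_model.fc1_weight`, `enc_model.fc1_bias`, and so on. Comparing this output with the decrypted output for the same image separates the error introduced by the polynomial activation from the noise introduced by CKKS.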


We check the accuracy and running time of the model that operates on encrypted data.

In [None]:
import time

start = time.time()
with torch.no_grad():
    correct = 0
    total = 0

    for images, labels in test_loader:

        image = images.view(-1, 28*28).numpy().flatten()

        x_enc = ts.ckks_vector(context, image)
        
        # classify the encrypted data
        enc_output = enc_model(x_enc)
        
        # decrypt the result
        decrypted_output = enc_output.decrypt()
        output = torch.tensor(decrypted_output, dtype=torch.float32).view(1, -1)
        
        # update the accuracy counters
        _, predicted = torch.max(output.data, 1)
        total += 1
        correct += (predicted.item() == labels.item())

        print(f"Processed {total} samples, Accuracy: {100 * correct / total:.2f}%")

        if total == 100: break

    print(f"processing time for 100 images: {time.time() - start}")

    print(f'Accuracy on test images: {100 * correct / total:.2f}%')

Processed 1 samples, Accuracy: 100.00%
Processed 2 samples, Accuracy: 100.00%
Processed 3 samples, Accuracy: 100.00%
Processed 4 samples, Accuracy: 100.00%
Processed 5 samples, Accuracy: 100.00%
Processed 6 samples, Accuracy: 100.00%
Processed 7 samples, Accuracy: 100.00%
Processed 8 samples, Accuracy: 100.00%
Processed 9 samples, Accuracy: 88.89%
Processed 10 samples, Accuracy: 90.00%
Processed 11 samples, Accuracy: 90.91%
Processed 12 samples, Accuracy: 91.67%
Processed 13 samples, Accuracy: 92.31%
Processed 14 samples, Accuracy: 92.86%
Processed 15 samples, Accuracy: 93.33%
Processed 16 samples, Accuracy: 93.75%
Processed 17 samples, Accuracy: 94.12%
Processed 18 samples, Accuracy: 94.44%
Processed 19 samples, Accuracy: 94.74%
Processed 20 samples, Accuracy: 95.00%
Processed 21 samples, Accuracy: 95.24%
Processed 22 samples, Accuracy: 95.45%
Processed 23 samples, Accuracy: 95.65%
Processed 24 samples, Accuracy: 95.83%
Processed 25 samples, Accuracy: 96.00%
Processed 26 samples, Accu

Conclusions: Because the data is simple, the encrypted model reached the same accuracy as the plaintext one (88%), but at the cost of a much longer execution time (roughly 4200x). The polynomial approximation of the sigmoid function and the fact that a single image has 784 pixels (28×28) both contribute significantly to the execution time.
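As a back-of-the-envelope check, the slowdown factor can be turned into absolute times. The sketch below assumes the plaintext timing printed earlier (about 0.063 s for 100 images) and the approximate factor of 4200 quoted above. The results are estimates derived from those two numbers, not measurements.

```python
plain_time_s = 0.06319689750671387  # measured above for 100 unencrypted images
slowdown = 4200                     # approximate factor quoted in the conclusions
n_images = 100

# Estimated (not measured) encrypted timings implied by the two numbers above
est_encrypted_total_s = plain_time_s * slowdown
est_per_image_s = est_encrypted_total_s / n_images

print(f"estimated encrypted time for {n_images} images: {est_encrypted_total_s:.0f} s")
print(f"estimated encrypted time per image: {est_per_image_s:.2f} s")
```

Under these assumptions the encrypted run comes out at roughly 265 s in total, which is a few seconds per image.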