In [1]:
import torch
torch.set_num_threads(1) # We ask torch to use a single thread 
# as we run async code which conflicts with multithreading
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import transforms
import time

from models.net import presnet10

In [2]:
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
train_transform= transforms.Compose([transforms.Resize((96, 96)),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std =std)])
# test_transform_list = train_transform_list

train_ds = torchvision.datasets.ImageFolder('./datasets/chest_xray/train' ,transform = train_transform)
# test_ds = torchvision.datasets.ImageFolder('./datasets/chest_xray/test' ,transform = test_transform_list)

dataloader = torch.utils.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

In [3]:
model = presnet10(2, act='poly').eval()
state = torch.load("./presnet10_chest_best.pth", map_location='cpu')
model.load_state_dict(state)

In [4]:
import syft as sy

hook = sy.TorchHook(torch) 
data_owner = sy.VirtualWorker(hook, id="data_owner")
model_owner = sy.VirtualWorker(hook, id="model_owner")
crypto_provider = sy.VirtualWorker(hook, id="crypto_provider")

from syft.serde.compression import NO_COMPRESSION
sy.serde.compression.default_compress_scheme = NO_COMPRESSION

data, true_labels = next(iter(dataloader))
data_ptr = data.send(data_owner)

# We store the true output of the model for comparison purpose
true_prediction = model(data)
model_ptr = model.send(model_owner)

In [5]:
encryption_kwargs = dict(
    workers=(data_owner, model_owner), # the workers holding shares of the secret-shared encrypted data
    crypto_provider=crypto_provider, # a third party providing some cryptography primitives
    protocol="fss", # the name of the crypto protocol, fss stands for "Function Secret Sharing"
    precision_fractional=4, # the encoding fixed precision (i.e. floats are truncated to the 4th decimal)
)
encrypted_data = data_ptr.encrypt(**encryption_kwargs).get()
encrypted_model = model_ptr.encrypt(**encryption_kwargs).get()

In [9]:
start_time = time.time()

encrypted_prediction = encrypted_model(encrypted_data)
encrypted_labels = encrypted_prediction.argmax(dim=1)

print(time.time() - start_time, "seconds")

labels = encrypted_labels.decrypt()

print("Predicted labels:", labels)
print("     True labels:", true_labels)



58.69201326370239 seconds
Predicted labels: tensor([0., 0.])
     True labels: tensor([0, 0])
