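# ENCRYPTED INFERENCE USING RESNET-18
Classify images from the hymenoptera (ants vs. bees) validation set with a fine-tuned ResNet-18 while both the images and the model stay secret-shared, using PySyft's Function Secret Sharing (FSS) protocol.

### Imports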

In [1]:
import os
import time

import torch
# Use a single intra-op thread: this notebook runs asynchronous code,
# which conflicts with torch's multithreading.
torch.set_num_threads(1)

import numpy as np
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms

import syft as sy

### Loading data

In [2]:
# Seed for the shuffled loader, so the batch drawn below (and therefore the
# printed predictions) is the same on every Restart & Run All.
SEED = 0

# Resize, center-crop to 224x224, convert to a tensor and normalize with the
# standard ImageNet channel means/stds.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_dir = '../../data/hymenoptera_data'
image_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transform)

torch.manual_seed(SEED)
# num_workers=0: load in the main process. Only a single batch is drawn from
# this loader, so spawning worker processes only adds start-up cost and extra
# processes next to the single-threaded / async setup of this notebook.
dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=2, shuffle=True, num_workers=0)

dataset_size = len(image_dataset)
class_names = image_dataset.classes

### Loading Model

In [3]:
# Build the ResNet-18 architecture with a 2-class head (ants vs. bees).
# No ImageNet weights are downloaded: load_state_dict is strict by default, so
# the checkpoint must provide every parameter/buffer and overwrites them all.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state = torch.load("../../models/resnet18_ants_bees.pt", map_location='cpu')
model.load_state_dict(state)
model.eval()  # inference mode (BatchNorm uses its running statistics)

# The ResNet stem applies relu and then maxpool. ReLU is monotone, so it
# commutes with max-pooling: swapping the two stem modules gives the same
# output, while the costly encrypted ReLU then runs on the smaller, pooled
# activation map. Only the stem is affected; residual blocks own their ReLUs.
model.maxpool, model.relu = model.relu, model.maxpool

### Virtual setup

In [4]:
hook = sy.TorchHook(torch)

# Three simulated parties: the owner of the data, the owner of the model, and
# a third party that provides cryptographic primitives.
data_owner, model_owner, crypto_provider = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("data_owner", "model_owner", "crypto_provider")
)

In [5]:
# Disable compression of serialized messages to speed up communication.
# The crypto material sent to the workers is very heavy and pseudo-random:
# compressing it takes non-negligible time and gains nothing, because
# random data cannot be compressed.
import syft.serde.compression as syft_compression

NO_COMPRESSION = syft_compression.NO_COMPRESSION
syft_compression.default_compress_scheme = NO_COMPRESSION

Send the data to the **data_owner** and the model to the **model_owner**.

In [6]:
data, true_labels = next(iter(dataloader))

# Plaintext reference prediction, computed before anything is sent away so it
# only touches local tensors. no_grad: forward pass only, no autograd graph.
with torch.no_grad():
    true_prediction = model(data)
# Former (misspelled) name, kept in case later cells still use it.
true_predicition = true_prediction

data_ptr = data.send(data_owner)
model_ptr = model.send(model_owner)

In [7]:
# Sanity check of the batch: (batch_size, channels, height, width) after the 224x224 center crop.
data.shape

torch.Size([2, 3, 224, 224])

### Encryption

In [7]:
# Settings shared by the data and the model encryption.
encryption_kwargs = {
    # the workers holding shares of the secret-shared encrypted values
    "workers": (data_owner, model_owner),
    # a third party providing some cryptography primitives
    "crypto_provider": crypto_provider,
    # the crypto protocol: "fss" stands for Function Secret Sharing
    "protocol": "fss",
    # fixed-precision encoding: floats are truncated to the 4th decimal
    "precision_fractional": 4,
}

# Encrypt on each owner's side, then retrieve the encrypted object locally;
# the individual shares remain on the two workers.
encrypted_data = data_ptr.encrypt(**encryption_kwargs).get()
encrypted_model = model_ptr.encrypt(**encryption_kwargs).get()

In [8]:
# Display the encrypted batch: a fixed-precision wrapper over additive shares held by
# data_owner and model_owner, with crypto_provider as the third party.
encrypted_data

(Wrapper)>FixedPrecisionTensor>[AdditiveSharingTensor]
	-> [PointerTensor | me:82059159364 -> data_owner:12074529607]
	-> [PointerTensor | me:31469196163 -> model_owner:81424253850]
	*crypto provider: crypto_provider*

### Secure inference
We can now run the secure inference and compare the predicted labels with `true_labels`.

In [None]:
# Time the encrypted forward pass plus the argmax over class scores.
# perf_counter is a monotonic, high-resolution clock meant for measuring durations.
start_time = time.perf_counter()
encrypted_predictions = encrypted_model(encrypted_data)
encrypted_labels = encrypted_predictions.argmax(dim=1)

print(time.perf_counter() - start_time, "seconds")

# Only the final labels are decrypted, not the full score tensor.
labels = encrypted_labels.decrypt()

print("Predicted labels:", labels)
print("     True labels:", true_labels)

Exception in thread Thread-20:
Traceback (most recent call last):
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/multiprocessing/pool.py", line 513, in _handle_workers
    cls._maintain_pool(ctx, Process, processes, pool, inqueue,
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/multiprocessing/pool.py", line 337, in _maintain_pool
    Pool._repopulate_pool_static(ctx, Process, processes, pool,
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/multiprocessing/pool.py", line 326, in _repopulate_pool_static
    w.start()
  File "/home/lytvyn/anaconda3/envs/syft_python/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/lytvyn/anacon