# Install dependencies

In [1]:
!pip install transformers[onnx]

Collecting transformers[onnx]
  Using cached transformers-4.16.2-py3-none-any.whl (3.5 MB)
Collecting sacremoses
  Using cached sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
Collecting huggingface-hub<1.0,>=0.1.0
  Using cached huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
Collecting requests
  Using cached requests-2.27.1-py2.py3-none-any.whl (63 kB)
Collecting numpy>=1.17
  Using cached numpy-1.19.5-cp36-cp36m-manylinux2010_x86_64.whl (14.8 MB)
Collecting filelock
  Using cached filelock-3.4.1-py3-none-any.whl (9.9 kB)
Collecting pyyaml>=5.1
  Using cached PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (603 kB)
Collecting tokenizers!=0.11.3,>=0.10.1
  Using cached tokenizers-0.11.4-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
Collecting tqdm>=4.27
  Using cached tqdm-4.62.3-py2.py3-none-any.whl (76 kB)
Collecting regex!=2019.12.17
  Using cached regex-2022.1.18-cp36-cp36m-manylinux_2_17_x86_64.manyl

In [2]:
!pip install blindai



# Load the model

In [1]:
from transformers import DistilBertForSequenceClassification

# Name of the pretrained checkpoint to start from
checkpoint_name = "distilbert-base-uncased"

# Build the sequence-classification model from that checkpoint
model = DistilBertForSequenceClassification.from_pretrained(checkpoint_name)

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier

# Export the model

In [2]:
from transformers import DistilBertTokenizer
import torch

# Create dummy input for export.
# truncation=True caps the sequence at max_length, so the dummy input always holds
# exactly 8 token ids, even if the sentence is replaced by a longer one.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
sentence = "I love AI and privacy!"
inputs = tokenizer(
    sentence,
    padding="max_length",
    truncation=True,
    max_length=8,
    return_tensors="pt",
)["input_ids"]

# Export the model to ONNX with its weights included.
# Only axis 0 (the batch axis) is declared dynamic for the input and the output.
torch.onnx.export(
    model,
    inputs,
    "./distilbert-base-uncased.onnx",
    export_params=True,
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

# Upload model to inference server

In [1]:
!wget https://raw.githubusercontent.com/mithril-security/blindai/master/examples/distilbert/hardware/host_server.pem
!wget https://raw.githubusercontent.com/mithril-security/blindai/master/examples/distilbert/hardware/policy.toml

--2022-02-23 16:10:27--  https://raw.githubusercontent.com/mithril-security/blindai/master/examples/distilbert/hardware/host_server.pem
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1155 (1,1K) [text/plain]
Saving to: ‘host_server.pem.1’


2022-02-23 16:10:27 (48,5 MB/s) - ‘host_server.pem.1’ saved [1155/1155]

--2022-02-23 16:10:27--  https://raw.githubusercontent.com/mithril-security/blindai/master/examples/distilbert/hardware/policy.toml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 881 [text/plain]
Saving to: ‘policy.toml.

In [3]:
from blindai.client import BlindAiClient, ModelDatumType

# Create the client and connect it to the local server, passing the policy file
# and the certificate downloaded above
client = BlindAiClient()
client.connect_server(
    addr="localhost",
    policy="policy.toml",
    certificate="host_server.pem",
)

# Upload the exported ONNX file; the input shape is taken from the dummy export
# input and the element type is declared as I64. The reply is shown as cell output.
client.upload_model(
    model="./distilbert-base-uncased.onnx",
    shape=inputs.shape,
    dtype=ModelDatumType.I64,
)

ok: true
msg: "OK"

# Send data for prediction

In [6]:
from transformers import DistilBertTokenizer

# Prepare the inputs.
# truncation=True caps the sequence at max_length, so the token-id list always has
# 8 entries (the length used for the export) even if the sentence is made longer.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
sentence = "I love AI and privacy!"
inputs = tokenizer(
    sentence,
    padding="max_length",
    truncation=True,
    max_length=8,
)["input_ids"]

In [5]:
from blindai.client import BlindAiClient

# Create a client and connect it to the local server
client = BlindAiClient()
client.connect_server(
    addr="localhost",
    policy="policy.toml",
    certificate="host_server.pem",
)

# Send the token ids to the server and keep its response
response = client.run_model(inputs)

In [8]:
# Display the output values carried by the response returned from run_model
response.output

output: 0.0005601687589660287
output: 0.06354495882987976
ok: true
msg: "OK"

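Here we compare the result returned by the server against the prediction of the original PyTorch model, run locally on the same input. The two sets of values should match.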

In [9]:
# Turn the token ids into a batch of one and run the local PyTorch model on it;
# the detached logits are displayed for comparison with the server output
local_batch = torch.tensor(inputs).unsqueeze(0)
model(local_batch).logits.detach()

SequenceClassifierOutput(loss=None, logits=tensor([[0.0006, 0.0635]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)