# Classification with Encrypted Neural Networks

In this tutorial, we'll look at how we can achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, suppose say Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We will see how CrypTen allows Alice and Bob to coordinate and classify the data, while achieving their privacy requirements.

To simulate this scenario, we will begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.

## Setup

We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. 

In [2]:
import crypten
import crypten.nn as nn
import math
import torch
import logging

crypten.init()
torch.set_num_threads(1)



In [3]:
class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Attention, self).__init__()

        assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.search_dim = embed_dim // num_heads

        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)


    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

        attn = q.matmul(k_t) / math.sqrt(q.size(-1))
        attn = attn.softmax(dim=-1)
        y = attn.matmul(v)
        y = y.transpose(1, 2)
        y = y.reshape(batch_size, seq_len, self.embed_dim)
        return y


class BertBlock(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(BertBlock, self).__init__()
        embed_dim = embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, num_heads)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        x = self.ln1(x + self.attn(x))
        x = self.ln2(x + self.ff(x))
        return x

class Bert(nn.Module):
    def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
        super(Bert, self).__init__()
        self.full = full
        if full:
            self.tok_embed = nn.Embedding(vocab_size, embed_dim)
            self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.ln = nn.LayerNorm
        self.blocks = nn.Sequential(
            *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
        )
        self.ln = nn.LayerNorm(embed_dim)
        if full:
            self.pooler = nn.Linear(embed_dim, embed_dim)
            self.tanh = nn.TANH()
            self.classifier = nn.Linear(embed_dim, 2)

    def forward(self, x, target=None):
        if self.full:
            tok_embedding = self.tok_embed(x)
            pos_embedding = self.pos_embed(x)[:, :x.size()[1], :]
            x = tok_embedding + pos_embedding
        x = self.ln(x)
        x = self.blocks(x)
        if self.full:
            x = self.pooler(x)
            x = self.tanh(x)
            x = self.classifier(x)
        return x

model = Bert(768, 12, 12, 30522, 1024, True) # bert base 13.5s
model.encrypt(src=0)

# Load data to Bob
print('loading data')
data_enc = crypten.cryptensor(torch.arange(64).reshape(1, 64))

# Classify the encrypted data
model.eval()
print("forward")
output_enc = model(data_enc)
print('output_enc')
# Compute the accuracy
output = output_enc.get_plain_text()
print(f"{output=}")


loading data
forward
output_enc
output=tensor([[[ 3.2761e-01, -2.6791e-01],
         [-1.1989e-01, -3.7218e-01],
         [ 1.6673e-01, -2.5415e-01],
         [ 2.8795e-01,  2.3799e-01],
         [ 2.1680e-01,  1.4282e-02],
         [ 2.2366e-01, -2.4100e-01],
         [ 5.3741e-02, -2.5662e-01],
         [ 1.0928e-01, -1.0956e-02],
         [ 2.2964e-01, -1.6299e-01],
         [ 2.3883e-01, -3.0801e-01],
         [ 1.6910e-01, -8.8898e-02],
         [ 3.8145e-01,  4.3121e-02],
         [ 2.7388e-01, -3.4944e-01],
         [ 1.1966e-01, -5.3261e-01],
         [ 2.6854e-01, -8.8959e-02],
         [ 4.8885e-01, -3.7724e-01],
         [ 2.8743e-01, -1.3196e-01],
         [ 3.5741e-01, -6.4240e-02],
         [ 8.6365e-03,  4.3289e-02],
         [-8.2016e-02,  8.9615e-02],
         [ 4.1743e-01, -4.1037e-01],
         [ 3.1194e-01, -3.7967e-01],
         [ 8.5495e-02, -4.7932e-01],
         [ 3.5640e-01, -1.0381e-01],
         [ 6.4499e-02,  2.8355e-01],
         [ 2.1394e-01, -3.0057e-01],

In [4]:
from transformers import AutoTokenizer, BertForSequenceClassification

bert_model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Access the model's weights
weights = bert_model.state_dict()

# Modify the weights or perform any operation you desire
# Example: Print the shape of each weight tensor
o = 0
b = 0
for name, weight in weights.items():
    if "weight" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        o += p
    elif "bias" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        b += p
    else:
        print(f"else {name}: {weight.size()}")
print(f"n_weight={o}, n_bias={b}, n_param={o+b}")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


bert.embeddings.word_embeddings.weight: torch.Size([30522, 768])
bert.embeddings.position_embeddings.weight: torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight: torch.Size([2, 768])
bert.embeddings.LayerNorm.weight: torch.Size([768])
bert.embeddings.LayerNorm.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias: torch.Size([768

In [39]:
model = Bert(768, 12, 12, 30522, 1024, True)

In [40]:
model.tok_embed.weight = weights["bert.embeddings.word_embeddings.weight"]
model.pos_embed.weight = weights["bert.embeddings.position_embeddings.weight"][None, :, :]
model.ln.weight = weights["bert.embeddings.LayerNorm.weight"]
model.ln.bias = weights["bert.embeddings.LayerNorm.bias"]
for m in range(len(model.blocks._modules)):
    layer = "bert.encoder.layer."
    model.blocks._modules[str(m)].attn.query.weight = weights[layer+str(m)+".attention.self.query.weight"]
    model.blocks._modules[str(m)].attn.query.bias = weights[layer+str(m)+".attention.self.query.bias"]
    model.blocks._modules[str(m)].attn.key.weight = weights[layer+str(m)+".attention.self.key.weight"]
    model.blocks._modules[str(m)].attn.key.bias = weights[layer+str(m)+".attention.self.key.bias"]
    model.blocks._modules[str(m)].attn.value.weight = weights[layer+str(m)+".attention.self.value.weight"]
    model.blocks._modules[str(m)].attn.value.bias = weights[layer+str(m)+".attention.self.value.bias"]
    model.blocks._modules[str(m)].attn.proj.weight = weights[layer+str(m)+".attention.output.dense.weight"] # .t()
    model.blocks._modules[str(m)].attn.proj.bias = weights[layer+str(m)+".attention.output.dense.bias"]
    model.blocks._modules[str(m)].ln1.weight = weights[layer+str(m)+".attention.output.LayerNorm.weight"]
    model.blocks._modules[str(m)].ln1.bias = weights[layer+str(m)+".attention.output.LayerNorm.bias"]
    model.blocks._modules[str(m)].ff._modules['0'].weight = weights[layer+str(m)+".intermediate.dense.weight"] # .t()
    model.blocks._modules[str(m)].ff._modules['0'].bias = weights[layer+str(m)+".intermediate.dense.bias"]
    model.blocks._modules[str(m)].ff._modules['2'].weight = weights[layer+str(m)+".output.dense.weight"] # .t()
    model.blocks._modules[str(m)].ff._modules['2'].bias = weights[layer+str(m)+".output.dense.bias"]
    model.blocks._modules[str(m)].ln2.weight = weights[layer+str(m)+".output.LayerNorm.weight"]
    model.blocks._modules[str(m)].ln2.bias = weights[layer+str(m)+".output.LayerNorm.bias"]
model.pooler.weight = weights["bert.pooler.dense.weight"] # .t()
model.pooler.bias = weights["bert.pooler.dense.bias"]
model.classifier.weight = weights["classifier.weight"] # .t()
model.classifier.bias = weights["classifier.bias"]

In [7]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = bert_model(**inputs)
realput = outputs
print(realput)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.2915, -0.1456]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [45]:
import codecs

def load_tsv(data_file, max_seq_len, delimiter='\t'):
    '''Load a tsv '''
    sentences = []
    targets = []
    with codecs.open(data_file, 'r', 'utf-8') as data_fh:
        for _ in range(1):
            data_fh.readline()
        for row in data_fh:
            row = row.strip().split(delimiter)
            sentence = "Question: " + row[1] + "\nAnswer Text: " + row[2]
            sentences.append(tokenizer(sentence, return_tensors="pt"))
            targets.append(1*(row[3] == "not_entailment"))
    return sentences, targets

data, targets = load_tsv("GLUE-baselines/glue_data/QNLI/train.tsv", 1024)

In [84]:
count = 0
total = 0
for label in range(1000):
    outputs = bert_model(**data[label])
    count += targets[label] == outputs.logits.argmax()
    total += 1
count / total

tensor(0.4940)

In [41]:
model.encrypt(src=0)
# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
x = crypten.cryptensor(data[0]['input_ids']).reshape(1, -1)
# Classify the encrypted data
model.eval()
print("forward")
x = model(x)
print('output_enc')
# Compute the accuracy
output = x.get_plain_text()
print(f"{output.shape=}")
print(f"{output=}")

loading data
forward
output_enc
output.shape=torch.Size([1, 69, 2])
output=tensor([[[-0.1482,  0.2658],
         [-0.2622,  0.2909],
         [-0.4788, -0.2028],
         [-0.3937,  0.0043],
         [-0.4800, -0.0475],
         [-0.3161, -0.1730],
         [-0.2097,  0.2421],
         [-0.5458, -0.2761],
         [-0.4550, -0.0066],
         [-0.3025, -0.0943],
         [-0.5511, -0.4168],
         [-0.1581,  0.2003],
         [-0.5367, -0.2714],
         [-0.4930, -0.0416],
         [-0.3660,  0.0542],
         [-0.4442, -0.1850],
         [-0.4354, -0.2316],
         [-0.3931, -0.0638],
         [-0.0797,  0.4111],
         [-0.3819, -0.1658],
         [-0.1856,  0.0288],
         [-0.4155, -0.2120],
         [-0.2390,  0.1315],
         [-0.2065,  0.0993],
         [-0.3333, -0.1773],
         [-0.5713, -0.1519],
         [-0.4734, -0.2253],
         [-0.4576, -0.1731],
         [-0.3396,  0.0236],
         [-0.5201, -0.0949],
         [-0.4643, -0.2706],
         [-0.5551, -0.0987