# Classification with Encrypted Neural Networks

In this tutorial, we'll look at how we can achieve the <i>Model Hiding</i> application we discussed in the Introduction. That is, suppose say Alice has a trained model she wishes to keep private, and Bob has some data he wishes to classify while keeping it private. We will see how CrypTen allows Alice and Bob to coordinate and classify the data, while achieving their privacy requirements.

To simulate this scenario, we will begin with Alice training a simple neural network on MNIST data. Then we'll see how Alice and Bob encrypt their network and data respectively, classify the encrypted data and finally decrypt the labels.

## Setup

We first import the `torch` and `crypten` libraries, and initialize `crypten`. We will use a helper script `mnist_utils.py` to split the public MNIST data into Alice's portion and Bob's portion. 

In [1]:
import crypten
import crypten.nn as nn
import math
import torch
import logging

crypten.init()
torch.set_num_threads(1)



[Device] LUTs initialized for cpu



In [2]:
# class Attention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(Attention, self).__init__()

#         assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.search_dim = embed_dim // num_heads

#         self.key = nn.Linear(embed_dim, embed_dim)
#         self.value = nn.Linear(embed_dim, embed_dim)
#         self.query = nn.Linear(embed_dim, embed_dim)
#         self.proj = nn.Linear(embed_dim, embed_dim)


#     def forward(self, x):
#         batch_size = x.shape[0]
#         seq_len = x.shape[1]

#         k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
#         v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
#         q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

#         attn = q.matmul(k_t) / math.sqrt(q.size(-1))
#         attn = attn.softmax(dim=-1)
#         y = attn.matmul(v)
#         y = y.transpose(1, 2)
#         y = y.reshape(batch_size, seq_len, self.embed_dim)
#         return y


# class BertBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(BertBlock, self).__init__()
#         embed_dim = embed_dim
#         self.ln1 = nn.LayerNorm(embed_dim)
#         self.ln2 = nn.LayerNorm(embed_dim)
#         self.attn = Attention(embed_dim, num_heads)
#         self.ff = nn.Sequential(
#             nn.Linear(embed_dim, embed_dim * 4),
#             nn.GELU(),
#             nn.Linear(embed_dim * 4, embed_dim),
#         )

#     def forward(self, x):
#         x = self.ln1(x + self.attn(x))
#         x = self.ln2(x + self.ff(x))
#         return x

# class Bert(nn.Module):
#     def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
#         super(Bert, self).__init__()
#         self.full = full
#         if full:
#             self.word_embed = nn.Embedding(vocab_size, embed_dim)
#             self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
#             self.type_embed = nn.Embedding(2, embed_dim)
#             self.ln = nn.LayerNorm(embed_dim)
#         self.blocks = nn.Sequential(
#             *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
#         )
#         if full:
#             self.pooler = nn.Linear(embed_dim, embed_dim)
#             self.classifier = nn.Linear(embed_dim, 2)

#     def forward(self, x, target=None):
#         if self.full:
#             word_embedding = self.word_embed(x["input_ids"])
#             pos_embedding = self.pos_embed(x)[:, :x["input_ids"].size()[1], :]
#             type_embedding = self.type_embed(x["token_type_ids"])
#             x = word_embedding + pos_embedding + type_embedding
#             x = self.ln(x)
#         x = self.blocks(x)
#         if self.full:
#             x = self.pooler(x[:, 0, :])
#             x = x.tanh()
#             x = self.classifier(x)
#         return x

# model = Bert(768, 12, 12, 30522, 1024, True) # bert base 13.5s

In [3]:
class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size, emb_size, max_seq_length):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, emb_size)
        self.position_embeddings = nn.Embedding(max_seq_length, emb_size)
        self.token_type_embeddings = nn.Embedding(2, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, input_ids, token_type_ids):
        word_emb = self.word_embeddings(input_ids)
        pos_emb = self.position_embeddings.weight[:input_ids.size()[1], :]
        type_emb = self.token_type_embeddings(token_type_ids)

        emb = word_emb + pos_emb + type_emb
        emb = self.LayerNorm(emb)

        return emb

class BertSelfAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_size = emb_size // self.n_heads
        self.query = nn.Linear(emb_size, emb_size)
        self.key = nn.Linear(emb_size, emb_size)
        self.value = nn.Linear(emb_size, emb_size)

    def forward(self, emb):
        B, T, C = emb.shape  # batch size, sequence length, embedding size

        q = self.query(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = self.key(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = self.value(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        weights = q @ k.transpose(-2, -1) * self.head_size**-0.5

        weights = weights.softmax(dim=-1)

        emb_rich = weights @ v
        emb_rich = emb_rich.transpose(1, 2).reshape(B, T, C)

        return emb_rich

class BertSelfOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, emb_rich, emb):
        x = self.dense(emb_rich)
        x = x + emb
        out = self.LayerNorm(x)

        return out


class BertAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.self = BertSelfAttention(emb_size, n_heads)
        self.output = BertSelfOutput(emb_size)

    def forward(self, emb):
        emb_rich = self.self(emb)
        out = self.output(emb_rich, emb)

        return out

class BertIntermediate(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, 4 * emb_size)
        self.gelu = nn.GELU()

    def forward(self, att_out):
        x = self.dense(att_out)
        out = self.gelu(x)
        return out


class BertOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(4 * emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, intermediate_out, att_out):
        x = self.dense(intermediate_out)
        x = x + att_out
        out = self.LayerNorm(x)

        return out

class BertLayer(nn.Module):
    def __init__(self, emb_size, n_heads ):
        super().__init__()
        self.attention = BertAttention(emb_size, n_heads)
        self.intermediate = BertIntermediate(emb_size)
        self.output = BertOutput(emb_size)

    def forward(self, emb):
        att_out = self.attention(emb)
        intermediate_out = self.intermediate(att_out)
        out = self.output(intermediate_out, att_out)

        return out

class BertEncoder(nn.Module):
    def __init__(self, emb_size, n_heads, n_layers):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(emb_size, n_heads) for i in range(n_layers)])

    def forward(self, x):
        for l in self.layer:
            x = l(x)
        return x

class BertPooler(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.tanh = nn.Tanh()

    def forward(self, encoder_out):
        pool_first_token = encoder_out[:, 0]
        out = self.dense(pool_first_token)
        out = self.tanh(out)
        return out

class BertModel(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads, n_layers):
        super().__init__()
        self.embeddings = BertEmbeddings(vocab_size, emb_size, seq_len)
        self.encoder = BertEncoder(emb_size, n_heads, n_layers)
        self.pooler = BertPooler(emb_size)

    def forward(self, input_ids, token_type_ids):
        emb = self.embeddings(input_ids, token_type_ids)
        out = self.encoder(emb)
        pooled_out = self.pooler(out)
        return out, pooled_out

class BertForSequenceClassification(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads, n_layers):
        super().__init__()
        self.bert = BertModel(vocab_size, emb_size, seq_len, n_heads, n_layers)
        self.classifier = nn.Linear(emb_size, 2)

    def forward(self, input_ids, token_type_ids):
        _, pooled_out = self.bert(input_ids, token_type_ids)
        logits = self.classifier(pooled_out)
        return logits

# model = LLM(Bert(768, 12, 12, 30522, 1024))
model = BertForSequenceClassification(28996, 768, 512, 12, 12)
# model = BertForSequenceClassification(30522, 128, 512, 2, 2) # tiny

In [4]:
from transformers import AutoTokenizer, BertForSequenceClassification

bert_model = BertForSequenceClassification.from_pretrained('gchhablani/bert-base-cased-finetuned-qnli')
# bert_model = BertForSequenceClassification.from_pretrained('M-FAC/bert-tiny-finetuned-qnli')
bert_model.eval()

# Access the model's weights
weights = bert_model.state_dict()

# Modify the weights or perform any operation you desire
# Example: Print the shape of each weight tensor
o = 0
b = 0
for name, weight in weights.items():
    if "weight" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        o += p
    elif "bias" in str(name):
        print(f"{name}: {weight.size()}")
        p = 1
        for w in weight.size():
            p *= w
        b += p
    else:
        print(f"else {name}: {weight.size()}")
print(f"n_weight={o}, n_bias={b}, n_param={o+b}")


bert.embeddings.word_embeddings.weight: torch.Size([28996, 768])
bert.embeddings.position_embeddings.weight: torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight: torch.Size([2, 768])
bert.embeddings.LayerNorm.weight: torch.Size([768])
bert.embeddings.LayerNorm.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias: torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight: torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight: torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias: torch.Size([768

In [5]:
model.load_state_dict(weights)

# model.word_embed.weight = weights["bert.embeddings.word_embeddings.weight"]
# model.pos_embed.data = weights["bert.embeddings.position_embeddings.weight"][None, :, :]
# model.type_embed.weight = weights["bert.embeddings.token_type_embeddings.weight"]
# model.ln.weight = weights["bert.embeddings.LayerNorm.weight"]
# model.ln.bias = weights["bert.embeddings.LayerNorm.bias"]
# for m in range(len(model.blocks._modules)):
#     layer = "bert.encoder.layer."
#     model.blocks._modules[str(m)].attn.query.weight = weights[layer+str(m)+".attention.self.query.weight"]
#     model.blocks._modules[str(m)].attn.query.bias = weights[layer+str(m)+".attention.self.query.bias"]
#     model.blocks._modules[str(m)].attn.key.weight = weights[layer+str(m)+".attention.self.key.weight"]
#     model.blocks._modules[str(m)].attn.key.bias = weights[layer+str(m)+".attention.self.key.bias"]
#     model.blocks._modules[str(m)].attn.value.weight = weights[layer+str(m)+".attention.self.value.weight"]
#     model.blocks._modules[str(m)].attn.value.bias = weights[layer+str(m)+".attention.self.value.bias"]
#     model.blocks._modules[str(m)].attn.proj.weight = weights[layer+str(m)+".attention.output.dense.weight"] # .t()
#     model.blocks._modules[str(m)].attn.proj.bias = weights[layer+str(m)+".attention.output.dense.bias"]
#     model.blocks._modules[str(m)].ln1.weight = weights[layer+str(m)+".attention.output.LayerNorm.weight"]
#     model.blocks._modules[str(m)].ln1.bias = weights[layer+str(m)+".attention.output.LayerNorm.bias"]
#     model.blocks._modules[str(m)].ff._modules['0'].weight = weights[layer+str(m)+".intermediate.dense.weight"] # .t()
#     model.blocks._modules[str(m)].ff._modules['0'].bias = weights[layer+str(m)+".intermediate.dense.bias"]
#     model.blocks._modules[str(m)].ff._modules['2'].weight = weights[layer+str(m)+".output.dense.weight"] # .t()
#     model.blocks._modules[str(m)].ff._modules['2'].bias = weights[layer+str(m)+".output.dense.bias"]
#     model.blocks._modules[str(m)].ln2.weight = weights[layer+str(m)+".output.LayerNorm.weight"]
#     model.blocks._modules[str(m)].ln2.bias = weights[layer+str(m)+".output.LayerNorm.bias"]
# model.pooler.weight = weights["bert.pooler.dense.weight"] # .t()
# model.pooler.bias = weights["bert.pooler.dense.bias"]
# model.classifier.weight = weights["classifier.weight"] # .t()
# model.classifier.bias = weights["classifier.bias"]

In [6]:
tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli")
# tokenizer = AutoTokenizer.from_pretrained("M-FAC/bert-tiny-finetuned-qnli")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = bert_model(**inputs)
realput = outputs
print(realput)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.5781,  0.1473]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [7]:
import codecs

def load_tsv(data_file, max_seq_len, delimiter='\t'):
    '''Load a tsv '''
    sentences = []
    targets = []
    with codecs.open(data_file, 'r', 'utf-8') as data_fh:
        for _ in range(1):
            data_fh.readline()
        for row in data_fh:
            row = row.strip().split(delimiter)
            sentences.append(tokenizer(row[1][:512], row[2][:512], return_tensors="pt"))
            targets.append(1*(row[3] == "not_entailment"))
    return sentences, targets

data, targets = load_tsv("GLUE-baselines/glue_data/QNLI/dev.tsv", 1024)

In [8]:
count = 0
total = 0
for label in range(len(data)):
    outputs = bert_model(**data[label])
    count += targets[label] == outputs.logits.argmax()
    total += 1
count / total

tensor(0.9099)

In [9]:
model.encrypt(src=0)
# Load data to Bob
print('loading data')
# data_enc = crypten.load_from_party('/tmp/bob_test.pth', src=ALICE)
x = {}
x['input_ids'] = crypten.cryptensor(data[0]["input_ids"], precision = 0)
x['token_type_ids'] = crypten.cryptensor(data[0]["token_type_ids"], precision = 0)
# Classify the encrypted data
model.eval()
print("forward")
x = model(x['input_ids'], x['token_type_ids'])
print('output_enc')
# Compute the accuracy
output = x.get_plain_text()
print(f"{output.shape=}")
print(f"{output=}")

loading data
forward
output_enc
output.shape=torch.Size([1, 2])
output=tensor([[ 2.1267, -1.6942]])


In [10]:
label = 3
x = {}
x['input_ids'] = crypten.cryptensor(data[label]["input_ids"], precision = 0)
x['token_type_ids'] = crypten.cryptensor(data[label]["token_type_ids"], precision = 0)
outputs = model(x['input_ids'], x['token_type_ids'])
cleartext = outputs.get_plain_text()
print(cleartext)

tensor([[ 2.7345, -2.2678]])


In [11]:
total = 0
count = 0
for i in range(10):
    outputs = bert_model(**data[i])
    print(outputs.logits)
    count += targets[i] == outputs.logits.argmax()
    total += 1
count / total

tensor([[ 2.0014, -1.6264]], grad_fn=<AddmmBackward0>)
tensor([[-3.5682,  2.9418]], grad_fn=<AddmmBackward0>)
tensor([[-2.8672,  2.4340]], grad_fn=<AddmmBackward0>)
tensor([[ 3.1729, -2.6355]], grad_fn=<AddmmBackward0>)
tensor([[-2.2055,  1.8402]], grad_fn=<AddmmBackward0>)
tensor([[-3.6816,  3.0014]], grad_fn=<AddmmBackward0>)
tensor([[-1.3215,  1.2353]], grad_fn=<AddmmBackward0>)
tensor([[-3.7192,  2.9996]], grad_fn=<AddmmBackward0>)
tensor([[-3.8530,  3.0773]], grad_fn=<AddmmBackward0>)
tensor([[ 3.5189, -2.9471]], grad_fn=<AddmmBackward0>)


tensor(1.)

In [None]:
total = 0
count = 0
for label in range(1000):
    x = {}
    x['input_ids'] = crypten.cryptensor(data[label]["input_ids"], precision = 0)
    x['token_type_ids'] = crypten.cryptensor(data[label]["token_type_ids"], precision = 0)
    outputs = model(x['input_ids'], x['token_type_ids'])
    cleartext = outputs.get_plain_text()
    print(cleartext)
    count += targets[label] == cleartext.argmax()
    total += 1
count / total

In [19]:
for label in range(861, 1000):
    x = {}
    x['input_ids'] = crypten.cryptensor(data[label]["input_ids"], precision = 0)
    x['token_type_ids'] = crypten.cryptensor(data[label]["token_type_ids"], precision = 0)
    outputs = model(x['input_ids'], x['token_type_ids'])
    cleartext = outputs.get_plain_text()
    print(cleartext)
    count += targets[label] == cleartext.argmax()
    total += 1
count / total

tensor([[-1.0731e+14,  4.7462e+13]])
tensor([[2.6860e+13, 9.0002e+13]])
tensor([[-5.0927e+13, -1.0232e+14]])
tensor([[-6.8693e+13, -6.0797e+13]])
tensor([[-2.8268e+13, -8.9721e+13]])
tensor([[-7.4354e+13, -6.7186e+13]])
tensor([[3.1812e+13, 8.1238e+13]])
tensor([[ 1.3858e+14, -1.3209e+14]])
tensor([[-1.3398e+14, -1.4564e+13]])
tensor([[-8.6160e+13,  9.1290e+13]])
tensor([[-5.9247e+13,  1.1057e+14]])
tensor([[-2.4512e+11,  5.3922e+12]])
tensor([[-9.2227e+13,  9.4108e+13]])
tensor([[ 3.1355e+12, -1.1937e+14]])
tensor([[-1.2403e+14, -7.3490e+13]])
tensor([[-5.7338e+13, -3.2418e+13]])
tensor([[-1.1355e+14, -4.5022e+13]])
tensor([[6.0592e+13, 5.1518e+13]])
tensor([[ 8.1546e+13, -7.4386e+13]])
tensor([[ 4.2795e+13, -1.0965e+14]])
tensor([[1.2805e+14, 4.6305e+13]])
tensor([[6.9031e+13, 3.2738e+12]])
tensor([[4.9871e+13, 6.0588e+13]])
tensor([[-4.6483e+13, -6.8444e+13]])
tensor([[-3.1408e+13,  1.3933e+14]])
tensor([[-1.2467e+13, -4.6917e+13]])
tensor([[-1.1689e+14, -4.2620e+12]])
tensor([[ 1.6

tensor(0.8070)

In [13]:
# Load data to Bob
print('loading data')
# Classify the encrypted data
print("forward")
x = {}
x['input_ids'] = crypten.cryptensor(data[0]["input_ids"], precision = 0)
x['token_type_ids'] = crypten.cryptensor(data[0]["token_type_ids"], precision = 0)
# x['input_ids'] = crypten.cryptensor(torch.tensor([1]), precision = 0)
# x['token_type_ids'] = crypten.cryptensor(torch.tensor([0, 1, 0, 1]), precision = 0)
word_embedding = model.word_embed(x["input_ids"])
pos_embedding = model.pos_embed(x)[:, :x["input_ids"].size()[1], :]
type_embedding = model.type_embed(x["token_type_ids"])
x = word_embedding + pos_embedding + type_embedding
x = model.ln(x)
# v = '0'
# x = model.blocks._modules[v].attn(x)
x = model.blocks._modules['0'](x)
x = model.blocks._modules['1'](x)
x = model.blocks._modules['2'](x)
x = model.blocks._modules['3'](x)
x = model.blocks._modules['4'](x)
x = model.blocks._modules['5'](x)
x = model.blocks._modules['6'](x)
x = model.blocks._modules['7'](x)
x = model.blocks._modules['8'](x)
x = model.blocks._modules['9'](x)
x = model.blocks._modules['10'](x)
x = model.blocks._modules['11'](x)
x = model.pooler(x[:, 0, :])
x = x.tanh()
x = model.classifier(x)

# x = model(x)
print('output_enc')
# Compute the accuracy
output = x.get_plain_text()
print(f"{output.shape=}")
print(f"{output=}")
print(f"{output.max()} {output.min()}")
print(f"{output.mean()} {output.var()}")

loading data
forward


AttributeError: 'BertForSequenceClassification' object has no attribute 'word_embed'

In [None]:
# class Attention(torch.nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(Attention, self).__init__()

#         assert embed_dim % num_heads == 0, "invalid heads and embedding dimension"

#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.search_dim = embed_dim // num_heads

#         self.key = torch.nn.Linear(embed_dim, embed_dim)
#         self.value = torch.nn.Linear(embed_dim, embed_dim)
#         self.query = torch.nn.Linear(embed_dim, embed_dim)
#         self.proj = torch.nn.Linear(embed_dim, embed_dim)


#     def forward(self, x):
#         batch_size = x.shape[0]
#         seq_len = x.shape[1]

#         k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).permute(0, 2, 3, 1)
#         v = self.value(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)
#         q = self.query(x).reshape(batch_size, seq_len, self.num_heads, self.search_dim).transpose(1, 2)

#         attn = q.matmul(k_t) / math.sqrt(q.size(-1))
#         attn = attn.softmax(dim=-1)
#         y = attn.matmul(v)
#         y = y.transpose(1, 2)
#         y = y.reshape(batch_size, seq_len, self.embed_dim)
#         return y


# class BertBlock(torch.nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super(BertBlock, self).__init__()
#         embed_dim = embed_dim
#         self.ln1 = torch.nn.LayerNorm(embed_dim)
#         self.ln2 = torch.nn.LayerNorm(embed_dim)
#         self.attn = Attention(embed_dim, num_heads)
#         self.ff = torch.nn.Sequential(
#             torch.nn.Linear(embed_dim, embed_dim * 4),
#             torch.nn.GELU(),
#             torch.nn.Linear(embed_dim * 4, embed_dim),
#         )

#     def forward(self, x):
#         x = self.ln1(x + self.attn(x))
#         x = self.ln2(x + self.ff(x))
#         return x

# class Bert(torch.nn.Module):
#     def __init__(self, embed_dim, num_heads, num_blocks, vocab_size, seq_len, full=True):
#         super(Bert, self).__init__()
#         self.full = full
#         if full:
#             self.word_embed = torch.nn.Embedding(vocab_size, embed_dim)
#             self.pos_embed = torch.nn.Parameter(torch.zeros(1, seq_len, embed_dim))
#             self.type_embed = torch.nn.Embedding(2, embed_dim)
#             self.ln = torch.nn.LayerNorm(embed_dim)
#         self.blocks = torch.nn.Sequential(
#             *[BertBlock(embed_dim, num_heads) for _ in range(num_blocks)]
#         )
#         if full:
#             self.pooler = torch.nn.Linear(embed_dim, embed_dim)
#             self.classifier = torch.nn.Linear(embed_dim, 2)

#     def forward(self, x, target=None):
#         if self.full:
#             word_embedding = self.word_embed(x["input_ids"])
#             pos_embedding = self.pos_embed[:, :x["input_ids"].size()[1], :]
#             type_embedding = self.type_embed(x["token_type_ids"])
#             x = word_embedding + pos_embedding + type_embedding
#             x = self.ln(x)
#         x = self.blocks(x)
#         if self.full:
#             x = self.pooler(x[:, 0, :])
#             x = x.tanh()
#             x = self.classifier(x)
#         return x

# plain_model = Bert(768, 12, 12, 30522, 1024, True) # bert base 13.5s

In [None]:
class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size, emb_size, max_seq_length):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, emb_size)
        self.position_embeddings = nn.Embedding(max_seq_length, emb_size)
        self.token_type_embeddings = nn.Embedding(2, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, input_ids, token_type_ids):
        word_emb = self.word_embeddings(input_ids)
        pos_emb = self.position_embeddings.weight[:input_ids.size()[1], :]
        type_emb = self.token_type_embeddings(token_type_ids)

        emb = word_emb + pos_emb + type_emb
        emb = self.LayerNorm(emb)

        return emb

class BertSelfAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_size = emb_size // self.n_heads
        self.query = nn.Linear(emb_size, emb_size)
        self.key = nn.Linear(emb_size, emb_size)
        self.value = nn.Linear(emb_size, emb_size)

    def forward(self, emb):
        B, T, C = emb.shape  # batch size, sequence length, embedding size

        q = self.query(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        k = self.key(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)
        v = self.value(emb).view(B, T, self.n_heads, self.head_size).transpose(1, 2)

        weights = q @ k.transpose(-2, -1) * self.head_size**-0.5

        weights = weights.softmax(dim=-1)

        emb_rich = weights @ v
        emb_rich = emb_rich.transpose(1, 2).contiguous().view(B, T, C)

        return emb_rich

class BertSelfOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, emb_rich, emb):
        x = self.dense(emb_rich)
        x = x + emb
        out = self.LayerNorm(x)

        return out


class BertAttention(nn.Module):
    def __init__(self, emb_size, n_heads):
        super().__init__()
        self.self = BertSelfAttention(emb_size, n_heads)
        self.output = BertSelfOutput(emb_size)

    def forward(self, emb):
        emb_rich = self.self(emb)
        out = self.output(emb_rich, emb)

        return out

class BertIntermediate(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, 4 * emb_size)
        self.gelu = nn.GELU()

    def forward(self, att_out):
        x = self.dense(att_out)
        out = self.gelu(x)
        return out


class BertOutput(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(4 * emb_size, emb_size)
        self.LayerNorm = nn.LayerNorm(emb_size)

    def forward(self, intermediate_out, att_out):
        x = self.dense(intermediate_out)
        x = x + att_out
        out = self.LayerNorm(x)

        return out

class BertLayer(nn.Module):
    def __init__(self, emb_size, n_heads ):
        super().__init__()
        self.attention = BertAttention(emb_size, n_heads)
        self.intermediate = BertIntermediate(emb_size)
        self.output = BertOutput(emb_size)

    def forward(self, emb):
        att_out = self.attention(emb)
        intermediate_out = self.intermediate(att_out)
        out = self.output(intermediate_out, att_out)

        return out

class BertEncoder(nn.Module):
    def __init__(self, emb_size, n_heads, n_layers):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(emb_size, n_heads) for i in range(n_layers)])

    def forward(self, x):
        for l in self.layer:
            x = l(x)
        return x

class BertPooler(nn.Module):
    def __init__(self, emb_size):
        super().__init__()
        self.dense = nn.Linear(emb_size, emb_size)
        self.tanh = nn.Tanh()

    def forward(self, encoder_out):
        pool_first_token = encoder_out[:, 0]
        out = self.dense(pool_first_token)
        out = self.tanh(out)
        return out

class BertModel(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads, n_layers):
        super().__init__()
        self.embeddings = BertEmbeddings(vocab_size, emb_size, seq_len)
        self.encoder = BertEncoder(emb_size, n_heads, n_layers)
        self.pooler = BertPooler(emb_size)

    def forward(self, input_ids, token_type_ids):
        emb = self.embeddings(input_ids, token_type_ids)
        out = self.encoder(emb)
        pooled_out = self.pooler(out)
        return out, pooled_out

class BertForSequenceClassification(nn.Module):
    def __init__(self, vocab_size, emb_size, seq_len, n_heads, n_layers):
        super().__init__()
        self.bert = BertModel(vocab_size, emb_size, seq_len, n_heads, n_layers)
        self.classifier = nn.Linear(emb_size, 2)

    def forward(self, input_ids, token_type_ids):
        _, pooled_out = self.bert(input_ids, token_type_ids)
        logits = self.classifier(pooled_out)
        return logits

# model = LLM(Bert(768, 12, 12, 30522, 1024))
model = BertForSequenceClassification(28996, 768, 512, 12, 12)
model.load_state_dict(weights)

In [None]:
# plain_model.word_embed.weight = torch.nn.Parameter(weights["bert.embeddings.word_embeddings.weight"])
# plain_model.pos_embed.data = torch.nn.Parameter(weights["bert.embeddings.position_embeddings.weight"][None, :, :])
# plain_model.type_embed.weight = torch.nn.Parameter(weights["bert.embeddings.token_type_embeddings.weight"])
# plain_model.ln.weight = torch.nn.Parameter(weights["bert.embeddings.LayerNorm.weight"])
# plain_model.ln.bias = torch.nn.Parameter(weights["bert.embeddings.LayerNorm.bias"])
# for m in range(len(plain_model.blocks._modules)):
#     layer = "bert.encoder.layer."
#     plain_model.blocks._modules[str(m)].attn.query.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.query.weight"])
#     plain_model.blocks._modules[str(m)].attn.query.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.query.bias"])
#     plain_model.blocks._modules[str(m)].attn.key.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.key.weight"])
#     plain_model.blocks._modules[str(m)].attn.key.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.key.bias"])
#     plain_model.blocks._modules[str(m)].attn.value.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.self.value.weight"])
#     plain_model.blocks._modules[str(m)].attn.value.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.self.value.bias"])
#     plain_model.blocks._modules[str(m)].attn.proj.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.output.dense.weight"]) # .t()
#     plain_model.blocks._modules[str(m)].attn.proj.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.output.dense.bias"])
#     plain_model.blocks._modules[str(m)].ln1.weight = torch.nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.weight"])
#     plain_model.blocks._modules[str(m)].ln1.bias = torch.nn.Parameter(weights[layer+str(m)+".attention.output.LayerNorm.bias"])
#     plain_model.blocks._modules[str(m)].ff._modules['0'].weight = torch.nn.Parameter(weights[layer+str(m)+".intermediate.dense.weight"]) # .t()
#     plain_model.blocks._modules[str(m)].ff._modules['0'].bias = torch.nn.Parameter(weights[layer+str(m)+".intermediate.dense.bias"])
#     plain_model.blocks._modules[str(m)].ff._modules['2'].weight = torch.nn.Parameter(weights[layer+str(m)+".output.dense.weight"]) # .t()
#     plain_model.blocks._modules[str(m)].ff._modules['2'].bias = torch.nn.Parameter(weights[layer+str(m)+".output.dense.bias"])
#     plain_model.blocks._modules[str(m)].ln2.weight = torch.nn.Parameter(weights[layer+str(m)+".output.LayerNorm.weight"])
#     plain_model.blocks._modules[str(m)].ln2.bias = torch.nn.Parameter(weights[layer+str(m)+".output.LayerNorm.bias"])
# plain_model.pooler.weight = torch.nn.Parameter(weights["bert.pooler.dense.weight"]) # .t()
# plain_model.pooler.bias = torch.nn.Parameter(weights["bert.pooler.dense.bias"])
# plain_model.classifier.weight = torch.nn.Parameter(weights["classifier.weight"]) # .t()
# plain_model.classifier.bias = torch.nn.Parameter(weights["classifier.bias"])

In [None]:
# Load data to Bob
print('loading data')
# Classify the encrypted data
print("forward")
x = {}
x['input_ids'] = data[0]["input_ids"]
x['token_type_ids'] = data[0]["token_type_ids"]
# x['input_ids'] = torch.tensor([1])
# x['token_type_ids'] = torch.tensor([0, 1, 0, 1])
word_embedding = plain_model.word_embed(x["input_ids"])
pos_embedding = plain_model.pos_embed[:, :x["input_ids"].size()[1], :]
type_embedding = plain_model.type_embed(x["token_type_ids"])
x = word_embedding + pos_embedding + type_embedding
x = plain_model.ln(x)
# v = '0'
# x = plain_model.blocks._modules[v].attn(x)
x = plain_model.blocks._modules['0'](x)
x = plain_model.blocks._modules['1'](x)
x = plain_model.blocks._modules['2'](x)
x = plain_model.blocks._modules['3'](x)
x = plain_model.blocks._modules['4'](x)
x = plain_model.blocks._modules['5'](x)
x = plain_model.blocks._modules['6'](x)
x = plain_model.blocks._modules['7'](x)
x = plain_model.blocks._modules['8'](x)
x = plain_model.blocks._modules['9'](x)
x = plain_model.blocks._modules['10'](x)
x = plain_model.blocks._modules['11'](x)
x = plain_model.pooler(x[:, 0, :])
x = x.tanh()
x = plain_model.classifier(x)

# x = plain_model(x)
print('output')
# Compute the accuracy
output = x
print(f"{output.shape=}")
print(f"{output=}")
print(f"{output.max()} {output.min()}")
print(f"{output.mean()} {output.var()}")

In [None]:
count = 0
total = 0
for label in range(10):
    x = {}
    x['input_ids'] = data[label]["input_ids"]
    x['token_type_ids'] = data[label]["token_type_ids"]
    outputs = plain_model(x)
    print(outputs)
    count += targets[label] == outputs.argmax()
    total += 1
count / total