## This Notebook is intended to show a method to compute precision in FHE computations

The resulting FHE vectors have been computed using the C++ program in verbose mode.

Replicate it by launching the following command:

```
./FHEBERT-tiny "Nuovo Cinema Paradiso has been an incredible movie! A gem in the italian culture." --verbose
```

In [39]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np
import math
from matplotlib import pyplot as plt 
from datasets import load_dataset
import pandas as pd

def precision(correct, approx):
    if type(approx) == list:
        approx = np.array(approx)
    absolute = sum(abs(correct - approx))/len(correct)
    relative = absolute / (sum(abs(correct))/len(correct))
    return 1 - relative

In [40]:
from transformers import logging
logging.set_verbosity_error() #Otherwise it will log annoying warnings

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
trained = torch.load('SST-2-BERT-tiny.bin', map_location=torch.device('cpu'))
model.load_state_dict(trained , strict=True)

model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, e

In [142]:
text = "it 's just disappointingly superficial -- a movie that has all the elements necessary to be a fascinating , involving character study , but never does more than scratch the surface . "
text = "[CLS] " + text + " [SEP]"

In [143]:
text

"[CLS] it 's just disappointingly superficial -- a movie that has all the elements necessary to be a fascinating , involving character study , but never does more than scratch the surface .  [SEP]"

In [144]:
#This is computed client-side

tokenized = tokenizer(text)
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

x = model.bert.embeddings(tokens_tensor, torch.tensor([[1] * len(tokenized_text)]))

In [145]:
#for i in range(len(x[0])):
#    np.savetxt('../sample-inputs/0/input_{}.txt'.format(i), x[0][i].detach(), delimiter=',')

### 1) Layer 1 -- Self-Attention

In [146]:
key = model.bert.encoder.layer[0].attention.self.key.weight.clone().detach().double().transpose(0, 1)
query = model.bert.encoder.layer[0].attention.self.query.weight.clone().detach().double().transpose(0, 1)
value = model.bert.encoder.layer[0].attention.self.value.weight.clone().detach().double().transpose(0, 1)

key_bias = model.bert.encoder.layer[0].attention.self.key.bias.clone().detach().double()
query_bias = model.bert.encoder.layer[0].attention.self.query.bias.clone().detach().double()
value_bias = model.bert.encoder.layer[0].attention.self.value.bias.clone().detach().double()

original_input_tensor = x.double()

input_tensor = x.double()

q = torch.matmul(input_tensor, query) + query_bias
k = torch.matmul(input_tensor, key) + key_bias
v = torch.matmul(input_tensor, value) + value_bias

q = q.reshape([1, input_tensor.size()[1], 2, 64])
k = k.reshape([1, input_tensor.size()[1], 2, 64])
v = v.reshape([1, input_tensor.size()[1], 2, 64])

q = q.permute([0, 2, 1, 3])
k = k.permute([0, 2, 3, 1])

qk = torch.matmul(q, k)
qk = qk / 8

qk_softmaxed = torch.softmax(qk, -1)

v = v.permute([0, 2, 1, 3])

fin = torch.matmul(qk_softmaxed, v)
fin = fin.permute([0, 2, 1, 3])
fin = fin.reshape([1, input_tensor.size()[1], 128])

In [148]:
fhe_vector = np.array([ -0.3605,  0.0869,  0.7868, -0.1405, -0.0443, -0.2340,  0.3175,  0.0588, -0.9211, -0.3049,  0.0520, -0.4095,  0.0655, -0.1715, -0.2530, -0.7243, -0.3613, -0.4495,  0.5664, -0.0007,  0.4017, -0.7573,  0.0569,  0.4485, -0.5187, -0.4934,  0.5262, -0.6699,  0.6010, -0.3901, -0.2280,  0.2874,  0.0076, -0.4084, -0.0639,  0.5529, -0.3038, -0.8664, -0.0366,  0.5128,  0.2738, -0.3366,  0.6202,  0.3919,  0.0032, -0.3683,  0.7703, -0.1979,  0.2448,  0.7228,  0.0926, -0.1409, -0.3108,  0.4950, -0.4295, -0.0281,  0.5658,  0.2900, -0.1685,  0.5634,  0.2320, -0.6074, -0.1157,  0.4062,  0.7357,  0.2348, -0.3566,  1.2485,  0.7350, -0.3265, -0.3070,  0.1094,  0.8037, -0.0739, -0.8402,  0.5542,  0.1176,  0.8950, -0.3943, -1.1855,  0.3308,  0.6542,  0.5699,  0.0745, -0.1792, -1.0596, -0.5243,  0.0910, -0.0077, -1.3643,  0.6721, -0.5084, -0.0865,  0.7021,  0.1404,  1.3437,  1.0542, -0.3776,  0.0284,  1.9061, -0.1753,  0.7719, -1.0014,  0.2928, -0.9781,  0.6475,  0.7803,  0.8051, -1.0638, -0.3390,  0.9720,  0.2629, -0.3725,  0.0700, -2.1387,  0.1663,  0.8791, -0.5437, -0.4642,  1.3472,  1.5909, -0.3683, -0.4747,  0.2629,  0.8553, -0.0531, -0.1120, -0.5098 ])

precision(fin[0][0].detach(), fhe_vector)

tensor(0.9926, dtype=torch.float64)

### 2) Layer 1 -- Self-Output

In [150]:
w_output_dense = model.bert.encoder.layer[0].attention.output.dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = model.bert.encoder.layer[0].attention.output.dense.bias.clone().detach().double()

mean = np.array([-0.03383045433490704, -0.04689138747464171, -0.04320052751297194, -0.04194874763842685, -0.03849735236740709, -0.03583471496309556, -0.036673685450259945, -0.03533623114666153, -0.03301200050649906, -0.03385619903604035, -0.03394064677150061, -0.03581378040060232, -0.04000193681582013, -0.042994980738727644, -0.042689484809151766, -0.0422699887342667, -0.040702211423783496, -0.043257636922742766, -0.040924377288572664, -0.04212762593354266, -0.040090620729304687, -0.03727317047412721, -0.030603299343800818, -0.034141189654495016, -0.03468711091296442, -0.032307857857310274, -0.02926372943560165, -0.031292906450152466, -0.037837883896213766, -0.03745859562807607, -0.03794657692710982, -0.03860214509229593, -0.036185650111238955, -0.039154371235979875, -0.03589729976884486, -0.031731895884233016, -0.03465287223481833, -0.031348414682812194, -0.03688161652969029, -0.03338290816163936, -0.038240660222183975, -0.037525466450406116, -0.038229222217722264, -0.041201914113547705, -0.04212576296359885, -0.03980083151775188, -0.04072657806877826, -0.040145599490268025, -0.036685242667777444, -0.034109016054392725, -0.03544325775104831, -0.03623692053970561, -0.04948334692050963, -0.04596823422981405, -0.04892271117435003])
var = np.array([0.7495962428549272, 0.6109555428467895, 0.6225590467577651, 0.62495153067201, 0.631395549935461, 0.634492711694546, 0.644892789064359, 0.6542099965205022, 0.6595559062153842, 0.6659906881037033, 0.6680168012366937, 0.6758412527257586, 0.6668118068796066, 0.6718192460326265, 0.67786737736941, 0.6808577853930836, 0.6736657333151266, 0.6676446046843724, 0.6659979061989304, 0.6743226078654423, 0.681388263935704, 0.6837117808950258, 0.6907147768934253, 0.684537831509984, 0.6896744328697597, 0.6916627127801457, 0.6954043965468235, 0.6954046755145293, 0.7001025287354249, 0.695094327647078, 0.6854203403085795, 0.7027792682295838, 0.6956849098218769, 0.6945153573872891, 0.6856697060013522, 0.6897353511373785, 0.700668908202082, 0.6965624918742969, 0.7082690699456209, 0.7043163331126293, 0.7070770512949652, 0.7042510307314358, 0.6978925459183357, 0.7205035876616076, 0.6902461198740245, 0.686971254827903, 0.7028843270104062, 0.7032880792671149, 0.7057843340136714, 0.7104860015626775, 0.7321738164781159, 0.71095817492914, 0.7401485084476891, 0.7312957890728539, 0.7375994654874705])
    
fin2 = torch.matmul(fin, w_output_dense) + b_output_dense
fin2_backup = fin2.clone()
fin2_backup = fin2_backup + original_input_tensor

fin3_whole = []

for i in range(len(original_input_tensor.squeeze())):
    fin2 = fin2_backup.squeeze()[i]
    fin3_corr = (fin2.squeeze().detach() - mean[i]) * var[i]
    
    #fin3_corr = (fin2.squeeze().detach() - torch.mean(fin2.squeeze())) / math.sqrt(torch.var(fin2.squeeze()))
    
    w_output_layernorm = model.bert.encoder.layer[0].attention.output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
    b_output_layernorm = model.bert.encoder.layer[0].attention.output.LayerNorm.bias.clone().detach().double()

    fin3_corr = fin3_corr * w_output_layernorm + b_output_layernorm
    fin3_whole.append(fin3_corr.detach())

fin3_whole = torch.cat(tuple(fin3_whole), 0).unsqueeze(0)

In [151]:
fhe_vector = np.array( [ -0.6502,  0.1038, -11.9365, -0.2933,  0.7621,  0.0864,  1.0436, -0.9450, -2.0711,  1.0147, -1.0093,  0.2296,  0.4937,  1.0723, -0.0604,  0.3134,  1.5943,  1.4551, -0.0586,  0.7537, -0.2927, -0.5716,  2.4983,  2.4837,  0.8170, -0.2943,  0.0615, -0.5154, -0.6066, -0.2910,  0.0230, -0.2595, -4.5228,  0.9756, -0.1828,  0.1736,  0.4966,  0.2262,  0.9596,  0.3358,  0.6954, -0.3204,  3.2484, -0.5211, -0.8727,  0.3096,  1.0428,  0.8283,  0.3429,  0.1663,  0.4823,  0.8321, -0.8450,  0.6826,  0.2240,  0.8249, -0.5279, -0.1942,  1.5627,  1.3711,  0.0298, -0.0128, -0.3125,  0.7707, -0.0211,  0.6369, -0.0852,  0.6062, -0.5113, -0.0658, -0.2131, -0.0578,  0.3495, -3.6247,  0.4657, -1.8394,  0.8633,  0.5575,  0.7089, -4.1608,  0.0935, -1.6439,  1.0793,  1.6602,  0.9174, -0.4732,  0.7837, -0.1984, -0.9227,  0.3616,  0.2056, -0.5405, -0.2051,  0.5238, -0.4161,  1.1202, -0.4941, -0.4859,  0.6556, -0.7438, -0.2146,  0.2825,  1.0641,  0.9456,  0.2273,  1.4860,  1.0908,  1.1163,  1.2558, -0.0576,  0.7241,  1.6288, -0.2308,  0.1932, -0.0035, -0.6822, -0.6336, -0.3637, -0.8280, -0.9068, -1.2381, -0.3880, -0.8409,  0.4272,  0.8688, -0.1926, -0.6532,  0.8866,  ])
precision(fin3_whole[0][0].detach(), fhe_vector)

tensor(0.9949, dtype=torch.float64)

In [81]:
fin3_whole

tensor([[[ -1.3148,   0.3858, -11.3850,  ...,  -0.3709,   0.8667,   1.2124],
         [ -3.4727,   0.9965,  -0.7988,  ...,  -3.9541,   1.2899,   1.2560],
         [ -3.1280,   0.5889,  -1.6285,  ...,  -3.5301,   0.3958,   1.1160],
         ...,
         [ -4.1072,   0.0962,  -1.3231,  ...,  -2.8202,   1.2232,   3.5629],
         [ -2.4529,   1.2259,  -0.3580,  ...,  -1.8126,  -0.9176,  -0.2391],
         [ -3.6870,   0.3899,   0.3177,  ...,  -1.1680,  -1.2935,   1.0331]]],
       dtype=torch.float64)

### 3) Layer 1 -- Intermediate

In [152]:
fin_4 = torch.matmul(fin3_whole, model.bert.encoder.layer[0].intermediate.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[0].intermediate.dense.bias

fin_5 = torch.nn.functional.gelu(fin_4)

In [153]:
fin_5

tensor([[[-1.6816e-01, -1.5905e-01,  7.4299e-02,  ..., -1.5669e-01,
           1.1731e-01, -8.6088e-03],
         [-6.1865e-03, -4.9313e-02, -1.6767e-01,  ..., -1.4382e-01,
           2.0933e-02, -4.2751e-03],
         [-1.1197e-01, -1.2775e-01, -9.4852e-02,  ..., -1.2677e-02,
           3.8604e-01, -9.9763e-02],
         ...,
         [-1.1775e-01, -5.8830e-02,  8.1227e-02,  ..., -1.6097e-01,
          -1.6579e-01, -1.9127e-02],
         [-9.6812e-02,  3.7330e-01, -8.2041e-02,  ..., -1.0699e-01,
          -1.2465e-01, -3.4668e-03],
         [-1.6409e-01,  2.5816e-03,  6.5666e-03,  ..., -1.3236e-01,
           7.3566e-01, -4.0566e-04]]], dtype=torch.float64,
       grad_fn=<GeluBackward0>)

In [157]:
fhe_vector = np.array([ -0.1685, -0.1586,  0.0765, -0.1698,  0.3273, -0.1607, -0.1696,  1.2687, -0.0419, -0.1277, -0.0821, -0.1406, -0.1599, -0.1520, -0.0060, -0.0025, -0.1700, -0.1660, -0.1254,  0.1932,  0.0589,  1.2671,  0.0365, -0.1622,  1.0573, -0.1267,  0.4731,  0.0235, -0.1493, -0.1698, -0.0288, -0.0114, -0.1693, -0.0000, -0.0410, -0.1287,  0.0718, -0.1129,  0.9469, -0.1466, -0.1699, -0.1624,  0.7507, -0.1682, -0.1668, -0.0837,  0.8414,  0.1870,  0.4747,  0.2399,  0.5632, -0.1233, -0.1152,  0.8708,  0.1919,  0.2623, -0.0803, -0.0168,  0.1063,  0.3965,  0.7230,  0.9475, -0.0497, -0.1606, -0.0839, -0.1213, -0.0706,  0.1691,  0.8465,  0.7356,  0.2259,  0.0206, -0.1344, -0.0210, -0.1337,  0.1457,  0.8905, -0.0590,  0.7177, -0.1134,  0.4805, -0.1641, -0.1128,  0.0463, -0.0404, -0.1591, -0.1124, -0.1692, -0.0003,  0.1175, -0.1519,  1.0902,  0.2578, -0.0237, -0.0899, -0.0111,  0.0978, -0.0064, -0.1350, -0.1407,  0.1565, -0.1618,  0.0684,  0.9367,  0.0593,  0.0659, -0.1598,  0.9292,  0.7847,  0.3192,  0.4959,  0.3916, -0.1640, -0.0241, -0.1470,  0.3871, -0.1155, -0.0486,  0.5500, -0.1475, -0.1380, -0.1625,  0.0002,  0.5865,  0.1104, -0.1695,  0.4012, -0.1105 ])
precision(fin_4[0][0][:128].detach(), fhe_vector)

tensor(0.2919, dtype=torch.float64)

In [158]:
fhe_vector = np.array([ -0.1583, -0.0932,  0.1390, -0.1316,  0.1832, -0.1579,  0.0295,  0.6584, -0.1697, -0.0745, -0.0778, -0.1361, -0.0851, -0.1277, -0.1238, -0.0024, -0.1698,  0.0091, -0.0671,  0.1949,  0.0653,  0.5006,  0.1600, -0.1693,  0.8949, -0.1537,  0.1074, -0.0665, -0.1015, -0.1674, -0.0546, -0.0174, -0.1697, -0.0003, -0.0391, -0.1185,  0.0782, -0.1243,  0.9874, -0.1250, -0.1696, -0.1400,  0.0816, -0.1607, -0.1580, -0.1083,  0.3395, -0.0009,  0.4383,  0.2845, -0.0485, -0.1432, -0.1681,  0.5591,  0.1691,  0.3170, -0.1693, -0.1617,  0.3886,  0.8495, -0.0699,  0.4940, -0.0654, -0.1281, -0.1352, -0.1420, -0.1601,  0.1337,  0.8145,  0.2782, -0.1447,  0.0203, -0.1170, -0.0540,  0.3685,  0.3343,  1.0245, -0.0258,  0.1222, -0.1396,  0.4261, -0.1621, -0.1361,  0.1135, -0.0559, -0.1423,  0.1038, -0.1685, -0.0009, -0.0530, -0.1691,  0.3990,  0.4315,  0.1031,  0.1540, -0.0150,  0.6171, -0.0768, -0.1354, -0.1670, -0.1306, -0.1573,  0.1795,  0.4059,  0.1407,  0.0300, -0.1289,  0.9738,  0.1634,  0.1293,  0.3501,  0.0138, -0.1146, -0.1273, -0.0390,  0.4399, -0.1216, -0.0955, -0.1645, -0.1451, -0.1102, -0.0096, -0.1562,  0.4812,  0.0059, -0.1617,  0.1348,  0.2014 ])
precision(fin_5[0][0][:128].detach(), fhe_vector)

tensor(0.4431, dtype=torch.float64)

### 4) Layer 1 -- Output

In [159]:
mean = np.array([-0.09545516102868973, 0.034540955180462664, 0.03934738149667437, 0.040802318439555035, 0.04426037798445811, 0.04919343175846099, 0.0493616301294401, 0.047896279398118795, 0.04912640635535303, 0.048717249992826256, 0.0477219385203478, 0.05095357678578503, 0.05094908370417657, 0.0493275745992752, 0.048418324664654545, 0.0473653504669205, 0.04528009986283869, 0.04524247257539856, 0.046555073355952846, 0.0516135997743503, 0.049103903254210594, 0.048877585502238356, 0.048364988370661784, 0.049043507301742846, 0.049933470462367846, 0.05175179126331398, 0.05057227793143223, 0.055763206569478994, 0.055243365455213404, 0.04986745821758072, 0.047789218698650125, 0.047852162700887234, 0.04279460740337753, 0.04280733225675328, 0.04644169155736491, 0.04783492130826333, 0.04759649093761958, 0.045252139153821, 0.04367184005341422, 0.039034762655413016, 0.04374965234639466, 0.04355128435775863, 0.04499861862695065, 0.04318602336450084, 0.04549296197766528, 0.03907804279518851, 0.037683132925437485, 0.04109696491189214, 0.04410155617431274, 0.05015992918511731, 0.04335430986396108, 0.046492484403760526, 0.044277581701870204, 0.03723061917091777, 0.039156973130334664])
var = np.array([0.4156698594967092, 0.7008452266859936, 0.7214270983257646, 0.7095727482866087, 0.7102521835201318, 0.710293676073547, 0.7091783271698753, 0.6973493176419543, 0.7011688527520855, 0.7007704875343309, 0.6950537183089973, 0.6948029158092094, 0.6919309911197036, 0.6933694537037308, 0.6970711644923971, 0.7004276850010867, 0.6964234913676165, 0.6987678419874651, 0.6951829293138483, 0.6973048809142951, 0.6989420799277399, 0.7005696487948311, 0.6993937733493811, 0.6902070532566239, 0.6958399824203775, 0.6900361005407983, 0.6925891359742274, 0.6831642926666377, 0.6865279710039072, 0.6904370385593245, 0.6963724536275457, 0.6948942601360332, 0.6784634186071326, 0.6759657478656234, 0.6828578884489792, 0.683566347862741, 0.6857777074044566, 0.672040915409448, 0.6784995422914343, 0.6732453264186854, 0.683881765911935, 0.6909411690410042, 0.6715428435769978, 0.6775867807314924, 0.6785015863916147, 0.676156117696202, 0.6786376609996214, 0.6763771062984715, 0.7119440584663215, 0.7070342067744777, 0.6895996022331654, 0.6683970656272868, 0.6695013664908844, 0.6566575067124804, 0.672887703816164])    
    
fin_6 = torch.matmul(fin_5, model.bert.encoder.layer[0].output.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[0].output.dense.bias
fin_6 = fin_6 + fin3_whole

fin7_whole = []

for i in range(len(input_tensor.squeeze())):
    fin_7 = fin_6.squeeze()[i]
    
    fin7_corr = (fin_7.squeeze().detach() - mean[i]) * var[i]
    
    w_output_layernorm = model.bert.encoder.layer[0].output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
    b_output_layernorm = model.bert.encoder.layer[0].output.LayerNorm.bias.clone().detach().double()

    fin7_corr = fin7_corr * w_output_layernorm + b_output_layernorm

    fin7_whole.append(fin7_corr.detach())

fin7_whole = torch.cat(tuple(fin7_whole), 0).unsqueeze(0)

In [160]:
fhe_vector = np.array([ -0.3007,  0.0240, -6.4075, -0.1774,  0.9323, -0.0049,  0.5888, -0.1074, -0.8474,  0.6121, -0.8743,  0.1986,  0.3994,  0.6398, -0.0501, -0.0060,  0.5857,  1.2665,  0.1118,  0.6966,  0.0023, -0.3905,  1.5584,  1.1895,  0.6970,  0.0453,  0.0163, -0.3130, -0.0078,  0.0211,  0.3925,  0.1109, -2.9748,  0.2944,  0.0968,  0.2541,  0.4003,  0.0404,  0.0288,  0.1705,  0.6248, -0.3489,  1.0776, -0.0865, -0.5512,  0.4508,  0.5693,  0.5113,  0.4272,  0.1222, -0.7093,  0.6784, -0.0379,  0.4320,  0.0886,  0.0856, -0.2977,  0.2092,  0.7456,  0.4874,  0.2559, -0.0279, -0.0128, -0.0677,  0.2999,  0.3994,  0.5172,  0.5540,  0.3606, -0.3796,  0.2487, -0.3342,  0.1204, -1.2565,  0.4647, -0.7414,  0.4185,  0.6306,  0.8280, -1.1122,  0.7631, -0.4105,  0.5644,  0.8341,  0.6168, -0.5376,  0.4419, -0.0303, -0.6366, -0.0609,  0.3319,  0.4920, -0.1648, -0.1613,  0.5028,  0.6506, -0.6504, -0.4798,  0.4654, -0.2111,  0.3318, -0.5171,  0.5169,  0.4179,  0.1037,  0.5769,  0.7918,  0.5318,  0.7841, -0.6297,  0.2247,  1.2829,  0.1715,  0.1452, -0.1313, -0.4007, -0.1594,  0.1177, -0.3348, -0.2741, -1.0940, -0.2212, -0.2961, -0.2539,  0.3822, -0.3997, -0.2394,  0.6668,  ])

precision(fin7_whole[0][0].detach(), fhe_vector)

tensor(0.9939, dtype=torch.float64)

### 5) Layer 2 -- Self-Attention

In [131]:
key = model.bert.encoder.layer[1].attention.self.key.weight.clone().detach().double().transpose(0, 1)
query = model.bert.encoder.layer[1].attention.self.query.weight.clone().detach().double().transpose(0, 1)
value = model.bert.encoder.layer[1].attention.self.value.weight.clone().detach().double().transpose(0, 1)

key_bias = model.bert.encoder.layer[1].attention.self.key.bias.clone().detach().double()
query_bias = model.bert.encoder.layer[1].attention.self.query.bias.clone().detach().double()
value_bias = model.bert.encoder.layer[1].attention.self.value.bias.clone().detach().double()

original_input_tensor = fin7_whole
input_tensor = fin7_whole

q = torch.matmul(input_tensor, query) + query_bias
k = torch.matmul(input_tensor, key) + key_bias
v = torch.matmul(input_tensor, value) + value_bias

q = q.reshape([1, input_tensor.size()[1], 2, 64])
k = k.reshape([1, input_tensor.size()[1], 2, 64])
v = v.reshape([1, input_tensor.size()[1], 2, 64])

q = q.permute([0, 2, 1, 3])
k = k.permute([0, 2, 3, 1])

qk = torch.matmul(q, k)
qk = qk / 8

qk_softmaxed = torch.softmax(qk, -1)

v = v.permute([0, 2, 1, 3])

fin = torch.matmul(qk_softmaxed, v)
fin = fin.permute([0, 2, 1, 3])
fin = fin.reshape([1, input_tensor.size()[1], 128])

In [132]:
fhe_vector = np.array([ -1.5990, -0.1837,  0.7751,  0.2522, -0.4854, -0.8096,  1.5433,  0.3790,  0.0293, -0.8233, -0.3404,  0.5676, -0.4548, -1.7416,  0.7405,  1.6679,  0.7490, -0.4478, -0.5280,  0.2524, -0.8667, -0.9974, -1.5663,  1.5735, -0.5942,  0.4008,  0.3410,  1.1989,  0.9992,  1.0091, -0.3735, -0.9058,  0.8441, -0.5615,  1.4018,  0.8261, -0.4877, -2.2596, -0.0759,  0.9762,  1.2387, -0.4508, -0.5494, -1.9519,  0.8798,  0.5510,  0.1867, -0.2034, -0.2550, -1.0106, -0.7441, -0.4794,  0.2770,  0.4403,  0.4754, -0.4346, -0.4162,  0.6085, -1.2210,  0.2973,  0.3303, -0.2580, -0.9881, -0.0548,  0.1130, -0.0202,  0.0646,  0.3497, -0.2268, -0.1469, -0.9055,  0.1102, -0.2019,  0.4469,  0.2901, -0.1854,  0.1660, -0.3442,  0.1237, -0.0641, -0.0894,  0.4678, -0.2254,  0.0495, -0.0175, -0.0848,  0.4013, -0.2342, -0.3845, -0.2491, -0.5496,  0.3275, -0.5900,  0.5558, -0.5520,  0.1564, -0.3758, -0.2200, -0.1933, -0.1291, -0.0595, -0.1405, -0.5287,  0.5134, -0.2915,  0.0759, -0.4710, -0.1138, -0.2776,  0.3620, -0.9443, -0.2583,  0.2026, -0.0014, -0.0870,  0.2710, -0.3310,  0.8927, -0.7023, -0.5042,  0.0742, -0.3483, -0.2828,  0.6910,  0.2384,  0.3929, -0.3206,  0.0425 ])

precision(fin[0][0].detach(), fhe_vector)

tensor(0.7153, dtype=torch.float64)

### 6) Layer 2 -- Self-Output

In [134]:
mean = np.array([0.04805131047475803, 0.014145706172069285, 0.010630181813540026, 0.010521146572975027, 0.00956244983947186, 0.008211288558782809, 0.008817800275674387, 0.008911457532306733, 0.008643898058317862, 0.008801769546523253, 0.009472254700839258, 0.008094415948174241, 0.007702615754430344, 0.005460620353838359, 0.007021847370084451, 0.008373831982472147, 0.01022061224155272, 0.00927594903773269, 0.009277225000069925, 0.007049453120897054, 0.008682554190420182, 0.008749022040809715, 0.010118317324741522, 0.008998865743435887, 0.008763833543884292, 0.008285728555981435, 0.006967351876718886, 0.00588068616144895, 0.0030701809065725363, 0.003659716972971551, 0.002116778487431024, 0.003947434346765913, 0.006907859825079262, 0.008494112860837831, 0.007040283968419036, 0.007197681884381672, 0.008232685835987293, 0.009965029801574864, 0.00731962961637719, 0.00830555309310382, 0.005340440177451385, 0.007833324368720607, 0.01047456825511633, 0.009674864773662995, 0.010093537461664302, 0.01588798917017868, 0.018537933333636507, 0.018245848282989877, 0.012253993810893607, 0.011354133953173591, 0.013474744814287221, 0.013707011955501919, 0.007918842609048385, 0.017240907760895086, 0.03465881962238184])
var = np.array([0.6741653046411179, 0.602392389437227, 0.5945841451997256, 0.5997135932136959, 0.6033806506910513, 0.6064839949503851, 0.6058735285405447, 0.6059001754921257, 0.6086086189801689, 0.6118981975241923, 0.6161533101614306, 0.6105411757987637, 0.6102443339235957, 0.6004337682468068, 0.6068584434133084, 0.6123178593290803, 0.6150302868629213, 0.6102744641580546, 0.6143169356654037, 0.6105845722771672, 0.61540315154488, 0.622109065598561, 0.6221720668578823, 0.6279330579960701, 0.6282907135959079, 0.6258439179151315, 0.6187239026398644, 0.618294817104495, 0.609488586748927, 0.6085185174201381, 0.6154275326252285, 0.6207534846328591, 0.6290521066315713, 0.6375810334496135, 0.6238236165346044, 0.6310571465398529, 0.6350551779511981, 0.6452639043477173, 0.6346915398812409, 0.646622546259538, 0.6435498445423712, 0.6401589932559348, 0.6458833892517316, 0.6354378204804867, 0.651796667347259, 0.6547600574517144, 0.6554038815336571, 0.655910889886979, 0.6412602949793637, 0.6489736968517984, 0.6633309254993116, 0.6771441398382873, 0.6423362709438692, 0.6302863730404997, 0.5940213893371686])

w_output_dense = model.bert.encoder.layer[1].attention.output.dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = model.bert.encoder.layer[1].attention.output.dense.bias.clone().detach().double()

fin2 = torch.matmul(fin, w_output_dense) + b_output_dense
fin2_backup = fin2.clone()
fin2_backup = fin2_backup + original_input_tensor

fin3_whole = []

for i in range(len(original_input_tensor.squeeze())):
    fin2 = fin2_backup.squeeze()[i]

    fin3_corr = (fin2.squeeze().detach() - mean[i]) * var[i]

    w_output_layernorm = model.bert.encoder.layer[1].attention.output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
    b_output_layernorm = model.bert.encoder.layer[1].attention.output.LayerNorm.bias.clone().detach().double()

    fin3_corr = fin3_corr * w_output_layernorm + b_output_layernorm
    fin3_whole.append(fin3_corr.detach())

fin3_whole = torch.cat(tuple(fin3_whole), 0).unsqueeze(0)

In [135]:
fhe_vector = np.array([ -1.3344,  0.2835, -1.7833, -0.1106,  0.5614, -0.2162,  1.0203,  0.0973, -0.4119, -0.0139, -0.2767,  0.2300, -0.0070,  0.7883,  0.1941,  1.4207,  1.3967,  0.0698, -1.4442, -0.5511, -0.0545, -0.0853,  0.6875,  0.5139,  0.4317, -0.2034, -0.0719, -0.4483,  0.2270, -0.2727,  0.7259, -0.3560, -2.6692,  0.4223, -0.1196, -0.2982,  0.7261,  0.0995, -1.1343,  1.3658, -0.6335,  0.2158,  0.1109, -0.5617,  0.4306, -1.8238,  0.3798, -0.5813,  0.3885,  0.3324,  0.9758,  1.5692,  0.7256,  1.2132,  0.4934,  0.2799, -0.9350, -1.0343,  0.7930,  0.1330,  1.0144, -0.0372, -1.0106, -1.5178, -0.7635, -0.0910,  0.2632,  1.2250,  0.9479, -0.3563,  1.2092, -0.6060, -0.3435, -0.0398, -0.0809, -0.2783,  1.0083, -0.3137,  0.3951, -0.6469,  0.9773, -0.8072,  0.0259,  0.8244, -0.6673, -1.4109, -0.2372, -1.2556, -0.5015, -1.4908,  1.2925,  0.9175, -0.4494, -0.2011,  1.0650,  0.6107, -0.2736,  0.9693, -0.1507,  0.5105, -1.2599,  0.0694,  0.1838, -1.4038,  0.3698,  0.9902, -0.3737,  2.2232,  0.2465, -0.3858, -0.4558, -0.0372,  1.6856,  0.3768,  0.1177,  0.0457,  0.8807, -0.2895, -0.5373,  1.1324, -1.0272,  0.7337,  1.7779, -0.8356, -1.2441, -0.7107,  0.4361,  2.1687,  ])

precision(fin3_whole[0][0].detach(), fhe_vector)

tensor(0.7321, dtype=torch.float64)

### 7) Layer 2 -- Intermediate

In [136]:
fin_4 = torch.matmul(fin3_whole, model.bert.encoder.layer[1].intermediate.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[1].intermediate.dense.bias   
fin_5 = torch.nn.functional.gelu(fin_4)    

In [137]:
fhe_vector = np.array([ -0.0683, -0.0562,  1.3694, -0.1403, -0.0703, -0.1537, -0.1658,  0.4683, -0.0210, -0.0005, -0.1563, -0.1662, -0.0726, -0.0936, -0.1059, -0.1640, -0.0935, -0.1621, -0.1689, -0.0099, -0.1663, -0.0288, -0.0431, -0.1237, -0.1699,  1.3531, -0.0271, -0.1230, -0.0315, -0.1606, -0.1465, -0.1607, -0.1111, -0.1253, -0.1450,  0.0236,  1.3971, -0.1058, -0.1201,  0.6417, -0.1595, -0.1340,  0.3707, -0.0008, -0.1128,  0.5451, -0.0293, -0.0071,  0.0274,  0.0609, -0.1688,  0.0656,  2.1240, -0.1667, -0.0440,  0.1143,  1.0968, -0.0513,  0.0032, -0.1634, -0.1636, -0.0114,  0.3805, -0.0070, -0.1653, -0.0769, -0.1158,  0.8741,  0.2711,  0.5258,  0.7994, -0.1006,  0.0409,  0.1818,  0.7667, -0.0002, -0.0010, -0.0818,  1.7126, -0.0568, -0.1693,  0.0353, -0.1208,  0.4403, -0.0188, -0.1565, -0.1689, -0.0890, -0.1334,  0.1987, -0.0233,  1.3172, -0.1609,  0.0282, -0.1696, -0.0133, -0.0055, -0.1478, -0.1054, -0.1695,  0.0069, -0.1269,  3.3577, -0.1573, -0.0397,  0.1693,  0.9749,  0.9535,  0.9454,  0.0970, -0.1689,  2.1287,  0.2266,  1.7717, -0.1216, -0.1288, -0.0847, -0.1600, -0.0314,  0.6490, -0.0465, -0.1664, -0.0266, -0.1589, -0.0447, -0.1676, -0.0243, -0.1359 ])
precision(fin_5[0][0][:128].detach(), fhe_vector)

tensor(-0.6709, dtype=torch.float64)

### 8) Layer 2 -- Output

In [138]:
fin_6 = torch.matmul(fin_5, model.bert.encoder.layer[1].output.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[1].output.dense.bias
fin_6 = fin_6 + fin3_whole

fin7_whole = []

mean = np.array([0.06643368, 0.05726708, 0.05311476, 0.05229822, 0.05352628, 0.05238868, 0.0536801 , 0.05327334, 0.05206954, 0.05110339, 0.051747  , 0.05016997, 0.04943122, 0.04937956, 0.04952862, 0.04973959, 0.04852742, 0.04696055, 0.04846476, 0.04925392,0.0509005 , 0.05373027, 0.05371865, 0.05446217, 0.05222489,0.05142676, 0.05080909, 0.05179351, 0.05049174, 0.04965748,0.05138143, 0.0499965 , 0.05194982, 0.05178364, 0.0521023 ,0.05059624, 0.05445499, 0.05507825, 0.05241935, 0.05073552,0.05200171, 0.04858642, 0.04419684, 0.04642237, 0.05115073,0.05028116, 0.05021724, 0.05312114, 0.0524375 , 0.04643478,0.05026358, 0.04248708, 0.04675281, 0.03895142, 0.04558007])
var = np.array([0.81992316, 0.78486345, 0.79259   , 0.79754392, 0.79350872, 0.79652433, 0.79935746, 0.79867687, 0.80257863, 0.80235328,0.80521209, 0.80621272, 0.80330435, 0.80469855, 0.81171202,0.81136354, 0.80977166, 0.8089956 , 0.8106946 , 0.80862825,0.81450049, 0.81722176, 0.82121488, 0.82012788, 0.8254015 ,0.82097106, 0.81742119, 0.82090554, 0.82116105, 0.82017896,0.82234659, 0.82832269, 0.82888739, 0.81852014, 0.82054523,0.8224114 , 0.82913892, 0.8289046 , 0.81985612, 0.83341215,0.82896934, 0.82315006, 0.82802216, 0.81886278, 0.8274004 ,0.83436616, 0.82014282, 0.82628005, 0.83230868, 0.84511334,0.85141143, 0.84934269, 0.83041272, 0.826798  , 0.83660989])

for i in range(len(input_tensor.squeeze())):
    fin_7 = fin_6.squeeze()[i]

    fin7_corr = (fin_7.squeeze().detach() - mean[i]) * var[i]

    w_output_layernorm = model.bert.encoder.layer[1].output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
    b_output_layernorm = model.bert.encoder.layer[1].output.LayerNorm.bias.clone().detach().double()

    fin7_corr = fin7_corr * w_output_layernorm + b_output_layernorm

    fin7_whole.append(fin7_corr.detach())

fin7_whole = torch.cat(tuple(fin7_whole), 0).unsqueeze(0)

In [139]:
fhe_vector = np.array([  1.7862, -2.0674, -0.2598, -0.2733, -0.6657,  0.9152,  0.6312,  1.0929,  0.3524,  0.7416,  1.2811,  0.2206, -0.8468, -0.5167, -0.1659,  0.6020,  1.3285, -0.9925, -1.9419,  0.2647,  0.2765,  2.3105,  2.8253,  0.5482, -1.4334, -0.3802, -0.3619, -1.4140,  0.5294,  0.6072, -1.8004,  0.5581, -2.8421,  0.2503,  0.4287, -0.3454,  0.2077, -1.0119, -1.5609, -0.6429, -1.4182,  0.1802, -1.7165, -0.2336,  0.9720, -1.6756, -0.1018,  0.1046,  1.0860,  0.9126,  0.9143,  1.2146,  0.9086, -0.7014,  0.1545,  0.2080, -3.3099, -0.0128,  2.6352,  1.0011,  0.7249, -1.0812, -1.1600, -1.2047, -0.4703, -0.2173, -1.3199,  2.1158,  0.7578, -0.3307,  0.9533,  0.5950, -0.1573,  0.8234,  0.7160,  0.6736,  0.2819,  0.8938, -0.2418,  1.3240, -2.0877, -1.4034, -1.4676,  0.0390,  1.2058, -0.2615, -1.5343, -2.3307,  1.4072,  0.2342, -0.0834,  1.4614, -0.2938,  0.3641,  1.0156, -0.0286, -0.7436, -1.0538,  1.2407, -0.2570, -0.8378, -0.0219,  1.4528, -0.0432, -0.9515, -1.8858, -0.4678, -1.1132,  1.9448, -0.1209,  0.0245, -1.3956, -0.0740,  2.0146,  0.8191,  1.0593,  0.4562, -0.5237, -2.1609, -0.7794,  1.5314,  0.5960,  1.6027,  0.2354, -0.7131, -0.2370, -2.6296, -2.5624,  ])

precision(fin7_whole[0][0].detach(), fhe_vector)

tensor(-0.6248, dtype=torch.float64)

### 9) Pooler

In [141]:
pooler_output = torch.tanh(torch.matmul(fin7_whole.double(), model.bert.pooler.dense.weight.transpose(0, 1).double()) + model.bert.pooler.dense.bias)

fhe_vector = np.array([ -0.9657,  0.0534, -0.9750,  0.5383, -0.9309, -0.1242, -0.8649,  0.9233, -0.1681,  0.1008, -0.4494,  0.0672, -0.1991,  0.7375, -0.1530, -0.5025,  0.9864, -0.0625, -0.4462,  0.6912,  0.9464,  0.1654,  0.4593,  0.1309, -0.9803,  0.1077, -0.9866,  0.9757,  0.9961,  0.1778,  0.0406, -0.1484, -0.7218, -0.8691,  0.9554,  0.9977, -0.9027, -0.0351, -0.1248, -0.5630,  0.9415,  0.8972, -0.9423, -0.2117,  0.9148, -0.2622, -0.8666,  0.7585, -0.0698, -0.1944,  0.6214, -0.9240,  0.3051,  0.8587,  0.9446, -0.9602,  0.9628, -0.8386, -0.2455, -0.5549, -0.3302, -0.2826, -0.1894,  0.1568, -0.4767, -0.9593,  0.9806,  0.6198, -0.5856,  0.9871,  0.9892, -0.0829, -0.9799,  0.0531, -0.2904,  0.5600, -0.9654,  0.0994, -0.7429,  0.2282,  0.0544,  0.1308, -0.7545, -0.9394,  0.9976, -0.8687,  0.4460, -0.4465,  0.6326,  0.7942,  0.9814,  0.9368, -0.8862,  0.9663,  0.5675,  0.2038, -0.7993, -0.8848, -0.9872, -0.9927, -0.5350,  0.9186, -0.9823,  0.0910,  0.3787,  0.5804, -0.9932, -0.8794, -0.4858,  0.9416,  0.9036,  0.9329, -0.9697,  0.4366,  0.5860,  0.1168,  0.5759, -0.2495,  0.0759, -0.2110,  0.2657, -0.9910, -0.6700,  0.2117, -0.9851,  0.9928,  0.9216,  0.1755 ])
precision(pooler_output[0][0].detach(), fhe_vector)

tensor(0.7244, dtype=torch.float64)

### 10) Classifier

In [25]:
classification = torch.matmul(pooler_output, model.classifier.weight.transpose(0, 1).double()) + model.classifier.bias.double()

fhe_vector = np.array([ -0.0197,  0.1699 ])

precision(classification[0][0].detach(), fhe_vector)

tensor(0.5020, dtype=torch.float64)

In [26]:
print("Plain circuit output: {}\nFHE circuit output: {}".format(classification[0][0].detach().numpy(), fhe_vector))

Plain circuit output: [0.02078085 0.13353153]
FHE circuit output: [-0.0197  0.1699]


In [27]:
model(tokens_tensor, torch.tensor([[1] * len(tokenized_text)]))

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.0951, -0.0199]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)