## This notebook shows a method to measure the precision of FHE computations

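The FHE vectors pasted in the cells below were computed with the C++ program in verbose mode. Each one is compared with a plaintext PyTorch (float64) reference of the same layer using `precision`, defined as $1 - \sum_i |r_i - f_i| \,/\, \sum_i |r_i|$ for reference $r$ and FHE output $f$ (1 means an exact match; values at or below 0 mean the error is as large as the signal).

Replicate them by launching the following command. The sentence passed to the program must be the same one analysed in this notebook (see the `text` cell below). The example command uses a different sentence, so replace it accordingly: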

```
./FHEBERT-tiny "Nuovo Cinema Paradiso has been an incredible movie! A gem in the italian culture." --verbose
```

In [39]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np
import math
from matplotlib import pyplot as plt 
from datasets import load_dataset
import pandas as pd

def precision(correct, approx):
    """Score how closely `approx` matches the reference values `correct`.

    The score is ``1 - sum(|correct - approx|) / sum(|correct|)``: one minus
    the mean absolute error relative to the mean absolute value of the
    reference (the ``1/len`` factors of the two means cancel). A perfect
    match gives 1.0; values <= 0 mean the error is as large as the signal.

    Args:
        correct: reference values (NumPy array or torch tensor), e.g. the
            plaintext computation. Reduction is over the leading axis, so
            1-D vectors give a scalar score.
        approx: approximate values, e.g. the decrypted FHE output. A list or
            tuple is converted to a NumPy array first.

    Returns:
        The score, typed by the arithmetic on `correct` (a 0-d torch tensor
        when `correct` is a tensor). If `correct` is all zeros the ratio is
        undefined and the result is inf/nan.

    Raises:
        ValueError: if the inputs differ in shape (they would otherwise
            broadcast silently) or are empty.
    """
    if isinstance(approx, (list, tuple)):
        approx = np.asarray(approx)
    correct_shape = tuple(np.shape(correct))
    approx_shape = tuple(np.shape(approx))
    if correct_shape != approx_shape:
        raise ValueError(
            "shape mismatch: correct {} vs approx {}".format(correct_shape, approx_shape))
    if len(correct) == 0:
        raise ValueError("cannot compute the precision of empty vectors")
    # Vectorised reductions; the builtin `sum` iterated element by element in Python.
    absolute_error = abs(correct - approx).sum(0)
    reference_magnitude = abs(correct).sum(0)
    return 1 - absolute_error / reference_magnitude

In [40]:
from transformers import logging

# Hide the warnings transformers prints while loading the checkpoint.
logging.set_verbosity_error()

# Tokenizer and architecture of BERT-tiny from the Hugging Face hub.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Overwrite the hub weights with the local checkpoint (its name suggests an
# SST-2 fine-tune -- confirm). strict=True errors on missing/unexpected keys.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code; only load trusted checkpoints (consider weights_only=True where the
# installed PyTorch supports it).
trained = torch.load('SST-2-BERT-tiny.bin', map_location='cpu')
model.load_state_dict(trained, strict=True)

# Evaluation mode (disables dropout) so the plaintext reference is deterministic.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, e

In [71]:
# Sentence to classify, with BERT's [CLS] / [SEP] markers added by hand.
text = (
    "[CLS] "
    "( lawrence bounces ) all over the stage , dancing , running , sweating , mopping his face and generally displaying the wacky talent that brought him fame in the first place . "
    " [SEP]"
)

In [72]:
# Show the exact string that is tokenized in the next cell.
text

'[CLS] ( lawrence bounces ) all over the stage , dancing , running , sweating , mopping his face and generally displaying the wacky talent that brought him fame in the first place .  [SEP]'

In [73]:
# Computed client-side: tokenization and the embedding lookup happen in plaintext.
tokenized = tokenizer(text)  # NOTE(review): not used by the visible cells
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

# NOTE(review): the second positional argument of the embeddings module is
# presumably `token_type_ids`, so every token is embedded as segment 1 rather
# than the usual 0 for a single sentence. Left as-is because the FHE vectors
# below were presumably computed from these same embeddings -- confirm intent.
x = model.bert.embeddings(tokens_tensor, torch.ones_like(tokens_tensor))

In [74]:
# Optional export step: uncomment to write each token's embedding vector x[0][i]
# to its own comma-separated text file (presumably the inputs read by the C++
# FHE program -- confirm against its loader).
#for i in range(len(x[0])):
#    np.savetxt('../sample-inputs/0/input_{}.txt'.format(i), x[0][i].detach(), delimiter=',')

### 1) Layer 1 -- Self-Attention

In [75]:
# Plaintext float64 reference for the self-attention of encoder layer 0.
# Head count and head size are read from the model config instead of being
# hard-coded (2 heads x 64 for BERT-tiny), so the cell also fits other BERT sizes.
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads

key = model.bert.encoder.layer[0].attention.self.key.weight.clone().detach().double().transpose(0, 1)
query = model.bert.encoder.layer[0].attention.self.query.weight.clone().detach().double().transpose(0, 1)
value = model.bert.encoder.layer[0].attention.self.value.weight.clone().detach().double().transpose(0, 1)

key_bias = model.bert.encoder.layer[0].attention.self.key.bias.clone().detach().double()
query_bias = model.bert.encoder.layer[0].attention.self.query.bias.clone().detach().double()
value_bias = model.bert.encoder.layer[0].attention.self.value.bias.clone().detach().double()

# Kept separately: it is added back as the residual in the self-output step.
original_input_tensor = x.double()

input_tensor = x.double()
seq_len = input_tensor.size()[1]

q = torch.matmul(input_tensor, query) + query_bias
k = torch.matmul(input_tensor, key) + key_bias
v = torch.matmul(input_tensor, value) + value_bias

# Split the hidden dimension into heads: (1, seq, hidden) -> (1, seq, heads, head_dim).
q = q.reshape([1, seq_len, num_heads, head_dim])
k = k.reshape([1, seq_len, num_heads, head_dim])
v = v.reshape([1, seq_len, num_heads, head_dim])

# q -> (1, heads, seq, head_dim) and k -> (1, heads, head_dim, seq), so that
# q @ k gives the (1, heads, seq, seq) attention scores.
q = q.permute([0, 2, 1, 3])
k = k.permute([0, 2, 3, 1])

# Scaled dot-product scores (scale sqrt(head_dim), i.e. 8 for BERT-tiny).
# No attention mask is applied: the input is a single unpadded sentence.
qk = torch.matmul(q, k)
qk = qk / math.sqrt(head_dim)

qk_softmaxed = torch.softmax(qk, -1)

v = v.permute([0, 2, 1, 3])

# Weighted sum of the values, then merge the heads back to (1, seq, hidden).
fin = torch.matmul(qk_softmaxed, v)
fin = fin.permute([0, 2, 1, 3])
fin = fin.reshape([1, seq_len, num_heads * head_dim])

In [77]:
# Layer-1 self-attention output for token 0 ([CLS]), as printed by the C++ program.
fhe_vector = np.array([ -0.1864,  0.0359,  0.8130, -0.0891,  0.1157,  0.0847,  0.1650, -0.5956, -0.8062, -0.0082,  0.1067, -0.4779, -0.0233, -0.0575, -0.3933, -0.4412, -0.4133, -0.6620,  0.4666, -0.2215,  0.0762, -0.7581, -0.2348,  0.3608, -0.0836, -0.1974,  0.6343, -0.5655, -0.0053, -0.4426,  0.0384,  0.2947, -0.2067, -0.4707, -0.0871,  0.6517, -0.6589, -0.4605, -0.0783,  0.1602,  0.1052, -0.2409,  1.0541, -0.0897,  0.6023, -0.4129,  0.8202,  0.2241,  0.1216,  0.7801,  0.4308, -0.2763,  0.0906,  0.2970, -0.5456,  0.1085,  0.3494,  0.6032, -0.2180,  0.3150,  0.4404, -0.6391,  0.2835,  0.5721,  1.4293,  0.0680, -0.3189,  1.1362,  0.4736,  0.5106, -0.1547, -0.2553,  0.9037,  0.1954, -0.6704, -0.1891,  0.0801,  0.0633, -0.3092, -0.9273,  0.8928,  0.4880,  0.7071, -0.2215,  0.2903, -0.9620, -0.1780, -0.6417, -1.2386, -0.6511,  0.3044, -0.3024, -0.0448, -0.0249, -0.7101,  0.9553,  0.5963, -0.3677,  0.3483, -0.3662, -0.2212,  0.4121, -0.8762,  0.8329, -0.8962,  0.5615,  0.0446, -0.8245, -0.8327,  0.2437,  1.1858, -0.9220,  0.1997, -0.0575,  0.1217, -0.1036, -0.1878, -0.1736,  0.3420,  0.2190,  0.1645,  0.8943, -0.8772, -0.0570,  0.0773,  0.0373,  0.9794, -0.4930 ])

precision(fin[0, 0].detach(), fhe_vector)

tensor(0.9878, dtype=torch.float64)

### 2) Layer 1 -- Self-Output

In [78]:
w_output_dense = model.bert.encoder.layer[0].attention.output.dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = model.bert.encoder.layer[0].attention.output.dense.bias.clone().detach().double()

# Per-position constants used instead of the exact LayerNorm statistics
# (one entry per token position). `var` multiplies the centred value below,
# so it presumably stores 1/sqrt(variance) -- TODO confirm how these tables
# were precomputed.
mean = np.array([-0.03383045433490704, -0.04689138747464171, -0.04320052751297194, -0.04194874763842685, -0.03849735236740709, -0.03583471496309556, -0.036673685450259945, -0.03533623114666153, -0.03301200050649906, -0.03385619903604035, -0.03394064677150061, -0.03581378040060232, -0.04000193681582013, -0.042994980738727644, -0.042689484809151766, -0.0422699887342667, -0.040702211423783496, -0.043257636922742766, -0.040924377288572664, -0.04212762593354266, -0.040090620729304687, -0.03727317047412721, -0.030603299343800818, -0.034141189654495016, -0.03468711091296442, -0.032307857857310274, -0.02926372943560165, -0.031292906450152466, -0.037837883896213766, -0.03745859562807607, -0.03794657692710982, -0.03860214509229593, -0.036185650111238955, -0.039154371235979875, -0.03589729976884486, -0.031731895884233016, -0.03465287223481833, -0.031348414682812194, -0.03688161652969029, -0.03338290816163936, -0.038240660222183975, -0.037525466450406116, -0.038229222217722264, -0.041201914113547705, -0.04212576296359885, -0.03980083151775188, -0.04072657806877826, -0.040145599490268025, -0.036685242667777444, -0.034109016054392725, -0.03544325775104831, -0.03623692053970561, -0.04948334692050963, -0.04596823422981405, -0.04892271117435003])
var = np.array([0.7495962428549272, 0.6109555428467895, 0.6225590467577651, 0.62495153067201, 0.631395549935461, 0.634492711694546, 0.644892789064359, 0.6542099965205022, 0.6595559062153842, 0.6659906881037033, 0.6680168012366937, 0.6758412527257586, 0.6668118068796066, 0.6718192460326265, 0.67786737736941, 0.6808577853930836, 0.6736657333151266, 0.6676446046843724, 0.6659979061989304, 0.6743226078654423, 0.681388263935704, 0.6837117808950258, 0.6907147768934253, 0.684537831509984, 0.6896744328697597, 0.6916627127801457, 0.6954043965468235, 0.6954046755145293, 0.7001025287354249, 0.695094327647078, 0.6854203403085795, 0.7027792682295838, 0.6956849098218769, 0.6945153573872891, 0.6856697060013522, 0.6897353511373785, 0.700668908202082, 0.6965624918742969, 0.7082690699456209, 0.7043163331126293, 0.7070770512949652, 0.7042510307314358, 0.6978925459183357, 0.7205035876616076, 0.6902461198740245, 0.686971254827903, 0.7028843270104062, 0.7032880792671149, 0.7057843340136714, 0.7104860015626775, 0.7321738164781159, 0.71095817492914, 0.7401485084476891, 0.7312957890728539, 0.7375994654874705])

# Dense projection of the attention output plus the residual connection.
fin2 = torch.matmul(fin, w_output_dense) + b_output_dense
fin2_backup = fin2.clone()
fin2_backup = fin2_backup + original_input_tensor

# Fail clearly (instead of an IndexError inside the loop) when the sentence is
# longer than the precomputed tables.
seq_len = fin2_backup.shape[1]
if seq_len > len(mean):
    raise ValueError(
        "sequence has {} tokens but only {} precomputed LayerNorm constants".format(seq_len, len(mean)))

# The LayerNorm affine parameters do not depend on the position: load them once.
w_output_layernorm = model.bert.encoder.layer[0].attention.output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
b_output_layernorm = model.bert.encoder.layer[0].attention.output.LayerNorm.bias.clone().detach().double()

fin3_whole = []

# Index the batch axis explicitly: `.squeeze()` would also drop the sequence
# axis of a one-token input and make the loop iterate over hidden units.
for i in range(seq_len):
    fin2 = fin2_backup[0, i]
    fin3_corr = (fin2.detach() - mean[i]) * var[i]

    # Reference formula with the exact per-token statistics (kept for comparison):
    #fin3_corr = (fin2.squeeze().detach() - torch.mean(fin2.squeeze())) / math.sqrt(torch.var(fin2.squeeze()))

    fin3_corr = fin3_corr * w_output_layernorm + b_output_layernorm
    fin3_whole.append(fin3_corr.detach())

fin3_whole = torch.cat(tuple(fin3_whole), 0).unsqueeze(0)

In [80]:
# Layer-1 self-output (after the approximated LayerNorm) for token 0, from the FHE run.
fhe_vector = np.array([ -1.3216,  0.3943, -11.3764, -0.1851,  1.0620,  0.1171,  0.5584, -0.2315, -1.5654,  0.5121, -0.0177,  0.7993,  0.5296,  0.2811, -0.2357,  1.1707,  1.0067,  1.1213, -0.9423, -0.0485, -0.1010, -0.7063,  2.6714,  1.0284, -0.2182,  0.4198, -0.2310, -0.3919, -0.3436,  0.0209,  0.0081, -0.0128, -6.0198,  0.4562, -0.7480,  0.2193,  0.2361, -0.2838,  0.2971,  0.3294,  0.5741, -0.2275,  2.5782, -0.5683,  0.2341, -0.1563,  0.9609,  0.8661,  0.0796,  0.3318,  1.1611,  0.3277,  0.1713,  1.1441,  0.9330,  1.0272, -0.6367,  0.2313,  0.5022,  0.4916, -0.1517,  0.1372, -0.2848, -0.0573, -0.3756,  0.0983,  0.3428,  0.0015,  0.1176,  0.2694,  0.5759, -0.7957,  0.3611, -2.4415,  0.1436, -0.9979,  1.1635,  0.1807,  0.7883, -3.1042,  0.0889, -1.1248,  1.1513,  0.6631,  0.2680, -0.4830,  0.4446, -0.4005, -0.8541, -0.5443,  0.2943, -0.1439,  0.6619, -0.3214,  0.3452,  0.2795,  0.5496,  0.5585, -0.0128, -0.5769, -0.9112, -0.1204,  0.2533,  0.3010,  0.2565,  1.8140,  0.3824,  1.5825,  0.2291, -0.3089, -0.1782,  0.2882,  0.0670,  0.5138,  0.1509, -0.0582,  1.0456, -1.1855, -0.0300,  1.0703, -1.1859,  0.2143,  0.1140, -0.7228,  0.4116, -0.3671,  0.8836,  1.2229,  ])
precision(fin3_whole[0, 0].detach(), fhe_vector)

tensor(0.9912, dtype=torch.float64)

In [81]:
# Whole layer-1 self-output, shape (1, tokens, hidden).
fin3_whole

tensor([[[ -1.3148,   0.3858, -11.3850,  ...,  -0.3709,   0.8667,   1.2124],
         [ -3.4727,   0.9965,  -0.7988,  ...,  -3.9541,   1.2899,   1.2560],
         [ -3.1280,   0.5889,  -1.6285,  ...,  -3.5301,   0.3958,   1.1160],
         ...,
         [ -4.1072,   0.0962,  -1.3231,  ...,  -2.8202,   1.2232,   3.5629],
         [ -2.4529,   1.2259,  -0.3580,  ...,  -1.8126,  -0.9176,  -0.2391],
         [ -3.6870,   0.3899,   0.3177,  ...,  -1.1680,  -1.2935,   1.0331]]],
       dtype=torch.float64)

### 3) Layer 1 -- Intermediate

In [96]:
# Layer-1 intermediate dense layer followed by GELU, in float64.
# Parameters are detached (like in the attention cells) so no autograd graph
# is built for this plaintext reference; the bias is made float64 explicitly
# instead of relying on type promotion.
w_intermediate = model.bert.encoder.layer[0].intermediate.dense.weight.detach().double().transpose(0, 1)
b_intermediate = model.bert.encoder.layer[0].intermediate.dense.bias.detach().double()

fin_4 = torch.matmul(fin3_whole, w_intermediate) + b_intermediate

fin_5 = torch.nn.functional.gelu(fin_4)

In [83]:
# Post-GELU layer-1 intermediate activations for every token.
fin_5

tensor([[[-0.1572, -0.0911,  0.1413,  ...,  0.2130,  0.3380, -0.0051],
         [-0.1364,  0.2151, -0.0464,  ...,  0.0195,  0.2342, -0.0062],
         [-0.0759, -0.0778,  0.1660,  ..., -0.1099,  0.1672, -0.0709],
         ...,
         [-0.1577, -0.0863, -0.1596,  ..., -0.1685, -0.1358, -0.0052],
         [-0.1675,  0.7096, -0.0950,  ..., -0.1235, -0.0770, -0.0036],
         [-0.1424,  0.2056,  0.0057,  ...,  0.0888,  0.9086, -0.0009]]],
       dtype=torch.float64, grad_fn=<GeluBackward0>)

In [122]:
# Layer-1 intermediate output before GELU (fin_4) for token 0, from the FHE run.
fhe_vector = np.array([ -0.5330, -0.2273,  0.2345, -0.3697,  0.2970, -0.5294,  0.0565,  0.8272, -0.7153, -0.1726, -1.6825, -0.3913, -0.2027, -1.2849, -1.3158, -3.1668, -0.7815,  0.0179, -1.7771,  0.3129,  0.1193,  0.6690,  0.2648, -0.6968,  1.0492, -0.4954,  0.1870, -0.1511, -0.2539, -0.6455, -1.8991, -2.4513, -0.7193, -3.7977, -0.0837, -1.3579,  0.1407, -1.3123,  1.1331, -0.3409, -0.7910, -0.4112,  0.1462, -0.5556, -1.0079, -1.4370,  0.4928, -0.0017,  0.6031,  0.4275, -0.1060, -0.4284, -0.6600,  0.7290,  0.2775,  0.4665, -0.8068, -0.9618,  0.5486,  1.0075, -0.1601,  0.6621, -0.1483, -0.3542, -1.2240, -0.4219, -0.9824,  0.2267,  0.9752,  0.4199, -0.4372,  0.0394, -1.3696, -1.9063,  0.5260,  0.4867,  1.1665, -2.2793,  0.2096, -1.1870,  0.5899, -0.9560, -1.2170,  0.1964, -1.8866, -1.1638,  0.1815, -0.8357, -3.4606, -0.1168, -0.6876,  0.5601,  0.5957,  0.1803,  0.2562, -0.0307,  0.7868, -1.6910, -1.2224, -0.6374, -0.3655, -1.0158,  0.2919,  0.5678,  0.2370,  0.0573, -0.3579,  1.1209,  0.2696,  0.2203,  0.5051,  0.0269, -1.3878, -1.2882, -0.0836,  0.6048, -1.3334, -1.5376, -0.5986, -0.4393, -0.2839, -0.0195, -0.5154,  0.6487,  0.0117, -0.5656,  0.2284,  0.3217,  0.7261,  0.1420, -0.4013, -0.5076, -0.0200,  1.2051,  1.0192, -1.7285, -0.0986, -0.6606,  0.1542, -0.3458,  0.7084, -0.6772, -2.5102,  0.1021, -1.5361, -0.5311, -0.3406, -0.2925, -1.3313,  0.4681, -2.2569, -0.2746,  1.6836, -0.4938,  0.4101,  0.2404,  0.7889, -0.8880, -1.2268, -0.7209,  0.3907, -0.1574,  0.2911,  0.5587, -0.9561, -3.4308,  0.1591, -0.0236,  0.5549, -2.5780, -0.3286,  0.7431,  0.2113,  0.4559, -0.4310, -0.2113, -0.1625, -0.1426, -0.8745,  0.8300,  0.1040,  0.0636, -0.1788, -0.9805,  0.4381,  0.0511, -0.8142, -0.2321,  0.1639,  0.5920, -0.8984, -1.1838, -0.2599,  0.8565,  1.1861, -1.2306,  5.2940, -0.5703, -0.0042,  0.3883, -0.0213,  0.0751,  0.6211, -1.2630,  0.7186, -0.7046, -1.3726, -0.4557, -0.9880, -1.0223, -0.6100, -1.3209,  0.2430,  0.3542, -1.0642,  0.2679, -0.1970, -0.0159,  1.1732, 
-0.8052, -0.1743, -0.6104,  1.4697, -1.3691,  0.2760,  1.4484, -0.7397, -0.3188, -0.2068,  0.7872, -1.0221, -0.0034, -0.6137,  1.0902, -0.9960, -2.4438, -0.5714,  0.7220, -0.6532,  0.3245, -0.6575, -0.6372, -0.9932,  0.0845, -0.0638, -1.4336, -0.7709,  0.2955, -1.3879, -0.1854,  1.0031, -0.9626,  0.8858,  0.0147,  0.4307, -0.9693, -0.2051,  0.0578, -1.2155, -0.8771, -0.6355, -1.3800,  0.5483,  0.7251, -0.6300,  0.6872, -2.1835, -2.7558,  0.8189, -2.2503, -3.3357, -0.2197,  0.1206, -0.7035, -0.3694, -0.5209,  0.4991, -1.1801, -2.2455, -0.1077,  0.6663,  0.3011,  0.6903, -1.2452, -1.2066, -0.7137, -0.0299, -2.0174,  0.5608, -0.5332,  0.4464, -0.8255, -0.9662,  0.9659, -0.3153, -0.1035, -0.1510,  0.1794, -0.5238, -0.2585, -0.9515, -1.1444, -1.0786, -2.3950, -0.3903,  0.5005, -2.2705,  0.6724,  0.2173, -0.6892, -1.3332, -0.6334, -0.6889, -0.2897, -0.5437, -1.1193,  0.1663, -0.5837, -4.1280,  0.3632, -0.3907, -1.2608, -0.6170, -0.8997, -1.5595, -1.0618,  0.4584, -0.5312,  0.6782,  0.4575, -1.4799,  0.5869, -0.2499, -0.5237, -0.7441, -0.7961, -2.4736,  0.3434, -1.0491, -0.5886, -0.3229, -0.3877, -1.0490, -0.3960, -0.1722, -0.3819,  0.5922, -0.0746,  0.1875,  0.5499, -0.9565, -0.2733, -0.8183,  0.0823, -0.5222, -1.2691, -2.1509, -2.0148, -0.4922, -0.5204, -1.3539, -1.2819, -3.8538, -2.3548, -0.0702, -0.7759,  1.4571, -0.4104,  0.1313,  0.3970,  0.1748,  0.7853,  0.0806, -1.5523, -1.3143, -2.2255, -1.4484, -1.5520,  0.3667,  0.1740, -0.7376,  0.3107, -0.2123, -0.6602, -1.3506,  0.3907,  0.4018,  0.2095, -0.9890, -1.4884, -0.2458,  0.0834, -1.2941, -0.3318, -2.7505, -0.4036, -0.1555,  1.0325, -1.1008, -0.5016,  0.4653, -0.1429, -0.0889,  0.6230, -0.2122, -0.2609,  0.1717, -0.5215, -0.5952, -0.0394, -0.1201, -1.1994,  0.7683, -1.0797, -1.1886, -0.3520, -0.2996,  1.1558, -0.9670,  1.5620, -1.5597, -0.4422, -0.0748,  1.1966,  1.0691, -1.0566, -0.7020,  0.2551, -0.7676, -0.6269, -0.8738, -0.8518,  1.2148,  1.7399, -0.6897,  0.7183, -1.2955, -0.6945, -0.1010,  0.6609, -0.4401,  
1.1732,  0.0250, -0.7444, -0.7502, -1.2021, -0.4836,  0.2209,  0.2644, -0.1128, -0.9396,  1.3635, -0.0583,  0.5863, -1.5772,  0.5194, -1.1027,  0.0509, -0.2314,  0.8496, -0.0734, -1.2835, -1.8268, -0.6830,  0.0890, -0.5979,  0.0618,  0.1963,  0.5300, -0.9159, -0.0039,  0.6527, -0.8332, -1.1300, -0.1125, -1.6626, -0.6901,  0.4938,  0.5833,  0.0380,  0.3883, -1.2735, -0.6300, -0.4347,  0.1410, -0.2494,  0.3981, -2.0557, -0.8664,  0.5679, -2.1442, -0.9085,  0.0939, -0.4817, -1.6977, -0.9859,  0.1788, -2.8050, -1.7254,  0.3313,  0.1732, -0.4663, -1.3678, -0.9631, -0.6200,  0.8757,  0.1189,  0.3029,  0.3052,  0.3388,  0.4981, -2.9121 ])
precision(fin_4[0, 0].detach(), fhe_vector)

tensor(0.9930, dtype=torch.float64)

In [125]:
# Layer-1 post-GELU activations for token 0; only the first 128 values are
# compared, matching the length of the FHE vector below.
fhe_vector = np.array([ -0.1583, -0.0932,  0.1390, -0.1316,  0.1832, -0.1579,  0.0295,  0.6584, -0.1697, -0.0745, -0.0778, -0.1361, -0.0851, -0.1277, -0.1238, -0.0024, -0.1698,  0.0091, -0.0671,  0.1949,  0.0653,  0.5006,  0.1600, -0.1693,  0.8949, -0.1537,  0.1074, -0.0665, -0.1015, -0.1674, -0.0546, -0.0174, -0.1697, -0.0003, -0.0391, -0.1185,  0.0782, -0.1243,  0.9874, -0.1250, -0.1696, -0.1400,  0.0816, -0.1607, -0.1580, -0.1083,  0.3395, -0.0009,  0.4383,  0.2845, -0.0485, -0.1432, -0.1681,  0.5591,  0.1691,  0.3170, -0.1693, -0.1617,  0.3886,  0.8495, -0.0699,  0.4940, -0.0654, -0.1281, -0.1352, -0.1420, -0.1601,  0.1337,  0.8145,  0.2782, -0.1447,  0.0203, -0.1170, -0.0540,  0.3685,  0.3343,  1.0245, -0.0258,  0.1222, -0.1396,  0.4261, -0.1621, -0.1361,  0.1135, -0.0559, -0.1423,  0.1038, -0.1685, -0.0009, -0.0530, -0.1691,  0.3990,  0.4315,  0.1031,  0.1540, -0.0150,  0.6171, -0.0768, -0.1354, -0.1670, -0.1306, -0.1573,  0.1795,  0.4059,  0.1407,  0.0300, -0.1289,  0.9738,  0.1634,  0.1293,  0.3501,  0.0138, -0.1146, -0.1273, -0.0390,  0.4399, -0.1216, -0.0955, -0.1645, -0.1451, -0.1102, -0.0096, -0.1562,  0.4812,  0.0059, -0.1617,  0.1348,  0.2014 ])
precision(fin_5[0, 0, :128].detach(), fhe_vector)

tensor(0.9910, dtype=torch.float64)

### 4) Layer 1 -- Output

In [127]:
# Per-position constants for the approximated output LayerNorm of layer 1
# (one entry per token position). `var` multiplies the centred value below,
# so it presumably stores 1/sqrt(variance) -- TODO confirm how these tables
# were precomputed.
mean = np.array([-0.09545516102868973, 0.034540955180462664, 0.03934738149667437, 0.040802318439555035, 0.04426037798445811, 0.04919343175846099, 0.0493616301294401, 0.047896279398118795, 0.04912640635535303, 0.048717249992826256, 0.0477219385203478, 0.05095357678578503, 0.05094908370417657, 0.0493275745992752, 0.048418324664654545, 0.0473653504669205, 0.04528009986283869, 0.04524247257539856, 0.046555073355952846, 0.0516135997743503, 0.049103903254210594, 0.048877585502238356, 0.048364988370661784, 0.049043507301742846, 0.049933470462367846, 0.05175179126331398, 0.05057227793143223, 0.055763206569478994, 0.055243365455213404, 0.04986745821758072, 0.047789218698650125, 0.047852162700887234, 0.04279460740337753, 0.04280733225675328, 0.04644169155736491, 0.04783492130826333, 0.04759649093761958, 0.045252139153821, 0.04367184005341422, 0.039034762655413016, 0.04374965234639466, 0.04355128435775863, 0.04499861862695065, 0.04318602336450084, 0.04549296197766528, 0.03907804279518851, 0.037683132925437485, 0.04109696491189214, 0.04410155617431274, 0.05015992918511731, 0.04335430986396108, 0.046492484403760526, 0.044277581701870204, 0.03723061917091777, 0.039156973130334664])
var = np.array([0.4156698594967092, 0.7008452266859936, 0.7214270983257646, 0.7095727482866087, 0.7102521835201318, 0.710293676073547, 0.7091783271698753, 0.6973493176419543, 0.7011688527520855, 0.7007704875343309, 0.6950537183089973, 0.6948029158092094, 0.6919309911197036, 0.6933694537037308, 0.6970711644923971, 0.7004276850010867, 0.6964234913676165, 0.6987678419874651, 0.6951829293138483, 0.6973048809142951, 0.6989420799277399, 0.7005696487948311, 0.6993937733493811, 0.6902070532566239, 0.6958399824203775, 0.6900361005407983, 0.6925891359742274, 0.6831642926666377, 0.6865279710039072, 0.6904370385593245, 0.6963724536275457, 0.6948942601360332, 0.6784634186071326, 0.6759657478656234, 0.6828578884489792, 0.683566347862741, 0.6857777074044566, 0.672040915409448, 0.6784995422914343, 0.6732453264186854, 0.683881765911935, 0.6909411690410042, 0.6715428435769978, 0.6775867807314924, 0.6785015863916147, 0.676156117696202, 0.6786376609996214, 0.6763771062984715, 0.7119440584663215, 0.7070342067744777, 0.6895996022331654, 0.6683970656272868, 0.6695013664908844, 0.6566575067124804, 0.672887703816164])

# Output dense projection plus the residual connection from the self-output step.
fin_6 = torch.matmul(fin_5, model.bert.encoder.layer[0].output.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[0].output.dense.bias
fin_6 = fin_6 + fin3_whole

# Iterate over this tensor's own tokens (not an unrelated `input_tensor`) and
# fail clearly when the sentence is longer than the precomputed tables.
seq_len = fin_6.shape[1]
if seq_len > len(mean):
    raise ValueError(
        "sequence has {} tokens but only {} precomputed LayerNorm constants".format(seq_len, len(mean)))

# The LayerNorm affine parameters do not depend on the position: load them once.
w_output_layernorm = model.bert.encoder.layer[0].output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
b_output_layernorm = model.bert.encoder.layer[0].output.LayerNorm.bias.clone().detach().double()

fin7_whole = []

# Index the batch axis explicitly: `.squeeze()` would also drop the sequence
# axis of a one-token input.
for i in range(seq_len):
    fin_7 = fin_6[0, i]
    fin7_corr = (fin_7.detach() - mean[i]) * var[i]
    fin7_corr = fin7_corr * w_output_layernorm + b_output_layernorm
    fin7_whole.append(fin7_corr.detach())

fin7_whole = torch.cat(tuple(fin7_whole), 0).unsqueeze(0)

In [130]:
# Layer-1 final output from the FHE run. NOTE(review): this check uses token
# position 34, unlike the other checks (token 0) -- confirm which token the
# C++ program printed.
fhe_vector = np.array([ -2.5672, -0.0404,  0.1736,  0.0734,  0.4014, -0.0162,  0.6670, -0.1658, -2.6238,  2.4949, -0.0569,  1.1267,  0.5967, -2.4788, -0.2496,  0.1204,  1.0123, -1.1119, -1.6441,  0.0791,  0.2120, -1.8691, -0.0163, -0.3107,  0.6332, -0.5724,  2.2917,  1.1560,  1.0361,  1.1970,  0.2287, -0.7155, -0.4602,  0.6244,  0.6548, -0.1565,  2.1784,  0.2266, -1.5136,  0.2002,  0.8796,  0.1156,  1.6029, -0.1679,  0.8397,  0.5351,  0.1815,  0.2280,  0.7055, -0.2807, -0.1966,  0.3239,  1.5733, -1.1548,  0.5877, -0.9185,  0.1887,  0.4781, -0.2558,  0.3328,  0.7042,  1.0406, -1.6705, -0.4073, -1.7941, -0.6432,  0.1797,  1.0775, -0.3360,  1.1114,  0.5275, -0.5175,  0.1560, -1.1586, -0.8505,  0.3479,  1.0907, -1.0918,  1.0484, -2.0402, -0.1048,  0.1999,  0.5569,  0.3111,  3.3590, -0.0307, -0.7502,  1.2190,  0.2485,  1.2731,  1.1261, -1.4094,  1.4868, -0.3189,  1.5277,  0.2742, -0.6627, -0.3644,  2.3280, -0.9356, -0.1478,  0.2325, -0.5535,  0.4641,  1.4898,  0.3516, -0.9462,  0.6526,  0.1597, -1.8249, -2.2442,  0.9491,  0.0111,  0.3965, -1.4054, -1.9002,  0.3990, -0.9621, -1.4328,  1.4456, -0.1939,  0.1988, -0.9830, -1.0697, -0.3367, -2.8929, -0.0422,  0.9205,  ])

precision(fin7_whole[0, 34].detach(), fhe_vector)

tensor(0.9980, dtype=torch.float64)

### 5) Layer 2 -- Self-Attention

In [16]:
# Plaintext float64 reference for the self-attention of encoder layer 1.
# Head count and head size are read from the model config instead of being
# hard-coded (2 heads x 64 for BERT-tiny), so the cell also fits other BERT sizes.
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads

key = model.bert.encoder.layer[1].attention.self.key.weight.clone().detach().double().transpose(0, 1)
query = model.bert.encoder.layer[1].attention.self.query.weight.clone().detach().double().transpose(0, 1)
value = model.bert.encoder.layer[1].attention.self.value.weight.clone().detach().double().transpose(0, 1)

key_bias = model.bert.encoder.layer[1].attention.self.key.bias.clone().detach().double()
query_bias = model.bert.encoder.layer[1].attention.self.query.bias.clone().detach().double()
value_bias = model.bert.encoder.layer[1].attention.self.value.bias.clone().detach().double()

# Layer-1 output; kept separately because it is added back as the residual in
# the self-output step.
original_input_tensor = fin7_whole
input_tensor = fin7_whole
seq_len = input_tensor.size()[1]

q = torch.matmul(input_tensor, query) + query_bias
k = torch.matmul(input_tensor, key) + key_bias
v = torch.matmul(input_tensor, value) + value_bias

# Split the hidden dimension into heads: (1, seq, hidden) -> (1, seq, heads, head_dim).
q = q.reshape([1, seq_len, num_heads, head_dim])
k = k.reshape([1, seq_len, num_heads, head_dim])
v = v.reshape([1, seq_len, num_heads, head_dim])

# q -> (1, heads, seq, head_dim) and k -> (1, heads, head_dim, seq), so that
# q @ k gives the (1, heads, seq, seq) attention scores.
q = q.permute([0, 2, 1, 3])
k = k.permute([0, 2, 3, 1])

# Scaled dot-product scores (scale sqrt(head_dim), i.e. 8 for BERT-tiny).
# No attention mask is applied: the input is a single unpadded sentence.
qk = torch.matmul(q, k)
qk = qk / math.sqrt(head_dim)

qk_softmaxed = torch.softmax(qk, -1)

v = v.permute([0, 2, 1, 3])

# Weighted sum of the values, then merge the heads back to (1, seq, hidden).
fin = torch.matmul(qk_softmaxed, v)
fin = fin.permute([0, 2, 1, 3])
fin = fin.reshape([1, seq_len, num_heads * head_dim])

In [17]:
# Layer-2 self-attention output for token 0, from the FHE run.
# NOTE(review): the saved result of this comparison is negative (relative
# error above 100%). The layer-2 cells also carry lower execution counts
# (In [16]-[23]) than the layer-1 cells (In [71]+), so their saved outputs
# presumably come from an earlier session with a different input. Re-run the
# notebook top to bottom before interpreting these numbers.
fhe_vector = np.array([ -0.8123, -0.8500,  0.0295,  0.2296,  0.4401, -0.6094,  1.6168,  0.2558, -0.2224, -0.6283, -0.5895,  0.7919, -0.2594, -0.3843,  0.0067,  1.5401, -0.0503,  0.1357, -0.4071, -0.4671, -1.0653, -1.1093, -2.0851,  0.5782,  0.5840, -0.6833,  1.5346,  1.3422,  0.2175,  0.9805, -0.1275, -1.5916,  1.0102, -0.1957,  0.0962, -0.0464, -0.4231, -1.3056,  0.0510, -1.1596,  0.1894,  0.4713, -0.0684, -1.0158, -0.2589, -0.5890, -0.8593, -0.2406,  0.2359,  0.8717, -0.7101, -1.6676, -0.3206, -0.3165, -0.8318, -0.7661, -0.8755,  0.2422, -1.1412, -0.2040,  0.8289, -0.2363, -0.6205, -0.4749, -0.5698,  0.6264, -0.6598, -0.3961,  0.3553, -0.3192, -0.1223, -0.0449, -0.3661,  0.7190, -0.3748,  0.1306,  0.9412, -1.5460,  0.8761, -0.1402,  1.2423,  0.7885, -0.3937, -0.0085, -1.3537,  0.0370, -1.2522,  1.0030,  2.0746, -0.7593, -0.2284, -0.3362,  0.6514, -0.0331, -0.1410,  1.6767, -0.2301,  1.1221, -0.6067,  0.6165, -0.6068, -1.2288, -0.9807,  0.4249, -1.3200, -0.2358, -0.9543,  0.7164,  0.9259,  0.6031,  0.3302, -0.2839, -0.2300, -0.8882, -0.6937, -0.4157,  0.5060, -0.0074, -0.8772, -0.7689,  0.4577, -1.0608, -0.1394, -1.5404,  1.5020,  0.1260, -0.5764, -0.5608 ])

precision(fin[0, 0].detach(), fhe_vector)

tensor(-0.2955, dtype=torch.float64)

### 6) Layer 2 -- Self-Output

In [18]:
# Per-position constants for the approximated attention-output LayerNorm of
# layer 2 (one entry per token position). `var` multiplies the centred value
# below, so it presumably stores 1/sqrt(variance) -- TODO confirm how these
# tables were precomputed.
mean = np.array([0.04805131047475803, 0.014145706172069285, 0.010630181813540026, 0.010521146572975027, 0.00956244983947186, 0.008211288558782809, 0.008817800275674387, 0.008911457532306733, 0.008643898058317862, 0.008801769546523253, 0.009472254700839258, 0.008094415948174241, 0.007702615754430344, 0.005460620353838359, 0.007021847370084451, 0.008373831982472147, 0.01022061224155272, 0.00927594903773269, 0.009277225000069925, 0.007049453120897054, 0.008682554190420182, 0.008749022040809715, 0.010118317324741522, 0.008998865743435887, 0.008763833543884292, 0.008285728555981435, 0.006967351876718886, 0.00588068616144895, 0.0030701809065725363, 0.003659716972971551, 0.002116778487431024, 0.003947434346765913, 0.006907859825079262, 0.008494112860837831, 0.007040283968419036, 0.007197681884381672, 0.008232685835987293, 0.009965029801574864, 0.00731962961637719, 0.00830555309310382, 0.005340440177451385, 0.007833324368720607, 0.01047456825511633, 0.009674864773662995, 0.010093537461664302, 0.01588798917017868, 0.018537933333636507, 0.018245848282989877, 0.012253993810893607, 0.011354133953173591, 0.013474744814287221, 0.013707011955501919, 0.007918842609048385, 0.017240907760895086, 0.03465881962238184])
var = np.array([0.6741653046411179, 0.602392389437227, 0.5945841451997256, 0.5997135932136959, 0.6033806506910513, 0.6064839949503851, 0.6058735285405447, 0.6059001754921257, 0.6086086189801689, 0.6118981975241923, 0.6161533101614306, 0.6105411757987637, 0.6102443339235957, 0.6004337682468068, 0.6068584434133084, 0.6123178593290803, 0.6150302868629213, 0.6102744641580546, 0.6143169356654037, 0.6105845722771672, 0.61540315154488, 0.622109065598561, 0.6221720668578823, 0.6279330579960701, 0.6282907135959079, 0.6258439179151315, 0.6187239026398644, 0.618294817104495, 0.609488586748927, 0.6085185174201381, 0.6154275326252285, 0.6207534846328591, 0.6290521066315713, 0.6375810334496135, 0.6238236165346044, 0.6310571465398529, 0.6350551779511981, 0.6452639043477173, 0.6346915398812409, 0.646622546259538, 0.6435498445423712, 0.6401589932559348, 0.6458833892517316, 0.6354378204804867, 0.651796667347259, 0.6547600574517144, 0.6554038815336571, 0.655910889886979, 0.6412602949793637, 0.6489736968517984, 0.6633309254993116, 0.6771441398382873, 0.6423362709438692, 0.6302863730404997, 0.5940213893371686])

w_output_dense = model.bert.encoder.layer[1].attention.output.dense.weight.clone().detach().double().transpose(0, 1)
b_output_dense = model.bert.encoder.layer[1].attention.output.dense.bias.clone().detach().double()

# Dense projection of the attention output plus the residual connection.
fin2 = torch.matmul(fin, w_output_dense) + b_output_dense
fin2_backup = fin2.clone()
fin2_backup = fin2_backup + original_input_tensor

# Fail clearly (instead of an IndexError inside the loop) when the sentence is
# longer than the precomputed tables.
seq_len = fin2_backup.shape[1]
if seq_len > len(mean):
    raise ValueError(
        "sequence has {} tokens but only {} precomputed LayerNorm constants".format(seq_len, len(mean)))

# The LayerNorm affine parameters do not depend on the position: load them once.
w_output_layernorm = model.bert.encoder.layer[1].attention.output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
b_output_layernorm = model.bert.encoder.layer[1].attention.output.LayerNorm.bias.clone().detach().double()

fin3_whole = []

# Index the batch axis explicitly: `.squeeze()` would also drop the sequence
# axis of a one-token input.
for i in range(seq_len):
    fin2 = fin2_backup[0, i]
    fin3_corr = (fin2.detach() - mean[i]) * var[i]
    fin3_corr = fin3_corr * w_output_layernorm + b_output_layernorm
    fin3_whole.append(fin3_corr.detach())

fin3_whole = torch.cat(tuple(fin3_whole), 0).unsqueeze(0)

In [19]:
# Layer-2 self-output (after the approximated LayerNorm) for token 0, from the FHE run.
fhe_vector = np.array([  1.3770, -1.8186, -1.6910,  0.6093, -0.1824,  0.1786,  1.6645,  0.7341, -0.6092,  0.7240,  1.0244, -0.7155, -0.0153,  0.1687, -0.1843, -0.0103,  1.8486, -0.8672, -1.6907,  0.5330, -0.2316,  1.0860,  3.3427,  1.8338, -0.4012, -0.4893,  0.4482, -1.6318,  0.7493,  0.5131, -1.1009,  1.2824, -3.2195,  0.6660, -0.3238, -0.4962,  0.3410, -1.0572, -1.1014,  0.1388, -1.7925,  0.8096, -2.0355, -0.9068,  1.1941, -1.8014,  0.0378, -0.2286,  1.4185,  0.5991,  1.5236,  0.1015,  1.5935, -1.3028,  1.0833,  0.0207, -2.5202,  0.4889,  1.9203,  0.3599,  1.5069, -0.5983, -0.9472, -1.4128,  0.0251, -0.8160, -1.4836,  0.9483,  0.5418,  0.0704,  2.0288,  0.7253,  0.7689,  0.0401,  0.4672, -0.9288, -0.4404,  0.5059,  1.1886,  1.2352, -0.6807, -0.8505, -0.8434, -0.8269,  0.4738, -0.1371, -0.7369, -1.1949,  1.9052, -0.0479,  0.1652,  1.2224,  0.0298,  1.5454,  0.6461,  1.4474, -0.2515,  0.0815,  1.0245,  0.3735, -0.4966,  0.7358,  1.1659, -0.0261, -0.8297, -0.9907,  0.1873, -0.1336,  2.1544, -1.1358, -0.3534, -0.7904,  0.3181,  2.8254,  0.1058,  0.4013, -0.1092, -1.3576, -1.6598, -1.1263,  1.2364,  1.3081,  0.7460,  1.7961, -0.8862, -1.6055, -2.5590, -1.2520,  ])

precision(fin3_whole[0, 0].detach(), fhe_vector)

tensor(-0.2931, dtype=torch.float64)

### 7) Layer 2 -- Intermediate

In [20]:
# Layer-2 intermediate dense layer followed by GELU, in float64.
# Parameters are detached (like in the attention cells) so no autograd graph
# is built for this plaintext reference; the bias is made float64 explicitly
# instead of relying on type promotion.
w_intermediate = model.bert.encoder.layer[1].intermediate.dense.weight.detach().double().transpose(0, 1)
b_intermediate = model.bert.encoder.layer[1].intermediate.dense.bias.detach().double()

fin_4 = torch.matmul(fin3_whole, w_intermediate) + b_intermediate
fin_5 = torch.nn.functional.gelu(fin_4)

In [21]:
# Layer-2 post-GELU activations for token 0; only the first 128 values are
# compared, matching the length of the FHE vector below.
fhe_vector = np.array([ -0.0683, -0.0562,  1.3694, -0.1403, -0.0703, -0.1537, -0.1658,  0.4683, -0.0210, -0.0005, -0.1563, -0.1662, -0.0726, -0.0936, -0.1059, -0.1640, -0.0935, -0.1621, -0.1689, -0.0099, -0.1663, -0.0288, -0.0431, -0.1237, -0.1699,  1.3531, -0.0271, -0.1230, -0.0315, -0.1606, -0.1465, -0.1607, -0.1111, -0.1253, -0.1450,  0.0236,  1.3971, -0.1058, -0.1201,  0.6417, -0.1595, -0.1340,  0.3707, -0.0008, -0.1128,  0.5451, -0.0293, -0.0071,  0.0274,  0.0609, -0.1688,  0.0656,  2.1240, -0.1667, -0.0440,  0.1143,  1.0968, -0.0513,  0.0032, -0.1634, -0.1636, -0.0114,  0.3805, -0.0070, -0.1653, -0.0769, -0.1158,  0.8741,  0.2711,  0.5258,  0.7994, -0.1006,  0.0409,  0.1818,  0.7667, -0.0002, -0.0010, -0.0818,  1.7126, -0.0568, -0.1693,  0.0353, -0.1208,  0.4403, -0.0188, -0.1565, -0.1689, -0.0890, -0.1334,  0.1987, -0.0233,  1.3172, -0.1609,  0.0282, -0.1696, -0.0133, -0.0055, -0.1478, -0.1054, -0.1695,  0.0069, -0.1269,  3.3577, -0.1573, -0.0397,  0.1693,  0.9749,  0.9535,  0.9454,  0.0970, -0.1689,  2.1287,  0.2266,  1.7717, -0.1216, -0.1288, -0.0847, -0.1600, -0.0314,  0.6490, -0.0465, -0.1664, -0.0266, -0.1589, -0.0447, -0.1676, -0.0243, -0.1359 ])
precision(fin_5[0, 0, :128].detach(), fhe_vector)

tensor(-0.4821, dtype=torch.float64)

### 8) Layer 2 -- Output

In [22]:
# Output dense projection plus the residual connection from the self-output step.
fin_6 = torch.matmul(fin_5, model.bert.encoder.layer[1].output.dense.weight.transpose(0, 1).double()) + model.bert.encoder.layer[1].output.dense.bias
fin_6 = fin_6 + fin3_whole

# Per-position constants for the approximated output LayerNorm of layer 2
# (one entry per token position). `var` multiplies the centred value below,
# so it presumably stores 1/sqrt(variance) -- TODO confirm how these tables
# were precomputed.
mean = np.array([0.06643368, 0.05726708, 0.05311476, 0.05229822, 0.05352628, 0.05238868, 0.0536801 , 0.05327334, 0.05206954, 0.05110339, 0.051747  , 0.05016997, 0.04943122, 0.04937956, 0.04952862, 0.04973959, 0.04852742, 0.04696055, 0.04846476, 0.04925392,0.0509005 , 0.05373027, 0.05371865, 0.05446217, 0.05222489,0.05142676, 0.05080909, 0.05179351, 0.05049174, 0.04965748,0.05138143, 0.0499965 , 0.05194982, 0.05178364, 0.0521023 ,0.05059624, 0.05445499, 0.05507825, 0.05241935, 0.05073552,0.05200171, 0.04858642, 0.04419684, 0.04642237, 0.05115073,0.05028116, 0.05021724, 0.05312114, 0.0524375 , 0.04643478,0.05026358, 0.04248708, 0.04675281, 0.03895142, 0.04558007])
var = np.array([0.81992316, 0.78486345, 0.79259   , 0.79754392, 0.79350872, 0.79652433, 0.79935746, 0.79867687, 0.80257863, 0.80235328,0.80521209, 0.80621272, 0.80330435, 0.80469855, 0.81171202,0.81136354, 0.80977166, 0.8089956 , 0.8106946 , 0.80862825,0.81450049, 0.81722176, 0.82121488, 0.82012788, 0.8254015 ,0.82097106, 0.81742119, 0.82090554, 0.82116105, 0.82017896,0.82234659, 0.82832269, 0.82888739, 0.81852014, 0.82054523,0.8224114 , 0.82913892, 0.8289046 , 0.81985612, 0.83341215,0.82896934, 0.82315006, 0.82802216, 0.81886278, 0.8274004 ,0.83436616, 0.82014282, 0.82628005, 0.83230868, 0.84511334,0.85141143, 0.84934269, 0.83041272, 0.826798  , 0.83660989])

# Iterate over this tensor's own tokens (not `input_tensor`) and fail clearly
# when the sentence is longer than the precomputed tables.
seq_len = fin_6.shape[1]
if seq_len > len(mean):
    raise ValueError(
        "sequence has {} tokens but only {} precomputed LayerNorm constants".format(seq_len, len(mean)))

# The LayerNorm affine parameters do not depend on the position: load them once.
w_output_layernorm = model.bert.encoder.layer[1].output.LayerNorm.weight.clone().detach().double().unsqueeze(0)
b_output_layernorm = model.bert.encoder.layer[1].output.LayerNorm.bias.clone().detach().double()

fin7_whole = []

# Index the batch axis explicitly: `.squeeze()` would also drop the sequence
# axis of a one-token input.
for i in range(seq_len):
    fin_7 = fin_6[0, i]
    fin7_corr = (fin_7.detach() - mean[i]) * var[i]
    fin7_corr = fin7_corr * w_output_layernorm + b_output_layernorm
    fin7_whole.append(fin7_corr.detach())

fin7_whole = torch.cat(tuple(fin7_whole), 0).unsqueeze(0)

In [23]:
# Layer-2 output (after the approximate LayerNorm) as dumped by the C++ FHE program
# in --verbose mode (hand-copied; keep verbatim). It is compared with position [0][0]
# of the plaintext fin7_whole.
# precision() returns 1 - mean|error| / mean|reference|, so it is negative whenever
# the average absolute error exceeds the average magnitude of the reference.
fhe_vector = np.array([  1.7862, -2.0674, -0.2598, -0.2733, -0.6657,  0.9152,  0.6312,  1.0929,  0.3524,  0.7416,  1.2811,  0.2206, -0.8468, -0.5167, -0.1659,  0.6020,  1.3285, -0.9925, -1.9419,  0.2647,  0.2765,  2.3105,  2.8253,  0.5482, -1.4334, -0.3802, -0.3619, -1.4140,  0.5294,  0.6072, -1.8004,  0.5581, -2.8421,  0.2503,  0.4287, -0.3454,  0.2077, -1.0119, -1.5609, -0.6429, -1.4182,  0.1802, -1.7165, -0.2336,  0.9720, -1.6756, -0.1018,  0.1046,  1.0860,  0.9126,  0.9143,  1.2146,  0.9086, -0.7014,  0.1545,  0.2080, -3.3099, -0.0128,  2.6352,  1.0011,  0.7249, -1.0812, -1.1600, -1.2047, -0.4703, -0.2173, -1.3199,  2.1158,  0.7578, -0.3307,  0.9533,  0.5950, -0.1573,  0.8234,  0.7160,  0.6736,  0.2819,  0.8938, -0.2418,  1.3240, -2.0877, -1.4034, -1.4676,  0.0390,  1.2058, -0.2615, -1.5343, -2.3307,  1.4072,  0.2342, -0.0834,  1.4614, -0.2938,  0.3641,  1.0156, -0.0286, -0.7436, -1.0538,  1.2407, -0.2570, -0.8378, -0.0219,  1.4528, -0.0432, -0.9515, -1.8858, -0.4678, -1.1132,  1.9448, -0.1209,  0.0245, -1.3956, -0.0740,  2.0146,  0.8191,  1.0593,  0.4562, -0.5237, -2.1609, -0.7794,  1.5314,  0.5960,  1.6027,  0.2354, -0.7131, -0.2370, -2.6296, -2.5624,  ])

precision(fin7_whole[0][0].detach(), fhe_vector)

tensor(-0.5569, dtype=torch.float64)

### 9) Pooler

In [24]:
# Pooler: dense projection followed by tanh. It is applied to every position of
# fin7_whole; only position [0][0] is compared with the FHE output below.
pooler_dense = model.bert.pooler.dense
pooler_pre_activation = fin7_whole.double() @ pooler_dense.weight.double().T + pooler_dense.bias
pooler_output = pooler_pre_activation.tanh()

# Pooler output as dumped by the C++ FHE program in --verbose mode (hand-copied;
# keep verbatim). It is compared with position [0][0] of the plaintext pooler_output.
fhe_vector = np.array([ -0.8440,  0.1941, -0.9109,  0.1285, -0.8042, -0.5718, -0.7415, -0.4866, -0.1436, -0.0316, -0.6325,  0.0129, -0.2240,  0.3304, -0.0773, -0.8037,  0.4338,  0.0093, -0.1401,  0.0085,  0.6732,  0.1331,  0.4453, -0.0120, -0.4030, -0.0919, -0.9976,  0.6423,  0.8940,  0.0913,  0.0908, -0.0297,  0.7762, -0.4384,  0.3557,  0.9624, -0.1017, -0.0680,  0.4804, -0.6385,  0.9297,  0.8713, -0.3583, -0.5340,  0.4769, -0.2850, -0.6156, -0.1654, -0.3428, -0.6050,  0.5404, -0.8458,  0.2055,  0.0719,  0.6019, -0.3941,  0.4473, -0.4733,  0.6917, -0.8246, -0.1190, -0.8463, -0.6532, -0.2030,  0.0843, -0.5770,  0.9581, -0.0900, -0.0180,  0.8932,  0.9717, -0.0080, -0.9736,  0.0909, -0.4380, -0.9297, -0.3076, -0.0041, -0.4441,  0.2525,  0.5803, -0.0977,  0.9803, -0.9771,  0.9810, -0.8303,  0.8472, -0.5449,  0.5220,  0.5931,  0.6716,  0.9346, -0.4750,  0.9427, -0.4397, -0.1431, -0.2779, -0.8458, -0.5880, -0.8033, -0.0258,  0.8072, -0.5636,  0.2409,  0.4521, -0.8313, -0.9927, -0.0616,  0.3463,  0.9466,  0.7628,  0.4724, -0.7301,  0.6899,  0.2496,  0.1475,  0.6461, -0.0828,  0.0777, -0.0982,  0.1507, -0.8850, -0.5445,  0.3068, -0.5583,  0.7481,  0.7680,  0.8013 ])
precision(pooler_output[0][0].detach(), fhe_vector)

tensor(0.8556, dtype=torch.float64)

### 10) Classifier

In [25]:
# Classifier head: a linear layer on the double-precision pooler output.
classifier_weight = model.classifier.weight.double()
classifier_bias = model.classifier.bias.double()
classification = pooler_output @ classifier_weight.T + classifier_bias

# Logits produced by the FHE circuit for the same input sentence.
fhe_vector = np.array([-0.0197, 0.1699])

precision(classification[0, 0].detach(), fhe_vector)

tensor(0.5020, dtype=torch.float64)

In [26]:
# Side-by-side view of the plaintext-circuit logits (position [0][0]) and the FHE logits.
plain_logits = classification[0][0].detach().numpy()
print("Plain circuit output: {}\nFHE circuit output: {}".format(plain_logits, fhe_vector))

Plain circuit output: [0.02078085 0.13353153]
FHE circuit output: [-0.0197  0.1699]


In [27]:
# Reference: the unmodified PyTorch model run end to end on the same tokens,
# with an attention mask of all ones (every token attended).
full_attention_mask = torch.tensor([[1] * len(tokenized_text)])
model(tokens_tensor, attention_mask=full_attention_mask)

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.0951, -0.0199]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)