Code for Testing 12 Attention Layers

## 1. Experiment Setup

**Import Modules**

In [None]:
import sys
import os 

# Make the project's `src` directory importable whether the notebook is started
# next to it ('./src') or one directory below it ('../src'). After the loop,
# `project_root` holds the '../src' candidate, as before.
for _relative_src in ('./src', '../src'):
    project_root = os.path.abspath(os.path.join(os.getcwd(), _relative_src))
    if project_root not in sys.path:
        sys.path.append(project_root)

import pickle    
import numpy as np
import math
import torch
from transformers import BertForNextSentencePrediction
import matplotlib.pyplot as plt

from liberate.fhe.bootstrapping import ckks_bootstrapping as bs

import thor
from thor import CkksEngine, ThorDataEncryptor, ThorLinearEvaluator
from thor.bert import ThorBert, ThorBertFF, ThorBertPooler, ThorBertClassifier

### 1-1. Initialize CKKS Engine

**Choose GPU**

In [None]:
# GPU ordinals used by this notebook; the first entry is the one whose memory
# usage is reported here and in later cells.
devices = [0]

with torch.cuda.device(devices[0]):
    # Drop cached allocator blocks first so the figure below reflects live tensors.
    torch.cuda.empty_cache()
    allocated_bytes = torch.cuda.memory_allocated(devices[0])
    print(allocated_bytes /1024**3)  # GiB

In [None]:
# Parameters handed to CkksEngine; `devices` is the GPU list chosen above.
params = {
    "logN": 16,
    "scale_bits": 41,
    "num_special_primes": 4,
    "devices": devices,
    "quantum": "pre_quantum",
}
engine = CkksEngine(params)

# GPU memory held on the first device once the engine exists, in GiB.
print("Memory allocated: ", torch.cuda.memory_allocated(devices[0]) /1024**3)

**Load Pre-Generated Keys**

In [None]:
# Rotation amounts whose keys are loaded from ./keys/keys0/rotk_dict/<key> in the
# key-loading cell and registered through engine.add_bs_key (presumably the keys
# bootstrapping needs - confirm against the key-generation script).
# Contents: a few negative powers of two, 0..16, multiples of 32 up to 512,
# and multiples of 1024 up to 16384.
rotk_dict_keys = [
    -32768, -16384, -1024, -512, -32, -16,
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
    32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384,
    416, 448, 480, 512, 1024, 2048, 3072, 4096, 5120, 6144,
    7168, 8192, 9216, 10240, 11264, 12288, 13312, 14336,
    15360, 16384
]                    

# Additional rotation amounts; their keys are not loaded from disk but derived
# from the secret key via engine.add_rot_keys_from_sk(deltas, sk).
deltas = [    
    1, 2, 3, 4, 5, 6, 7, 8, 9,
    10, 11, 12, 16, 2048, 4096, 6144, 8192, 10240,
    12288, 14336, 18432, 20480, 22528, 24576, 26624, 28672, 30720,
    256, 2288, 4320, 6352, 8384, 10416, 12448, 14480, 16512,
    18544, 20576, 22608, 24640, 26672, 28704, 30736,
    512, 768, 1280, 2544, 2800, 3312, 4576, 4832, 5344,
    6608, 6864, 8640, 8896, 10672,
]

In [None]:
# The secret key is kept in the notebook: it is used below to derive extra
# rotation keys and later to decrypt intermediate results for plotting.
sk = engine.load("./keys/keys0/sk")

# Each pre-generated key is loaded and then registered with the engine through
# its matching add_* method.
pk = engine.load("./keys/keys0/pk")
engine.add_pk(pk)

evk = engine.load("./keys/keys0/evk")
engine.add_evk(evk)

gk = engine.load("./keys/keys0/gk")
engine.add_gk(gk)

conjk = engine.load("./keys/keys0/conjk")
engine.add_conj_key(conjk)

# One key file per rotation amount listed in `rotk_dict_keys`.
rotk_dict = {
    rotation: engine.load(f"./keys/keys0/rotk_dict/{rotation}")
    for rotation in rotk_dict_keys
}
bs.create_cts_stc_const(engine)
engine.add_bs_key(rotk_dict)

# Rotation keys for `deltas` are generated from the secret key instead of loaded.
engine.add_rot_keys_from_sk(deltas, sk)
print("Memory allocated: ", torch.cuda.memory_allocated(devices[0]) /1024**3)

### 1-2. Load and Encrypt Data

**Set Dataset Type and Target Data Index**

In [None]:
# Task name; it selects the ./datasets/<type>, ./finetuned_models/<type> and
# ./encoded_models_new/<type> directories and the GLUE configuration used below.
dataset_type = 'mrpc'
# Zero-based position of the evaluation batch to encrypt and classify.
target_idx = 0

**Initialize DataEncryptor and DataLoader**

In [None]:
# Local directory holding the selected dataset.
dataset = f'./datasets/{dataset_type}'

# Embedding sub-module of the pretrained BERT checkpoint; the encryptor uses it to
# turn token ids into embeddings in the clear before encryption.
bert_embeddings = BertForNextSentencePrediction.from_pretrained('bert-base-uncased').bert.embeddings

data_encryptor = ThorDataEncryptor(
    dataset_type,
    dataset,
    embedding_model=bert_embeddings,
    ckks_engine=engine,
    test=False,
)

# Evaluation batches; both the encryption cell and the plain-model cell iterate it.
data_loader = data_encryptor.eval_dataloader

**Encrypt Data as "x"**

In [None]:
def encode_attention_mask(engine, attention_mask:np.ndarray, level:int=15) -> np.ndarray:
    """
    Encode a 128-token attention mask into 8 plaintexts.

    Only the number of non-zero entries of ``attention_mask`` is used
    (``n_tokens``): column ``c`` counts as a real token iff ``c < n_tokens``,
    i.e. the real tokens are assumed to occupy the leading positions.

    Plaintext ``i`` (0..7) is a vector of ``2**15`` slots. For ``j`` in 0..15,
    ``t`` in 0..127 and ``head`` in 0..11, slot ``j*2**11 + t*16 + head`` is 1
    when column ``(i*16 + j + t) % 128`` is a real token and 0 otherwise. The
    remaining 4 of every 16 slots (``head`` 12..15) stay 0.

    Note: this notebook-level helper is standalone; the encryption cell calls
    ``data_encryptor.encode_attention_mask`` instead.

    Parameters
    ----------
    engine : object exposing ``encode(msg, level)``.
    attention_mask : array-like of shape (128,).
    level : level forwarded unchanged to ``engine.encode``.

    Returns
    -------
    np.ndarray
        Object array of shape (8,) holding the values returned by ``engine.encode``.

    Raises
    ------
    ValueError
        If ``attention_mask`` does not have shape (128,).
    """
    # Accept any array-like and leave the caller's argument untouched.
    mask = np.asarray(attention_mask)
    if mask.shape != (128,):
        raise ValueError("Shape of attention mask should be (128,)")
    n_tokens = np.count_nonzero(mask)

    token_offsets = np.arange(128)        # t
    diag_offsets = np.arange(16)          # j
    head_is_used = np.arange(16) < 12     # only the first 12 of every 16 slots are filled

    plaintexts = np.full((8,), None, dtype=object)
    for i in range(8):
        diag_index = i * 16 + diag_offsets                                   # (16,)
        col_index = (diag_index[:, None] + token_offsets[None, :]) % 128     # (16, 128)
        is_token = col_index < n_tokens                                      # (16, 128)
        # (16, 128, 16) flattened in C order gives slot j*2**11 + t*16 + head.
        msg = (is_token[:, :, None] & head_is_used).astype(float).reshape(-1)
        plaintexts[i] = engine.encode(msg, level)
    return plaintexts


In [None]:
# Locate evaluation batch number `target_idx`, embed it in the clear and encrypt it.
# Produces:
#   x                   - encrypted embedding (input of the first encoder layer)
#   attention_mask      - plain attention-mask tensor (reused by the plain-model cell)
#   thor_attention_mask - encoded mask handed to the HE softmax
for idx, batch in enumerate(data_loader):
    if idx != target_idx:
        continue
    data = {k: v for k, v in batch.items() if k in ['input_ids', 'token_type_ids']}
    embedding = data_encryptor.embed_data(data)
    x = data_encryptor.encrypt_embedding(embedding, pk, level = 20)
    attention_mask = batch['attention_mask']
    thor_attention_mask = data_encryptor.encode_attention_mask(attention_mask.cpu().numpy().squeeze().T, level=15)
    break
else:
    # Without this, an out-of-range target_idx would silently leave `x` and the
    # masks undefined (or stale from an earlier run).
    raise IndexError(f"target_idx={target_idx} was not found in data_loader")

### 1-3. Load and Run Plain (Non-HE) Model for Comparison

**Load and Run Plain Model**

In [None]:
# Load the fine-tuned plain (unencrypted) model and run it on the CPU.
model_plain  = thor.utils.load_model(dataset_type, f'./finetuned_models/{dataset_type}/model.safetensors')
model_plain.eval()
device = torch.device("cpu")
model_plain.to(device)

# Run the plain model on the same batch that was encrypted above. `outputs` is
# later indexed as `outputs.hidden_states[layer]`, so the loaded model is assumed
# to return hidden states - confirm in thor.utils.load_model.
for idx, batch in enumerate(data_loader):
    if idx != target_idx:
        continue
    batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
    with torch.no_grad():
        outputs = model_plain(**batch)
    break
else:
    # Fail loudly instead of leaving `outputs` undefined or stale.
    raise IndexError(f"target_idx={target_idx} was not found in data_loader")

def get_nonlinear_in_out(hidden_states, layer_idx):
    """
    Re-run one encoder layer of the plain model step by step and return its intermediates.

    Reads the notebook globals ``model_plain``, ``attention_mask`` (the plain mask
    tensor of the target batch) and ``device``.

    Parameters
    ----------
    hidden_states : torch.Tensor
        Input of encoder layer ``layer_idx`` (the caller passes
        ``outputs.hidden_states[layer_idx]``).
    layer_idx : int
        Index into ``model_plain.bert.encoder.layer``.

    Returns
    -------
    tuple of 14 numpy arrays, each ``.squeeze()``-d, in this order:
        hidden_states, q, softmax input, softmax output, attention context,
        LayerNorm-1 input, LayerNorm-1 output, GELU input, GELU output,
        second dense output, LayerNorm-2 input, LayerNorm-2 output,
        pooler dense output, pooler output.
        The two pooler entries are computed from this layer's output for every
        layer; ``k`` and ``v`` are computed but not returned.
    """
    with torch.no_grad():
        bert_layer_m = model_plain.bert.encoder.layer[layer_idx]
        attention_m = bert_layer_m.attention.self
        bert_output_m = bert_layer_m.attention.output

        # Per-head projections and scaled dot-product scores.
        q = attention_m.transpose_for_scores(attention_m.query(hidden_states))
        k = attention_m.transpose_for_scores(attention_m.key(hidden_states))
        v = attention_m.transpose_for_scores(attention_m.value(hidden_states))
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(attention_m.attention_head_size)

        # The second argument is the input shape (batch, seq_len); the previous
        # literal 768 was the hidden size, not a shape.
        extended_att_mask = model_plain.get_extended_attention_mask(
            attention_mask, hidden_states.shape[:-1]
        ).to(device)
        sfmtx_in = attention_scores+extended_att_mask
        att_probs_m = torch.nn.functional.softmax(sfmtx_in, dim=-1)
        sfmtx_out = att_probs_m

        # Weighted sum over values, then merge the heads back into one hidden axis.
        att_context_m = torch.matmul(att_probs_m, v)
        context_layer = att_context_m.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (attention_m.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        # Attention output projection + residual + LayerNorm.
        dense_output_m = bert_output_m.dense(context_layer)
        ln1_in = dense_output_m + hidden_states
        ln1_out = bert_output_m.LayerNorm(ln1_in)

        # Feed-forward block + residual + LayerNorm.
        gelu_in = bert_layer_m.intermediate.dense(ln1_out)
        gelu_out = bert_layer_m.intermediate.intermediate_act_fn(gelu_in)
        dense2_out = bert_layer_m.output.dense(gelu_out)
        ln2_in = dense2_out + ln1_out
        ln2_out = bert_layer_m.output.LayerNorm(ln2_in)

        # Pooler applied to the first token of this layer's output.
        pooler_m = model_plain.bert.pooler
        pooler_dense_output = pooler_m.dense(ln2_out[:, 0])
        pooler_output = pooler_m.activation(pooler_dense_output)

    return (
        hidden_states.cpu().numpy().squeeze(),
        q.cpu().numpy().squeeze(),
        sfmtx_in.cpu().numpy().squeeze(),
        sfmtx_out.cpu().numpy().squeeze(),
        att_context_m.cpu().numpy().squeeze(),
        ln1_in.cpu().numpy().squeeze(),
        ln1_out.cpu().numpy().squeeze(),
        gelu_in.cpu().numpy().squeeze(),
        gelu_out.cpu().numpy().squeeze(),
        dense2_out.cpu().numpy().squeeze(),
        ln2_in.cpu().numpy().squeeze(),
        ln2_out.cpu().numpy().squeeze(),
        pooler_dense_output.cpu().numpy().squeeze(),
        pooler_output.cpu().numpy().squeeze()
        )
    
# Reference activations of the plain model, one list entry per encoder layer.
# plot_variables looks these lists up by name.
hidden_states, qs = [], []
ks = []  # declared but never filled: get_nonlinear_in_out does not return k
sftmx_ins, sftmx_outs, att_contexts = [], [], []
ln1_ins, ln1_outs = [], []
gelu_ins, gelu_outs, dense2_outs = [], [], []
ln2_ins, ln2_outs = [], []

for layer in range(12):
    (hidden_state, q, sftmx_in, sftmx_out, att_context,
     ln1_in, ln1_out, gelu_in, gelu_out, dense2_out,
     ln2_in, ln2_out, pooler_dense_out, pooler_out) = get_nonlinear_in_out(outputs.hidden_states[layer], layer)

    # The pooler values are not stored; after the loop they hold the last layer's.
    for collected, value in (
        (hidden_states, hidden_state),
        (qs, q),
        (sftmx_ins, sftmx_in),
        (sftmx_outs, sftmx_out),
        (att_contexts, att_context),
        (ln1_ins, ln1_in),
        (ln1_outs, ln1_out),
        (gelu_ins, gelu_in),
        (gelu_outs, gelu_out),
        (dense2_outs, dense2_out),
        (ln2_ins, ln2_in),
        (ln2_outs, ln2_out),
    ):
        collected.append(value)

### 1-4. Load HE Model

**Load Model Weights**

In [None]:

def _load_pickle(path):
    """Return the object unpickled from the file at `path`."""
    with open(path, 'rb') as f:
        return pickle.load(f)


# NOTE(security): pickle.load can execute arbitrary code contained in the file;
# only load weight files that you generated yourself or otherwise trust.
weights_pt = _load_pickle(f"./encoded_models_new/{dataset_type}/att.pkl")               # attention weights
ff_weights = _load_pickle(f"./encoded_models_new/{dataset_type}/ff.pkl")                # feed-forward weights
pooler_weights = _load_pickle(f"./encoded_models_new/{dataset_type}/pooler.pkl")        # pooler weights
classifier_weights = _load_pickle(f"./encoded_models_new/{dataset_type}/cls.pkl")       # classifier weights

**Initialize HE Model**

In [None]:

# The linear evaluator performs operations such as HE matrix multiplication.
evaluator = ThorLinearEvaluator(engine)

# Attention blocks are built from the attention weights; the feed-forward blocks,
# pooler and classifier are created here and attached afterwards.
thor_bert = ThorBert(evaluator, weights_pt)

# One feed-forward block per encoder layer.
thor_ffs = [ThorBertFF(evaluator, ff_weights, layer_no) for layer_no in range(12)]
thor_bert.ffs = thor_ffs

thor_bert.pooler = ThorBertPooler(evaluator, pooler_weights)
thor_bert.classifier = ThorBertClassifier(evaluator, classifier_weights)

**Define Forward Layer Function**

In [None]:
import time

def forward_layer(x):
    """
    Run one encrypted BERT encoder layer (attention block + feed-forward block).

    The layer is selected through notebook globals rather than arguments: the
    calling cell must set ``layer_idx``, ``thor_attention`` and ``thor_ff``
    beforehand, and ``engine``, ``evaluator``, ``devices`` and
    ``thor_attention_mask`` must already exist.

    Parameters
    ----------
    x : np.ndarray (dtype=object) of ciphertexts, shape (8,) or (4,)
        * shape (8,): entries ``i`` and ``i+4`` are combined into one ciphertext
          as ``x[i] + imult(x[i+4])``.
        * shape (4,): entries are used as-is and split into 8 ciphertexts for the
          residual connection. In this case the caller's array is modified in
          place (its entries are replaced by ``level_up(..., 21)`` results).

    Returns
    -------
    ln2_out
        Output of ``thor_ff.layernorm`` (levelled up to 21 when its level is > 8);
        this is what the next layer's cell passes back in as ``x``.
    intermediates : tuple
        ``(x, q_wo_rescale, sftmx_in, sftmx_out, att_context, ln1_in, ln1_out,
        gelu_in_wo_bs, gelu_out, dense2_out, ln2_in, ln2_out)`` in the order
        expected by ``plot_variables``.

    Raises
    ------
    ValueError
        If ``x.shape`` is neither (8,) nor (4,).
    """
    # Reject unsupported inputs before any model is moved to the GPU; previously
    # this case surfaced later as a NameError on `x_cplx`.
    if x.shape not in ((8,), (4,)):
        raise ValueError(f"forward_layer expects x with shape (8,) or (4,), got {x.shape}")

    thor_attention.to(devices)
    thor_ff.to(devices)
    try:
        print("layer_idx:", layer_idx)

        # ---- Input packing -------------------------------------------------
        if x.shape == (8,):
            x_cplx = np.full((4,), None, dtype=object)
            for i in range(4):
                x_cplx[i] = engine.cc_add(x[i], engine.imult(x[i+4]))
            if layer_idx != 0:
                # For every layer after the first, add a copy rotated by -6 slots.
                for i in range(4):
                    x_cplx[i] = engine.cc_add(x_cplx[i], engine.rotate_left(x_cplx[i], -6))
        else:  # x.shape == (4,)
            x_cplx = x
            x = np.full((8,), None, dtype=object)
            for i in range(4):
                conj = engine.conjugate(x_cplx[i])
                # (z + conj z)/2 and i*(conj z - z)/2: presumably the real and
                # imaginary parts, assuming imult multiplies by i - confirm.
                x[i] =  engine.mult_scalar(engine.cc_add(x_cplx[i], conj), 1/2)
                x[i+4] =  engine.mult_scalar(engine.imult(engine.cc_sub(conj, x_cplx[i])), 1/2)
                x_cplx[i] = engine.level_up(x_cplx[i], 21)

        # ---- Q / K / V projections -------------------------------------------
        x_cplx_rots = evaluator.make_rotated_copies(x_cplx)
        q_wo_rescale = thor_attention.query(x_cplx_rots)
        k = thor_attention.key(x_cplx_rots)
        v = thor_attention.value(x_cplx_rots)

        l_k = evaluator.transpose_upper_to_lower(k)
        l_k_cplx = np.full((4,), None, dtype=object)
        for i in range(4):
            l_k_cplx[i] = engine.cc_add(engine.level_up(l_k[i], l_k[i].level_calc+1), engine.imult(evaluator.rotate_internal(l_k[i], 64, mode='att')))
            l_k_cplx[i] = engine.rescale(l_k_cplx[i])

        q = np.full_like(q_wo_rescale, None, dtype=object)
        for i in range(4):
            q[i] = engine.rescale(q_wo_rescale[i])
        q_copies = evaluator.make_copies(q)

        # ---- Attention scores, bootstrap, softmax ------------------------------
        sftmx_scale = 1
        sftmx_in = thor_attention.calculate_attention_score(l_k_cplx, q_copies, bootstrap=False, scale=sftmx_scale, rescale=False)

        # Bootstrap entries i and i+4 together as one ciphertext, then split again.
        # No 1/2 factor is applied on the way back (unlike the (4,) input branch).
        for i in range(4):
            temp = engine.cc_add(sftmx_in[i], engine.imult(sftmx_in[i+4]))
            temp = engine.bootstrap(temp)
            conj = engine.conjugate(temp)
            sftmx_in[i] = engine.cc_add(temp, conj)
            sftmx_in[i+4] = engine.imult(engine.cc_sub(conj, temp))

        sftmx_out = thor_attention.softmax(x=sftmx_in, attention_mask=thor_attention_mask, rescale=False, debug=False, sk=None)

        # ---- Attention context -----------------------------------------------
        v_cplx = np.full((2,), None, dtype=object)
        for i in range(2):
            v_cplx[i] = engine.cc_add(v[i], engine.imult(v[i+2]))
        # Bring softmax output and values to a common level before rescaling.
        if sftmx_out[0].level_calc < v_cplx[0].level_calc:
            for j in range(128):
                # NOTE(review): the target level is read from v[0] while the
                # comparison above uses v_cplx[0]; confirm both are always equal.
                sftmx_out[j] = engine.level_up(sftmx_out[j], v[0].level_calc)
        elif sftmx_out[0].level_calc > v_cplx[0].level_calc:
            for j in range(2):
                v_cplx[j] = engine.level_up(v_cplx[j], sftmx_out[0].level_calc)
        for i in range(2):
            v_cplx[i] = engine.rescale(v_cplx[i])
        sftmx_out_rescale = np.full((128,), None, dtype=object)
        for j in range(128):
            sftmx_out_rescale[j] = engine.rescale(sftmx_out[j])
        att_context = thor_attention.calculate_attention_context(v_cplx, sftmx_out_rescale, rescale=False)

        for i in range(2):
            att_context[i] = engine.bootstrap(att_context[i])

        # ---- Attention output projection + residual + LayerNorm ----------------
        att_context_rots = thor_attention.evaluator.make_rotated_copies(att_context)
        dense_output = thor_attention.dense(att_context_rots)
        x_out_sum = np.full((8,), None, dtype=object)
        for i in range(4):
            x_out_sum[i] = engine.add(x[i], dense_output[i])
            x_out_sum[i+4] = engine.add(x[i+4], dense_output[i+4])
        ln1_in = x_out_sum

        ln1_out = thor_attention.layernorm(x=ln1_in, sk=None)

        # ---- Feed-forward input: 64 masked / rotated copies ---------------------
        dense1_in = np.full((64,), None,dtype=object)
        # Keep only the slots whose index modulo 16 is below 6.
        slot_mask = np.full((engine.num_slots,), 1, dtype=int)
        slot_mask[np.arange(engine.num_slots) % (16) >= 6] = 0
        for i in range(4):
            temp = engine.cc_add(ln1_out[i], engine.imult(ln1_out[i+4]))
            temp = engine.mc_mult(slot_mask, temp)
            dense1_in[16*i] = engine.cc_add(temp, engine.rotate_left(temp, -8))
            # Each following copy is the previous one rotated by 2**11 slots.
            for j in range(1, 16):
                index = 16*i+j
                dense1_in[index] = engine.rotate_left(dense1_in[index-1], 2**11)

        gelu_in_wo_bs = thor_ff.dense1(dense1_in)

        # Bootstrap rows 0 and 1 of each column together; the value is halved
        # before bootstrapping and split without a further 1/2 factor.
        for i in range(8):
            temp = engine.cc_add(gelu_in_wo_bs[0,i], engine.imult(gelu_in_wo_bs[1,i]))
            temp = engine.mult_scalar(temp, 1/2)
            temp = engine.bootstrap(temp)
            conj = engine.conjugate(temp)
            gelu_in_wo_bs[0,i] = engine.cc_add(temp, conj)
            gelu_in_wo_bs[1,i] = engine.imult(engine.cc_sub(conj, temp))

        # ---- GELU, second dense, residual, bootstrap, LayerNorm -----------------
        gelu_out = thor_ff.gelu(x=gelu_in_wo_bs)
        dense2_out = thor_ff.dense2(gelu_out)
        ln2_in = np.full((8,), None, dtype=object)
        for i in range(8):
            ln2_in[i] = engine.add(ln1_out[i], dense2_out[i])
        # Split without a 1/2 factor; plot_variables divides the decoded ln2_in by 2.
        for i in range(4):
            temp = engine.cc_add(ln2_in[i], engine.imult(ln2_in[i+4]))
            temp = engine.bootstrap(temp)
            conj = engine.conjugate(temp)
            ln2_in[i] = engine.cc_add(temp, conj)
            ln2_in[i+4] = engine.imult(engine.cc_sub(conj, temp))

        # The former `layer_idx == 9 or layer_idx == 10` branch called exactly the
        # same function with the same arguments, so a single call is equivalent.
        ln2_out = thor_ff.layernorm(x=ln2_in, sk=None)

        if ln2_out[0].level >8:
            for i in range(8):
                ln2_out[i] = engine.level_up(ln2_out[i], 21)

        return ln2_out, (x, q_wo_rescale, sftmx_in, sftmx_out, att_context, ln1_in, ln1_out, gelu_in_wo_bs, gelu_out, dense2_out, ln2_in, ln2_out)
    finally:
        # Always move the layer's weights back to the CPU, even when a step fails,
        # so a failed run does not keep them on the GPU.
        thor_attention.cpu()
        thor_ff.cpu()

## 2. Forward Attention Layers

### 2-1. Run and Plot Layer 0

**Code for Plotting and Comparison with Plain Model**

In [None]:
# Intermediates returned by forward_layer, appended once per executed layer cell.
variables_list = []

# For each of the 12 heads, the positions inside a 2**11-slot block whose
# index modulo 16 equals the head number (np.where result, i.e. a 1-tuple).
_slot_mod_16 = np.arange(2**11) % 16
h_indices = [np.where(_slot_mod_16 == head) for head in range(12)]

def plot_variables(variables, i=0, j=0, h=0):
    """
    Decrypt the intermediates of one HE layer and plot them against the plain model.

    Reads the notebook globals ``layer_idx``, ``sk``, ``h_indices``, ``engine`` and
    the per-layer reference lists (``hidden_states``, ``qs``, ``sftmx_ins``, ...).

    Parameters
    ----------
    variables : sequence
        The 12 intermediates returned by ``forward_layer`` (second return value).
    i : int
        Which ciphertext of each intermediate to decrypt.
    j : int
        Which block of ``2**11`` slots inside the decrypted vector to show.
    h : int
        Head index (0..11); selects the slots ``h_indices[h]`` and the matching
        part of the plain tensor.

    Each intermediate gets one panel of a 4x3 grid. Intermediates with no
    ciphertext at index ``i`` are reported and their panel is hidden.
    """
    variable_names = ['x', 'q', 'sftmx_in', 'sftmx_out', 'att_context', 'ln1_in', 'ln1_out', 'gelu_in_wo_bs', 'gelu_out', 'dense2_out', 'ln2_in', 'ln2_out']
    global_vars = [hidden_states, qs, sftmx_ins,  sftmx_outs, att_contexts, ln1_ins, ln1_outs, gelu_ins, gelu_outs, dense2_outs, ln2_ins, ln2_outs]

    fig, axs = plt.subplots(4, 3, figsize=(15, 15))
    fig.suptitle(f'Variables Plot (Layer {layer_idx+1})', fontsize=16)

    drawn_panels = set()
    for index, (var, name, global_var) in enumerate(zip(variables, variable_names, global_vars)):
        row = index // 3
        col = index % 3

        # Multi-dimensional ciphertext arrays: only the first row is inspected.
        if isinstance(var, np.ndarray) and var.ndim > 1:
            var = var[0]

        if len(var) <= i:
            print(f'{name} is not available: shape is {len(var)}')
            continue

        # Decrypt+decode ciphertext i, take slot block j, keep the slots of head h.
        current_var = engine.decrode(var[i], sk, is_real=True)[2**11*j:2**11*(j+1)][h_indices[h]]
        global_var = global_var[layer_idx]

        # Bring the plain tensor into a 2-D layout comparable with the HE packing.
        if global_var.ndim == 3:
            global_var = global_var[h].T
        elif name in ['gelu_in', 'gelu_out', 'gelu_in_wo_bs']:
            # Note: the GELU tensors always use the first of 24 row blocks (h is ignored).
            global_var = np.vsplit(global_var.T, 24)[0]
        else:
            global_var = np.vsplit(global_var.T, 6)[h]

        # presumably extracts diagonal number i*16+j of the matrix - confirm in thor.utils.matrix
        global_var_layer = thor.utils.matrix.ld(global_var, i*16+j)

        # Undo the scaling the HE pipeline applies to some intermediates.
        if name == 'sftmx_in':
            global_var_layer = global_var_layer[:40]
            current_var = current_var[:40]

            # NOTE(review): layer index 2 uses a different factor (64) than all
            # other layers (32); the reason is not visible here - confirm.
            if layer_idx != 2:
                current_var = current_var * 32
            else:
                current_var = current_var * 64
        elif name == 'gelu_in_wo_bs':
            current_var = current_var * 64
        elif name == "ln2_in" :
            # forward_layer re-splits ln2_in after bootstrapping without a 1/2 factor.
            current_var = current_var/2

        axs[row, col].plot(current_var, label=f'HE {name}')
        axs[row, col].plot(global_var_layer, label=f'Plain {name}', linestyle='--')
        axs[row, col].set_title(name)
        axs[row, col].grid(True)
        axs[row, col].legend()
        drawn_panels.add((row, col))

    for ax in axs.flat:
        ax.set(xlabel='Index', ylabel='Decoded Value')

    # Hide only panels that received no data. The previous unconditional
    # `axs[-1, -1].axis('off')` removed the axes of the 12th panel (ln2_out),
    # because 12 variables fill the whole 4x3 grid.
    for position, ax in np.ndenumerate(axs):
        if position not in drawn_panels:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 0
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Input is the encrypted embedding `x`; `x1` feeds the next layer's cell.
x1, variables =  forward_layer(x)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables,0)

### 2-2. Run and Plot Layer 1

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 1
thor_attention = thor_bert.attentions[layer_idx]    
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x1`; `x2` feeds the next layer's cell.
x2 , variables2 = forward_layer(x1)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables2)
plot_variables(variables2, 0)

### 2-3. Run and Plot Layer 2

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 2
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x2`; `x3` feeds the next layer's cell.
x3, variables3 = forward_layer(x2)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables3)
plot_variables(variables3, 0)

### 2-4. Run and Plot Layer 3

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 3
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x3`; `x4` feeds the next layer's cell.
x4, variables = forward_layer(x3)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables, 0)

### 2-5. Run and Plot Layer 4

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 4
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x4`; `x5` feeds the next layer's cell.
x5 , variables = forward_layer(x4)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-6. Run and Plot Layer 5

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 5
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x5`; `x6` feeds the next layer's cell.
x6 , variables = forward_layer(x5)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-7. Run and Plot Layer 6

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 6
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x6`; `x7` feeds the next layer's cell.
x7, variables = forward_layer(x6)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-8. Run and Plot Layer 7

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 7
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x7`; `x8` feeds the next layer's cell.
x8 , variables = forward_layer(x7)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-9. Run and Plot Layer 8

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 8
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x8`; `x9` feeds the next layer's cell.
x9, variables = forward_layer(x8)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-10. Run and Plot Layer 9

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 9
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x9`; `x10` feeds the next layer's cell.
x10, variables = forward_layer(x9)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-11. Run and Plot Layer 10

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 10
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x10`; `x11` feeds the next layer's cell.
x11, variables = forward_layer(x10)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

### 2-12. Run and Plot Layer 11

In [None]:
# forward_layer and plot_variables read `layer_idx`, `thor_attention` and `thor_ff`
# as globals, so they must be set before the calls below.
layer_idx = 11
thor_attention = thor_bert.attentions[layer_idx]
thor_ff = thor_bert.ffs[layer_idx]
# Consumes the previous layer's output `x11`; `x12` is passed to the pooler cell.
x12, variables = forward_layer(x11)
# Not idempotent: re-running this cell appends a second entry.
variables_list.append(variables)
plot_variables(variables)

## 3. Run Pooler and Classification

### 3-1. Run Pooler

In [None]:
# Move the pooler to the GPU(s) and apply it to the last encoder layer's output.
thor_bert.pooler.to(devices)
# NOTE: this rebinds `x`, which until now held the encrypted embedding consumed by
# the layer-0 cell; re-running that cell afterwards requires re-encrypting first.
x = thor_bert.pooler.forward(x12)

### 3-2. Run Classification

In [None]:
# Move the classifier to the GPU(s) and apply it to the pooler output held in `x`.
thor_bert.classifier.to(devices)
# Not idempotent: `x` is overwritten, so running this cell twice feeds the
# classifier its own output.
x = thor_bert.classifier.forward(x)

## 4. Compare the Prediction with the Ground-Truth Label

In [None]:
import pandas as pd
from datasets import load_dataset

# Load the dataset from the GLUE benchmark
# (this rebinds `dataset`, which earlier held the local dataset path string)
dataset = load_dataset("glue", dataset_type)

# Extract the validation split from the dataset
val_set = dataset["validation"]

# Decrypt the two encrypted outputs in `x` with the secret key and take slot 0 of
# each. `is_real=True` matches the decoding used in plot_variables, so the two
# values compared below are decoded as real numbers.
a = engine.decrode(x[0], sk, is_real=True)[0]
b = engine.decrode(x[1], sk, is_real=True)[0]

# Predict 0 if a > b, otherwise predict 1
pred = 0 if a > b else 1

# Retrieve the ground-truth label from the validation set.
# Assumes batch number `target_idx` of data_loader is example `target_idx` of the
# GLUE validation split (batch size 1, no shuffling) - confirm in ThorDataEncryptor.
label = val_set["label"][target_idx]

# Display the prediction and the actual label
print(f"Predicted by HE: {pred}, Ground Truth: {label}")