In [1]:
import numpy as np

### Encrypting the input matrices as diagonal vectors

In [2]:
# Problem sizes: A is m x n, B is n x n, and c diagonals are packed per ciphertext.
# Assumptions (enforced below):
#   * n is a multiple of m. If it is not, pad n up to the smallest multiple of m
#     that exceeds it; this increases the number of rotations, which we accept.
#   * n is a power of two, so the packed vectors tile the CKKS slots exactly.
#   * c divides m (later code uses m // c and would silently truncate otherwise).
n = 64
m = 16
c = 4

assert n % m == 0, "n must be a multiple of m"
assert n > 0 and (n & (n - 1)) == 0, "n must be a power of two"
assert m % c == 0, "c must divide m"

In [3]:
# Small-integer test matrices whose entries cycle through 1..5 in row-major order.
A = (np.arange(m * n) % 5 + 1).reshape(m, n)
B = (np.arange(n * n) % 5 + 1).reshape(n, n)

# Title, matrix, blank line, title, matrix — one item per output line.
print(f"A matrix-{m}x{n}", A, "", f"B_matrix-{n}x{n}", B, sep="\n")

A matrix-16x64
[[1 2 3 ... 2 3 4]
 [5 1 2 ... 1 2 3]
 [4 5 1 ... 5 1 2]
 ...
 [3 4 5 ... 4 5 1]
 [2 3 4 ... 3 4 5]
 [1 2 3 ... 2 3 4]]

B_matrix-64x64
[[1 2 3 ... 2 3 4]
 [5 1 2 ... 1 2 3]
 [4 5 1 ... 5 1 2]
 ...
 [5 1 2 ... 1 2 3]
 [4 5 1 ... 5 1 2]
 [3 4 5 ... 4 5 1]]


In [4]:
# A, B diagonal vector list 생성
from func import *

# Lower diagonal vectors: m of them for A (m x n) and n of them for B (n x n).
A_list = [lower_diagonal_vector(A, k) for k in range(m)]
B_list = [lower_diagonal_vector(B, k) for k in range(n)]


In [5]:
from liberate import fhe
from liberate.fhe import presets

# Create the CKKS engine from a liberate.FHE preset.
# Original author's note on the "bronze" preset: log N = 14, log S = 13
# (the number of levels was not recorded — TODO confirm from presets.params).
grade = "bronze"
params = presets.params[grade]
engine = fhe.ckks_engine(**params)

# Key generation. Rotation keys are deliberately NOT created here: generating
# all of them at once caused a memory error, so they are generated separately.
sk = engine.create_secret_key()
pk = engine.create_public_key(sk=sk)
evk = engine.create_evk(sk=sk)


In [6]:
# # 출력 제한 제거 for test
# np.set_printoptions(threshold=np.inf, linewidth=np.inf)

# Encrypt the diagonal vectors c at a time, spreading their elements `gap` slots apart.
def pack_lowers_with_gap(diag_list, c, gap, engine, pk):
    """Pack diagonal vectors into ciphertexts, ``c`` diagonals per ciphertext.

    The diagonals ``diag_list[k*c:(k+1)*c]`` are concatenated. Element ``i`` of
    that concatenation is written to slot ``i * gap``; every other slot is zero.
    When ``len(diag_list)`` is not a multiple of ``c``, the last ciphertext
    holds fewer diagonals and its remaining slots stay zero.

    Args:
        diag_list: sequence of equal-length diagonal vectors.
        c: number of diagonal vectors packed into one ciphertext (>= 1).
        gap: slot stride between consecutive elements (>= 1).
        engine, pk: liberate.FHE engine and public key used to encode/encrypt.

    Returns:
        List of ciphertexts, one per group of ``c`` diagonals. The list is
        empty for an empty ``diag_list``.

    Raises:
        ValueError: if ``c < 1``, ``gap < 1``, or the diagonals differ in length.

    Elements whose slot index would be ``>= engine.num_slots`` are dropped.
    Choose ``gap <= num_slots // (c * len(diagonal))`` so nothing is lost.
    """
    if c < 1:
        raise ValueError(f"c must be >= 1, got {c}")
    if gap < 1:
        # gap == 0 would put every element in slot 0, silently overwriting data.
        raise ValueError(f"gap must be >= 1, got {gap}")
    if len(diag_list) == 0:
        return []

    slot_count = engine.num_slots
    n_per_diag = len(diag_list[0])
    if any(len(d) != n_per_diag for d in diag_list):
        raise ValueError("all diagonal vectors must have the same length")

    ct_list = []
    for start_idx in range(0, len(diag_list), c):
        chunk = diag_list[start_idx:start_idx + c]
        flat = np.concatenate([np.asarray(d, dtype=np.float64) for d in chunk])

        bigVec = np.zeros(slot_count, dtype=np.float64)
        strided = bigVec[::gap]  # view onto slots 0, gap, 2*gap, ... < slot_count
        k = min(strided.size, flat.size)
        strided[:k] = flat[:k]

        plain = engine.encode(bigVec.tolist())
        ct_list.append(engine.encrypt(plain, pk))

    return ct_list

# Slot stride of the packed layout. It must match the stride the multiplication
# and the ct_B packing use, so derive it from the same formula: num_slots // (n * c).
gap = engine.num_slots // (n * c)
ct_A = pack_lowers_with_gap(A_list, c, gap, engine, pk)


In [7]:
# In the full pipeline the incoming diagonal vectors would be converted to the
# stacked layout; for testing we encrypt them directly in stacked form.
def pack_lowers_with_gap_stacked(diag_list, c, gap, engine, pk):
    """Encrypt each diagonal vector repeated ``c`` times, ``gap`` slots apart.

    For every diagonal of length n, its c-fold repetition (length n*c) is
    written so that element ``i`` lands in slot ``i * gap``. All other slots
    are zero.

    Args:
        diag_list: sequence of diagonal vectors.
        c: number of repetitions per diagonal (>= 1).
        gap: slot stride between consecutive elements (>= 1).
        engine, pk: liberate.FHE engine and public key used to encode/encrypt.

    Returns:
        List with one ciphertext per diagonal vector.

    Raises:
        ValueError: if ``c < 1`` or ``gap < 1``.

    Elements whose slot index would be ``>= engine.num_slots`` are dropped.
    """
    if c < 1:
        raise ValueError(f"c must be >= 1, got {c}")
    if gap < 1:
        # gap == 0 would put every element in slot 0, silently overwriting data.
        raise ValueError(f"gap must be >= 1, got {gap}")

    slot_count = engine.num_slots

    B_cipher = []
    for diag_vec in diag_list:
        replicate = np.tile(np.asarray(diag_vec, dtype=np.float64), c)

        bigVec = np.zeros(slot_count, dtype=np.float64)
        strided = bigVec[::gap]  # view onto slots 0, gap, 2*gap, ... < slot_count
        k = min(strided.size, replicate.size)
        strided[:k] = replicate[:k]

        plain = engine.encode(bigVec.tolist())
        B_cipher.append(engine.encrypt(plain, pk))

    return B_cipher

# Encrypt B's diagonals in stacked form with the shared slot stride num_slots // (n * c).
ct_B = pack_lowers_with_gap_stacked(
    diag_list=B_list,
    c=c,
    gap=engine.num_slots // (n * c),
    engine=engine,
    pk=pk,
)

In [8]:
# Rotation keys are created lazily and cached by rotation amount.
galois_cache = {}  # delta -> rotation key


def rotate(ct, delta, engine, sk, galois_cache):  # TODO: precompute the needed keys up front instead
    """Rotate ``ct`` by ``delta`` using a cached rotation key.

    On the first request for a given ``delta``, a key is generated with
    ``engine.create_rotation_key(sk=sk, delta=-delta)`` and stored in
    ``galois_cache[delta]``. Later calls with the same ``delta`` reuse it.
    The sign flip is intentional: the original notes say liberate's rotate
    moves in the opposite direction from OpenFHE (assumption — confirm).
    """
    try:
        key = galois_cache[delta]
    except KeyError:
        key = engine.create_rotation_key(sk=sk, delta=-delta)
        galois_cache[delta] = key
    return engine.rotate_single(ct, key)

# rot_result = rotate(ct_A[1], -16, engine, sk, galois_cache) # openFHE와 달리 rotate를 하면 왼쪽으로 이동한다.
# rot_result2 = rotate(ct_A[0], -16, engine, sk, galois_cache)
# print(engine.decrode(rot_result, sk))
# print(engine.decrode(rot_result2, sk))
# print(galois_cache)

In [9]:
# Mask construction. TODO: check whether the masks can be multiplied in as
# plaintexts instead of being encrypted.
mask_cache = {}  # (n, c, ell, gap) -> (mu_l0, mu_l1, mu_l2, mu_l3)


def get_mask(n, c, ell, gap, engine, pk):
    """Return the four 0/1 selection masks for diagonal ``ell``, encrypted.

    Over the ``s = n * c`` logical positions, position ``i`` sits in stacked
    block ``r = i // n`` at offset ``i % n``:

    * it is "front" when the offset is ``< n - ell``, "back" otherwise;
    * it is in the upper group when ``ell % c <= r < c``, the lower group otherwise.

    The four masks are:

    * mu_l0: lower group, front
    * mu_l1: upper group, front
    * mu_l2: lower group, back
    * mu_l3: upper group, back

    Logical position ``i`` maps to slot ``i * gap``. Positions whose slot
    index is ``>= engine.num_slots`` are dropped. mu_l2 is additionally shifted
    with ``np.roll(..., -gap * n)`` before encryption.

    The masks are produced with ``engine.encorypt`` and are therefore
    ciphertexts, not plaintexts. Results are memoized in ``mask_cache`` under
    ``(n, c, ell, gap)``. The cache key does not include ``engine`` or ``pk``.
    """
    key = (n, c, ell, gap)
    if key in mask_cache:
        return mask_cache[key]

    slot_count = engine.num_slots
    positions = np.arange(n * c)
    block = positions // n
    upper = (ell % c <= block) & (block < c)
    front = (positions % n) < (n - ell)

    selections = (
        ~upper & front,   # mu_l0
        upper & front,    # mu_l1
        ~upper & ~front,  # mu_l2
        upper & ~front,   # mu_l3
    )

    slot_of = positions * gap
    fits = slot_of < slot_count
    layouts = []
    for chosen in selections:
        vec = np.zeros(slot_count, dtype=np.float64)
        vec[slot_of[chosen & fits]] = 1.0
        layouts.append(vec)
    layouts[2] = np.roll(layouts[2], -gap * n)

    # Encrypt in mu_l0, mu_l1, mu_l2, mu_l3 order.
    mask_cache[key] = tuple(engine.encorypt(vec.tolist(), pk) for vec in layouts)
    return mask_cache[key]


In [10]:
# Ciphertext x ciphertext matrix multiplication.
def cipher_ciphert_mult_liberate(A_cipher, B_cipher, n, m, c, engine, sk, evk, galois_cache, pk=None):
    """Encrypted (m x n) @ (n x n) matrix product following Algorithm 2 of THOR.

    Args:
        A_cipher: m // c ciphertexts, each presumably packing c diagonals of A
            with slot stride num_slots // (n * c).
        B_cipher: n ciphertexts holding B's diagonals in stacked (c-fold
            replicated) form; B_cipher[ell] is used for diagonal ell.
        n, m, c: matrix sizes and number of diagonals packed per ciphertext.
        engine: liberate.FHE CKKS engine.
        sk: secret key; used here only to create rotation keys on demand.
        evk: evaluation key passed to every ciphertext multiplication.
        galois_cache: dict of rotation keys shared with ``rotate``.
        pk: public key used to encrypt the masks. When omitted, the
            notebook-level ``pk`` is used (the original signature had no
            ``pk`` and read the global).

    Returns:
        List of m // c ciphertexts. Ciphertext j is expected to hold diagonals
        c*j .. c*j + c - 1 of A @ B in the same slot layout.
    """
    if pk is None:
        pk = globals()["pk"]

    gap = engine.num_slots // (n * c)
    m_tilde = m // c

    def _accumulate(total, term):
        # Running ciphertext sum; None means nothing has been added yet.
        return term if total is None else engine.add(total, term)

    # ell = 0 term: no rotation and no masking required.
    C_cipher = [engine.mult(A_cipher[j], B_cipher[0], evk=evk) for j in range(m_tilde)]
    if n <= 1:
        # Only the main diagonal exists; there is nothing to combine.
        return C_cipher

    # C'_j  = sum over ell of (part0[j-1, ell] + part1[j, ell])
    # C''_j = sum over ell of (part2[j-1, ell] + part3[j, ell])
    # Parts 0/2 of block j are added straight into block j+1's sums, so the
    # intermediate ciphertexts never need to be kept alive for all (j, ell).
    c_prime = [None] * m_tilde
    c_dprime = [None] * m_tilde

    for ell in range(1, n):
        # mu_l3 is not applied directly; part3 is obtained by subtraction below.
        mu_l0, mu_l1, mu_l2, _mu_l3 = get_mask(n, c, ell, gap, engine, pk)
        delta = (ell - n * (ell % c)) * gap

        for j in range(m_tilde):
            idxA = (j - ell // c) % m_tilde
            rotated = rotate(A_cipher[idxA], delta, engine, sk, galois_cache)
            product = engine.mult(rotated, B_cipher[ell], evk)

            part0 = engine.mult(product, mu_l0, evk)
            part1 = engine.mult(product, mu_l1, evk)
            part2 = engine.mult(product, mu_l2, evk)
            # part3 = product * (1 - mu_l0 - mu_l1 - mu_l2), where mu_l2 is the
            # rolled mask from get_mask. NOTE(review): confirm this complement is
            # the intended region.
            part3 = engine.sub(product, part0)
            part3 = engine.sub(part3, part1)
            part3 = engine.sub(part3, part2)

            j_next = (j + 1) % m_tilde
            c_prime[j] = _accumulate(c_prime[j], part1)
            c_prime[j_next] = _accumulate(c_prime[j_next], part0)
            c_dprime[j] = _accumulate(c_dprime[j], part3)
            c_dprime[j_next] = _accumulate(c_dprime[j_next], part2)

    for j in range(m_tilde):
        rotated_dprime = rotate(c_dprime[j], -n * gap, engine, sk, galois_cache)
        partial = engine.add(C_cipher[j], c_prime[j])
        C_cipher[j] = engine.add(partial, rotated_dprime)

    return C_cipher

In [11]:
# Run the encrypted matrix multiplication.
C_cipher = cipher_ciphert_mult_liberate(
    A_cipher=ct_A,
    B_cipher=ct_B,
    n=n,
    m=m,
    c=c,
    engine=engine,
    sk=sk,
    evk=evk,
    galois_cache=galois_cache,
)


In [None]:
# Result check: compare the decrypted product with the plaintext A @ B block by
# block, with an absolute error tolerance of 0.5.
C_true = A @ B
m_tilde = m // c
result_gap = engine.num_slots // (n * c)  # same slot stride the multiplication uses

# Expected packing: block i holds diagonals c*i .. c*i + c - 1 of A @ B.
C_diag = []
for i in range(m_tilde):
    diag = np.zeros(n * c)
    for j in range(c):
        diag[j * n:(j + 1) * n] = lower_diagonal_vector(C_true, c * i + j)
    C_diag.append(diag)

# Decrypt each result ciphertext and keep the n*c values at stride result_gap.
c_diag = []
for j in range(m_tilde):
    vals = engine.decode(engine.decrypt(C_cipher[j], sk))
    c_diag.append(np.array(vals)[::result_gap][:n * c])

# Check every block: index with this loop's j (not the leftover i from above).
for j in range(m_tilde):
    if np.all(np.abs(C_diag[j] - c_diag[j]) <= 0.5):
        print("good")
    else:
        print("sad,,")
    

good
good
good
good
