In [None]:
# hide
%load_ext autoreload
%autoreload 2
%load_ext pycodestyle_magic

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The pycodestyle_magic extension is already loaded. To reload it, use:
  %reload_ext pycodestyle_magic


# Matrix multiplication in CKKS

This notebook implements the algorithm of https://eprint.iacr.org/2018/1041.pdf, which multiplies two $d \times d$ matrices in CKKS using $\mathcal{O}(d)$ operations.

In [None]:
# export 
import numpy as np
import tenseal.sealapi as seal
from typing import List

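To perform matrix multiplication we must implement four matrix operators (all indices are taken modulo $d$):

- $\sigma(A)_{i,j} = A_{i,i+j}$
- $\tau(A)_{i,j} = A_{i+j,j}$
- $\phi(A)_{i,j} = A_{i,j+1}$
- $\psi(A)_{i,j} = A_{i+1,j}$

The product is then $AB = \sum_{k=0}^{d-1} \left(\phi^k \circ \sigma(A)\right) \odot \left(\psi^k \circ \tau(B)\right)$, where $\odot$ is the element-wise product.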

## Base class

In [None]:
def sigma_diagonal_vector(d: int, k: int) -> np.array:
    """Return the 0/1 mask of the k-th diagonal of the sigma permutation.

    Matrices are d x d and flattened row-major into d**2 slots; sigma reads
    A[i, (i + j) mod d]. For -d < k < d, slot l is selected when
    ``0 <= l - d*k < d - k`` (k >= 0) or ``-k <= l - d*(d + k) < d`` (k < 0).

    Returns an integer vector of length d**2.
    """
    slots = np.arange(d ** 2)
    if k < 0:
        offset = slots - d * (d + k)
        selected = (offset >= -k) & (offset < d)
    else:
        selected = (slots >= d * k) & (slots < d * k + d - k)
    return selected.astype(slots.dtype)

In [None]:
def tau_diagonal_vector(d: int, k: int) -> np.array:
    """Return the 0/1 mask paired with the rotation by d*k in the tau transform.

    tau reads A[(i + j) mod d, j]; rotating the row-major vector by d*k feeds
    exactly the slots ``k + d*i`` (column k), so those are set to 1.
    Expects 0 <= k < d (k >= d raises IndexError).

    Returns a float vector of length d**2.
    """
    mask = np.zeros(d ** 2)
    mask[k + d * np.arange(d)] = 1
    return mask

In [None]:
def row_diagonal_vector(d, k):
    """Return the masks ``(v_k, v_k_d)`` used to rotate every row left by k.

    Slot l sits in column ``l % d``. Columns ``< d - k`` take their value from
    the rotation by k (mask ``v_k``); the remaining columns wrap around within
    the row and take it from the rotation by ``k - d`` (mask ``v_k_d``).

    Both masks are integer 0/1 vectors of length d**2 and are complementary.
    """
    columns = np.arange(d ** 2) % d
    from_same_row = columns < d - k
    head = from_same_row.astype(columns.dtype)
    tail = (~from_same_row).astype(columns.dtype)
    return head, tail

In [None]:
def column_diagonal_vector(d, k):
    """Return the all-ones mask for shifting every column up by k.

    With row-major packing, moving each column up by k is a plain rotation of
    the whole vector by d*k, so nothing needs masking. ``k`` is accepted only to
    match the signature of the other ``*_diagonal_vector`` helpers.
    """
    return np.full(d ** 2, 1.0)

In [None]:
class MatrixMultiplicator:
    """Multiply two d x d matrices packed row-major into vectors.

    Implements the algorithm of https://eprint.iacr.org/2018/1041.pdf. It only
    needs rotations, additions and element-wise products, so those primitives
    are injected: the defaults work on plain numpy vectors, and homomorphic
    versions can be passed to run on ciphertexts.

    Args:
        d: matrix dimension; operands hold d*d packed entries.
        create_zero: callable returning an additive identity. It is only used
            when a sum has no terms.
        sigma_diagonal_vector, tau_diagonal_vector: callables ``(d, k)``
            returning the mask of the k-th diagonal of sigma / tau. Sigma is
            queried for -d < k < d with a rotation by k. Tau is queried for
            0 <= k < d with a rotation by d*k.
        row_diagonal_vector: callable ``(d, k)`` returning the pair of masks
            ``(v_k, v_k_d)`` used to rotate every row left by k.
        column_diagonal_vector: callable ``(d, k)`` returning the mask applied
            after rotating by d*k.
        rotate: ``rotate(x, k)`` must move slot i+k to slot i (left rotation);
            defaults to ``np.roll(x, -k)``.
        add: ``add(x, y)``; defaults to ``x + y``.
        pmult: product of an operand with a mask; defaults to ``x * y``.
        cmult: product of two operands; defaults to ``x * y``.
    """

    def __init__(self, d, create_zero, sigma_diagonal_vector, tau_diagonal_vector,
                 row_diagonal_vector, column_diagonal_vector,
                 rotate=None, add=None, pmult=None, cmult=None):
        self.d = d
        self.create_zero = create_zero
        self.sigma_diagonal_vector = sigma_diagonal_vector
        self.tau_diagonal_vector = tau_diagonal_vector
        self.row_diagonal_vector = row_diagonal_vector
        self.column_diagonal_vector = column_diagonal_vector

        # Plain numpy defaults; pass homomorphic primitives to run encrypted.
        self.rotate = rotate if rotate is not None else (lambda x, k: np.roll(x, -k))
        self.add = add if add is not None else (lambda x, y: x + y)
        self.pmult = pmult if pmult is not None else (lambda x, y: x * y)
        self.cmult = cmult if cmult is not None else (lambda x, y: x * y)

    def _accumulate(self, terms):
        """Sum ``terms`` with ``self.add``, seeding the sum with the first term.

        Seeding with the first term instead of ``create_zero()`` saves one
        operand per transform. With a CKKS backend, that operand is an extra
        encryption of zero.
        ``create_zero()`` is returned only when ``terms`` is empty.
        """
        total = None
        for term in terms:
            total = term if total is None else self.add(total, term)
        return self.create_zero() if total is None else total

    def sigma_lin_transform(self, input):
        """Apply sigma(A)_{i,j} = A_{i,i+j} as a masked sum of rotations by k, -d < k < d."""
        d = self.d
        return self._accumulate(
            self.pmult(self.rotate(input, k), self.sigma_diagonal_vector(d, k))
            for k in range(-d + 1, d)
        )

    def tau_lin_transform(self, input):
        """Apply tau(A)_{i,j} = A_{i+j,j} as a masked sum of rotations by d*k, 0 <= k < d."""
        d = self.d
        return self._accumulate(
            self.pmult(self.rotate(input, k * d), self.tau_diagonal_vector(d, k))
            for k in range(d)
        )

    def row_lin_transform(self, input, k):
        """Apply phi^k: rotate every row left by k.

        Slots whose source stays in the same row come from a rotation by k. The
        others wrap around within their row, i.e. they come from a rotation by
        k - d.
        """
        v_k, v_k_d = self.row_diagonal_vector(self.d, k)
        head = self.pmult(self.rotate(input, k), v_k)
        tail = self.pmult(self.rotate(input, k - self.d), v_k_d)
        return self.add(head, tail)

    def column_lin_transform(self, input, k):
        """Apply psi^k: rotate every column up by k, i.e. rotate the vector by d*k."""
        v_k = self.column_diagonal_vector(self.d, k)
        return self.pmult(self.rotate(input, self.d * k), v_k)

    def matmul(self, A, B):
        """Return the packed product of the packed matrices A and B.

        Uses AB = sum_{k<d} phi^k(sigma(A)) * psi^k(tau(B)), where * is the
        element-wise product (``cmult``).
        """
        sigma_A = self.sigma_lin_transform(A)
        tau_B = self.tau_lin_transform(B)

        output = self.cmult(sigma_A, tau_B)  # k = 0 term
        for k in range(1, self.d):
            shift_A = self.row_lin_transform(sigma_A, k)
            shift_B = self.column_lin_transform(tau_B, k)
            output = self.add(output, self.cmult(shift_A, shift_B))
        return output
        

In [None]:
def encode_matrices_to_vector(matrix):
    """Interleave g square matrices of shape (g, d, d) into one float vector.

    Entry (i, j) of matrix k lands at index ``g * (d*i + j) + k``, so the g
    copies of each coordinate occupy g consecutive slots.
    Returns a float vector of length g * d * d.
    """
    shape = matrix.shape
    assert len(shape) == 3, "Non tridimensional tensor"
    assert shape[1] == shape[2], "Non square matrices"

    # (g, d, d) -> (d, d, g) puts the g values of each coordinate side by side;
    # C-order flattening then yields index g * (d*i + j) + k.
    interleaved = np.asarray(matrix).transpose(1, 2, 0)
    return interleaved.reshape(-1).astype(float)

def decode_vector_to_matrices(vector, d):
    """Inverse of the interleaved packing: rebuild g matrices of shape (d, d).

    ``g = len(vector) // d**2``; entry (i, j) of matrix k is read from index
    ``g * (d*i + j) + k``. Entries past ``g * d**2`` are ignored.
    Returns a new float array of shape (g, d, d).
    """
    n = len(vector)
    g = n // (d ** 2)

    # Indices are (d*i + j) * g + k, i.e. a C-ordered (d, d, g) block.
    block = np.asarray(vector)[:g * d * d].reshape(d, d, g)
    return block.transpose(2, 0, 1).astype(float)

In [None]:
def encode_matrix_to_vector(matrix: np.array) -> np.array:
    """Flatten a d x d matrix row-major into a new float vector of length d*d."""
    shape = matrix.shape
    assert len(shape) == 2 and shape[0] == shape[1], "Non square matrix"
    side = shape[0]
    # np.array copies, so the result never aliases the caller's matrix.
    return np.array(matrix, dtype=float).reshape(side * side)

def decode_vector_to_matrix(vector):
    """Rebuild a square float matrix from its row-major vector of length d*d.

    Raises AssertionError ("Non square matrix") if ``vector`` is not 1-D or its
    length is not a perfect square.
    """
    n = len(vector)
    d = np.sqrt(n)
    assert len(vector.shape) == 1 and d.is_integer(), "Non square matrix"
    side = int(d)
    return np.array(vector, dtype=float).reshape(side, side)

In [None]:
def weave(vector, g):
    """Repeat every entry of ``vector`` g times in a row, as a new float vector.

    This turns a single-matrix mask into one that acts on g interleaved copies.
    """
    return np.repeat(np.asarray(vector, dtype=float), g)

In [None]:
# Plain (unencrypted) sanity check on random 3 x 3 matrices stored row-major.
d = 3

A = np.random.randn(d**2)
B = np.random.randn(d**2)


def create_zero():
    # Reads the global `d` at call time, so it follows later reassignments of `d`.
    return np.zeros(d**2)


mm = MatrixMultiplicator(d, create_zero, sigma_diagonal_vector, tau_diagonal_vector,
                         row_diagonal_vector, column_diagonal_vector)

In [None]:
def l2_error(x, y):
    """Mean squared difference between x and y (an MSE, not an L2 norm)."""
    diff = x - y
    return (diff ** 2).mean()

In [None]:
# Reference product computed by numpy on the decoded matrices.
expected = encode_matrix_to_vector(
    decode_vector_to_matrix(A) @ decode_vector_to_matrix(B)
)
predicted = mm.matmul(A, B)

l2_error(expected, predicted)

2.8760553836182724e-32

In [None]:
# Parallel plain check: g identical copies of A (resp. B), interleaved slot-wise.
d = 3
g = 2

A = np.random.randn(d**2)
B = np.random.randn(d**2)

C = encode_matrices_to_vector(np.stack([decode_vector_to_matrix(A)] * g))
D = encode_matrices_to_vector(np.stack([decode_vector_to_matrix(B)] * g))


# Every mask is woven (each entry repeated g times) and every rotation is
# scaled by g, so the g interleaved products are computed at once.
# These read the global `g` (and `parallel_create_zero` the global `d`) at call time.
def parallel_sigma_diagonal_vector(d, k):
    return weave(sigma_diagonal_vector(d, k), g)


def parallel_tau_diagonal_vector(d, k):
    return weave(tau_diagonal_vector(d, k), g)


def parallel_row_diagonal_vector(d, k):
    return [weave(vector, g) for vector in row_diagonal_vector(d, k)]


def parallel_column_diagonal_vector(d, k):
    return weave(column_diagonal_vector(d, k), g)


def parallel_create_zero():
    return np.zeros(g * (d ** 2))


def parallel_rotate(x, k):
    return np.roll(x, -(k * g))


pmm = MatrixMultiplicator(d, parallel_create_zero, parallel_sigma_diagonal_vector,
                          parallel_tau_diagonal_vector, parallel_row_diagonal_vector,
                          parallel_column_diagonal_vector, parallel_rotate)

In [None]:
# Reference: numpy's batched matmul over the g decoded (A, B) pairs.
# NOTE(review): the saved output (~3.45) was produced when this reference
# multiplied C by itself instead of C by D; re-run to refresh it.
expected = encode_matrices_to_vector(
    np.matmul(decode_vector_to_matrices(C, d), decode_vector_to_matrices(D, d))
)
predicted = pmm.matmul(C, D)

l2_error(expected, predicted)

3.4530904186503086

In [None]:
# Plain product of the current A and B, displayed as a d x d matrix.
decode_vector_to_matrix(mm.matmul(A, B))

array([[ 2.22480278,  0.84396379, -0.8360246 ],
       [ 0.11749061,  0.0406489 ,  0.11238539],
       [ 1.19914222, -0.98280776, -0.16491496]])

In [None]:
# Parallel product: g copies of the same d x d result, since C and D replicate A and B.
decode_vector_to_matrices(pmm.matmul(C, D), d)

array([[[ 2.22480278,  0.84396379, -0.8360246 ],
        [ 0.11749061,  0.0406489 ,  0.11238539],
        [ 1.19914222, -0.98280776, -0.16491496]],

       [[ 2.22480278,  0.84396379, -0.8360246 ],
        [ 0.11749061,  0.0406489 ,  0.11238539],
        [ 1.19914222, -0.98280776, -0.16491496]]])

In [None]:
import builtins
from cryptotree.seal_helper import print_ctx, print_ptx, create_seal_globals, append_globals_to_builtins

# CKKS parameters: ring degree, coefficient-modulus bit sizes, encoding precision.
poly_modulus_degree = 2 ** 13
moduli = [40, 30, 30, 30, 40]
PRECISION_BITS = 30

# NOTE(review): presumably these cryptotree helpers create the notebook globals
# used below without definition (encoder, encryptor, decryptor, evaluator,
# galois_keys, scale) and publish them to builtins for exported code — confirm.
create_seal_globals(globals(), poly_modulus_degree, moduli, PRECISION_BITS)
append_globals_to_builtins(globals(), builtins)

In [None]:
# Single (non-parallel) encrypted experiment with random 28 x 28 matrices.
d = 28

A, B = np.random.randn(d ** 2), np.random.randn(d ** 2)

In [None]:
# Encode each matrix into a shared plaintext buffer, then encrypt it.
ptx = seal.Plaintext()
ctA = seal.Ciphertext()
ctB = seal.Ciphertext()

encoder.encode(A, scale, ptx)
encryptor.encrypt(ptx, ctA)

encoder.encode(B, scale, ptx)
encryptor.encrypt(ptx, ctB)

In [None]:
def get_vector(ctx):
    """Decrypt ``ctx`` and decode its slots into a real-valued numpy array."""
    plain = seal.Plaintext()
    decryptor.decrypt(ctx, plain)
    slots = encoder.decode_double(plain)
    return np.array(slots)

def encode(vector):
    """CKKS-encode ``vector`` at the global ``scale`` into a new plaintext."""
    plain = seal.Plaintext()
    encoder.encode(vector, scale, plain)
    return plain

def encrypt(vector):
    """Encode ``vector`` with ``encode`` and encrypt it into a new ciphertext."""
    cipher = seal.Ciphertext()
    encryptor.encrypt(encode(vector), cipher)
    return cipher

In [None]:
def ckks_create_zero():
    """Return a fresh encryption of zero across every CKKS slot."""
    plain = seal.Plaintext()
    encoder.encode(np.zeros(encoder.slot_count()), scale, plain)
    cipher = seal.Ciphertext()
    encryptor.encrypt(plain, cipher)
    return cipher

def ckks_rotate(ctx, k):
    """Return a new ciphertext with the slots of ``ctx`` rotated left by ``k``.

    Slot i receives slot i + k, cyclically over all slots; uses ``galois_keys``.
    """
    rotated = seal.Ciphertext()
    evaluator.rotate_vector(ctx, k, galois_keys, rotated)
    return rotated

def ckks_add(ctx1, ctx2):
    """Return ``ctx1 + ctx2`` as a new ciphertext.

    If the two ciphertexts have different parameters, ``ctx1`` is mod-switched
    *in place* to ``ctx2``'s parameters first. The caller's ``ctx1`` is modified,
    and presumably ``ctx1`` must be the one at the higher level.
    """
    if ctx1.parms_id() != ctx2.parms_id():
        evaluator.mod_switch_to_inplace(ctx1, ctx2.parms_id())
    total = seal.Ciphertext()
    evaluator.add(ctx1, ctx2, total)
    return total

def ckks_pmult(ctx, ptx):
    """Multiply ciphertext ``ctx`` by plaintext ``ptx``, then rescale.

    ``ptx`` is mod-switched *in place* to ``ctx``'s parameters when they differ.
    After rescaling, the result's scale is forced back to the global ``scale``.
    NOTE(review): that treats the post-rescale scale as exactly ``scale``,
    which only holds approximately.
    """
    if ptx.parms_id() != ctx.parms_id():
        evaluator.mod_switch_to_inplace(ptx, ctx.parms_id())
    product = seal.Ciphertext()
    evaluator.multiply_plain(ctx, ptx, product)
    evaluator.rescale_to_next_inplace(product)
    product.scale = scale
    return product

def ckks_cmult(ctx1, ctx2):
    """Multiply two ciphertexts, then rescale.

    ``ctx2`` is mod-switched *in place* to ``ctx1``'s parameters when they
    differ. After rescaling, the result's scale is forced back to the global
    ``scale``.
    """
    if ctx2.parms_id() != ctx1.parms_id():
        evaluator.mod_switch_to_inplace(ctx2, ctx1.parms_id())
    product = seal.Ciphertext()
    evaluator.multiply(ctx1, ctx2, product)
    evaluator.rescale_to_next_inplace(product)
    product.scale = scale
    return product

# CKKS mask builders: the same 0/1 vectors, encoded as plaintexts.
def ckks_sigma_diagonal_vector(d, k):
    return encode(sigma_diagonal_vector(d, k))


def ckks_tau_diagonal_vector(d, k):
    return encode(tau_diagonal_vector(d, k))


def ckks_row_diagonal_vector(d, k):
    return [encode(vector) for vector in row_diagonal_vector(d, k)]


def ckks_column_diagonal_vector(d, k):
    return encode(column_diagonal_vector(d, k))


cmm = MatrixMultiplicator(
    d, ckks_create_zero,
    ckks_sigma_diagonal_vector, ckks_tau_diagonal_vector,
    ckks_row_diagonal_vector, ckks_column_diagonal_vector,
    ckks_rotate, ckks_add, ckks_pmult, ckks_cmult,
)

In [None]:
# Plain reference multiplicator for the current d.
mm = MatrixMultiplicator(
    d=d,
    create_zero=create_zero,
    sigma_diagonal_vector=sigma_diagonal_vector,
    tau_diagonal_vector=tau_diagonal_vector,
    row_diagonal_vector=row_diagonal_vector,
    column_diagonal_vector=column_diagonal_vector,
)

In [None]:
# Homomorphic product of the encrypted A and B.
predicted = cmm.matmul(ctA, ctB)

In [None]:
expected = mm.matmul(A, B)

# NOTE(review): the encoder pads the d**2 values with zeros up to the full
# slot count, so homomorphic rotations wrap around modulo the slot count rather
# than modulo d**2. The tau and column shifts rely on wrap-around modulo d**2,
# which likely explains the large error below. The parallel experiment, where
# g * d**2 fills every slot, does not have this problem.
l2_error(get_vector(predicted)[:d ** 2], expected)

33.51652090661767

In [None]:
# Parallel encrypted experiment: g interleaved copies of a d x d product.
d = 16
g = 16

A = np.random.randn(d ** 2)
B = np.random.randn(d ** 2)

C = encode_matrices_to_vector(np.stack([decode_vector_to_matrix(A)] * g))
ctC = encrypt(C)

D = encode_matrices_to_vector(np.stack([decode_vector_to_matrix(B)] * g))
ctD = encrypt(D)


# Woven masks, encoded as plaintexts; they read the global `g` at call time.
def parallel_ckks_sigma_diagonal_vector(d, k):
    return encode(weave(sigma_diagonal_vector(d, k), g))


def parallel_ckks_tau_diagonal_vector(d, k):
    return encode(weave(tau_diagonal_vector(d, k), g))


def parallel_ckks_row_diagonal_vector(d, k):
    return [encode(weave(vector, g)) for vector in row_diagonal_vector(d, k)]


def parallel_ckks_column_diagonal_vector(d, k):
    return encode(weave(column_diagonal_vector(d, k), g))


def parallel_ckks_rotate(ctx, k):
    # Shifting by k matrix entries moves k * g slots once the copies are interleaved.
    return ckks_rotate(ctx, k * g)


pcmm = MatrixMultiplicator(d, ckks_create_zero, parallel_ckks_sigma_diagonal_vector,
                           parallel_ckks_tau_diagonal_vector, parallel_ckks_row_diagonal_vector,
                           parallel_ckks_column_diagonal_vector, parallel_ckks_rotate,
                           ckks_add, ckks_pmult, ckks_cmult)

In [None]:
1 << 13  # == 2**13 == poly_modulus_degree used above

8192

In [None]:
# Rebuild the plain parallel multiplicator for the new d and g.
# The parallel_* helpers from the earlier plain-parallel cell read the globals
# `d`/`g` at call time, so redefining them is unnecessary. Only the
# multiplicator, which stores d, needs rebuilding. The original cell also built
# `pmm` twice.
pmm = MatrixMultiplicator(d, parallel_create_zero, parallel_sigma_diagonal_vector,
                          parallel_tau_diagonal_vector, parallel_row_diagonal_vector,
                          parallel_column_diagonal_vector, parallel_rotate)

In [None]:
# Homomorphic product of the g interleaved copies.
predicted = pcmm.matmul(ctC, ctD)

In [None]:
# Plain reference for the same interleaved products; compare every slot.
expected = pmm.matmul(C, D)

l2_error(get_vector(predicted), expected)

3.647689479760129e-06

In [None]:
# `v` is not defined by any earlier cell; decrypt the parallel result here so
# the cell survives Restart & Run All (this equals l2_error(v, expected)).
v = get_vector(predicted)
((v - expected) ** 2).mean()

5.134337397070778e-06

In [None]:
# Demo of a single left rotation on an explicitly created ciphertext.
# `ctx` is otherwise not defined before this cell; the values [1, 2, 3, 4] are
# chosen to match the saved output.
ctx = encrypt(np.array([1.0, 2.0, 3.0, 4.0]))
evaluator.rotate_vector_inplace(ctx, 1, galois_keys)
print_ctx(ctx)


    [ 1.9994646, 3.0002853, 4.0000778, ..., -0.0000002, 0.0000010, 0.9999999 ]



In [None]:
# export
from fastcore.test import test_close

def test_sum(x: List[float], evaluator, encoder, encryptor, decryptor, scale, eps=1e-2,
             galois_keys=None):
    """Check that ``sum_reduce`` homomorphically computes ``sum(x)``.

    ``x`` is encoded and encrypted, then reduced with ``sum_reduce`` over
    ``len(x)`` slots and decrypted. Slot 0 is compared with ``np.sum(x)``
    using ``test_close`` with tolerance ``eps``.

    ``galois_keys`` is optional for backward compatibility. When omitted, the
    ``galois_keys`` visible at module level (or published to ``builtins``) is
    used, as before.
    """
    if galois_keys is None:
        import builtins
        galois_keys = globals().get("galois_keys", getattr(builtins, "galois_keys", None))
        if galois_keys is None:
            raise NameError("name 'galois_keys' is not defined")
    n_slot = len(x)

    ptx = seal.Plaintext()
    encoder.encode(x, scale, ptx)

    ctx = seal.Ciphertext()
    encryptor.encrypt(ptx, ctx)

    output = sum_reduce(ctx, evaluator, galois_keys, n_slot)
    decryptor.decrypt(output, ptx)

    values = encoder.decode_double(ptx)

    homomorphic_output = values[0]
    expected_output = np.sum(x)

    test_close(homomorphic_output, expected_output, eps)
    
def test_dot_product_plain(x: List[float], y: List[float],
                           evaluator, encoder, encryptor, decryptor,
                           galois_keys,
                           scale, eps=1e-2):
    """Check that ``dot_product_plain`` homomorphically computes ``np.dot(x, y)``.

    ``x`` is encrypted and ``y`` is only encoded (kept as plaintext). The result
    of ``dot_product_plain`` over ``len(x)`` slots is decrypted, and slot 0 is
    compared with ``np.dot(x, y)`` using ``test_close`` with tolerance ``eps``.

    Raises AssertionError if ``x`` and ``y`` have different lengths.
    """
    assert len(x) == len(y), f"x and y must have same length {len(x)} != {len(y)}"
    n_slot = len(x)

    ptx = seal.Plaintext()
    encoder.encode(x, scale, ptx)

    ctx = seal.Ciphertext()
    encryptor.encrypt(ptx, ctx)

    pty = seal.Plaintext()
    encoder.encode(y, scale, pty)

    output = dot_product_plain(ctx, pty, evaluator, galois_keys, n_slot)
    # Reuse the x plaintext as the decryption buffer.
    decryptor.decrypt(output, ptx)

    values = encoder.decode_double(ptx)

    homomorphic_output = values[0]
    expected_output = np.dot(x, y)

    test_close(homomorphic_output, expected_output, eps)

In [None]:
# Self-contained inputs for the reduction demos (previously hidden state).
# The values [1, 2, 3, 4] reproduce the saved outputs: slot 0 is 10 and 30.
x_demo = np.array([1.0, 2.0, 3.0, 4.0])
n_slot = len(x_demo)
ctx = encrypt(x_demo)
ptx = encode(x_demo)

# NOTE(review): sum_reduce and dot_product_plain are not defined or imported in
# this notebook section; import them before running this cell.
print_ctx(sum_reduce(ctx, evaluator, galois_keys, n_slot))
print_ctx(dot_product_plain(ctx, ptx, evaluator, galois_keys, n_slot))


    [ 9.9999583, 9.0000185, 6.9999091, ..., 1.0000008, 3.0003153, 6.0000488 ]


    [ 30.0000054, 29.0000064, 25.0000052, ..., 0.9999990, 5.0000002, 13.9999994 ]



In [None]:
# Exercise an even and an odd input length.
test_sum([1, 2], evaluator, encoder, encryptor, decryptor, scale=scale)
test_sum([1, 2, 3], evaluator, encoder, encryptor, decryptor, scale=scale)

In [None]:
# Plaintext-weighted dot products, including a negative weight.
test_dot_product_plain([1, 2, 3], [1, 1, 1], evaluator, encoder, encryptor, decryptor,
                       galois_keys=galois_keys, scale=scale)
test_dot_product_plain([1, 2, 3, 5], [1, 1, 1, -6], evaluator, encoder, encryptor, decryptor,
                       galois_keys=galois_keys, scale=scale)