# Introduction to CKKS

## Introduction to Homomorphic Encryption

### Context

Homomorphic encryption (HE) is a form of encryption that lets data owners encrypt their data and have a third party perform computations on it without learning anything about the underlying data. The encrypted result of the computation is then sent back to the data owner, who is the only one able to decrypt it.

HE makes this possible by relying on homomorphic properties of encryption and decryption. Additions and multiplications performed on encrypted data decrypt to the same result as if those operations had been performed directly on the unencrypted data.

More formally, a ring homomorphism $h$ from a ring $R$ to a ring $R'$ satisfies the following two properties for all $x, y \in R$:
$$h(x + y) = h(x) + h(y)$$
$$h(x * y) = h(x) * h(y)$$
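
A familiar example is reduction modulo $n$, the map $h: \mathbb{Z} \to \mathbb{Z}/n\mathbb{Z}$ with $h(x) = x \bmod n$. The short check below verifies both properties on a few integers. The modulus 7 and the sample values are arbitrary choices for illustration.

```python
import itertools

N_MOD = 7  # arbitrary modulus chosen for the illustration

def h(x: int) -> int:
    """Reduction modulo N_MOD: a ring homomorphism from Z to Z/7Z."""
    return x % N_MOD

samples = [-12, -3, 0, 5, 9, 41]
for a, b in itertools.product(samples, repeat=2):
    # Addition and multiplication in Z/7Z are themselves done modulo 7
    assert h(a + b) == (h(a) + h(b)) % N_MOD
    assert h(a * b) == (h(a) * h(b)) % N_MOD
print("both homomorphism properties hold on all samples")
```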

Suppose we have an encryption homomorphism $e$ and a decryption homomorphism $d$ such that $d(e(x)) = x$, together with a function $f$ that is a composition of additions and multiplications. The following scenario then becomes possible:

- The user encrypts their data $x$ using $e$ and sends $e(x)$ to an untrusted third party.
- The third party computes $f$ on the encrypted $e(x)$. Because $e$ is a homomorphism, $f(e(x)) = e(f(x))$. The third party then sends the result back to the user.
- Finally, the user decrypts the output and obtains $d(e(f(x))) = f(x)$, without ever exposing their data to the untrusted third party.

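The scenario can be made concrete with a deliberately insecure toy scheme over the integers. It is not CKKS, and the parameters `T` and `P` are made up. A message $m$ is hidden as $m + T r + P q$, with small noise $r$ and a random multiple $q$ of the secret integer $P$. Decryption reduces modulo $P$ and then modulo $T$. Sums and products of ciphertexts keep the form "message plus noise plus a multiple of $P$", so the third party can evaluate $f$ without knowing $P$.

```python
import random

# Toy, INSECURE symmetric scheme over the integers, for illustration only.
# T and P are hypothetical parameters, far too small for any real security.
T = 1000              # plaintext modulus: messages are integers mod T
P = 10**30 + 57       # secret key: a large integer known only to the user

def encrypt(m: int, rng: random.Random) -> int:
    r = rng.randrange(1, 10)        # small noise
    q = rng.randrange(1, 10**6)     # random multiple of the secret key
    return m + T * r + P * q

def decrypt(c: int) -> int:
    # Correct as long as the accumulated noise term stays below P
    return (c % P) % T

def f(v):
    """A composition of additions and multiplications: v*v + 3*v."""
    return v * v + 3 * v

rng = random.Random(0)
x = 12
c = encrypt(x, rng)              # user -> third party
c_out = f(c)                     # third party computes on the ciphertext only
print(decrypt(c_out), f(x) % T)  # user decrypts: 180 180
```

Each multiplication roughly squares the noise term $m + T r$. After a few multiplications it would exceed $P$ and decryption would fail. This growth is why practical schemes can only evaluate a bounded number of multiplications on a ciphertext.
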
### The CKKS scheme

![Overview CKKS](images/overview_ckks.PNG)
<center>High level view of CKKS.</center>

We will now focus on the CKKS scheme, which supports approximate homomorphic arithmetic on vectors of complex numbers, and therefore also on real numbers. The figure above gives a high-level view of how CKKS works:

- The user first generates a secret key and a public key.
- The user then encodes the complex vector to be computed on into a plaintext polynomial.
- This plaintext polynomial is encrypted with the public key into a ciphertext made of two polynomials.
- A third party performs computations on the ciphertext.
- The result is decrypted with the secret key into a plaintext polynomial.
- Finally, the user decodes the plaintext to read the values.

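The encoding step can be sketched for a tiny ring. CKKS identifies a polynomial in $\mathbb{R}[X]/(X^N + 1)$ with its values at the $N$ complex roots of $X^N + 1$ (the canonical embedding). Extending $N/2$ complex numbers with their conjugates therefore determines a polynomial with real coefficients, which is then multiplied by the scale $\Delta$ and rounded to integers.

The NumPy sketch below follows that textbook description with a dense linear solve and plain coefficient rounding. It is not how SEAL's `CKKSEncoder` works internally, since SEAL uses a fast FFT-style transform. `N = 8`, the scale and the input vector are arbitrary.

```python
import numpy as np

def _roots(n: int) -> np.ndarray:
    """The n complex roots of X^n + 1: xi^(2k+1) with xi = exp(i*pi/n)."""
    xi = np.exp(1j * np.pi / n)
    return xi ** (2 * np.arange(n) + 1)

def ckks_encode(z, n: int, scale: float) -> np.ndarray:
    """Encode n/2 complex numbers into the n integer coefficients of a polynomial."""
    z = np.asarray(z, dtype=complex)
    # Conjugate-symmetric extension: root k and root n-1-k are complex conjugates
    expanded = np.concatenate([z, np.conj(z[::-1])])
    vandermonde = np.vander(_roots(n), n, increasing=True)  # V[k, j] = root_k ** j
    coeffs = np.linalg.solve(vandermonde, expanded)          # inverse canonical embedding
    # Coefficients are real up to floating-point noise: scale them, then round
    return np.round(scale * coeffs.real).astype(np.int64)

def ckks_decode(coeffs: np.ndarray, n: int, scale: float) -> np.ndarray:
    """Evaluate the polynomial at the roots and keep the first n/2 values."""
    values = np.polyval(coeffs[::-1] / scale, _roots(n))  # polyval wants highest degree first
    return values[: n // 2]

N, SCALE = 8, 2.0 ** 20
z = [1 + 2j, 3 - 1j, 0.5, -2]
poly = ckks_encode(z, N, SCALE)
print(poly)                                      # 8 integer coefficients
print(np.round(ckks_decode(poly, N, SCALE), 4))  # close to z
```

Rounding changes each coefficient by at most 1/2, so the decoded values differ from `z` by at most roughly $N / (2\Delta)$. A larger scale buys more precision.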

## Deep dive into the code

### Setup

```python
import tenseal.sealapi as seal
```

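First, we set the encryption parameters SEAL needs to build the context. The coefficient modulus is the product of four primes of 60, 40, 40 and 60 bits, 200 bits in total. The two 40-bit primes match the encoding scale $\Delta = 2^{40}$, so each rescaling after a multiplication divides by roughly $\Delta$. For 128-bit security, a 200-bit coefficient modulus requires $N \geq 8192$: SEAL allows at most 218 bits at that degree, and only 109 bits at $N = 4096$.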

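The bit-budget check can be written out in a few lines of plain Python. The table holds SEAL's default limits for 128-bit classical security, the values that `CoeffModulus::MaxBitCount` returns in SEAL's C++ API. The helper `smallest_poly_modulus_degree` is hypothetical and written only for this illustration.

```python
from typing import Sequence

# Largest total coefficient-modulus size (in bits) giving 128-bit classical
# security for each polynomial modulus degree, per SEAL's default table.
MAX_BITS_TC128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}

def smallest_poly_modulus_degree(moduli_bits: Sequence[int]) -> int:
    """Return the smallest N whose 128-bit security budget fits the given primes."""
    total = sum(moduli_bits)
    for n in sorted(MAX_BITS_TC128):
        if total <= MAX_BITS_TC128[n]:
            return n
    raise ValueError(f"{total} bits exceed every tabulated degree")

moduli = [60, 40, 40, 60]
n = smallest_poly_modulus_degree(moduli)
print(sum(moduli), n, n // 2)   # 200 8192 4096  (total bits, N, CKKS slots)
```
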
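```python
poly_modulus_degree = 8192
# Bit sizes of the primes forming the coefficient modulus (200 bits in total)
moduli = [60, 40, 40, 60]
PRECISION_BITS = 40
scale = pow(2.0, PRECISION_BITS)  # encoding scale Delta = 2^40

parms = seal.EncryptionParameters(seal.SCHEME_TYPE.CKKS)
parms.set_poly_modulus_degree(poly_modulus_degree)
parms.set_coeff_modulus(seal.CoeffModulus.Create(
    poly_modulus_degree, moduli))

# True: build the modulus switching chain; TC128: enforce 128-bit classical security
context = seal.SEALContext.Create(parms, True, seal.SEC_LEVEL_TYPE.TC128)
```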

We can now generate the keys, then create the encryptor, evaluator, decryptor and encoder from the context and those keys.

```python
keygen = seal.KeyGenerator(context)

public_key = keygen.public_key()          # used to encrypt
secret_key = keygen.secret_key()          # used to decrypt, never shared
relin_keys = keygen.relin_keys_local()    # needed to relinearize after multiplications
galois_keys = keygen.galois_keys_local()  # needed to rotate slots

encryptor = seal.Encryptor(context, public_key)
evaluator = seal.Evaluator(context)
decryptor = seal.Decryptor(context, secret_key)
encoder = seal.CKKSEncoder(context)
```

For convenience, we define below some helper functions that print the content of a ciphertext or a plaintext directly, so we don't have to decrypt and decode by hand every time. They are mainly meant for teaching and debugging.

```python
def print_vector(vec, print_size=4, prec=3):
    """Prints a vector with a given level of precision and print size"""
    slot_count = len(vec)
    print()
    if slot_count <= 2*print_size:
        print("    [", end="")
        for i in range(slot_count):
            print(" " + (f"%.{prec}f" % vec[i]) + ("," if (i != slot_count - 1) else " ]\n"), end="")
    else:
        # Print the first and last print_size values, separated by an ellipsis
        print("    [", end="")
        for i in range(print_size):
            print(" " + (f"%.{prec}f" % vec[i]) + ",", end="")
        print(" ...,", end="")
        for i in range(slot_count - print_size, slot_count):
            print(" " + (f"%.{prec}f" % vec[i]) + ("," if (i != slot_count - 1) else " ]\n"), end="")
    print()

def print_ptx(ptx: seal.Plaintext):
    """Decodes a plaintext and prints its first and last 3 slots with 7 decimals."""
    result = encoder.decode_double(ptx)
    print_vector(result, 3, 7)

def print_ctx(ctx: seal.Ciphertext):
    """Decrypts a ciphertext with the secret key and prints it like print_ptx."""
    ptx = seal.Plaintext()
    decryptor.decrypt(ctx, ptx)
    print_ptx(ptx)
```

We can now start playing with SEAL. Let's see how to encode and encrypt a vector using SEAL.

```python
x = [1, 2, 3, 4]
ptx = seal.Plaintext()
ctx = seal.Ciphertext()

encoder.encode(x, scale, ptx)
encryptor.encrypt(ptx, ctx)
```

Printing the plaintext and the decrypted ciphertext shows that they hold the same values, up to tiny approximation errors (hence the `-0.0000000` entries):

```python
print_ptx(ptx)
print_ctx(ctx)
```

```
    [ 1.0000000, 2.0000000, 3.0000000, ..., 0.0000000, -0.0000000, 0.0000000 ]

    [ 1.0000000, 2.0000000, 3.0000000, ..., -0.0000000, 0.0000000, -0.0000000 ]
```



### Homomorphic operations

#### Addition

Let's encode and encrypt a second vector, $y$, so that we can add and multiply the two ciphertexts.

```python
y = [-1, -2, -3, -4]
pty = seal.Plaintext()
cty = seal.Ciphertext()

encoder.encode(y, scale, pty)
encryptor.encrypt(pty, cty)
```

Now let's add them together:

```python
ct_plus = seal.Ciphertext()
evaluator.add(ctx, cty, ct_plus)

print_ctx(ct_plus)
```

```
    [ -0.0000000, -0.0000000, 0.0000000, ..., -0.0000000, -0.0000000, -0.0000000 ]
```



Perfect, it works! Since $y = -x$, every slot of the sum is (approximately) zero.

#### Multiplication

Now let's multiply them:

```python
ct_mul = seal.Ciphertext()
evaluator.multiply(ctx, cty, ct_mul)

print_ctx(ct_mul)
```

```
    [ -1.0000000, -4.0000000, -9.0000000, ..., -0.0000000, -0.0000000, -0.0000000 ]
```



This works as well!

Multiplication has a side effect, though. The number of polynomials stored in the ciphertext, all of which are needed for decryption, has increased:

```python
print(f"Size of the ciphertext after multiplication {ct_mul.size()}")
print(f"Size of the ciphertext after addition {ct_plus.size()}")
```

```
Size of the ciphertext after multiplication 3
Size of the ciphertext after addition 2
```


If we leave things as they are, the ciphertext keeps growing with each multiplication. Every subsequent operation then becomes more expensive and more complicated.

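The third polynomial comes from how decryption works. A fresh ciphertext $(c_0, c_1)$ decrypts as $c_0 + c_1 s$ with secret key $s$. Multiplying two such expressions gives $c_0 d_0 + (c_0 d_1 + c_1 d_0) s + c_1 d_1 s^2$, so the product needs three components, and each further multiplication adds more.

The sketch below uses plain integers as stand-ins for ring elements. This is a simplification: real ciphertexts hold polynomials reduced modulo $X^N + 1$ and $q$. The key `s` and the ciphertext values are hypothetical.

```python
from typing import List

def tensor(c: List[int], d: List[int]) -> List[int]:
    """Multiply two ciphertexts as polynomials in s: sizes add up, minus one."""
    out = [0] * (len(c) + len(d) - 1)
    for i, ci in enumerate(c):
        for j, dj in enumerate(d):
            out[i + j] += ci * dj
    return out

def decrypt_like(c: List[int], s: int) -> int:
    """Evaluate c_0 + c_1*s + c_2*s^2 + ...: what decryption computes."""
    return sum(ci * s**i for i, ci in enumerate(c))

s = 17                         # stand-in secret key (hypothetical value)
ct_a, ct_b = [5, -3], [2, 4]   # two size-2 "ciphertexts" (hypothetical values)

prod_ab = tensor(ct_a, ct_b)
print(len(prod_ab))                # 3
print(len(tensor(prod_ab, ct_b)))  # 4: sizes keep growing without relinearization
print(decrypt_like(prod_ab, s) == decrypt_like(ct_a, s) * decrypt_like(ct_b, s))  # True
```

Relinearization removes the $s^2$ component with the relinearization keys, which are essentially encryptions of $s^2$ under $s$. This brings the ciphertext back to two components.
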
This is why we generated relinearization keys earlier. Relinearization uses them to shrink the ciphertext back to two polynomials:

```python
print(f"Size of the ciphertext before relinearization {ct_mul.size()}")
evaluator.relinearize_inplace(ct_mul, relin_keys)
print(f"Size of the ciphertext after relinearization {ct_mul.size()}")
```

```
Size of the ciphertext before relinearization 3
Size of the ciphertext after relinearization 2
```


So far so good!

But multiplication causes another problem in CKKS: the scale has changed. Since both inputs were encoded at scale $\Delta$, their product is at scale $\Delta^2$: $z = \Delta x \cdot \Delta y = \Delta^2 x y$. If we now add something at a different scale, say $\Delta x$, SEAL throws an error because the operands are at different scales:

```python
try:
    evaluator.add_inplace(ct_mul, ctx)
except ValueError as e:
    print(e)
```

```
scale mismatch
```


We can check the two scales manually:

```python
print(ct_mul.scale)
print(ctx.scale)
```

```
1.2089258196146292e+24
1099511627776.0
```

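These values are $2^{80} = \Delta^2$ and $2^{40} = \Delta$. The bookkeeping behaves like ordinary fixed-point arithmetic. In the plain-Python sketch below (no encryption, arbitrary values 1.5 and -2.25), each number is stored as an integer together with the scale it was multiplied by. The product of two values at scale $\Delta$ is at scale $\Delta^2$. Adding integers at different scales silently produces garbage, and that is the mistake SEAL's scale check prevents.

```python
DELTA = 2 ** 40

def encode(v: float, scale: int) -> int:
    """Fixed-point encoding: round(v * scale)."""
    return round(v * scale)

x_int = encode(1.5, DELTA)
y_int = encode(-2.25, DELTA)

z_int = x_int * y_int       # the product carries scale DELTA**2
z_scale = DELTA * DELTA
print(z_int / z_scale)      # -3.375

# Adding x (scale DELTA) to z (scale DELTA**2) ignores the mismatch:
bad_sum = (z_int + x_int) / z_scale
print(bad_sum)              # about -3.375, not 1.5 - 3.375 = -1.875
```
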

So we need to rescale $z$. SEAL divides it by the last prime of its current coefficient modulus (one of the 40-bit primes) and drops that prime, which decreases the level of $z$ by one:

```python
evaluator.rescale_to_next_inplace(ct_mul)
print(ct_mul.scale)
```

```
1099511775231.0198
```


The scale has indeed decreased. However, we divided by a prime close to $\Delta$ rather than by $\Delta$ itself, so the two scales are still not exactly equal:

```python
print(ct_mul.scale == scale)
print(ct_mul.scale)
print(scale)
```

```
False
1099511775231.0198
1099511627776.0
```


The scales are very close, with a relative difference of about $1.3 \times 10^{-7}$, but they are not identical. That's why we manually set the scale of $z$ to $\Delta$, accepting a relative error of the same order in the decrypted values.

```python
ct_mul.scale = scale
print(ct_mul.scale == ctx.scale)
```

```
True
```

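In the same fixed-point picture, rescaling divides both the integer and its scale by a modulus $Q$ close to $\Delta$. The sketch below uses the hypothetical value $Q = 2^{40} - 147455$. It was chosen only because $\Delta^2 / Q$ reproduces the rescaled scale printed above; it is not claimed to be SEAL's actual prime.

Decoding with the tracked scale $\Delta^2 / Q$ is accurate. Forcing the scale to $\Delta$ introduces a relative error of $(\Delta - Q)/Q \approx 1.3 \times 10^{-7}$, which is the price of `ct_mul.scale = scale`.

```python
DELTA = 2 ** 40
Q = 2 ** 40 - 147455   # hypothetical modulus near DELTA (not claimed to be SEAL's prime)

x_int = round(1.5 * DELTA)     # x at scale DELTA
y_int = round(-2.25 * DELTA)   # y at scale DELTA
z_int = x_int * y_int          # z = x * y at scale DELTA**2

# Rescale: divide the integer by Q (rounding to nearest) and the scale by Q
z_int = (z_int + Q // 2) // Q
z_scale = DELTA * DELTA / Q
print(z_scale)                  # slightly above DELTA, not equal to it

print(z_int / z_scale)          # accurate: -3.375 up to rounding
print(z_int / DELTA)            # scale forced to DELTA: relative error ~1.3e-7

# Once z is declared to be at scale DELTA, it can be added to x
print((z_int + x_int) / DELTA)  # close to 1.5 - 3.375 = -1.875
```
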

We are not done yet. Rescaling consumed one prime of $z$'s coefficient modulus, so $z$ is now one level below $x$, whose coefficient modulus still contains that prime. SEAL refuses to add ciphertexts at different levels:

```python
try:
    evaluator.add_inplace(ct_mul, ctx)
except ValueError as e:
    print(e)
```

```
encrypted1 and encrypted2 parameter mismatch
```


The last step is to bring $x$ down to the level of $z$ by dropping one of its primes with modulus switching. Then we can add the two ciphertexts:

```python
ctx_leveled = seal.Ciphertext()
# Switch x down to the parameters (level) of z; its scale is unchanged
evaluator.mod_switch_to(ctx, ct_mul.parms_id(), ctx_leveled)

# Same scale and same level: the addition now succeeds
evaluator.add_inplace(ct_mul, ctx_leveled)
```

The result should now decrypt to approximately $x \cdot y + x$. That is $0, -2, -6, -12$ in the first four slots and zeros elsewhere:

```python
print_ctx(ct_mul)
```

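SEAL stores each ciphertext coefficient in residue number system (RNS) form, with one residue per prime of the current coefficient modulus. For CKKS, `mod_switch_to` lowers the level without rescaling, and the scale stays the same. This amounts to dropping residues. Because the smaller modulus divides the larger one, any relation that held modulo the full product, including the decryption equation, still holds modulo the remaining primes.

The toy sketch below uses three small hypothetical primes in place of SEAL's 40- to 60-bit ones, and a single small integer in place of a polynomial coefficient.

```python
from math import prod

PRIMES = [97, 101, 103]   # hypothetical toy modulus chain

def to_rns(v: int, primes) -> list:
    """Represent v by one residue per prime."""
    return [v % p for p in primes]

def from_rns(residues, primes) -> int:
    """Chinese remaindering, returning the representative closest to zero."""
    q = prod(primes)
    total = 0
    for r, p in zip(residues, primes):
        q_p = q // p
        total += r * q_p * pow(q_p, -1, p)
    total %= q
    return total - q if total > q // 2 else total

v = -1234
full = to_rns(v, PRIMES)
dropped = full[:-1]                    # "mod switch": forget the last prime
print(from_rns(full, PRIMES))          # -1234
print(from_rns(dropped, PRIMES[:-1]))  # -1234, still recoverable modulo 97 * 101
```

Rescaling also drops the last prime, but it additionally divides the value by that prime, which is what changes the scale.
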


#### Rotation

Finally, CKKS supports rotation, which cyclically shifts the slots of an encrypted vector. Rotations require the Galois keys generated earlier.

```python
print("Initial vector : ")
print_ctx(ctx)

ctx_shifted = seal.Ciphertext()

print("Galois rotation of 1 (shift on the left) :")
evaluator.rotate_vector(ctx, 1, galois_keys, ctx_shifted)
print_ctx(ctx_shifted)

print("Galois rotation of -1 (shift on the right) :")
evaluator.rotate_vector(ctx, -1, galois_keys, ctx_shifted)
print_ctx(ctx_shifted)
```

```
Initial vector : 

    [ 1.0000000, 2.0000000, 3.0000000, ..., -0.0000000, 0.0000000, -0.0000000 ]

Galois rotation of 1 (shift on the left) :

    [ 2.0000003, 3.0000000, 4.0000000, ..., 0.0000000, 0.0000000, 1.0000000 ]

Galois rotation of -1 (shift on the right) :

    [ 0.0000003, 0.9999998, 1.9999999, ..., -0.0000000, -0.0000000, 0.0000000 ]
```

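Here is a plaintext model of what `rotate_vector` does to the slots. With $N = 8192$, CKKS has $N/2 = 4096$ slots, and a rotation by $k$ shifts all of them cyclically to the left (to the right for negative $k$). That is why the 1 reappears in the last slot above. NumPy's `np.roll` with shift `-k` gives the same layout. This models the expected slot layout only; it is not an encrypted computation.

```python
import numpy as np

SLOTS = 8192 // 2            # CKKS slot count for poly_modulus_degree = 8192
v = np.zeros(SLOTS)
v[:4] = [1, 2, 3, 4]         # the vector x, padded with zeros

left = np.roll(v, -1)        # models rotate_vector(ctx, 1, ...): shift left
right = np.roll(v, 1)        # models rotate_vector(ctx, -1, ...): shift right
print(left[:3], left[-3:])   # [2. 3. 4.] [0. 0. 1.]
print(right[:3], right[-3:]) # [0. 1. 2.] [0. 0. 0.]
```
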


Rotations and additions can be combined to sum the slots of a ciphertext, and from there to compute dot products. The helpers below assume that only the first `n_slot` slots can be non-zero. After $\lceil \log_2 n_{slot} \rceil$ rotate-and-add steps, the first slot holds the sum.

```python
import numpy as np

def sum_reduce(ctx: seal.Ciphertext, evaluator: seal.Evaluator,
               galois_keys: seal.GaloisKeys, n_slot: int) -> seal.Ciphertext:
    """Sums the slots of the ciphertext, assuming only the first n_slot slots are non-zero.

    The first slot of the output holds the sum. The input ciphertext is not
    modified; if n_slot == 1 there is nothing to add and ctx itself is returned.
    """
    n_steps = int(np.ceil(np.log2(n_slot)))

    output = ctx
    for i in range(n_steps):
        # Rotate left by 2**i and add: slot 0 now holds the sum of slots 0 .. 2**(i+1) - 1
        rotated = seal.Ciphertext()
        evaluator.rotate_vector(output, 2**i, galois_keys, rotated)
        summed = seal.Ciphertext()
        evaluator.add(output, rotated, summed)
        output = summed
    return output

def dot_product_plain(ctx: seal.Ciphertext, ptx: seal.Plaintext,
                      evaluator: seal.Evaluator, galois_keys: seal.GaloisKeys, n_slot: int):
    """Computes the dot product between a ciphertext and a plaintext"""
    output = seal.Ciphertext()

    evaluator.multiply_plain(ctx, ptx, output)
    output = sum_reduce(output, evaluator, galois_keys, n_slot)

    return output

def dot_product(ctx: seal.Ciphertext, cty: seal.Ciphertext,
                evaluator: seal.Evaluator, galois_keys: seal.GaloisKeys,
                relin_keys: seal.RelinKeys, n_slot: int):
    """Computes the dot product between two ciphertexts"""
    output = seal.Ciphertext()

    evaluator.multiply(ctx, cty, output)
    evaluator.relinearize_inplace(output, relin_keys)  # back to 2 polynomials
    evaluator.rescale_to_next_inplace(output)          # scale back from Delta^2 to ~Delta
    output = sum_reduce(output, evaluator, galois_keys, n_slot)

    return output
```

```python
ctx_sum_reduce = sum_reduce(ctx, evaluator, galois_keys, len(x))
pt = seal.Plaintext()
decryptor.decrypt(ctx_sum_reduce, pt)
values = encoder.decode_double(pt)

he_sum = values[0]
regular_sum = sum(x)

np.abs(he_sum - regular_sum)
```

```
2.744962301903797e-07
```

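The rotate-and-add loop in `sum_reduce` can be checked in plaintext with `np.roll`. After the step with rotation $2^i$, slot 0 holds the sum of the first $2^{i+1}$ slots. So $\lceil \log_2 n \rceil$ steps suffice, as long as every slot from position `n_slot` onward is zero. `sum_reduce_plain` is a hypothetical helper written for this plaintext illustration only.

```python
import math
import numpy as np

def sum_reduce_plain(v: np.ndarray, n_slot: int) -> np.ndarray:
    """Plaintext analogue of sum_reduce: rotate left by 1, 2, 4, ... and add."""
    out = v.copy()
    for i in range(math.ceil(math.log2(n_slot))):
        out = out + np.roll(out, -(2 ** i))
    return out

v = np.zeros(16)
v[:5] = [1.0, 2.0, 3.0, 4.0, 5.0]
reduced = sum_reduce_plain(v, 5)
print(reduced[0])   # 15.0, the sum of the first five slots
```

The other slots end up holding partial or wrapped-around sums, so only slot 0 should be read.
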
```python
ctx_dot_product = dot_product(ctx, cty, evaluator, galois_keys, relin_keys, len(x))

decryptor.decrypt(ctx_dot_product, pt)
values = encoder.decode_double(pt)

he_dot_product = values[0]
regular_dot_product = np.dot(x, y)

np.abs(he_dot_product - regular_dot_product)
```

```
2.647775545483455e-07
```