In [34]:
import tenseal.sealapi as seal
import util
import numpy as np
import math

In this example we demonstrate evaluating a polynomial function
PI*x^3 + 0.4*x + 1
on encrypted floating-point input data x for a set of 4096 equidistant points
in the interval [0, 1]. This example demonstrates many of the main features
of the CKKS scheme, but also the challenges in using it.
We start by setting up the CKKS scheme.

In [2]:
parms = seal.EncryptionParameters(seal.SCHEME_TYPE.CKKS)

We saw in `2_encoders.cpp' that multiplication in CKKS causes scales
in ciphertexts to grow. The scale of any ciphertext must not get too close
to the total size of coeff_modulus, or else the ciphertext simply runs out of
room to store the scaled-up plaintext. The CKKS scheme provides a `rescale'
functionality that can reduce the scale, and stabilize the scale expansion.
Rescaling is a kind of modulus switch operation (recall `3_levels.cpp').
As modulus switching, it removes the last of the primes from coeff_modulus,
but as a side-effect it scales down the ciphertext by the removed prime.
Usually we want to have perfect control over how the scales are changed,
which is why for the CKKS scheme it is more common to use carefully selected
primes for the coeff_modulus.
More precisely, suppose that the scale in a CKKS ciphertext is S, and the
last prime in the current coeff_modulus (for the ciphertext) is P. Rescaling
to the next level changes the scale to S/P, and removes the prime P from the
coeff_modulus, as usual in modulus switching. The number of primes limits
how many rescalings can be done, and thus limits the multiplicative depth of
the computation.
It is possible to choose the initial scale freely. One good strategy can be
to is to set the initial scale S and primes P_i in the coeff_modulus to be
very close to each other. If ciphertexts have scale S before multiplication,
they have scale S^2 after multiplication, and S^2/P_i after rescaling. If all
P_i are close to S, then S^2/P_i is close to S again. This way we stabilize the
scales to be close to S throughout the computation. Generally, for a circuit
of depth D, we need to rescale D times, i.e., we need to be able to remove D
primes from the coefficient modulus. Once we have only one prime left in the
coeff_modulus, the remaining prime must be larger than S by a few bits to
preserve the pre-decimal-point value of the plaintext.
Therefore, a generally good strategy is to choose parameters for the CKKS
scheme as follows:
    (1) Choose a 60-bit prime as the first prime in coeff_modulus. This will
        give the highest precision when decrypting;
    (2) Choose another 60-bit prime as the last element of coeff_modulus, as
        this will be used as the special prime and should be as large as the
        largest of the other primes;
    (3) Choose the intermediate primes to be close to each other.
We use CoeffModulus::Create to generate primes of the appropriate size. Note
that our coeff_modulus is 200 bits total, which is below the bound for our
poly_modulus_degree: CoeffModulus::MaxBitCount(8192) returns 218.

In [16]:
poly_modulus_degree = 8192
parms.set_poly_modulus_degree(poly_modulus_degree)
parms.set_coeff_modulus(seal.CoeffModulus.Create(poly_modulus_degree, [60, 40, 40, 60]))
print('Coefficient modulus count: ' + str(seal.CoeffModulus.MaxBitCount(8192, seal.SEC_LEVEL_TYPE.TC128)))

Coefficient modulus count: 218


We choose the initial scale to be 2^40. At the last level, this leaves us
60-40=20 bits of precision before the decimal point, and enough (roughly
10-20 bits) of precision after the decimal point. Since our intermediate
primes are 40 bits (in fact, they are very close to 2^40), we can achieve
scale stabilization as described above.

In [19]:
scale = pow(2.0, 40)

context = seal.SEALContext(parms,True,seal.SEC_LEVEL_TYPE.TC128)

Generate keys using 'SEAL::KeyGenerator'

In [25]:

keygen = seal.KeyGenerator(context)
secret_key = keygen.secret_key()

public_key = seal.PublicKey()
keygen.create_public_key(public_key)

relin_keys = seal.RelinKeys()
keygen.create_relin_keys(relin_keys)

galois_keys = seal.GaloisKeys()
keygen.create_galois_keys()


<_sealapi_cpp.SerializableGaloisKeys at 0x7f80111a1cb0>

Generate encryptors

In [84]:
encryptor = seal.Encryptor(context, public_key)
evaluator = seal.Evaluator(context)
decryptor = seal.Decryptor(context, secret_key)

encoder = seal.CKKSEncoder(context)
slot_count = seal.CKKSEncoder.slot_count(encoder)

print('Number of slots: ' + str(slot_count))


Number of slots: 4096


Create a vector with 4096 equidistant points in [0,1]

In [30]:
input = np.linspace(0,1,slot_count)


In [31]:
plain_coeff3 = seal.Plaintext()
encoder.encode(3.14159265, scale, plain_coeff3)

plain_coeff1 = seal.Plaintext()
encoder.encode(0.3, scale, plain_coeff1)

plain_coeff0 = seal.Plaintext()
encoder.encode(1.0, scale, plain_coeff0)



In [56]:
x_plain = seal.Plaintext()
encoder.encode(input, scale, x_plain)
print('%-----------------------------%%-----------------------------%')
x1_encrypted = seal.Ciphertext()
encryptor.encrypt(x_plain, x1_encrypted)

x3_encrypted = seal.Ciphertext()
evaluator.square(x1_encrypted, x3_encrypted)
evaluator.relinearize_inplace(x3_encrypted, relin_keys)
print('%-----------------------------%%-----------------------------%')
print('\t+Scale of x^2 before rescale: ' + str(math.log2(x3_encrypted.scale)) + ' bits.')
print('\t+Rescale x^2.')
evaluator.rescale_to_next_inplace(x3_encrypted)
print('\t+Scale of x^2 before rescale: ' + str(math.log2(x3_encrypted.scale)) + ' bits.')

print('%-----------------------------%%-----------------------------%')
print('\t+Compute and rescale PI*x')
x1_encrypted_coeff3 = seal.Ciphertext()
evaluator.multiply_plain(x1_encrypted, plain_coeff3, x1_encrypted_coeff3)
print('\t+Scale of PI*x before rescale: ' + str(math.log2(x1_encrypted_coeff3.scale)) + ' bits.')
evaluator.rescale_to_next_inplace(x1_encrypted_coeff3)
print('\t+Scale of PI*x after rescale: ' + str(math.log2(x1_encrypted_coeff3.scale)) + ' bits.')

print('%-----------------------------%%-----------------------------%')
print('\t+Compute, relinearize and rescale (PI*x)*x^2')
evaluator.multiply_inplace(x3_encrypted, x1_encrypted_coeff3)
evaluator.relinearize_inplace(x3_encrypted, relin_keys)
print('\t+Scale of PI*x^3 before rescale: ' + str(math.log2(x3_encrypted.scale)) + ' bits.')
evaluator.rescale_to_next_inplace(x3_encrypted)
print('\t+Scale of PI*x^3 after rescale: ' + str(math.log2(x3_encrypted.scale)) + ' bits.')

%-----------------------------%%-----------------------------%
Scale of x^2 before rescale: 80.0 bits.
%-----------------------------%%-----------------------------%
Rescale x^2.
Scale of x^2 before rescale: 40.00000019347918 bits.
%-----------------------------%%-----------------------------%
Compute and rescale PI*x
Scale of PI*x before rescale: 80.0 bits.
Scale of PI*x after rescale: 40.00000019347918 bits.
%-----------------------------%%-----------------------------%
Compute, relinearize and rescale (PI*x)*x^2
Scale of PI*x^3 before rescale: 80.00000038695836 bits.
Scale of PI*x^3 after rescale: 40.00000135435979 bits.


Time to compute the degree 1 term. 

In [57]:
print('%-----------------------------%%-----------------------------%')
print('\t+Compute and rescale 0.4*x.')
evaluator.multiply_plain_inplace(x1_encrypted, plain_coeff1)
print('\t+Scale of 0.4*x before rescale: ' + str(math.log2(x1_encrypted.scale)) + ' bits.')
evaluator.rescale_to_next_inplace(x1_encrypted)
print('\t+Scale of 0.4*v after rescale: ' + str(math.log2(x1_encrypted.scale)) + ' bits.')


%-----------------------------%%-----------------------------%
	+Compute and rescale 0.4*x.
	+Scale of 0.4*x before rescale: 80.0 bits.
	+Scale of 0.4*v after rescale: 40.00000019347918 bits.


In [69]:
print('%-----------------------------%%-----------------------------%')
print("Parameters used by all three terms are different.")
print("\t+ Modulus chain index for x3_encrypted: " +  str(context.get_context_data(x3_encrypted.parms_id()).chain_index()))
print("\t+ Modulus chain index for x1_encrypted: " + str(context.get_context_data(x1_encrypted.parms_id()).chain_index()))
print("\t+ Modulus chain index for plain_coeff0: " + str(context.get_context_data(plain_coeff0.parms_id()).chain_index()))

%-----------------------------%%-----------------------------%
Parameters used by all three terms are different.
	+ Modulus chain index for x3_encrypted: 0
	+ Modulus chain index for x1_encrypted: 1
	+ Modulus chain index for plain_coeff0: 2


Note that this means that
- product x^2 has scale 2^80 and is at level 2
- product PI*x has scale 2^80 and is at level 2
- both were rescaled down to scale 2^80/P_2 and level 1
- product 0.4*x has scale 2^80
- it was rescaled down to scale 2^80/P_2 and level 1
- the constant term 1 has scale 2^40 and is at level 2

In [71]:
print('The exact scales of all three terms are different: ')
print('\t+ Exact scale in PI*x^3 : ' + str(x3_encrypted.scale))
print('\t+ Exact scale in 0.4*x \t: ' + str(x1_encrypted.scale))
print('\t+ Exact scale in 1 \t: ' + str(plain_coeff0.scale))

The exact scales of all three terms are different: 
	+ Exact scale in PI*x^3 : 1099512659965.7515
	+ Exact scale in 0.4*x 	: 1099511775231.0198
	+ Exact scale in 1 	: 1099511627776.0


In [78]:
print('%-----------------------------%%-----------------------------%')
print('Normalize scales to 2^40')
x3_encrypted.scale = pow(2.0,40)
x1_encrypted.scale = pow(2.0,40)

%-----------------------------%%-----------------------------%
Normalize scales to 2^40


In [79]:
print('%-----------------------------%%-----------------------------%')
print("Normalize encryption paramters to the lowest level.")
last_parms_id = x3_encrypted.parms_id()
evaluator.mod_switch_to_inplace(x1_encrypted, last_parms_id)
evaluator.mod_switch_to_inplace(plain_coeff0, last_parms_id)

%-----------------------------%%-----------------------------%
Normalize encryption paramters to the lowest level.


Now all three ciphertexts should be compatible to be added.

In [80]:
print('%-----------------------------%%-----------------------------%')
print('Compute PI*x^3+0.4*x+1.')
encrypted_result = seal.Ciphertext()
evaluator.add(x3_encrypted, x1_encrypted, encrypted_result)
evaluator.add_plain_inplace(encrypted_result, plain_coeff0)

%-----------------------------%%-----------------------------%
Compute PI*x^3+0.4*x+1.


Print the true result.

In [92]:
plain_result = seal.Plaintext()
print('%-----------------------------%%-----------------------------%')
print('Decrypt and decode result.')
print('\t+Expected result')
true_result = (3.14159254*input**2 + 0.4)*input+1
print(true_result[0:-1:1000])

#Decrypt, decode and print result
decryptor.decrypt(encrypted_result, plain_result)
result = encoder.decode_double(plain_result)
print(result[0:-1:1000])

%-----------------------------%%-----------------------------%
Decrypt and decode result.
	+Expected result
[1.         1.14342979 1.5613577  2.52828187 4.31870044]
[1.000000000343631, 1.1190098167856988, 1.5125180285499003, 2.4550230407647304, 4.221023223123293]
