Test comparison of two scalars using the SEAL API exposed by the tenseal module. The parameters and setup are explained in `5_ckks_basics` on the SEAL GitHub page. 

In [18]:
import tenseal.sealapi as seal
import util
from math import log2,ceil

Note that using larger bit sizes for the primes in the coefficient modulus chain is advantageous, since the primes can then lie relatively closer to the scale, which keeps the scale more stable under rescaling. A larger prime and scale also improve precision. 

In [51]:
parms = seal.EncryptionParameters(seal.SCHEME_TYPE.CKKS)
poly_modulus_degree = 2**14
mod_pow = 35
parms.set_poly_modulus_degree(poly_modulus_degree)
parms.set_coeff_modulus(seal.CoeffModulus.Create(poly_modulus_degree, [60, mod_pow,mod_pow,mod_pow,mod_pow,mod_pow,mod_pow,60]))
scale = pow(2.0, mod_pow)

# What are the specific primes in the modulus chain? These help calculate exact scales 
# of rescaled ciphertexts later on in this notebook. 
modulus_chain = [modulus.value() for modulus in parms.coeff_modulus()]

primes = [coeff.value() for coeff in parms.coeff_modulus()]

context = seal.SEALContext(parms,True,seal.SEC_LEVEL_TYPE.TC128)

util.print_parameters(context)

print("|\t Max Bit Count: " + str(seal.CoeffModulus.MaxBitCount(poly_modulus_degree, seal.SEC_LEVEL_TYPE.TC128)))

/
|Encryption parameters: 
|	scheme: CKKS
|	poly_modulus_degree: 16384
|	coeff_modulus_size: 330 (60 35 35 35 35 35 35 60) bits
|	 Max Bit Count: 438
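
As a quick sanity check on the parameters above, the total bit count of the chain must not exceed the maximum reported for this `poly_modulus_degree` at 128-bit security. A minimal sketch in plain Python, with the bit sizes and the limit copied from the printed output (no SEAL calls are involved):

```python
# Bit sizes of the primes in the coefficient modulus chain, as printed above.
mod_pow = 35
bit_sizes = [60] + [mod_pow] * 6 + [60]

# Limit printed by CoeffModulus.MaxBitCount for poly_modulus_degree 2**14 (TC128).
max_bit_count = 438

total_bits = sum(bit_sizes)
# The last prime is the special prime used for key switching and the first
# one stays until decryption, so only the intermediate primes can be
# dropped by rescaling.
rescale_levels = len(bit_sizes) - 2

print(f"total bits: {total_bits} (limit {max_bit_count})")
print(f"rescaling levels available: {rescale_levels}")
```

With these sizes there is room left under the limit, so `mod_pow` could in principle be raised if more precision is needed.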


Generate keys using SEAL keygenerator. Generate encoder, evaluator, encryptor and decryptor. 

In [20]:

keygen = seal.KeyGenerator(context)
secret_key = keygen.secret_key()

public_key = seal.PublicKey()
keygen.create_public_key(public_key)

relin_keys = seal.RelinKeys()
keygen.create_relin_keys(relin_keys)

galois_keys = seal.GaloisKeys()
# Fill the key object in place, as done for the public and relin keys above
keygen.create_galois_keys(galois_keys)

encryptor = seal.Encryptor(context, public_key)
evaluator = seal.Evaluator(context)
decryptor = seal.Decryptor(context, secret_key)

encoder = seal.CKKSEncoder(context)
slot_count = encoder.slot_count()

print('Number of slots: ' + str(slot_count))

Number of slots: 8192


Let $x$ be the encrypted difference between two numbers $a,b\in[0,1]$, such that $x\in[-1,1]$.

In [21]:
a=0.2;b=0.8
x = a-b
x_plain = seal.Plaintext()
encoder.encode(x,scale, x_plain)

x_enc = seal.Ciphertext()
encryptor.encrypt(x_plain, x_enc)

In [22]:
# Worst way of copying a ciphertext: multiply by a plaintext 1 and rescale. Consumes one ciphertext level. 
def bad_copy(cipher : seal.Ciphertext,scale : float, evaluator  : seal.Evaluator, encoder : seal.CKKSEncoder):
    dummy_plain = seal.Plaintext()
    encoder.encode(1, scale, dummy_plain)

    copy = seal.Ciphertext()
    evaluator.multiply_plain(cipher, dummy_plain, copy)

    evaluator.rescale_to_next_inplace(copy)
    
    return copy 


The cell below simply tests the copying method above and prints some information about the copied ciphertext. 

In [23]:
# Test bad_copy
x_copy = bad_copy(x_enc, scale, evaluator, encoder)

print("First data level parms_id():")
print(context.first_parms_id())
print("Original parms_id():")
print(x_enc.parms_id())
# Note that the parms_id() of the copy shows that this ciphertext
# is a level lower than the original. 
print("Copy parms_id():")
print(x_copy.parms_id())

# Print scale of original for later comparison
print("\nOriginal scale")
print(x_enc.scale)

# The copy is at a different level from the original. Modulus
# switching the original brings both to the same level,
# effectively wasting one ciphertext level.  
print("Modulus switched original parms_id():")
evaluator.mod_switch_to_next_inplace(x_enc)
print(x_enc.parms_id())

# Modulus switching does not change the scale
print("Modulus switched scale: ")
print(x_enc.scale)

# What are the specific primes in the modulus chain? 
print("\nPrimes in the coefficient modulus chain:")
print(modulus_chain)

# What is the exact scale of the copied ciphertext?
print("\nCopy scale")
print(x_copy.scale)

# What happens when we operate on original and copy, 
# for illustration purposes only.
test = seal.Ciphertext()
evaluator.multiply(x_copy, x_enc, test)

# Scale has now grown 
print("Product scale:")
print(test.scale)

# and we want to rescale
evaluator.rescale_to_next_inplace(test)
print("Product scale after rescale:")
print(test.scale)

# Note that the exact scale of the rescaled product (which happens
# to encrypt the square of the original value) is the scale of the
# original (scale) times the scale of the copy
# (scale**2/modulus_chain[-2], where modulus_chain[-2] is the first
# prime dropped), divided by the next prime dropped in the chain,
# modulus_chain[-3].
print("scale^3/P_1/P_2")
print(scale**3/modulus_chain[-2]/modulus_chain[-3])

tmp_result = seal.Plaintext()
decryptor.decrypt(x_copy, tmp_result)
print("Decrypted copy:")
print(encoder.decode_double(tmp_result)[0])
decryptor.decrypt(x_enc, tmp_result)
print("Decrypted original:")
print(encoder.decode_double(tmp_result)[0])


First data level parms_id():
[8581331200408610277, 13302195022564177709, 1237030387993002687, 5771053831278510699]
Original parms_id():
[8581331200408610277, 13302195022564177709, 1237030387993002687, 5771053831278510699]
Copy parms_id():
[7388653974824800976, 12986223737403185069, 8737996544829265756, 18402466612834462340]

Original scale
34359738368.0
Modulus switched original parms_id():
[7388653974824800976, 12986223737403185069, 8737996544829265756, 18402466612834462340]
Modulus switched scale: 
34359738368.0

Primes in the coefficient modulus chain:
[1152921504606683137, 34357411841, 34357444609, 34357805057, 34358788097, 34359214081, 34359410689, 1152921504606748673]

Copy scale
34360066050.12501
Product scale:
1.1806028797894944e+21
Product scale after rescale:
34360590350.1252
scale^3/P_1/P_2
34360590350.1252
Decrypted copy:
-0.5999999507772056
Decrypted original:
-0.5999999513003557
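
The scales printed above can be reproduced without SEAL. A minimal sketch, assuming only that a multiplication multiplies the scales and that rescaling divides the scale by the prime being dropped; the primes are copied from the printed chain:

```python
scale = 2.0 ** 35

# Coefficient modulus chain as printed above; the last entry is the special prime.
modulus_chain = [1152921504606683137, 34357411841, 34357444609, 34357805057,
                 34358788097, 34359214081, 34359410689, 1152921504606748673]

# bad_copy: multiply_plain gives scale**2, then rescaling drops modulus_chain[-2].
copy_scale = scale * scale / modulus_chain[-2]

# multiply(copy, original) multiplies the two scales ...
product_scale = copy_scale * scale
# ... and the following rescale drops the next prime, modulus_chain[-3].
rescaled_scale = product_scale / modulus_chain[-3]

print("Copy scale:           ", copy_scale)
print("Product scale:        ", product_scale)
print("Product after rescale:", rescaled_scale)

# Relative drift of the scale away from 2**35 after one and two rescales.
print("Relative drift:", copy_scale / scale - 1, rescaled_scale / scale - 1)
```

The drift grows with every rescale because each prime differs slightly from $2^{35}$, which is why the scales of the powers computed later no longer agree exactly.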


Polynomial approximations of the comparison function use the polynomials $f_3(x)$ and $g_3(x)$ from "Efficient Sorting of Homomorphic Encrypted Data with k-way Sorting Network". Note that the argument to the function in this case will be the encrypted difference between $a$ and $b$, i.e., $a-b$. 

* $g_3(x) = (35x-35x^3+21x^5-5x^7)/2^4$

* $f_3(x) = (4589x-16577x^3+25614x^5-12860x^7)/2^{10}$

The comparison function is implemented in the paper as

* $(x>y) := (f_3^{(d_f)}\circ g_3^{(d_g)}(x-y) + 1)/2$

$d_g$ and $d_f$ are not specified in the paper. We begin testing with $d_f=d_g=1$. 

Note that the highest degree term is $x^7 = (x^2)^2 \cdot x^2 \cdot x$, requiring four multiplications. We will begin by calculating the different terms:
- $x^7$
- $x^5$
- $x^3$
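
Before moving to ciphertexts, it can help to evaluate the polynomials on plain floats. A minimal sketch, using the labels $g_3$ and $f_3$ exactly as written in this notebook and assuming $d_f=d_g=1$ by default; `compare_plain` is a hypothetical helper name, not part of SEAL or the paper:

```python
def g3(x: float) -> float:
    return (35*x - 35*x**3 + 21*x**5 - 5*x**7) / 2**4

def f3(x: float) -> float:
    return (4589*x - 16577*x**3 + 25614*x**5 - 12860*x**7) / 2**10

def compare_plain(a: float, b: float, d_f: int = 1, d_g: int = 1) -> float:
    """Approximate (a > b) as (f3^(d_f) o g3^(d_g)(a - b) + 1) / 2."""
    y = a - b
    for _ in range(d_g):   # g3 is applied first, d_g times
        y = g3(y)
    for _ in range(d_f):   # then f3, d_f times
        y = f3(y)
    return (y + 1) / 2

a, b = 0.2, 0.8
print("g3(a - b):    ", g3(a - b))
print("approx (a > b):", compare_plain(a, b))
print("approx (b > a):", compare_plain(b, a))
```

Both polynomials are odd, so the two comparison values sum to 1 up to floating-point error. The plain value of `g3(a - b)` is the reference that the encrypted result should approach.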

In [24]:
# A function for controlling size and scale of ciphertext after multiplication
def relinearize_and_rescale_inplace(cipher : seal.Ciphertext, evaluator : seal.Evaluator, relin_keys : seal.RelinKeys):
    # Control size by relinearization
    evaluator.relinearize_inplace(cipher, relin_keys)
    # Control scale by rescaling
    evaluator.rescale_to_next_inplace(cipher)

# Evaluator.exponentiate is not supported for CKKS, so exponentiation is written as a separate function.
# Note that "res" must be an encryption of 1 for this to work. 
def square_and_multiply(cipher : seal.Ciphertext, exp : int,  evaluator : seal.Evaluator, relin_keys : seal.RelinKeys,scale : float, encoder: seal.CKKSEncoder, res : seal.Ciphertext):  
    if exp==0:
        raise ValueError("Exponent cannot be zero --> transparent ciphertext.")
    binary_exp = bin(exp)[2:]
    for char in binary_exp[::-1]:
        if char=='1':
            if res.data() == None:
                raise ValueError("Ciphertext res must be initialized to 1.")
            else:
                #print("Current scale:\t" + str(res.scale))
                evaluator.multiply_inplace(res, cipher)
                relinearize_and_rescale_inplace(res, evaluator, relin_keys)
                evaluator.square_inplace(cipher)
                relinearize_and_rescale_inplace(cipher, evaluator, relin_keys)
                #print("Current scale:\t" + str(res.scale))
        else:
            evaluator.square_inplace(cipher)
            relinearize_and_rescale_inplace(cipher,evaluator, relin_keys)
            evaluator.mod_switch_to_next_inplace(res)

# Divide by a power of 2: raising the scale by a factor 2**power divides the decoded value by 2**power
def divByPo2(cipher: seal.Ciphertext, power : int):
    cipher.scale = cipher.scale*(2**power)



The cell below simply tests the function `square_and_multiply` above. Use this to test with different exponents to exemplify different levels, etc.
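
To see which levels to expect, the bit scan in `square_and_multiply` can be mirrored on plain floats. This is only a sketch: `square_and_multiply_plain` is a hypothetical helper, and it assumes, as in the function above, that `res` loses exactly one level per binary digit of the exponent (a rescale after a multiplication, or a modulus switch otherwise).

```python
def square_and_multiply_plain(x: float, exp: int):
    """Return (x**exp, levels consumed by res) using the same bit scan."""
    if exp <= 0:
        raise ValueError("Exponent must be positive.")
    res = 1.0
    levels_used = 0
    for bit in bin(exp)[2:][::-1]:   # least significant bit first
        if bit == '1':
            res *= x                 # multiply + rescale on the ciphertext side
        # otherwise res is only modulus switched
        levels_used += 1
        x *= x                       # the base is squared in every iteration
    return res, levels_used

for e in (3, 4, 5, 7):
    value, levels = square_and_multiply_plain(-0.6, e)
    print(f"exp={e}: value={value:.6E}, levels used by res={levels}")
```

Under this model the number of levels used equals the bit length of the exponent, so exponent 4 uses three levels even though $x^4$ itself only has multiplicative depth 2.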

In [25]:
# Need fresh ciphertext
encryptor.encrypt(x_plain, x_enc)

dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
# Note that the ciphertext x7 is initialized as an encryption of 1,
# which is required by square_and_multiply
x7_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x7_enc)
# Have to waste one level to make square_and_multiply compatible
# with the copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x7_enc)
exponent = 4
square_and_multiply(x_copy, exponent, evaluator, relin_keys,scale, encoder, x7_enc)

plain_result = seal.Plaintext()
decryptor.decrypt(x7_enc, plain_result)
result = encoder.decode_double(plain_result)[0]
print(f"Exponent is {exponent} leading to circuit with depth {ceil(log2(exponent))}")
print(f"Encrypted result is: {result:.6E}")
print(f"True result is: {x**exponent:.6E}")
print(f"Error is: {(result-x**exponent):.4E}")

Exponent is 4 leading to circuit with depth 2
Encrypted result is: 1.296002E-01
True result is: 1.296000E-01
Error is: 1.9327E-07


The cell below tests the function `divByPo2`, resulting ciphertexts are not used in subsequent calculations. 

In [26]:
# Test divByPo2
# Fresh ciphertext 
encryptor.encrypt(x_plain, x_enc)

power = 10
print("Ciphertext level before divByPo2: " + str(context.get_context_data(x_enc.parms_id()).chain_index()))
divByPo2(x_enc, power)
test_plain = seal.Plaintext()
decryptor.decrypt(x_enc, test_plain)
# The divByPo2 does not change the level of the ciphertext.
print("Ciphertext level after divByPo2: " + str(context.get_context_data(x_enc.parms_id()).chain_index()))

print("\ndivByPo2 result: ")
print(encoder.decode_double(test_plain)[0])

print("True value: ")
print(x/2**power)


Ciphertext level before divByPo2: 6
Ciphertext level after divByPo2: 6

divByPo2 result: 
-0.000585937472349454
True value: 
-0.0005859375000000001


Now we are ready to calculate all odd powers up to 7 using the `square_and_multiply` algorithm on "copied" ciphertexts, thereby avoiding changing the original. After calculating the powers, check the `parms_id()`. 

In [52]:
# Fresh ciphertext:
encryptor.encrypt(x_plain, x_enc)

# First x^7
dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
x7_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x7_enc)
# Have to waste one level to make square_and_multiply compatible
# with the copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x7_enc)
print("Original scale:\t" + str(scale))
print("x^7 scale bef.:\t" + str(x7_enc.scale))
square_and_multiply(x_copy, 7, evaluator, relin_keys,scale, encoder, x7_enc)

print("x^7 chain_index():\t" + str(context.get_context_data(x7_enc.parms_id()).chain_index()))
print("x^7 scale:\t" + str(x7_enc.scale))
# Then x^5 
x5_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x5_enc)
# Have to waste one level to make square_and_multiply compatible
# with the copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x5_enc)
square_and_multiply(x_copy, 5, evaluator, relin_keys,scale, encoder, x5_enc)

print("x^5 chain_index():\t" + str(context.get_context_data(x5_enc.parms_id()).chain_index()))
print("x^5 scale:\t" + str(x5_enc.scale))

# Now x^3
x3_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x3_enc)
# Have to waste one level to make square_and_multiply compatible
# with the copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x3_enc)
square_and_multiply(x_copy, 3, evaluator, relin_keys,scale, encoder, x3_enc)

# We expect the levels of x^7 and x^5 to be the same, since both require
# a multiplicative depth of ceil(log2(7))=ceil(log2(5))=3. However, x^3 should
# be at a higher level, since it only needs depth 2. 

print("x^3 chain_index():\t" + str(context.get_context_data(x3_enc.parms_id()).chain_index()))
print("x^3 scale:\t" + str(x3_enc.scale))

# Finally, use the copy of x for calculations. 
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
print("x_copy chain_index():\t" + str(context.get_context_data(x_copy.parms_id()).chain_index()))



Original scale:	34359738368.0
x^7 scale bef.:	34359738368.0
x^7 chain_index():	2
x^7 scale:	34367964214.10345
x^5 chain_index():	2
x^5 scale:	34365833833.954185
x^3 chain_index():	3
x^3 scale:	34362720405.22447
x_copy chain_index():	5


We are now ready to multiply by the plain coefficients of the polynomial. Note, however, that we might as well use repeated addition and `divByPo2` to avoid further multiplications.

In [36]:
def add_many_coeff(coeff: int, cipher: seal.Ciphertext,evaluator: seal.Evaluator, result: seal.Ciphertext):
    # Multiply by a small positive integer coefficient by summing coeff copies of the ciphertext
    evaluator.add_many([cipher] * coeff, result)

In [61]:
# Want to calculate g, i.e., we need 35x, 35x^3, 21x^5, 5x^7, all divided by 2^4


# PROBLEM - HAVE NOT MANAGED TO GET SCALE STABILIZATION
x_coeff = seal.Ciphertext()
add_many_coeff(35, x_copy, evaluator, x_coeff)
divByPo2(x_coeff, 4)

tmp_result = seal.Plaintext()
decryptor.decrypt(x_coeff, tmp_result)
print("35x/2^4")
util.print_info(x_coeff, decryptor, context, encoder, 35*x/2**4)

x3_coeff = seal.Ciphertext()
add_many_coeff(35, x3_enc, evaluator, x3_coeff)
divByPo2(x3_coeff, 4)
decryptor.decrypt(x3_coeff, tmp_result)
print("35x^3/2^4")
util.print_info(x3_coeff, decryptor, context, encoder, 35*x**3/2**4)

x5_coeff = seal.Ciphertext()
add_many_coeff(21, x5_enc, evaluator, x5_coeff)
divByPo2(x5_coeff, 4)
decryptor.decrypt(x5_coeff, tmp_result)
print("21x^5/2^4")
util.print_info(x5_coeff, decryptor, context, encoder, 21*x**5/2**4)

x7_coeff = seal.Ciphertext()
add_many_coeff(5, x7_enc, evaluator, x7_coeff)
divByPo2(x7_coeff, 4)
decryptor.decrypt(x7_coeff, tmp_result)
print("5x^7/2^4")
util.print_info(x7_coeff, decryptor, context, encoder, 5*x**7/2**4)

35x/2^4
	Enc. result:	-1.3125001755103025
	Plain result:	-1.3125000000000002
	Scale:	549761056802.0001
	Ch. ind:	5
35x^3/2^4
	Enc. result:	-0.47250007385581344
	Plain result:	-0.47250000000000025
	Scale:	549803526483.5916
	Ch. ind:	3
21x^5/2^4
	Enc. result:	-0.10206002736599178
	Plain result:	-0.10206000000000007
	Scale:	549853341343.267
	Ch. ind:	2
5x^7/2^4
	Enc. result:	-0.008748005707963065
	Plain result:	-0.00874800000000001
	Scale:	549887427425.6552
	Ch. ind:	2


With all the powers we need, we are now ready to calculate the value of the polynomial. Note, however, that the ciphertexts have different scales and that we need to change the levels of 35x and 35x^3 to be able to sum the terms. This can be done in two ways, following `5_ckks_basics.ipynb`. I will first try the easiest method of simply setting the scale manually, since in all cases the scales are almost equal. 

In [62]:
# Note that 35x and 35x^3 are still at different levels, which must be handled. 
evaluator.mod_switch_to_inplace(x_coeff, x7_coeff.parms_id())
evaluator.mod_switch_to_inplace(x3_coeff, x7_coeff.parms_id())
# Now manually set every scale to that of 5x^7/2^4 (which is not exact, and leads to a loss in precision). 
# It is, however, the simplest solution. 

print(f"Scale of 5x^7/2^4:\t{x7_coeff.scale}")

x_coeff.scale = x7_coeff.scale; x3_coeff.scale=x7_coeff.scale; x5_coeff.scale=x7_coeff.scale

# Can we still decrypt to the correct values after overriding the scale?
print("35x")
util.print_info(x_coeff, decryptor, context, encoder, 35*x/2**4)

print("35x^3")
util.print_info(x3_coeff, decryptor, context, encoder, 35*x**3/2**4)

Scale of 5x^7/2^4:	549887427425.6552
35x
	Enc. result:	-1.3121985474725357
	Plain result:	-1.3125000000000002
	Scale:	549887427425.6552
	Ch. ind:	2
35x^3
	Enc. result:	-0.47242798055208546
	Plain result:	-0.47250000000000025
	Scale:	549887427425.6552
	Ch. ind:	2
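
The loss of precision above is what one would predict from the scale mismatch alone: overriding the scale from $s$ to $s'$ multiplies the decoded value by $s/s'$. A short check in plain Python, with all numbers copied from the printed outputs in this notebook:

```python
# (decoded value, true scale) before the override, from the earlier output.
before = {
    "35x/2^4":   (-1.3125001755103025, 549761056802.0001),
    "35x^3/2^4": (-0.47250007385581344, 549803526483.5916),
}
forced_scale = 549887427425.6552  # scale of 5x^7/2^4

# Decoding divides by the scale, so a wrong scale rescales the value.
predicted = {name: value * true_scale / forced_scale
             for name, (value, true_scale) in before.items()}

for name, value in predicted.items():
    print(f"{name}: predicted value after override = {value}")
```

The predictions agree with the decrypted values to many digits, which suggests that the error of order $10^{-4}$ comes from the manual scale assignment rather than from encryption noise.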


It is now possible to sum all terms together! The cell below calculates the result and prints some intermediate results. 

In [65]:
res = seal.Ciphertext()
# First (35x-35x^3)/2**4
evaluator.sub(x_coeff, x3_coeff, res)
print("35x-35x^3")
util.print_info(res, decryptor, context,encoder, (35*x-35*x**3)/2**4)
# Now add 21x^5/2**4
evaluator.add_inplace(res, x5_coeff)
print("35x-35x^3+21x^5")
util.print_info(res, decryptor, context, encoder, (35*x-35*x**3+21*x**5)/2**4)
# Finally subtract 5x^7/2**4

evaluator.sub_inplace(res, x7_coeff)

print("g(x) = 35x-35x^3+21x^5-5x^7")
util.print_info(res, decryptor, context, encoder, (35*x-35*x**3+21*x**5-5*x**7)/2**4)



35x-35x^3
	Enc. result:	-0.8397705669204507
	Plain result:	-0.84
	Scale:	549887427425.6552
	Ch. ind:	2
35x-35x^3+21x^5
	Enc. result:	-0.9418242678524709
	Plain result:	-0.94206
	Scale:	549887427425.6552
	Ch. ind:	2
g(x) = 35x-35x^3+21x^5-5x^7
	Enc. result:	-0.933076262144508
	Plain result:	-0.933312
	Scale:	549887427425.6552
	Ch. ind:	2
