Test comparison of two scalars using SEAL API from tenseal module. Parameters and setup is explained in 5_ckks_basics on the SEAL github page. 

In [108]:
import tenseal.sealapi as seal
import util
from math import log2,ceil

In [109]:
parms = seal.EncryptionParameters(seal.SCHEME_TYPE.CKKS)
poly_modulus_degree = 8192
parms.set_poly_modulus_degree(poly_modulus_degree)
parms.set_coeff_modulus(seal.CoeffModulus.Create(poly_modulus_degree, [49, 30,30,30,30, 49]))
scale = pow(2.0, 30)

context = seal.SEALContext(parms,True,seal.SEC_LEVEL_TYPE.TC128)

util.print_parameters(context)

print("|\t Max Bit Count: " + str(seal.CoeffModulus.MaxBitCount(poly_modulus_degree, seal.SEC_LEVEL_TYPE.TC128)))

/
|Encryption parameters: 
|	scheme: CKKS
|	poly_modulus_degree: 8192
|	coeff_modulus_size: 218 (49 30 30 30 30 49) bits
|	 Max Bit Count: 218


Generate keys using SEAL keygenerator. Generate encoder, evaluator, encryptor and decryptor. 

In [110]:

keygen = seal.KeyGenerator(context)
secret_key = keygen.secret_key()

public_key = seal.PublicKey()
keygen.create_public_key(public_key)

relin_keys = seal.RelinKeys()
keygen.create_relin_keys(relin_keys)

galois_keys = seal.GaloisKeys()
keygen.create_galois_keys()

encryptor = seal.Encryptor(context, public_key)
evaluator = seal.Evaluator(context)
decryptor = seal.Decryptor(context, secret_key)

encoder = seal.CKKSEncoder(context)
slot_count = seal.CKKSEncoder.slot_count(encoder)

print('Number of slots: ' + str(slot_count))

Number of slots: 4096


Let $x$ be the encrypted difference between two number in $a,b\in[0,1]$, such that $x\in[-1,1]$.

In [134]:
a=0.2;b=0.8
x = a-b
x_plain = seal.Plaintext()
encoder.encode(x,scale, x_plain)

x_enc = seal.Ciphertext()
encryptor.encrypt(x_plain, x_enc)

In [112]:
# Worst way of copying a ciphertext - multiply by plaintext 1. Consumes one ciphertext level. 
def bad_copy(cipher : seal.Ciphertext,scale : int, evaluator  : seal.Evaluator, encoder : seal.CKKSEncoder):
    dummy_plain = seal.Plaintext()
    encoder.encode(1, scale, dummy_plain)

    copy = seal.Ciphertext()
    evaluator.multiply_plain(cipher, dummy_plain, copy)

    evaluator.rescale_to_next_inplace(copy)
    
    return copy 


Polynomial approximations of the comparison function using polynomials f_3(x), g_3(x) from "Efficient Sorting of Homomorphic Encrypted Data with k-way Sorting Network". Note that the argument to the function in this case will be the encrypted `difference`between $a$ and $b$, i.e., $a-b$. 

* $g_3(x) = (35x-35x^3+21x^5-5x^7)/2^4$

* $f_3(x) = (4589x-16577x^3+25614x^5-12860x^7)/2^{10}$

Comparison function is implemented in the paper as

* $(x>y) := (f_3^{(d_f)}\circ g_3^{(d_g)}(x-y) + 1)/2$

$d_g$ and $d_f$ are not specified in the paper. Will begin testing with $d_f=d_g=1$. 

Note that the highest degree term is $x^7 = (x^2)^2(x^2*)x$, requiring four multiplications. We will begin by calculating the different temrs, 
- $x^7$
- $x^5$
- $x^3$

In [113]:
# A function for controlling size and scale of ciphertext after multiplication
def relinearize_and_rescale_inplace(cipher : seal.Ciphertext, eval : seal.Evaluator, relin_keys : seal.RelinKeys):
    # Control size by relinearization
    evaluator.relinearize_inplace(cipher, relin_keys)
    # Control scale by rescaling
    evaluator.rescale_to_next_inplace(cipher)

# evaluator::exponentiate not supported for CKKS, write separate function for exponentiation, note that "res" 
# should be an encryption of 1 for this to work. 
def square_and_multiply(cipher : seal.Ciphertext, exp : int,  evaluator : seal.Evaluator, relin_keys : seal.RelinKeys,scale : int, encoder: seal.CKKSEncoder, res : seal.Ciphertext):  
    if exp==0:
        raise ValueError("Exponent cannot be zero --> transparent ciphertext.")
    binary_exp = bin(exp)[2:]
    for char in binary_exp[::-1]:
        if char=='1':
            if res.data() == None:
                raise ValueError("Ciphertext res must be initialized to 1.")
            else:
                evaluator.multiply_inplace(res, cipher)
                relinearize_and_rescale_inplace(res, evaluator, relin_keys)
                evaluator.square_inplace(cipher)
                relinearize_and_rescale_inplace(cipher, evaluator, relin_keys)
        else:
            evaluator.square_inplace(cipher)
            relinearize_and_rescale_inplace(cipher,evaluator, relin_keys)
            evaluator.mod_switch_to_next_inplace(res)

# Divide by power of 2, divides cipher by 2**power
def divByPo2(cipher: seal.Ciphertext, power : int):
    cipher.scale = cipher.scale*(2**power)



In [130]:
dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
x7_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x7_enc)
# Have to waste one level to make square_and_multiply compatible 
# copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x7_enc)
exponent = 4
square_and_multiply(x_copy, exponent, evaluator, relin_keys,scale, encoder, x7_enc)

plain_result = seal.Plaintext()
decryptor.decrypt(x7_enc, plain_result)
result = encoder.decode_double(plain_result)[0]
print(f"Exponent is {exponent} leading to circuit with depth {ceil(log2(exponent))}")
print(f"Encrypted result is: {result:.6E}")
print(f"True result is: {x**exponent:.6E}")
print(f"Error is: {(result-x**exponent):.4E}")

ValueError: encrypted_ntt and plain_ntt parameter mismatch

Now we are redo to calculate all odd powers up to 7 using the `square_and_multiply` algorithm on "copied" ciphertexts, thereby avoiding changing the original. After calculating the powers, check the `parms_id()`. 

In [137]:
# First x^7
dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
x7_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x7_enc)
# Have to waste one level to make square_and_multiply compatible 
# copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x7_enc)
square_and_multiply(x_copy, 7, evaluator, relin_keys,scale, encoder, x7_enc)

print("x^7 chain_index():\t" + str(context.get_context_data(x7_enc.parms_id()).chain_index()))

# Then x^5 
dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
x5_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x5_enc)
# Have to waste one level to make square_and_multiply compatible 
# copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x5_enc)
square_and_multiply(x_copy, 5, evaluator, relin_keys,scale, encoder, x5_enc)

print("x^5 chain_index():\t" + str(context.get_context_data(x7_enc.parms_id()).chain_index()))

# Now x^3
dummy_plain = seal.Plaintext()
encoder.encode(1, scale, dummy_plain)
x3_enc = seal.Ciphertext()
encryptor.encrypt(dummy_plain, x3_enc)
# Have to waste one level to make square_and_multiply compatible 
# copied ciphertext
x_copy = bad_copy(x_enc, scale, evaluator, encoder)
evaluator.mod_switch_to_next_inplace(x3_enc)
square_and_multiply(x_copy, 3, evaluator, relin_keys,scale, encoder, x3_enc)

print("x^3 chain_index():\t" + str(context.get_context_data(x7_enc.parms_id()).chain_index()))

# Finally, use the copy of x for calculations. 

print("x_copy chain_index():\t" + str(context.get_context_data(x7_enc.parms_id()).chain_index()))

x^7 chain_index():	0
x^5 chain_index():	0
x^3 chain_index():	0
x_copy chain_index():	0


In [128]:
# Test bad_copy
x_copy = bad_copy(x_enc, scale, evaluator, encoder)

print("First data level parms_id():")
print(context.first_parms_id())
print("Original parms_id():")
print(x_enc.parms_id())
# Note that the parms_id() of the copy shows that this ciphertext
# is a level lower than the original. 
print("Copy parms_id():")
print(x_copy.parms_id())

# Print scale of original for later comparison
print("\nOriginal scale")
print(x_enc.scale)

# Copy is at different level from original, this
# can be changed with modulus switching, effectively
# wasting one ciphertext level.  
print("Modulus switched original parms_id():")
evaluator.mod_switch_to_next_inplace(x_enc)
print(x_enc.parms_id())

# Modulus switching does not change the scale
print("Modulus switched scale: ")
print(x_enc.scale)

# What are the specific primes in the modulus chain? 
print("\nPrimes in the coefficient modulus chain:")
modulus_chain = [modulus.value() for modulus in parms.coeff_modulus()]
print(modulus_chain)

# What is the exact scale of the copied ciphertext?
print("\nCopy scale")
print(x_copy.scale)

# What happens when we operate on original and copy, 
# for illustration purposes only.
test = seal.Ciphertext()
evaluator.multiply(x_copy, x_enc, test)

# Scale has now grown 
print("Product scale:")
print(test.scale)

# and we want to rescale
evaluator.rescale_to_next_inplace(test)
print("Product scale after rescale:")
print(test.scale)

# Note that the exact scale of the new rescaled product
# which happens to be the square of the original cipher is
# exactly the scale of the original (scale) times the scale
# copy (scale**2)/modulus_chain[-2], which is then rescaled
# (divided by the next prime in the chain), modulus_chain[-3]
print(scale**3/modulus_chain[-2]/modulus_chain[-3])

tmp_result = seal.Plaintext()
decryptor.decrypt(x_copy, tmp_result)
print(encoder.decode_double(tmp_result)[0])
decryptor.decrypt(x_enc, tmp_result)
print(encoder.decode_double(tmp_result)[0])

First data level parms_id():
[18132249767353198549, 718253398100440172, 8509878810699523384, 11073798577702196591]
Original parms_id():
[18132249767353198549, 718253398100440172, 8509878810699523384, 11073798577702196591]
Copy parms_id():
[16690438325808200324, 16265141331075099064, 4540980657725171339, 14048225904169892080]

Original scale
1099511627776.0
Modulus switched original parms_id():
[16690438325808200324, 16265141331075099064, 4540980657725171339, 14048225904169892080]
Modulus switched scale: 
1099511627776.0

Primes in the coefficient modulus chain:
[562949952798721, 1073430529, 1073479681, 1073643521, 1073692673, 562949952847873]

Copy scale
1099561960704.0117
Product scale:
1.208981161254238e+24
Product scale after rescale:
1126054540084388.0
1073889293.751133
-0.0005859369418990447
-0.0005859367152752603


In [127]:
# Test divByPo2
# Fresh ciphertext 
x_enc = seal.Ciphertext()
encryptor.encrypt(x_plain, x_enc)

power = 10
print("Ciphertext level before divByPo2: " + str(context.get_context_data(x_enc.parms_id()).chain_index()))
divByPo2(x_enc, power)
test_plain = seal.Plaintext()
decryptor.decrypt(x_enc, test_plain)
# The divByPo2 does not change the level of the ciphertext.
print("Ciphertext level after divByPo2: " + str(context.get_context_data(x_enc.parms_id()).chain_index()))

print("\ndivByPo2 result: ")
print(encoder.decode_double(test_plain)[0])

print("True value: ")
print(x/2**power)


Ciphertext level before divByPo2: 4
Ciphertext level after divByPo2: 4

divByPo2 result: 
-0.0005859367152752603
True value: 
-0.0005859375000000001
