In [24]:
import ckks
import numpy as np

In [25]:
M = 16 # Well use the roots of the M-th cyclotomic polynomial for encoding and decoding
N = M//2 # We are guaranteed the M-th cyclotomic is of degree M/2 if M is a power of 2

# XI is our complex root of unity, and want to work with the M-th
XI = np.exp(2 * np.pi * 1j / M)


In [26]:
# after changing the encoder out input vectors should be of size M/2 so if we increase the size of M the tests should still work
encoder = ckks.CKKSEncoder(M)

In [27]:
v = np.array([1,2,3,4])

In [28]:
encoder.pi_inverse(v)

array([1, 2, 3, 4, 4, 3, 2, 1])

In [29]:
p = encoder.encode(v)

In [30]:
print(p)

160.0 - 50.0·x + 0.0·x² - 3.0·x³ + 0.0·x⁴ + 4.0·x⁵ + 0.0·x⁶ + 50.0·x⁷


In [31]:
v_decoded = encoder.decode(p)

In [32]:
print(v_decoded) # we can the result of decoding is very close to original but with some error

[1.01458223+0.01443562j 2.00310646-0.00597943j 2.99689354-0.00597943j
 3.98541777+0.01443562j]


In [33]:
# distance from orig
np.linalg.norm(v_decoded - v)

0.030542827521940662

# Homomorphic operations
## Addition

In [34]:
v1 = np.array([1,2,3,4])
v2 = np.array([1,-2,3,-4])

In [35]:
p1 = encoder.encode(v1)
p2 = encoder.encode(v2)

In [36]:
v1_p_v2 = p1 + p2 # close to [0,0,0,0]
v1_m_v2 = p1 - p2 # close to 2*v1


In [37]:
print(np.round(encoder.decode(v1_p_v2)))
print(np.round(encoder.decode(v1_m_v2)))

[ 2.-0.j  0.+0.j  6.-0.j -0.-0.j]
[ 0.-0.j  4.+0.j -0.+0.j  8.+0.j]


## Multiplication
Need a polynomial modulus. I believe in CKKKS there is a far more complex renormalization operations that
ensures polynomial products are computed using the modulus accurately but for now well just perform mod
in a straight forward manner

In [38]:
from numpy.polynomial import Polynomial
poly_mod = Polynomial([1,0,0,0,0,0,0,0,1]) # X^8 + 1, so polynomial products will never exceed degree 7

In [39]:
v1_prod_v2 = p1 * p2 % poly_mod

In [40]:
print(np.round(encoder.decode(v1_prod_v2)))

[   65.-2.j  -255.-0.j   576.-3.j -1023.-4.j]


# Part 2 stuff

In [41]:
z = np.array([0,1])

# should double to [0,1,1,0]
z_pi = encoder.pi_inverse(z)
print(z_pi)

# should half to [0,1]
print(encoder.pi(z_pi))

[0 1 1 0]
[0 1 1 0]


In [42]:
print(encoder.sigma_R_basis)

[[ 1.00000000e+00+0.j          1.00000000e+00+0.j
   1.00000000e+00+0.j          1.00000000e+00+0.j
   1.00000000e+00+0.j          1.00000000e+00+0.j
   1.00000000e+00+0.j          1.00000000e+00+0.j        ]
 [ 9.23879533e-01+0.38268343j  3.82683432e-01+0.92387953j
  -3.82683432e-01+0.92387953j -9.23879533e-01+0.38268343j
  -9.23879533e-01-0.38268343j -3.82683432e-01-0.92387953j
   3.82683432e-01-0.92387953j  9.23879533e-01-0.38268343j]
 [ 7.07106781e-01+0.70710678j -7.07106781e-01+0.70710678j
  -7.07106781e-01-0.70710678j  7.07106781e-01-0.70710678j
   7.07106781e-01+0.70710678j -7.07106781e-01+0.70710678j
  -7.07106781e-01-0.70710678j  7.07106781e-01-0.70710678j]
 [ 3.82683432e-01+0.92387953j -9.23879533e-01-0.38268343j
   9.23879533e-01-0.38268343j -3.82683432e-01+0.92387953j
  -3.82683432e-01-0.92387953j  9.23879533e-01+0.38268343j
  -9.23879533e-01+0.38268343j  3.82683432e-01-0.92387953j]
 [-2.22044605e-16+1.j          3.33066907e-16-1.j
  -1.11022302e-15+1.j          1.27675648e

In [43]:
# check that linear combination of sigma basis is encoded as an integer polynomial
coords1 = [1,0,0,0]
coords2 = [1,1,1,1]
coords3 = [2,2,2,0]

In [44]:
# b1 = np.matmul(encoder.sigma_R_basis.T,coords1)
# b2 = np.matmul(encoder.sigma_R_basis.T,coords2)
# b3 = np.matmul(encoder.sigma_R_basis.T,coords3)

In [45]:
# print(b1)
# print(b2)
# print(b3)

In [46]:
# print(encoder.encode(b1),"\n\n")
# print(encoder.encode(b2),"\n\n")
# print(encoder.encode(b3),"\n\n")

In [47]:
# A non linear combination should not encode to a integer polynomial
# coords4 = [1.5,1.5,1.5,1.5]
# b4 = np.matmul(encoder.sigma_R_basis.T,coords4)
# print(encoder.encode(b4))

# Full encoder tests

In [48]:
encoder = ckks.CKKSEncoder(8)

In [49]:
z = np.array([3 +4j, 2 - 1j])
z

array([3.+4.j, 2.-1.j])

In [50]:
p = encoder.encode(z)
p

Polynomial([160.,  91., 160.,  45.], domain=[-1,  1], window=[-1,  1], symbol='x')

In [51]:
encoder.decode(p)

array([3.008233+4.00260191j, 1.991767-0.99739809j])