In [13]:
import jax.numpy as jnp
from ckks.encryptor import Encryptor
from ckks.encoder import Encoder
from ckks.utils import get_modulo

In [62]:
M = 32
scale = 2**10

In [63]:
message = jnp.array([777,666,555,444,333,222,111,0])
encoder = Encoder(M, scale, 0)
encryptor = Encryptor(13, M, 256, seed=0)

In [64]:
code = encoder.encode(message)

In [65]:
jnp.around(encoder.decode(code)).astype(int)

DeviceArray([777, 666, 555, 444, 333, 222, 111,   0], dtype=int32)

In [66]:
cipher = encryptor.encrypt(code)
# cipher

In [67]:
unencrypted = encryptor.decrypt(cipher)
# encoder.decode(unencrypted)

In [68]:
jnp.around(encoder.decode(unencrypted.astype(int))).astype(int)

DeviceArray([777, 666, 555, 444, 333, 222, 111,   0], dtype=int32)

In [69]:
m1 = jnp.arange(111,999,111)
m2 = jnp.arange(999,111,-111)
m1, m2

(DeviceArray([111, 222, 333, 444, 555, 666, 777, 888], dtype=int32),
 DeviceArray([999, 888, 777, 666, 555, 444, 333, 222], dtype=int32))

In [70]:
m1 = jnp.arange(1, 9, 1)
m2 = jnp.arange(9, 1,-1)
m1, m2

(DeviceArray([1, 2, 3, 4, 5, 6, 7, 8], dtype=int32),
 DeviceArray([9, 8, 7, 6, 5, 4, 3, 2], dtype=int32))

In [71]:
e1, e2 = encoder.encode(m1), encoder.encode(m2)

In [72]:
c1, c2 = encryptor.encrypt(e1), encryptor.encrypt(e2)
# c1, c2 = Cipher(c1, modulo=encryptor.modulo), Cipher(c2, modulo=encryptor.modulo)

In [74]:
encoder.decode(encryptor.decrypt(c1).astype(int)[-16:]), \
encoder.decode(encryptor.decrypt(c2).astype(int)[-16:])

(DeviceArray([1.000067 +0.00100189j, 1.9995699+0.00072992j,
              2.9988773-0.00014907j, 3.999386 +0.00149509j,
              5.000607 +0.00149519j, 6.0011153-0.00014973j,
              7.0004206+0.00073349j, 7.9999185+0.00100619j],            dtype=complex64),
 DeviceArray([8.998018 +0.0010035j , 7.998806 +0.00073636j,
              7.0000377-0.00014234j, 6.0002327+0.00150545j,
              4.999774 +0.00150531j, 3.9999697-0.00014192j,
              3.0012035+0.00073266j, 2.0019972+0.00099885j],            dtype=complex64))

In [32]:
def cipheradd(cipher1, cipher2, modulo):
    return [
        jnp.polydiv(jnp.polyadd(cipher1[0],cipher2[0]), modulo)[1],
        jnp.polydiv(jnp.polyadd(cipher1[1],cipher2[1]), modulo)[1],
    ]

In [33]:
raw_cipher = cipheradd(c1, c2, encryptor.modulo)

In [34]:
m1 + m2

DeviceArray([10, 10, 10, 10, 10, 10, 10, 10], dtype=int32)

In [35]:
encoder.decode(encryptor.decrypt(raw_cipher)).astype(int)

DeviceArray([ 9,  9,  9,  9, 10, 10, 10, 10], dtype=int32)

In [36]:
def ciphermul(cipher1, cipher2, modulo):
    return [
        jnp.polydiv(jnp.polymul(cipher1[0],cipher2[0]), modulo)[1],
        jnp.polydiv(jnp.polymul(cipher1[0],cipher2[1]), modulo)[1] + \
        jnp.polydiv(jnp.polymul(cipher1[1],cipher2[0]), modulo)[1],
        jnp.polydiv(jnp.polymul(cipher1[1],cipher2[1]), modulo)[1],
    ]

In [37]:
raw_cipher = ciphermul(c1, c2, encryptor.modulo)

In [45]:
s = encryptor.s
ss = jnp.polydiv(jnp.polymul(s,s), encryptor.modulo)[1]

In [57]:
x = raw_cipher[0] + jnp.polydiv(jnp.polymul(raw_cipher[1], encryptor.s), encryptor.modulo)[1] + \
    jnp.polydiv(jnp.polymul(raw_cipher[2], ss), encryptor.modulo)[1]

In [59]:
encoder.decode(x).astype(int)

  return _convert_element_type(operand, new_dtype, weak_type=False)


DeviceArray([  8314,   6129,  -8019, -22166, -23103,  -9195,   8029,
              14787], dtype=int32)