In [21]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp

from ckks.encoder import Encoder
from ckks.rlwe_jit import Encryptor

from ckks.utils import get_modulo

In [26]:
scale = 2**12
M = 32

n = M//2
q = 3
P = 1023
seed = 5
std = 1
p = 3
L = 17
h = 2

def mul(x, y):
    return jnp.polydiv(
        jnp.polymul(x, y),
        get_modulo(n)
    )[1][-n:]
    
def shift_mod(x, modulo):
    modulo_half = modulo // 2
    return jnp.mod(x + modulo_half, modulo) - modulo_half

encoder = Encoder(M, scale)
encryptor = Encryptor(n, q, p, L, P, std, h, seed)

sk, pk = encryptor.generate_keys()

m1 = jnp.array([111, 222, 333, 444, 555, 666, 777, 888])
m2 = jnp.flip(m1)

In [34]:
mul(m1, m2)

DeviceArray([  98568.,  283383.,  542124.,  862470., 1232100., 1638693.,
             2069928., 2513484., 2069928., 1638693., 1232100.,  862470.,
              542124.,  283383.,   98568.], dtype=float64)

In [27]:
e1 = encoder.encode(m1)
e2 = encoder.encode(m2)
e1

DeviceArray([2045952, -732260,       0,  -76547,       0,  -22835,
                   0,   -5763,       0,    5763,       0,   22836,
                   0,   76548,       0,  732260], dtype=int64)

In [41]:
encoder.decode()

DeviceArray([1.02195302e+09-5.28954473e+07j,
             1.02195302e+09-1.48300987e+08j,
             1.02195302e+09-2.18011752e+08j,
             1.02195302e+09-2.54547315e+08j,
             1.02195302e+09-2.54547315e+08j,
             1.02195302e+09-2.18011752e+08j,
             1.02195302e+09-1.48300987e+08j,
             1.02195302e+09-5.28954473e+07j], dtype=complex128)

In [38]:
mul(e1, e2)

DeviceArray([ 0.00000000e+00, -3.10068179e+12,  0.00000000e+00,
              6.52036043e+11, -2.04595200e+06,  1.46167476e+11,
             -2.04595200e+06,  4.63388484e+10,  0.00000000e+00,
              0.00000000e+00,  0.00000000e+00, -4.63388484e+10,
             -2.04595200e+06, -1.46167476e+11, -2.04595200e+06,
             -6.52036043e+11], dtype=float64)

In [37]:
encoder.decode(mul(e1, e2))

DeviceArray([ -96674.05969718 -19230.01100028j,
             -143423.40765994 -95832.80209768j,
             -123213.41971075-184401.56905511j,
              -48074.22224456-241685.0911216j ,
               48074.09024873-241685.1173772j ,
              123213.10104464-184401.781981j  ,
              143423.72632605 -95832.32518013j,
               96674.19169301 -19229.34741245j], dtype=complex128)

In [29]:
c1 = encryptor.encrypt(e1, pk)
c2 = encryptor.encrypt(e2, pk)

In [48]:
(mul(e1, e2),)

DeviceArray([ 0.00000000e+00, -3.10068179e+12,  0.00000000e+00,
              6.52036043e+11, -2.04595200e+06,  1.46167476e+11,
             -2.04595200e+06,  4.63388484e+10,  0.00000000e+00,
              0.00000000e+00,  0.00000000e+00, -4.63388484e+10,
             -2.04595200e+06, -1.46167476e+11, -2.04595200e+06,
             -6.52036043e+11], dtype=float64)

In [46]:
encryptor.decrypt(encryptor.mul(c1, c2, L), L, sk)

DeviceArray([ 2.45949510e+07, -1.67854716e+08, -4.21551000e+05,
             -1.32345860e+07, -1.43014760e+07,  1.35045607e+08,
              5.88604800e+06, -1.26976295e+08, -1.30887490e+07,
             -3.89947600e+06, -1.59377230e+07,  1.56273274e+08,
             -1.39299670e+07, -9.83145260e+07, -1.87181270e+07,
              1.65292520e+07], dtype=float64)

In [30]:
rec_e1 = encryptor.decrypt(c1, L, sk)
rec_e2 = encryptor.decrypt(c2, L, sk)
rec_e1

DeviceArray([ 2.045955e+06, -7.322600e+05,  5.000000e+00, -7.654400e+04,
             -6.000000e+00, -2.283700e+04, -6.000000e+00, -5.760000e+03,
              1.000000e+00,  5.767000e+03, -1.000000e+00,  2.283900e+04,
             -3.000000e+00,  7.655200e+04, -6.000000e+00,  7.322660e+05],            dtype=float64)

In [31]:
encoder.decode(rec_e1)

DeviceArray([110.99997082+4.71127635e-04j, 222.0036717 +3.43116368e-05j,
             332.99810763+4.29585616e-03j, 443.9972331 -6.12162142e-04j,
             554.99916798+1.13364777e-04j, 666.00008202-1.56714894e-04j,
             777.00313994+2.77310877e-03j, 888.00448618+6.43489775e-03j],            dtype=complex128)

In [32]:
encoder.decode(rec_e2)

DeviceArray([887.99975516+4.71127635e-04j, 777.00346502+3.43116359e-05j,
             665.99805099+4.29585616e-03j, 554.99731665-6.12162142e-04j,
             443.99908444+1.13364777e-04j, 333.00013866-1.56714893e-04j,
             222.00334662+2.77310877e-03j, 111.00470184+6.43489775e-03j],            dtype=complex128)