In [1]:
import jax
from jax import config  # top-level access; the `jax.config` submodule import is deprecated/removed in newer JAX

# Enable 64-bit mode before any arrays are created: the moduli used in this
# notebook (q = 671082899 and q*P = 686517805677 for the evaluation key) do not
# fit exactly in int32/float32, so default 32-bit JAX dtypes would corrupt results.
config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp

from polynomials.PolynomialRing import make_polynomial_ring
from ckks.rlwe_ckks import RLWECKKS
from ckks.utils import *

In [61]:
# Scheme parameters.
n = 4            # NOTE(review): presumably the ring dimension (plaintexts below have 4 coefficients) — confirm in RLWECKKS
q = 671_082_899  # ciphertext modulus: decryptions below are reported "mod 671082899"
P = 1_023        # auxiliary modulus: the evaluation key is displayed mod q*P
std = 1          # NOTE(review): presumably the std-dev of the error distribution — confirm in RLWECKKS
seed = 3         # seed passed to RLWECKKS, presumably for reproducible keys/encryptions

In [62]:
# Two plaintext messages, one integer per coefficient
# (NOTE(review): presumably each must have length n — confirm in RLWECKKS.encrypt).
m1 = jnp.array([1231, 2314, 3145, 4156])
m2 = jnp.array([3021, 1432, 5143, 5654])

# Instantiate the scheme with the parameters above and generate the key pair.
rlwe = RLWECKKS(n, q, P, std, seed=seed)
sk, pk = rlwe.generate_keys()

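## Homomorphic addition & multiplication

Decryptions match the plaintext results only approximately: each ciphertext carries small noise, and the error is much larger after multiplication than after addition.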

In [63]:
# Encrypt both plaintexts under the public key (m1 first, then m2, preserving call order).
c1, c2 = [rlwe.encrypt(message, pk) for message in (m1, m2)]

In [64]:
# Decrypt c1 with the secret key: the result should approximate m1 up to small noise.
rlwe.decrypt(c1, sk)

Polynomial ring: [1233. 2317. 3147. 4155.] (mod 671082899), reminder range: (-335541450, 335541449]

In [65]:
# Decrypt c2 with the secret key: the result should approximate m2 up to small noise.
rlwe.decrypt(c2, sk)

Polynomial ring: [3023. 1435. 5145. 5653.] (mod 671082899), reminder range: (-335541450, 335541449]

In [66]:
# Homomorphic addition: c_add should decrypt to approximately m1 + m2.
c_add = rlwe.add(c1, c2)

In [67]:
# Decrypt the sum; compare with the element-wise plaintext sum m1 + m2.
rlwe.decrypt(c_add, sk)

Polynomial ring: [4256. 3752. 8292. 9808.] (mod 671082899), reminder range: (-335541450, 335541449]

In [68]:
# Homomorphic multiplication of the two ciphertexts.
# NOTE(review): presumably returns an un-relinearized ciphertext (it is passed to
# rlwe.relinear later) — confirm in RLWECKKS.mul.
c_mul = rlwe.mul(c1, c2)

In [69]:
# Decrypt the product directly (before relinearization); compare with the plaintext
# ring product computed in the next cell. The error is much larger than for addition.
rlwe.decrypt(c_mul, sk)

Polynomial ring: [35967653. 31524400. 30393807.  4306121.] (mod 671082899), reminder range: (-335541450, 335541449]

In [70]:
# Reference result: multiply the plaintexts directly in the polynomial ring and keep
# the last n coefficients, for comparison with the decrypted c_mul above.
# Uses the parameter n instead of a hard-coded 4 so the check stays valid if n changes.
ring_polymul(m1, m2, get_modulo(n))[-n:]

DeviceArray([35919892., 31490632., 30402752.,  4352298.], dtype=float64)

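## Relinearization

Inspect the secret key and the evaluation key, then relinearize the product ciphertext and check that it still decrypts to approximately the plaintext product.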

In [71]:
# Display the secret key — acceptable only in this toy demo; never print real secret keys.
sk

Polynomial ring: [0. 1. 1. 0.] (mod 671082899), reminder range: (-335541450, 335541449]

In [72]:
# Evaluation key: displayed as a pair of polynomials modulo q*P (= 671082899 * 1023).
rlwe.evk

(Polynomial ring: [-2.61346704e+11  2.17189085e+11  2.04108754e+11  2.27014422e+11] (mod 686517805677), reminder range: (-343258902839, 343258902838],
 Polynomial ring: [ 319552596274  -92538173187 -332632925991  115443842514] (mod 686517805677), reminder range: (-343258902839, 343258902838])

In [73]:
# Relinearize the product ciphertext.
# NOTE(review): presumably uses rlwe.evk internally — confirm in RLWECKKS.relinear.
cc = rlwe.relinear(c_mul)

In [74]:
# Decrypt the relinearized ciphertext; it should still approximate the plaintext
# ring product computed by ring_polymul above (up to noise).
rlwe.decrypt(cc, sk)

Polynomial ring: [36018641. 31519389. 30796087.  4089999.] (mod 671082899), reminder range: (-335541450, 335541449]