In [1]:
# Enable 64-bit (float64/int64) JAX arrays; this must run before any JAX arrays are
# created. `from jax.config import config` is deprecated (and no longer importable in
# newer JAX releases); `from jax import config` is the supported spelling and still
# binds the name `config` for any later use.
from jax import config
config.update("jax_enable_x64", True)
import jax
from jax import jit
import jax.numpy as jnp

from polynomials.PolynomialRing import make_polynomial_ring
from ckks.rlwe_ckks import RLWECKKS
from ckks.utils import *

In [2]:
# Scheme parameters.
#   n    -- ring degree: plaintexts and decrypted polynomials have n coefficients
#   q    -- coefficient modulus; decrypt() results are shown reduced mod q
#   P    -- auxiliary modulus; the evaluation key rlwe.evk is reduced mod P*q
n, q, P = 4, 671082899, 1023
#   seed -- forwarded to RLWECKKS as its RNG seed
#   std  -- standard deviation forwarded to RLWECKKS (presumably for noise sampling — confirm)
seed, std = 3, 1

In [12]:
# Plaintext coefficient vectors (length n = 4, matching the ring degree).
m1 = jnp.array([-1231, 2314, 3145, -4156])
m2 = jnp.array([3021, -1432, 5143, 5654])

# Instantiate the scheme from the parameters above, then draw a (secret, public) key pair.
rlwe = RLWECKKS(n, q, P, std, seed=seed)
sk, pk = rlwe.generate_keys()

## Homomorphic addition & multiplication

In [13]:
# Encrypt both plaintexts under the public key, m1 first and then m2.
c1, c2 = (rlwe.encrypt(msg, pk) for msg in (m1, m2))

In [14]:
# Decrypt c1: the result should approximate m1 up to small noise.
rlwe.decrypt(c1, sk)

Polynomial ring: [-1229.  2317.  3147. -4157.] (mod 671082899), reminder range: (-335541450, 335541449]

In [15]:
# Decrypt c2: the result should approximate m2 up to small noise.
rlwe.decrypt(c2, sk)

Polynomial ring: [ 3023. -1429.  5145.  5653.] (mod 671082899), reminder range: (-335541450, 335541449]

In [16]:
# Add the two ciphertexts without decrypting; the sum should decrypt to roughly m1 + m2.
c_add = rlwe.add(c1, c2)

In [17]:
# Compare with m1 + m2 = [1790, 882, 8288, 1498]; small deviations are decryption noise.
rlwe.decrypt(c_add, sk)

Polynomial ring: [1794.  888. 8292. 1496.] (mod 671082899), reminder range: (-335541450, 335541449]

In [18]:
# Multiply the two ciphertexts without decrypting (not yet relinearized).
c_mul = rlwe.mul(c1, c2)

In [19]:
# Should approximate the ring product of m1 and m2 computed in the clear in the next cell.
rlwe.decrypt(c_mul, sk)

Polynomial ring: [-12090211.  38944960. -12358323. -23378727.] (mod 671082899), reminder range: (-335541450, 335541449]

In [20]:
# Plaintext reference for the homomorphic product: multiply m1 and m2 in the clear.
# Uses the ring degree `n` rather than a hard-coded 4, so the check stays valid if `n` changes.
# NOTE(review): assumes get_modulo(n) builds the ring modulus for degree n and that the
# trailing n entries of ring_polymul's output are the reduced product coefficients — confirm.
ring_polymul(m1, m2, get_modulo(n))[-n:]

DeviceArray([-12118088.,  38928334., -12345864., -23354388.], dtype=float64)

## Relinearization

In [21]:
# Show the secret-key polynomial (acceptable in a toy demo; never display a real secret key).
sk

Polynomial ring: [0. 1. 1. 0.] (mod 671082899), reminder range: (-335541450, 335541449]

In [22]:
# Evaluation key pair; its polynomials are reduced mod P*q = 1023 * 671082899 = 686517805677.
rlwe.evk

(Polynomial ring: [-2.61346704e+11  2.17189085e+11  2.04108754e+11  2.27014422e+11] (mod 686517805677), reminder range: (-343258902839, 343258902838],
 Polynomial ring: [ 319552596274  -92538173187 -332632925991  115443842514] (mod 686517805677), reminder range: (-343258902839, 343258902838])

In [23]:
# Relinearize the product ciphertext (presumably using rlwe.evk internally — confirm in RLWECKKS).
cc = rlwe.relinear(c_mul)

In [24]:
# After relinearization, the ciphertext should still decrypt to roughly the plaintext product above.
rlwe.decrypt(cc, sk)

Polynomial ring: [-12039223.  38939949. -11956043. -23594849.] (mod 671082899), reminder range: (-335541450, 335541449]