In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

from jax.config import config
config.update("jax_enable_x64", True)
import jax
from jax import jit
import jax.numpy as jnp

from polynomials.PolynomialRing import make_polynomial_ring
from ckks.rlwe_ckks import RLWECKKS
from ckks.utils import *

In [68]:
n = 4  # passed to RLWECKKS first; presumably the ring degree (matches the 4-coefficient plaintexts and get_modulo(4) below) — TODO confirm
q = 671082899  # coefficient modulus; all ring outputs below are printed "(mod 671082899)"
P = 1023  # read back as rlwe.P and used as the divisor of c_mul[2] when relinearizing by hand
seed = 3  # RNG seed forwarded to RLWECKKS
std = 1  # forwarded to RLWECKKS as `std`; presumably the error-distribution standard deviation — TODO confirm

In [124]:
# Build the scheme from the config constants, generate keys, and define two integer plaintexts.
rlwe = RLWECKKS(n, q, P, std, seed=seed)

# sk is used for decrypt and pk for encrypt in the cells below.
sk, pk = rlwe.generate_keys()

m1 = jnp.array([-1231,23,3145,-4156])  # plaintext
m2 = jnp.array([3021,-1432,51435,5654])  # plaintext

## Homomorphic addition & multiplication

In [125]:
# Encrypt both plaintexts under the public key.
c1 = rlwe.encrypt(m1, pk)
c2 = rlwe.encrypt(m2, pk)

In [126]:
# Should be close to m1 = [-1231, 23, 3145, -4156] (small noise expected).
rlwe.decrypt(c1, sk)

Polynomial ring: [-1228.    26.  3147. -4158.] (mod 671082899), reminder range: (-335541450, 335541449]

In [127]:
# Should be close to m2 = [3021, -1432, 51435, 5654] (small noise expected).
rlwe.decrypt(c2, sk)

Polynomial ring: [ 3024. -1429. 51437.  5652.] (mod 671082899), reminder range: (-335541450, 335541449]

In [128]:
# Add the two ciphertexts without decrypting them.
c_add = rlwe.add(c1, c2)

In [129]:
# Should be close to m1 + m2 = [1790, -1409, 54580, 1498].
rlwe.decrypt(c_add, sk)

Polynomial ring: [ 1796. -1403. 54584.  1494.] (mod 671082899), reminder range: (-335541450, 335541449]

In [130]:
# Raw ciphertext pairs.
# NOTE(review): in the output below, the second component of c1 and c2 is identical,
# which suggests both encryptions reused the same randomness — check RNG handling in RLWECKKS.encrypt.
c1, c2

((Polynomial ring: [-1.82931753e+08 -1.35543333e+08 -1.81941724e+08  1.39344490e+08] (mod 671082899), reminder range: (-335541450, 335541449],
  Polynomial ring: [ 2.90549755e+08 -1.51201107e+08  3.34131632e+08 -1.98588273e+08] (mod 671082899), reminder range: (-335541450, 335541449]),
 (Polynomial ring: [-1.82927501e+08 -1.35544788e+08 -1.81893434e+08  1.39354300e+08] (mod 671082899), reminder range: (-335541450, 335541449],
  Polynomial ring: [ 2.90549755e+08 -1.51201107e+08  3.34131632e+08 -1.98588273e+08] (mod 671082899), reminder range: (-335541450, 335541449]))

In [131]:
# Multiply the ciphertexts; the output shows the result has three components.
c_mul = rlwe.mul(c1, c2)
c_mul

(Polynomial ring: [ 2.80076640e+08 -1.57477034e+08  6.83146500e+07  2.73488170e+08] (mod 671082899), reminder range: (-335541450, 335541449],
 Polynomial ring: [ 1.11563170e+08 -1.89286256e+08 -1.80023060e+08 -9.89432250e+07] (mod 671082899), reminder range: (-335541450, 335541449],
 Polynomial ring: [-2.67162012e+08  1.99275260e+08 -7.56586580e+07  2.73069055e+08] (mod 671082899), reminder range: (-335541450, 335541449])

In [132]:
# NOTE(review): re-displays c_mul exactly as the previous cell did; redundant, candidate for removal.
c_mul

(Polynomial ring: [ 2.80076640e+08 -1.57477034e+08  6.83146500e+07  2.73488170e+08] (mod 671082899), reminder range: (-335541450, 335541449],
 Polynomial ring: [ 1.11563170e+08 -1.89286256e+08 -1.80023060e+08 -9.89432250e+07] (mod 671082899), reminder range: (-335541450, 335541449],
 Polynomial ring: [-2.67162012e+08  1.99275260e+08 -7.56586580e+07  2.73069055e+08] (mod 671082899), reminder range: (-335541450, 335541449])

In [133]:
# Decrypt the 3-component product directly, before relinearization; compare with the reference product below.
rlwe.decrypt(c_mul, sk)

Polynomial ring: [-2.26741110e+07  1.71674500e+08 -1.97921595e+08  3.01842570e+07] (mod 671082899), reminder range: (-335541450, 335541449]

In [134]:
# Reference m1*m2 computed on plaintexts: reduce by get_modulo(n) and keep the last n
# coefficients of element [1] (presumably the remainder — TODO confirm ring_polymul's return).
# Uses the config `n` instead of a hard-coded 4 so it tracks the scheme's parameters.
ring_polymul(m1, m2, get_modulo(n))[1][-n:]

DeviceArray([-2.28359850e+07,  1.71563360e+08, -1.97814305e+08,
              3.03503520e+07], dtype=float64)

In [145]:
def mul(x, y, modulus=q, degree=n):
    """Multiply two raw coefficient arrays and reduce them in the ring used by this notebook.

    Plain-array counterpart of ring multiplication. It is used below to cross-check
    relinearization on `.coeffs` arrays.

    Args:
        x, y: 1-D coefficient arrays in `jnp.polymul` order (highest degree first);
            assumes the same ordering as `PolynomialRing.coeffs` — TODO confirm.
        modulus: modulus passed to `shift_mod`. Defaults to the notebook's `q`
            (bound when this cell runs).
        degree: selects the reduction polynomial `get_modulo(degree)` and how many
            trailing remainder coefficients are kept. Defaults to the notebook's `n`
            (bound when this cell runs).

    Returns:
        The last `degree` coefficients of the remainder of x*y divided by
        get_modulo(degree), passed through `shift_mod(..., modulus)`.

    NOTE(review): with float64 inputs (as the outputs above show), products of
    coefficients near q/2 (~3.4e8) reach ~1e17 > 2**53, so low-order bits can be lost
    before `shift_mod` is applied.
    """
    product = jnp.polymul(x, y)
    # polydiv returns (quotient, remainder); only the remainder matters for ring reduction.
    _, remainder = jnp.polydiv(product, get_modulo(degree))
    return shift_mod(remainder[-degree:], modulus)

def add(x, y):
    """Add two raw coefficient arrays and reduce the sum with `shift_mod` modulo `q`."""
    summed = jnp.polyadd(x, y)
    return shift_mod(summed, q)

## Relinearization

In [159]:
# Sanity check of ring multiplication on a small hand-picked polynomial.
s = rlwe.PR(jnp.array([0,0,1,-1]))
s*s

Polynomial ring: [ 0  1 -2  1] (mod 671082899), reminder range: (-335541450, 335541449]

In [139]:
# Relinearize by hand: scale the third component down by P, multiply by evk[0],
# round, and add to the first component.
# (The trailing comma in the original made this a 1-tuple; it is removed so the
# result is a ring element, directly comparable with relinear's first component.)
x = 1/rlwe.P * c_mul[2]
c_mul[0] + (x * rlwe.evk[0]).round()

(Polynomial ring: [ 3.27298174e+08  3.33997132e+08 -9.45168020e+07  1.65997843e+08] (mod 671082899), reminder range: (-335541450, 335541449],)

In [142]:
# Raw coefficient arrays, so the manual relinearization can be redone with the array helpers mul/add.
cm0 = c_mul[0].coeffs
cm2 = c_mul[2].coeffs
evk = rlwe.evk[0].coeffs

In [151]:
# Same first-component computation as the manual relinearization, but on raw arrays.
# NOTE(review): unlike `(x * rlwe.evk[0]).round()` there, the product is not rounded here,
# and `x` is reassigned (overwrites the earlier `x`).
x = cm2/rlwe.P

add(cm0, mul(x, evk))

DeviceArray([ 3.27295726e+08,  3.33987444e+08, -9.45275540e+07,
              1.65997707e+08], dtype=float64)

In [140]:
# Define `cc` here instead of relying on a later cell, so this cell also works on a fresh kernel.
cc = rlwe.relinear(c_mul)
cc

(Polynomial ring: [ 3.27298174e+08  3.33997132e+08 -9.45168020e+07  1.65997843e+08] (mod 671082899), reminder range: (-335541450, 335541449],
 Polynomial ring: [-1.22374326e+08  2.58116669e+08  6.31049950e+07 -2.25235029e+08] (mod 671082899), reminder range: (-335541450, 335541449])

In [152]:
# Reduce the 3-component product to two components, then decrypt.
# Compare with the reference m1*m2 from ring_polymul above.
cc = rlwe.relinear(c_mul)
rlwe.decrypt(cc, sk)

Polynomial ring: [-2.25630610e+07  1.71867098e+08 -1.97377505e+08  3.02555000e+07] (mod 671082899), reminder range: (-335541450, 335541449]

## Rescaling

[polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polynomial_ring.<locals>.PolynomialRing,
 polynomials.PolynomialRing.make_polyn