In [38]:
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial, reduce
from jax.lax import scan
from typing import List

from ckks.encoder import Encoder
from ckks.encryptor import Encryptor
from ckks.context import Context
from ckks.cipher import Cipher
from ckks.utils import *

## Encryptor

In [240]:
M = 4*2
SCALE = 64
P = 1024
q = 11
LAMBDA = 10
p = 2 
l = 3
L = 17
ql = (p**l) * q

In [241]:
ctx = Context(M, SCALE, q, p, l, L, P)
encoder = Encoder(M, SCALE)
encryptor = Encryptor(ctx)
print(encryptor.Q)
print(encryptor.pub_key)

1441792
[DeviceArray([ 440102., -526757.,  415383., -679460.], dtype=float32), DeviceArray([ 494704,  316153, -147435, -439674], dtype=int32)]


In [242]:
fake = jnp.array([111, -222, 333, 444])
encryptor.encrypt(fake),\
encryptor.decrypt(encryptor.encrypt(fake), L)

([DeviceArray([ 440213., -526979.,  415716., -679016.], dtype=float32),
  DeviceArray([ 494704,  316153, -147435, -439674], dtype=int32)],
 DeviceArray([ 111., -222.,  333.,  444.], dtype=float32))

In [243]:
m1 = jnp.array([111,222])
m2 = jnp.array([999,20])

encrypt = lambda x: encryptor.encrypt(encoder.encode(x))
decrypt = lambda x: encoder.decode(encryptor.decrypt(x, L)) # L ~ first time encryption

In [244]:
e1 = encoder.encode(m1)
e2 = encoder.encode(m2)

c1 = encryptor.encrypt(e1)
c2 = encryptor.encrypt(e2)

In [245]:
e2

DeviceArray([ 32608,  22152,      0, -22152], dtype=int32)

In [246]:
decrypt(c1), decrypt(c2)

(DeviceArray([110.99211-9.5367432e-06j, 222.00787+2.4795532e-05j], dtype=complex64),
 DeviceArray([998.9947  +7.6293945e-05j,  20.005432-1.8310547e-04j], dtype=complex64))

## Arithmetic

### ADD

In [247]:
def cipheradd(cipher1, cipher2):
    return [
        shift_mod(jnp.polyadd(cipher1[0],cipher2[0]), encryptor.Q) , 
        shift_mod(jnp.polyadd(cipher1[1],cipher2[1]), encryptor.Q),
    ]

In [248]:
mulc = cipheradd(c1, c2)

In [249]:
decrypt(mulc)

DeviceArray([1109.9868+7.6293945e-05j,  242.0133-1.6784668e-04j], dtype=complex64)

### Constant Multiplication

In [250]:
f1 = jnp.array([11,22,33,44])
f2 = jnp.array([0,0,0,1])

ef1 = encryptor.encrypt(f1)
ef2 = encryptor.encrypt(f2)
ef1

[DeviceArray([ 440113., -526735.,  415416., -679416.], dtype=float32),
 DeviceArray([ 494704,  316153, -147435, -439674], dtype=int32)]

In [251]:
def ciphermul_constant(cipher, constant):
    return [
        shift_mod(ring_polymul(cipher[0], constant, encryptor.modulo)[-4:], encryptor.Q),
        shift_mod(ring_polymul(cipher[1], constant, encryptor.modulo)[-4:], encryptor.Q),
    ]

In [252]:
ring_polymul(ef1[0], f2, get_modulo(4))[-4:]

DeviceArray([ 440113., -526735.,  415416., -679416.], dtype=float32)

In [253]:
mulc = ciphermul_constant(ef1, f2)
mulc

[DeviceArray([ 440113., -526735.,  415416., -679416.], dtype=float32),
 DeviceArray([ 494704.,  316153., -147435., -439674.], dtype=float32)]

In [254]:
encryptor.decrypt(mulc, L) #+ ctx.Q//2

DeviceArray([11., 22., 33., 44.], dtype=float32)

In [255]:
jnp.polydiv(jnp.polymul(f1, f2), get_modulo(4))[1]

DeviceArray([ 0.,  0.,  0., 11., 22., 33., 44.], dtype=float32)

### Cipher Multiplication

In [256]:
def mod(x): return shift_mod(x, ctx.Q)
_mul = lambda x, y: ring_polymul(x, y, ctx.modulo)[-4:]
_modmul = lambda x, y: mod(ring_polymul(x, y, ctx.modulo)[-4:])
def mul(*args):
    return reduce(_mul, args)

def modmul(*args):
    return reduce(_modmul, args)

In [289]:

def ciphermul(ct1, ct2):
    return [
        modmul(ct1[0], ct2[0]),
        mod(mul(ct1[0], ct2[1]) + mul(ct1[1], ct2[0])),
        modmul(ct1[1], ct2[1]),
    ]

def relinear(ct):
    return [
        mod(ct[0] + 1/ctx.P * mod(jnp.around(mul(ct[2], ctx.evk[0])))),
        mod(ct[1] + 1/ctx.P * mod(jnp.around(mul(ct[2], ctx.evk[1])))),
    ]

def cmul(ct1, ct2):
    return relinear(ciphermul(ct1, ct2))

In [290]:
ring_polymul(ctx.sk, ctx.sk, ctx.modulo)[-4:]

DeviceArray([-36.,  56., -52.,  17.], dtype=float32)

In [291]:
f1 = jnp.array([1,2,3,4])
f2 = jnp.array([0,0,0,2])

cf1 = encryptor.encrypt(f1)
cf2 = encryptor.encrypt(f2)

In [292]:
ct = cmul(cf1, cf2)

In [293]:
ct

[DeviceArray([-196800.,  393024.,  229312.,   97600.], dtype=float32),
 DeviceArray([ 1.31264e+05, -2.62592e+05,  3.35808e+05, -6.40000e+01], dtype=float32)]

In [294]:
encryptor.decrypt(ct, L)

DeviceArray([ 232064.,  211200., -204544., -178048.], dtype=float32)

In [295]:
shift_mod(cf1[0] + ring_polymul(cf1[1], ctx.sk, ctx.modulo)[-4:], ctx.Q) ,\
shift_mod(cf2[0] + ring_polymul(cf2[1], ctx.sk, ctx.modulo)[-4:], ctx.Q)

(DeviceArray([1., 2., 3., 4.], dtype=float32),
 DeviceArray([0., 0., 0., 2.], dtype=float32))

In [264]:
shift_mod(
    mul(cf1[1], cf2[1], ctx.sk, ctx.sk) \
        + mul(cf1[0], cf2[1], ctx.sk) \
        + mul(cf1[1], cf2[0], ctx.sk) \
        + mul(cf1[0], cf2[0]),
    ctx.Q)

DeviceArray([  65536., -720896.,  -65536., -589824.], dtype=float32)

In [265]:
mod(
    mul(
        mod(cf1[0] + mul(cf1[1], ctx.sk)),
        mod(cf2[0] + mul(cf2[1], ctx.sk)),
    )
)

DeviceArray([2., 4., 6., 8.], dtype=float32)

In [266]:
a, b = cf1[0], modmul(cf1[1], ctx.sk)
c, d = cf2[0], modmul(cf2[1], ctx.sk)

mod(
    modmul(a,c) + modmul(b,d) + modmul(a,d) + modmul(b,c)
)

DeviceArray([65536.,     0., 65536., 32768.], dtype=float32)

In [267]:
mul(
    mod(a+b),
    mod(c+d)
)

DeviceArray([2., 4., 6., 8.], dtype=float32)

In [268]:
mul(
    cf1[0] + mul(cf1[1], ctx.sk),
    cf2[0] + mul(cf2[1], ctx.sk),
)

DeviceArray([-1.2976126e+07, -2.0787714e+12, -4.1575241e+12,
             -2.0787570e+12], dtype=float32)

In [269]:
cf1[1], ctx.a

(DeviceArray([ 494704,  316153, -147435, -439674], dtype=int32),
 DeviceArray([ 494704,  316153, -147435, -439674], dtype=int32))

In [223]:
(3 * 4) % 5

2