In [1]:
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
from jax.lax import scan
from typing import List

from ckks.encoder import Encoder
from ckks.encryptor import Encryptor
from ckks.context import Context
from ckks.cipher import Cipher
from ckks.utils import *

## Encryptor

In [2]:
# Toy parameter set shared by every cell below (values kept tiny on purpose).
# NOTE(review): meanings are inferred from how the names are used in this
# notebook; the authoritative definitions live in ckks.context.Context.
M = 2 * 4         # first argument of Context and Encoder
SCALE = 1         # second argument of Context and Encoder
P = 1024          # last argument of Context
LAMBDA = 10       # not referenced by any cell visible in this notebook
p, q = 2, 11      # base and factor of the level modulus (p ** level) * q
l, L = 3, 15      # L is the level handed to decrypt for fresh ciphertexts
ql = q * p ** l   # the level modulus evaluated at level l

In [3]:
# Shared objects: later cells read `ctx`, `encoder` and `encryptor` as globals.
ctx = Context(M, SCALE, q, p, l, L, P)
encoder = Encoder(M, SCALE)
encryptor = Encryptor(ctx)

# NOTE(review): the recorded output for Q is 327680 = 10 * 2**15, whereas
# q * p**L with the constants defined above would be 11 * 2**15 = 360448.
# Either Q is derived differently inside Encryptor or this output is stale
# (execution counts in this notebook are not sequential) -- confirm by
# restarting the kernel and running all cells.
print(encryptor.Q, encryptor.pub_key, sep="\n")

327680
[DeviceArray([-50944.,  93723.,  94856.,  98777.], dtype=float32), DeviceArray([ 134256,   86777,   81941, -144762], dtype=int32)]


In [5]:
# Round-trip a raw coefficient vector through encrypt/decrypt, bypassing the
# encoder. The cell displays (ciphertext, decrypted vector).
# NOTE(review): in every recorded output the second ciphertext component is
# identical to encryptor.pub_key[1] -- confirm whether Encryptor.encrypt is
# meant to be randomised.
fake = jnp.array([111, -222, 333, 444])
(
    encryptor.encrypt(fake),
    encryptor.decrypt(encryptor.encrypt(fake), L),
)

([DeviceArray([113007., -70339., -68651., -64619.], dtype=float32),
  DeviceArray([ 134256,   86777,   81941, -144762], dtype=int32)],
 DeviceArray([ 111., -222.,  333.,  444.], dtype=float32))

In [6]:
# Two sample messages used by the encode -> encrypt -> decrypt -> decode demo.
m1 = jnp.array([111, 222])
m2 = jnp.array([999, 20])


def encrypt(x):
    """Encode ``x`` with the global ``encoder`` and encrypt it with ``encryptor``.

    Args:
        x: message vector accepted by ``encoder.encode``.

    Returns:
        Whatever ``encryptor.encrypt`` returns for the encoded message (in the
        recorded outputs: a list of two coefficient arrays).
    """
    return encryptor.encrypt(encoder.encode(x))


def decrypt(x, level=None):
    """Decrypt ciphertext ``x`` with ``encryptor`` and decode it with ``encoder``.

    Args:
        x: ciphertext as produced by ``encrypt``.
        level: forwarded as the second argument of ``encryptor.decrypt``.
            When omitted, the global ``L`` is used (looked up at call time,
            as the original lambda did); L is the level of a first-time
            encryption.

    Returns:
        The decoded message returned by ``encoder.decode``.
    """
    if level is None:
        level = L
    return encoder.decode(encryptor.decrypt(x, level))

In [7]:
# Same pipeline as encrypt(), spelled out so the encoded plaintexts (e1, e2)
# stay available next to the ciphertexts (c1, c2).
e1, e2 = encoder.encode(m1), encoder.encode(m2)
c1, c2 = encryptor.encrypt(e1), encryptor.encrypt(e2)

In [8]:
# Round-trip check: should reproduce m1 and m2.
# NOTE(review): the recorded values are off by up to ~0.85 (111.85 vs 111,
# 20.68 vs 20) -- presumably encoder rounding at SCALE = 1; confirm.
tuple(decrypt(cipher) for cipher in (c1, c2))

(DeviceArray([111.845665-9.5367432e-06j, 222.15433 +2.4795532e-05j], dtype=complex64),
 DeviceArray([999.31793+9.1552734e-05j,  20.68219-1.8310547e-04j], dtype=complex64))

## Arithmetic

### Addition

In [9]:
def cipheradd(cipher1, cipher2):
    """Add two ciphertexts component by component.

    Components 0 and 1 of the inputs are summed pairwise with ``jnp.polyadd``
    and each sum is passed through ``shift_mod`` together with
    ``encryptor.Q``. Any further components of the inputs are ignored.

    Args:
        cipher1: indexable ciphertext with at least two coefficient arrays.
        cipher2: indexable ciphertext with at least two coefficient arrays.

    Returns:
        A two-element list holding the processed component sums.
    """
    return [
        shift_mod(jnp.polyadd(cipher1[idx], cipher2[idx]), encryptor.Q)
        for idx in range(2)
    ]

In [10]:
# Homomorphic sum of c1 and c2.
# NOTE(review): despite its name, `mulc` holds a *sum* here; the same name is
# reassigned to a product in the constant-multiplication section further down.
mulc = cipheradd(c1, c2)

In [11]:
# Expect approximately m1 + m2 = [1110, 242].
decrypt(mulc)

DeviceArray([1111.1636 +7.6293945e-05j,  242.83652-1.8310547e-04j], dtype=complex64)

### Constant Multiplication

In [19]:
# Raw coefficient vectors (no encoder) for the constant-multiplication demo.
f1 = jnp.array([11, 22, 33, 44])
f2 = jnp.array([0, 0, 11, 11])

# Only f1 is encrypted; f2 is used in the clear as the constant.
ef1 = encryptor.encrypt(f1)
ef1

[DeviceArray([112907., -70095., -68951., -65019.], dtype=float32),
 DeviceArray([ 134256,   86777,   81941, -144762], dtype=int32)]

In [20]:
def ciphermul_constant(cipher, constant):
    """Multiply a two-component ciphertext by a plaintext polynomial.

    Each of ``cipher[0]`` and ``cipher[1]`` is multiplied by ``constant`` with
    ``ring_polymul`` (using ``encryptor.modulo``), cut back to the trailing
    coefficients, and passed through ``shift_mod`` with ``encryptor.Q``.

    Args:
        cipher: indexable ciphertext; only components 0 and 1 are used.
        constant: plaintext coefficient array.

    Returns:
        A two-element list with the processed products.
    """
    # Keep as many trailing coefficients as the ciphertext polynomial has.
    # This replaces the previously hard-coded 4 and is identical for the
    # 4-coefficient ciphertexts used in this notebook.
    n_coeffs = len(cipher[0])
    return [
        shift_mod(
            ring_polymul(component, constant, encryptor.modulo)[-n_coeffs:],
            encryptor.Q,
        )
        for component in (cipher[0], cipher[1])
    ]

In [21]:
# First component before shift_mod: the same ring_polymul call that
# ciphermul_constant makes, but with get_modulo(4) in place of
# encryptor.modulo.
ring_polymul(ef1[0], f2, get_modulo(4))[-4:]

DeviceArray([  470932., -1529506., -1473670., -1957186.], dtype=float32)

In [22]:
# Ciphertext of f1 multiplied by the plaintext constant f2.
# Note: this overwrites `mulc`, which previously held the result of cipheradd.
mulc = ciphermul_constant(ef1, f2)
mulc

[DeviceArray([ -20588.,  -54946.,     890., -154946.], dtype=float32),
 DeviceArray([-26237.,  53658., 128169.,  43762.], dtype=float32)]

In [23]:
# Decrypt at level L; the result should match the plaintext ring product of
# f1 and f2 computed in the following cell.
encryptor.decrypt(mulc, L) # "+ ctx.Q//2" is commented out and is not applied

DeviceArray([363., 605., 847., 363.], dtype=float32)

In [24]:
# Plaintext reference: remainder (index 1 of polydiv's result) of f1 * f2
# divided by get_modulo(4). Its last four entries equal the decrypted
# ciphertext shown above.
jnp.polydiv(jnp.polymul(f1, f2), get_modulo(4))[1]

DeviceArray([  0.,   0.,   0., 363., 605., 847., 363.], dtype=float32)

### Cipher Multiplication

In [58]:
def ciphermul(c1, c2):
    """Multiply two two-component ciphertexts into a three-component one.

    With c1 = (a0, a1) and c2 = (b0, b1) the result is
    ``[a0*b0, a0*b1 + a1*b0, a1*b1]``. Every product is a ``ring_polymul``
    with ``encryptor.modulo`` cut back to the trailing coefficients, and every
    component is passed through ``shift_mod`` with ``encryptor.Q``. No
    relinearisation is performed; use ``deciphermul`` to decrypt the result.

    NOTE(review): the operands shown in this notebook are float32 arrays with
    entries around 1e5, so the products (~1e10) exceed float32's exact integer
    range (2**24). The recorded output of this function consists solely of
    multiples of 512, which is consistent with such rounding -- consider
    enabling 64-bit mode in JAX or using integer arithmetic.

    Args:
        c1: indexable ciphertext; components 0 and 1 are used.
        c2: indexable ciphertext; components 0 and 1 are used.

    Returns:
        A three-element list of coefficient arrays.
    """
    ring_modulo = encryptor.modulo
    # Keep as many trailing coefficients as the ciphertext polynomial has
    # (replaces the hard-coded 4; identical for this notebook's inputs).
    n_coeffs = len(c1[0])

    def _mul(a, b):
        # Product in the ring, trimmed to the ciphertext length.
        return ring_polymul(a, b, ring_modulo)[-n_coeffs:]

    d0 = _mul(c1[0], c2[0])
    d1 = jnp.polyadd(_mul(c1[0], c2[1]), _mul(c1[1], c2[0]))
    d2 = _mul(c1[1], c2[1])
    return [shift_mod(component, encryptor.Q) for component in (d0, d1, d2)]

def deciphermul(ct, l):
    """Decrypt a three-component ciphertext such as the one from ``ciphermul``.

    Computes ``ct[0] + ct[1] * ctx.sk + ct[2] * ctx.sk_square`` using
    ``ring_polymul`` / ``ring_polyadd`` with ``encryptor.modulo`` and passes
    the sum through ``shift_mod`` with the modulus ``(ctx.p ** l) * ctx.q``.

    NOTE(review): ``ciphermul`` processes its components with ``encryptor.Q``
    while this function uses ``(ctx.p ** l) * ctx.q``; confirm which level
    ``l`` makes the two consistent.

    Args:
        ct: indexable ciphertext with three coefficient arrays.
        l: level used as the exponent of ``ctx.p`` in the final modulus.

    Returns:
        The coefficient array returned by the final ``shift_mod`` call.
    """
    ring_modulo = encryptor.modulo
    # Keep as many trailing coefficients as the ciphertext polynomial has
    # (replaces the hard-coded 4; identical for this notebook's inputs).
    n_coeffs = len(ct[0])

    linear_term = ring_polymul(ct[1], ctx.sk, ring_modulo)[-n_coeffs:]
    quadratic_term = ring_polymul(ct[2], ctx.sk_square, ring_modulo)[-n_coeffs:]

    total = ring_polyadd(ct[0], linear_term, ring_modulo)
    total = ring_polyadd(total, quadratic_term, ring_modulo)
    return shift_mod(total, (ctx.p ** l) * ctx.q)

In [59]:
# Inputs for the ciphertext-by-ciphertext demo. `f1` and `f2` are rebound
# here; they previously held the constant-multiplication operands.
f1 = jnp.array([1, 2, 3, 4])
f2 = jnp.array([0, 0, 10, 10])

# This time both operands are encrypted.
cf1, cf2 = encryptor.encrypt(f1), encryptor.encrypt(f2)

In [60]:
# Homomorphic product of the two ciphertexts; ciphermul returns three
# components, so the result must be decrypted with deciphermul.
ct = ciphermul(cf1, cf2)

In [61]:
# NOTE(review): every entry of the recorded output is a multiple of 512,
# which suggests the float32 products were rounded before the modular step.
ct

[DeviceArray([-148992.,  149504., -159744.,  -61440.], dtype=float32),
 DeviceArray([  57344.,   57344., -134656., -116736.], dtype=float32),
 DeviceArray([  90112., -122880.,  -16384., -147456.], dtype=float32)]

In [72]:
# Decrypt at level 0, i.e. with modulus (ctx.p ** 0) * ctx.q = ctx.q.
# NOTE(review): the plaintext product of f1 and f2, reduced the same way as
# in the constant-multiplication check, is [30, 50, 70, 30]; the recorded
# output does not reproduce it. Revisit the level argument and the float32
# precision noted for `ct` before relying on this result.
deciphermul(ct, 0)

DeviceArray([-3., -5.,  1., -3.], dtype=float32)