In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit

from functools import partial
from typing import List

# Root PRNG key (seed 0). JAX keys are stateless: drawing twice from the same
# key returns the same samples, so split it (jax.random.split) before
# independent draws.
jax_key = jax.random.PRNGKey(seed=0)


In [2]:
# Toy parameters:
#   q -- exclusive upper bound for sampled coefficients (used as randint maxval)
#   m -- ring degree; polynomial products are reduced modulo X^m + 1
#   n -- half the ring degree (presumably the message-slot count; the demo
#        messages below have n = 8 entries)
q, m = 32, 16
n = m // 2

In [3]:
# Hand-rolled version of the key setup that the CKKS class below packages.
# Split the root key first: drawing `a` and `s` from the same `jax_key` would
# make the secret key `s` identical to the public sample `a`.
key_a, key_s = jax.random.split(jax_key)
a = jax.random.randint(key_a, (m,), 0, q).astype(int)  # public sample
s = jax.random.randint(key_s, (m,), 0, q).astype(int)  # secret key
e = jnp.array([0, 0, 0, 0]).astype(int)  # error term (all zeros in this cell)

# Reduction polynomial X^m + 1: coefficients [1, 0, ..., 0, 1], shape (m+1,)
modulo = jnp.zeros((m + 1,)).astype(int)
modulo = modulo.at[0].set(1)
modulo = modulo.at[-1].set(1)

# a*s mod (X^m + 1); polydiv returns (quotient, remainder).
# NOTE: coefficients are not reduced mod q here.
a_s = jnp.polydiv(jnp.polymul(a, s), modulo)[1]

message = jnp.array([7, 6, 5, 4, 3, 2, 1, 0]).astype(int)

In [4]:
class CKKS:
    """Toy ring-LWE style encryption over Z[X]/(X^m + 1).

    Polynomials are coefficient arrays in highest-degree-first order (the
    ``jnp.poly*`` convention). Ciphertext coefficients are NOT reduced mod
    ``q`` and no CKKS encoding/scaling is applied, so this is a teaching
    sketch rather than a secure scheme: ``decrypt(encrypt(msg))`` returns the
    message polynomial plus the fixed error term ``e``.

    Args:
        q: Exclusive upper bound for the sampled coefficients of ``a`` and ``s``.
        m: Ring degree; polynomial products are reduced modulo X^m + 1.
        key: Optional JAX PRNG key. Defaults to ``jax.random.PRNGKey(0)``
            (the same seed as the notebook's ``jax_key``).
    """

    def __init__(self, q: int, m: int, key=None):
        if key is None:
            key = jax.random.PRNGKey(0)
        # Independent subkeys: drawing `a` and `s` from one key yields identical
        # arrays, i.e. the secret key would equal the public component `a`.
        key_a, key_s = jax.random.split(key)
        self.a = jax.random.randint(key_a, (m,), 0, q).astype(int)  # public sample
        self.s = jax.random.randint(key_s, (m,), 0, q).astype(int)  # secret key
        # Fixed error polynomial (constant term 5); a real scheme would sample
        # it from a small distribution.
        self.e = jnp.array([5]).astype(int)

        # Reduction polynomial X^m + 1: coefficients [1, 0, ..., 0, 1], shape (m+1,)
        modulo = jnp.zeros((m + 1,)).astype(int)
        modulo = modulo.at[0].set(1)
        self.modulo = modulo.at[-1].set(1)

        self.a_s = self._reduce(jnp.polymul(self.a, self.s))

        # Public key (b, a) with b = -a*s + e, so that b + a*s == e.
        self.pub_key = [
            jnp.polyadd(-self.a_s, self.e),
            self.a,
        ]

    def _reduce(self, poly: jnp.ndarray) -> jnp.ndarray:
        """Return ``poly`` mod (X^m + 1).

        ``jnp.polydiv`` keeps the dividend's length, so the remainder is
        left-padded with zeros instead of being trimmed to length m, and it is
        floating point because polydiv promotes to an inexact dtype.
        """
        return jnp.polydiv(poly, self.modulo)[1]

    def encrypt(self, message: jnp.ndarray) -> List[jnp.ndarray]:
        """Encrypt ``message`` as the pair ``[b + message, a]``.

        NOTE: no fresh randomness is drawn per call, so encrypting the same
        message twice yields the same ciphertext.
        """
        return [
            jnp.polyadd(self.pub_key[0], message),
            self.pub_key[1],
        ]

    def decrypt(self, ciphertext: List[jnp.ndarray]) -> jnp.ndarray:
        """Return ``c0 + c1*s mod (X^m + 1)``, which equals message + e here.

        Only the first two ciphertext components are used.
        """
        return jnp.polyadd(
            ciphertext[0],
            self._reduce(jnp.polymul(ciphertext[1], self.s)),
        )

# Demo instance for the round trip. Note q=13 here, not the q=32 from the
# parameter cell; coefficients are far larger than q because nothing is reduced mod q.
message = jnp.array([777, 666, 555, 444, 333, 222, 111, 0], dtype=int)
ckks = CKKS(q=13, m=16)

In [5]:
# Round trip: decryption should return the message plus the scheme's error
# polynomial, left-padded with zeros (polydiv does not trim the remainder).
cipher = ckks.encrypt(message)
message_decrypted = ckks.decrypt(cipher)

expected_tail = jnp.polyadd(message, ckks.e)
assert jnp.allclose(message_decrypted[-expected_tail.shape[0]:], expected_tail)
assert jnp.allclose(message_decrypted[:-expected_tail.shape[0]], 0)
message_decrypted

DeviceArray([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
               0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
               0.,   0.,   0., 777., 666., 555., 444., 333., 222., 111.,
               5.], dtype=float32)

## Ciphertext-Ciphertext Multiplication