In [3]:
import jax
import jax.numpy as jnp

from ckks.encoder import Encoder
from ckks.encryptor import Encryptor
from ckks.context import Context
from ckks.cipher import Cipher
from ckks.utils import *

In [4]:
# Scheme parameters consumed by Context / Encoder in the next cell.
# NOTE(review): the meaning of each parameter is defined by the ckks package,
# not by this notebook — confirm against ckks.context.Context before tuning.
M = 2 * 4
SCALE = 64
P = 1024
LAMBDA = 10  # not referenced by any other visible cell

# Modulus-chain inputs: base value q, factor p, and the two level counts l / L.
q = 10
p = 2
l = 3
L = 10
ql = q * p**l  # derived value q * p^l (= 80 with the settings above)

In [6]:
# Build the shared objects used by every later cell from the parameters above.
# Context receives the full parameter set; Encoder only needs M and SCALE.
# NOTE(review): arguments are positional — their order must match the
# signatures in ckks.context / ckks.encoder; confirm if those change.
ctx = Context(M, SCALE, q, p, l, L, P)
encoder = Encoder(M, SCALE)
encryptor = Encryptor(ctx)

In [4]:
# Two sample plaintext vectors used throughout the notebook.
m1 = jnp.array([111, 222])
m2 = jnp.array([999, 20])


def encrypt(x):
    """Encode ``x`` with the notebook's encoder, then encrypt the result.

    Thin convenience wrapper around ``encryptor.encrypt(encoder.encode(x))``.
    """
    return encryptor.encrypt(encoder.encode(x))


def decrypt(x):
    """Decrypt ``x`` with the notebook's encryptor, then decode the result.

    Inverse pipeline of ``encrypt``: ``encoder.decode(encryptor.decrypt(x))``.
    """
    return encoder.decode(encryptor.decrypt(x))

In [5]:
# Encode both sample messages first, then encrypt both encodings
# (same call order as doing each step by hand).
e1, e2 = [encoder.encode(message) for message in (m1, m2)]
c1, c2 = [encryptor.encrypt(encoded) for encoded in (e1, e2)]

In [6]:
# Round-trip check: decrypting c2 should give back m2 = [999, 20] up to approximation noise.
decrypt(c2)

DeviceArray([1000.401   -1.470932j ,   19.224121-1.1587372j], dtype=complex64)

In [7]:
# Inspect the raw ciphertext returned by encryptor.encrypt for m1.
c1

[DeviceArray([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0662e+04,
              -2.5070e+03,  6.0000e+00,  2.5160e+03], dtype=float32),
 DeviceArray([8, 1, 7, 0], dtype=int32)]

In [8]:
def cipheradd(cipher1, cipher2, modulo):
    """Add two ciphertexts component by component, without any reduction.

    Args:
        cipher1: Indexable holding two polynomial coefficient arrays.
        cipher2: Indexable holding two polynomial coefficient arrays.
        modulo: Accepted for signature compatibility but not used here —
            the sums are returned unreduced.

    Returns:
        A two-element list: the ``jnp.polyadd`` of the first components and
        the ``jnp.polyadd`` of the second components.
    """
    return [jnp.polyadd(cipher1[idx], cipher2[idx]) for idx in range(2)]

In [9]:
# NOTE: despite its name, `mulc` holds a ciphertext *sum* (result of cipheradd), not a product.
# NOTE(review): the 4 passed to get_modulo is a magic number — presumably M // 2; confirm and name it.
# The cipheradd defined in the preceding cell does not use the modulo argument.
mulc = cipheradd(c1, c2, get_modulo(4))

In [10]:
# Display the two components of the summed ciphertext.
mulc

[DeviceArray([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3276e+04,
               1.9650e+04,  1.2000e+01, -1.9632e+04], dtype=float32),
 DeviceArray([16,  2, 14,  0], dtype=int32)]

In [16]:
# Homomorphic-addition check: the result should approximate m1 + m2 = [1110, 242].
decrypt(mulc)

DeviceArray([1112.7994 -2.9419556j,  240.45068-2.3172607j], dtype=complex64)

In [185]:
import jax
import jax.numpy as jnp
from typing import List

from ckks.utils import get_modulo


def cipheradd(cipher1, cipher2, modulo):
    """Add two ciphertexts component-wise and reduce modulo a polynomial.

    Args:
        cipher1: Indexable holding two polynomial coefficient arrays
            (highest degree first, as expected by ``jnp.polyadd``).
        cipher2: Indexable holding two polynomial coefficient arrays.
        modulo: Coefficient array of the polynomial modulus.

    Returns:
        A two-element list ``[r0, r1]`` where ``r_i`` is the remainder of
        ``cipher1[i] + cipher2[i]`` divided by ``modulo``.
    """
    return [
        jnp.polydiv(jnp.polyadd(cipher1[0], cipher2[0]), modulo)[1],
        # The second component must combine index 1 of both inputs
        # (the previous version repeated index 0 here).
        jnp.polydiv(jnp.polyadd(cipher1[1], cipher2[1]), modulo)[1],
    ]

def ciphersub(cipher1, cipher2, modulo):
    """Subtract two ciphertexts component-wise and reduce modulo a polynomial.

    Args:
        cipher1: Indexable holding two polynomial coefficient arrays
            (highest degree first, as expected by ``jnp.polysub``).
        cipher2: Indexable holding two polynomial coefficient arrays.
        modulo: Coefficient array of the polynomial modulus.

    Returns:
        A two-element list ``[r0, r1]`` where ``r_i`` is the remainder of
        ``cipher1[i] - cipher2[i]`` divided by ``modulo``.
    """
    return [
        jnp.polydiv(jnp.polysub(cipher1[0], cipher2[0]), modulo)[1],
        # The second component must combine index 1 of both inputs
        # (the previous version repeated index 0 here).
        jnp.polydiv(jnp.polysub(cipher1[1], cipher2[1]), modulo)[1],
    ]

def ciphermul(cipher1, cipher2, modulo):
    """Multiply two two-component ciphertexts into a three-component result.

    Args:
        cipher1: Indexable holding two polynomial coefficient arrays
            (highest degree first).
        cipher2: Indexable holding two polynomial coefficient arrays.
        modulo: Coefficient array of the polynomial modulus.

    Returns:
        A three-element list ``[d0, d1, d2]`` with, all reduced by ``modulo``:
        ``d0 = c1[0]*c2[0]``, ``d1 = c1[0]*c2[1] + c1[1]*c2[0]`` and
        ``d2 = c1[1]*c2[1]``.
    """

    def mul_mod(lhs, rhs):
        # Remainder of the polynomial product lhs * rhs divided by modulo.
        return jnp.polydiv(jnp.polymul(lhs, rhs), modulo)[1]

    return [
        mul_mod(cipher1[0], cipher2[0]),
        # polyadd right-aligns the two remainders; a plain `+` would broadcast
        # or fail when their coefficient arrays differ in length.
        jnp.polyadd(
            mul_mod(cipher1[0], cipher2[1]),
            mul_mod(cipher1[1], cipher2[0]),
        ),
        mul_mod(cipher1[1], cipher2[1]),
    ]

def ciphermul_constant(cipher, constant, modulo):
    """Multiply both ciphertext components by a plaintext polynomial.

    Args:
        cipher: Indexable holding two polynomial coefficient arrays.
        constant: Coefficient array of the polynomial to multiply by.
        modulo: Coefficient array of the polynomial modulus.

    Returns:
        A two-element list with the remainder of ``cipher[i] * constant``
        divided by ``modulo`` for ``i`` in ``(0, 1)``.
    """
    scaled = []
    for component in (cipher[0], cipher[1]):
        product = jnp.polymul(component, constant)
        _, remainder = jnp.polydiv(product, modulo)
        scaled.append(remainder)
    return scaled

def relinearize(c_mult, evk, p, modulo):
    """Fold the third component of a product ciphertext back into two.

    Args:
        c_mult: Three-element indexable ``[d0, d1, d2]`` of polynomial
            coefficient arrays, e.g. the output of ``ciphermul``.
        evk: Two-element indexable of polynomial coefficient arrays
            (the evaluation key).
        p: Scalar divisor applied to the ``d2 * evk`` products.
        modulo: Coefficient array of the polynomial modulus.

    Returns:
        A two-element list ``[d0 + t0, d1 + t1]`` reduced by ``modulo``,
        where ``t_i = (d2 * evk[i] mod modulo) / p``.
    """
    # Correction terms derived from the third component and the evaluation key.
    # NOTE(review): no rounding is applied after the division by p, matching
    # the previous behaviour — confirm whether the scheme expects rounding.
    correction = [
        1 / p * jnp.polydiv(jnp.polymul(c_mult[2], evk[0]), modulo)[1],
        1 / p * jnp.polydiv(jnp.polymul(c_mult[2], evk[1]), modulo)[1],
    ]
    # The previous version multiplied c_mult[i] by the whole two-element list,
    # which is not a polynomial; each component only needs its own correction
    # term added.
    return [
        jnp.polydiv(jnp.polyadd(c_mult[0], correction[0]), modulo)[1],
        jnp.polydiv(jnp.polyadd(c_mult[1], correction[1]), modulo)[1],
    ]

class Cipher:
    """Ciphertext wrapper exposing ``+``, ``-`` and ``*`` via operator overloads.

    The arithmetic itself is delegated to the module-level helper functions,
    which are bound below as static methods so subclasses can swap them out.

    Attributes set by ``__init__``:
        content: list of polynomial coefficient arrays.
        modulo: coefficient array of the polynomial modulus.
        evk: evaluation key used for relinearisation, or ``None``.
        p: scalar passed to the relinearisation helper, or ``None``.
        depth: initialised to 1; not updated by the operators in this class.
    """

    _add = staticmethod(cipheradd)
    _sub = staticmethod(ciphersub)
    _mul = staticmethod(ciphermul)
    _mul_constant = staticmethod(ciphermul_constant)
    _relin = staticmethod(relinearize)

    def __init__(self, content: List[jnp.ndarray], modulo, evk=None, p=None) -> None:
        """Store the ciphertext components and resolve the modulus.

        Args:
            content: list of polynomial coefficient arrays.
            modulo: either a coefficient array used as-is, or an ``int``
                that is converted with ``get_modulo``.
            evk: optional evaluation key; required for Cipher * Cipher.
            p: optional scalar for relinearisation; required for
                Cipher * Cipher.

        Raises:
            TypeError: if ``modulo`` is neither an array nor an ``int``
                (previously this left ``self.modulo`` undefined).
        """
        self.content = content
        if isinstance(modulo, jnp.ndarray):
            self.modulo = modulo
        elif isinstance(modulo, int):
            self.modulo = get_modulo(modulo)
        else:
            raise TypeError(
                f"modulo must be a jnp.ndarray or int, got {type(modulo).__name__}"
            )
        self.evk = evk
        self.p = p
        self.depth = 1

    def _spawn(self, content):
        """Wrap ``content`` in a new instance sharing this one's modulus and keys."""
        # Construct with the original two-argument call so subclasses that
        # override __init__(content, modulo) keep working, then copy the keys.
        result = self.__class__(content, self.modulo)
        result.evk = self.evk
        result.p = self.p
        return result

    def __mul__(self, other):
        """Multiply by another ``Cipher`` (with relinearisation) or by an array.

        Raises:
            ValueError: for Cipher * Cipher when ``evk`` or ``p`` is missing.
            NotImplementedError: for any other operand type.
        """
        if isinstance(other, Cipher):
            if self.evk is None or self.p is None:
                raise ValueError(
                    "Cipher * Cipher requires evk and p to be set for relinearization"
                )
            # _mul is a staticmethod: it takes the two contents and the modulus only.
            c_mult = self._mul(self.content, other.content, self.modulo)
            content = self._relin(c_mult, self.evk, self.p, self.modulo)
        elif isinstance(other, jnp.ndarray):
            # jnp.array is a factory function, not a type; jnp.ndarray is the
            # class usable with isinstance.
            content = self._mul_constant(self.content, other, self.modulo)
        else:
            raise NotImplementedError
        return self._spawn(content)

    def __add__(self, other):
        """Add another ``Cipher``; other operand types are not supported."""
        if isinstance(other, Cipher):
            content = self._add(self.content, other.content, self.modulo)
        else:
            raise NotImplementedError
        return self._spawn(content)

    def __sub__(self, other):
        """Subtract another ``Cipher``; other operand types are not supported."""
        if isinstance(other, Cipher):
            content = self._sub(self.content, other.content, self.modulo)
        else:
            raise NotImplementedError
        return self._spawn(content)