# Encryption in CKKS

In [1]:
from numpy.polynomial import Polynomial
import numpy as np

def round_coordinates(coordinates):
    """Return the fractional part ``x - floor(x)`` of every coordinate.

    Despite the name nothing is rounded here: the result is nominally in
    [0, 1) (floating-point rounding can yield exactly 1.0 for tiny negative
    inputs).
    """
    integer_part = np.floor(coordinates)
    return coordinates - integer_part

def coordinate_wise_random_rounding(coordinates):
    """Round each coordinate randomly to one of its two neighbouring integers.

    A coordinate ``x`` with fractional part ``f = x - floor(x)`` is rounded up
    to ``floor(x) + 1`` with probability ``f`` and down to ``floor(x)``
    otherwise, so the rounding is unbiased in expectation.

    Args:
        coordinates: 1-D array-like of real numbers.

    Returns:
        A list of Python ints, one per coordinate.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    r = round_coordinates(coordinates)

    # One uniform draw per coordinate replaces the per-element
    # np.random.choice call: round up exactly when u < r (probability r).
    round_up = np.random.random(r.shape) < r
    f = np.where(round_up, r - 1, r)

    rounded_coordinates = coordinates - f
    # rint before int(): int() truncates toward zero, so any float residue
    # such as 2.9999999999999996 would otherwise become 2.
    return [int(coeff) for coeff in np.rint(rounded_coordinates)]

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials.

    A vector of ``M // 4`` complex slots is expanded to a conjugate-symmetric
    vector (``pi_inverse``), multiplied by ``scale``, projected on the lattice
    sigma(R) with randomized rounding, and mapped to integer polynomial
    coefficients (``sigma_inverse``). ``decode`` reverses these steps
    approximately.
    """

    def __init__(self, context, scale: float):
        """Initializes the encoder.

        Args:
            context: object providing ``M`` (twice the ring degree N) and a
                ``QPolynomial`` factory used to wrap encoded coefficients.
            scale: factor applied to messages before rounding to integers.
        """
        M = context.M
        self.context = context
        # Primitive M-th root of unity; its odd powers are the evaluation points.
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.create_sigma_R_basis()
        self.scale = scale

        # Number of independent complex slots: N / 2 = M / 4.
        self.slot_count = M // 4

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> list:
        """Computes the Vandermonde matrix from an M-th root of unity.

        Row ``i`` holds ``root**0, ..., root**(N-1)`` for
        ``root = xi**(2*i + 1)`` and ``N = M // 2``. Returned as a list of
        lists of complex numbers.
        """
        N = M // 2
        matrix = []
        for i in range(N):
            # Each row uses a different odd power of xi.
            root = xi ** (2 * i + 1)
            matrix.append([root ** j for j in range(N)])
        return matrix

    def sigma_inverse(self, b: np.ndarray) -> Polynomial:
        """Encodes the vector b in a polynomial using an M-th root of unity.

        Solves ``A @ coeffs = b`` where ``A`` is the Vandermonde matrix of the
        odd powers of ``xi``, so that ``sigma(sigma_inverse(b)) ~= b``.
        """
        # The Vandermonde matrix is already cached (transposed) as the sigma(R)
        # basis. Reusing it avoids rebuilding it with an O(N^2) Python loop,
        # and it is built from self.M (the original read a global ``M``).
        A = self.sigma_R_basis.T
        coeffs = np.linalg.solve(A, b)
        return Polynomial(coeffs)

    def sigma(self, p: Polynomial) -> np.ndarray:
        """Decodes a polynomial by evaluating it at the odd powers of xi.

        Returns the length ``M // 2`` vector ``(p(xi**1), p(xi**3), ...)``.
        """
        N = self.M // 2
        return np.array([p(self.xi ** (2 * i + 1)) for i in range(N)])

    def pi(self, z: np.ndarray) -> np.ndarray:
        """Projects a vector of H into C^{N/2} by keeping its first M // 4 entries."""
        N = self.M // 4
        return z[:N]

    def pi_inverse(self, z: np.ndarray) -> np.ndarray:
        """Expands a vector of C^{N/2} by appending its reversed complex
        conjugate, producing a conjugate-symmetric vector of H."""
        z_conjugate = z[::-1]
        z_conjugate = [np.conjugate(x) for x in z_conjugate]
        return np.concatenate([z, z_conjugate])

    def create_sigma_R_basis(self):
        """Creates the basis (sigma(1), sigma(X), ..., sigma(X** N-1)).

        Stored as rows of ``self.sigma_R_basis`` (the transposed Vandermonde
        matrix).
        """
        self.sigma_R_basis = np.array(self.vandermonde(self.xi, self.M)).T

    def compute_basis_coordinates(self, z):
        """Computes the coordinates of z with respect to the orthogonal
        sigma(R) basis: ``Re(<z, b>) / <b, b>`` for each basis vector ``b``
        (``np.vdot`` conjugates its first argument)."""
        output = np.array([np.real(np.vdot(z, b) / np.vdot(b, b)) for b in self.sigma_R_basis])
        return output

    def sigma_R_discretization(self, z):
        """Projects a vector on the lattice sigma(R): its basis coordinates are
        randomly rounded to integers and recombined with the basis."""
        coordinates = self.compute_basis_coordinates(z)

        rounded_coordinates = coordinate_wise_random_rounding(coordinates)
        y = np.matmul(self.sigma_R_basis.T, rounded_coordinates)
        return y

    def encode(self, z: np.ndarray):
        """Encodes a vector by expanding it first to H, scaling it,
        projecting it on the lattice of sigma(R), and performing sigma
        inverse.

        Returns:
            ``context.QPolynomial`` built from the integer coefficients.
        """
        pi_z = self.pi_inverse(z)
        scaled_pi_z = self.scale * pi_z
        rounded_scale_pi_zi = self.sigma_R_discretization(scaled_pi_z)
        p = self.sigma_inverse(rounded_scale_pi_zi)

        # The exact solution has integer coefficients (it inverts a lattice
        # point); round away the numerical imprecision of the linear solve.
        coef = np.round(np.real(p.coef)).astype(int)
        p = self.context.QPolynomial(coef)

        return p

    def decode(self, p) -> np.ndarray:
        """Decodes a polynomial by removing the scale, evaluating it on the
        roots, and projecting the result on C^(N/2)."""
        rescaled_p = p / self.scale
        z = self.sigma(rescaled_p)
        pi_z = self.pi(z)
        return pi_z

In [2]:
from numpy.polynomial import Polynomial

def mod_q(coeffs: np.ndarray, q: int) -> np.ndarray:
    """Return ``coeffs`` reduced modulo ``q`` into the centred range (-q/2, q/2].

    A new array is returned; ``coeffs`` itself is not modified.
    """
    reduced = coeffs % q
    # Residues strictly above q/2 wrap around to the negative side.
    reduced[reduced > q / 2] -= q
    return reduced

def mod_q_polynomial(p: Polynomial, q: int) -> Polynomial:
    """Return a new Polynomial whose coefficients are those of ``p`` centred
    modulo ``q`` (see ``mod_q``)."""
    return Polynomial(mod_q(p.coef, q))

In [23]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial
from __future__ import annotations
        
class QPolynomialGenerator:
    """Factory for QPolynomial objects that share one ``q`` and the ring
    modulus X^N + 1."""

    def __init__(self, N: int, q: int):
        self.q = q

        # X^N + 1 in increasing-degree order: [1, 0, ..., 0, 1].
        modulus_coeffs = np.zeros(N + 1)
        modulus_coeffs[[0, N]] = 1
        self.poly_modulus = Polynomial(modulus_coeffs)

    def __call__(self, coef) -> "QPolynomial":
        """Wrap ``coef`` in a QPolynomial bound to this generator."""
        return QPolynomial(coef, self.q, self.poly_modulus, self)
    
class QPolynomial:
    """Polynomial in Z_q[X]/(X^N + 1).

    Addition, subtraction, multiplication and negation are reduced modulo
    ``poly_modulus`` (X^N + 1); division by a scalar is applied to the
    coefficients directly. Operands may be other QPolynomials, numpy
    Polynomials or scalars, on either side of the operator.

    NOTE(review): coefficients are never reduced modulo ``q`` here; ``q`` is
    only stored. Confirm whether callers rely on this before adding a mod-q
    step.
    """

    def __init__(self, coef, q, poly_modulus, generator):
        self.p = Polynomial(coef)
        self.q = q
        self.poly_modulus = poly_modulus
        # Factory (normally a QPolynomialGenerator) used to wrap every result.
        self.generator = generator

    def __getattr__(self, k):
        """Delegate unknown attributes (``coef``, ``degree``, ...) to ``self.p``.

        Only invoked when normal lookup fails. Raises AttributeError when the
        wrapped polynomial lacks the attribute too, so ``hasattr`` and
        ``getattr(obj, name, default)`` behave normally.
        """
        # Guard against infinite recursion when ``p`` is not set yet
        # (e.g. instances created through __new__ by copy/pickle).
        if k == "p":
            raise AttributeError(k)
        try:
            return getattr(self.p, k)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {k!r}"
            ) from None

    @staticmethod
    def _unwrap(operand):
        """Return the numpy Polynomial inside a QPolynomial, else ``operand``."""
        return operand.p if isinstance(operand, QPolynomial) else operand

    def _reduce(self, poly) -> "QPolynomial":
        """Reduce ``poly`` modulo ``poly_modulus`` and wrap it via the generator."""
        return self.generator((poly % self.poly_modulus).coef)

    def __mul__(self, right) -> "QPolynomial":
        return self._reduce(self.p * self._unwrap(right))

    def __rmul__(self, left) -> "QPolynomial":
        return self.__mul__(left)

    def __add__(self, right) -> "QPolynomial":
        return self._reduce(self.p + self._unwrap(right))

    def __radd__(self, left) -> "QPolynomial":
        return self.__add__(left)

    def __sub__(self, qpoly) -> "QPolynomial":
        # Accept scalars/Polynomials too, consistent with __add__ and __mul__.
        return self._reduce(self.p - self._unwrap(qpoly))

    def __rsub__(self, qpoly) -> "QPolynomial":
        # Only reached when the left operand is not a QPolynomial (e.g. 1 - x),
        # so it must not assume ``qpoly.p`` exists.
        return self._reduce(self._unwrap(qpoly) - self.p)

    def __neg__(self) -> "QPolynomial":
        return self._reduce(-self.p)

    def __pos__(self) -> "QPolynomial":
        return self._reduce(self.p)

    def __truediv__(self, scale) -> "QPolynomial":
        # No modulus reduction: dividing by a scalar does not raise the degree.
        return self.generator((self.p / scale).coef)

    def __call__(self, x):
        """Evaluate the underlying polynomial at ``x``."""
        return self.p(x)

    def __repr__(self):
        return self.p.__repr__()

In [24]:
class PolynomialSampler:
    """Base class to sample polynomials from the parameters of a context."""

    def __init__(self, context):
        self.context = context

    def __getattr__(self, k):
        """Context variables are directly linked to the instance.

        Attributes not found on the sampler (e.g. ``N``, ``q``, ``h``,
        ``sigma``, ``p``) are looked up on ``self.context``.
        """
        # Without this guard, a missing ``context`` (e.g. an instance created
        # through __new__ by copy/pickle) recursed until RecursionError.
        if k == "context":
            raise AttributeError(k)
        return getattr(self.context, k)

    def polynomial(self, coeffs):
        """Wrap ``coeffs`` with the context's QPolynomial factory."""
        p = self.context.QPolynomial(coeffs)
        return p

In [25]:
class UniformPolynomial(PolynomialSampler):
    """Samples polynomials with coefficients uniform modulo q, centred."""

    def sample(self) -> "QPolynomial":
        """Sample using the context modulus ``q``."""
        return self.sample_manually(self.q)

    def sample_manually(self, q) -> "QPolynomial":
        """Draw N coefficients uniformly from {0, ..., q-1}, centre them into
        (-q/2, q/2] and wrap them as a polynomial."""
        raw = np.random.choice(q, size=self.N)
        centred = mod_q(raw, q)
        return self.polynomial(centred)

In [26]:
class ZO(PolynomialSampler):
    def sample(self):
        """Sample N coefficients from {0, 1, -1}.

        0 is drawn with probability ``1 - p`` and each of +1 / -1 with
        probability ``p / 2``, where ``p`` is read from the context.
        """
        rho = self.p
        weights = [1 - rho, rho / 2, rho / 2]
        coeffs = np.random.choice([0, 1, -1], size=self.N, p=weights)
        return self.polynomial(coeffs)

In [27]:
class HWT(PolynomialSampler):
    def sample(self):
        """Sample a polynomial with exactly h nonzero coefficients, each +-1.

        ``N`` and ``h`` are read from the context.

        Raises:
            ValueError: if ``h`` is not between 0 and ``N`` (inclusive).
        """
        n, h = self.N, self.h
        # Previously h > N failed deep inside np.random.choice with an
        # unrelated message; fail early with an explicit one instead.
        if not 0 <= h <= n:
            raise ValueError(f"Hamming weight h={h} must be between 0 and N={n}")

        coeffs = np.random.choice([-1, 1], size=n)

        # Zero out N - h distinct positions so exactly h entries stay nonzero.
        slots = np.random.choice(n, n - h, replace=False)
        coeffs[slots] = 0

        p = self.polynomial(coeffs)
        return p

In [28]:
class DG(PolynomialSampler):
    def sample(self):
        """Sample a discrete-Gaussian-like polynomial.

        Draws N normal values with standard deviation ``sigma`` (read from the
        context), rounds them to integers with coordinate-wise random rounding
        and centres them modulo ``q``.
        """
        gaussian = np.random.normal(np.zeros(self.N), self.sigma)
        integer_noise = np.array(coordinate_wise_random_rounding(gaussian))
        centred = mod_q(integer_noise, self.q)
        return self.polynomial(centred)

In [29]:
class Context:
    """Holds the scheme parameters and the samplers built from them.

    Args:
        N: ring degree; polynomials are taken modulo X^N + 1 and M = 2 * N.
        q: coefficient modulus, also used as ``special_prime``.
        h, sigma, p: optional overrides for the defaults set by
            ``setup_parameters`` (HWT Hamming weight, DG standard deviation,
            ZO nonzero probability). ``None`` keeps the default.
    """

    def __init__(self, N, q, h=None, sigma=None, p=None):
        self.N = N
        self.M = N * 2
        self.q = q
        # NOTE(review): the relinearization special prime is simply q here.
        self.special_prime = q

        self.QPolynomial = QPolynomialGenerator(N, q)

        self.setup_parameters()
        # Keyword overrides win over the defaults chosen by setup_parameters
        # (e.g. a smaller h is needed when N < 64).
        if h is not None:
            self.h = h
        if sigma is not None:
            self.sigma = sigma
        if p is not None:
            self.p = p

        self.hwt = HWT(self)
        self.dg = DG(self)
        self.uniform = UniformPolynomial(self)
        self.zo = ZO(self)

    def setup_parameters(self):
        """Set the default sampler parameters.

        ``h`` is the Hamming weight used by HWT, ``sigma`` the standard
        deviation used by DG, and ``p`` the probability of a nonzero
        coefficient used by ZO.
        """
        self.h = 64
        self.sigma = 3
        self.p = 0.5

In [61]:
class Keygen:
    """Generates the secret, public and relinearization keys of a context."""

    def __init__(self, context):
        self.context = context

    def generate_secret_key(self):
        """Return a secret key drawn from the context's HWT sampler."""
        return self.context.hwt.sample()

    def generate_public_key(self, s):
        """Return ``(b, a)`` with ``a`` uniform, ``e`` from DG and
        ``b = -a*s + e``."""
        ctx = self.context
        a = ctx.uniform.sample()
        e = ctx.dg.sample()
        b = -(a * s) + e
        return b, a

    def generate_relin_key(self, s):
        """Return ``(b, a)`` with ``a`` uniform modulo ``q * P``, ``e`` from DG
        and ``b = -a*s + e + P*s*s``, where ``P`` is the special prime."""
        ctx = self.context
        special = ctx.special_prime
        a = ctx.uniform.sample_manually(ctx.q * special)
        e = ctx.dg.sample()
        masked = -(a * s) + e
        b = masked + special * s * s
        return b, a

In [62]:
class Encryptor:
    """Public-key encryptor: ``(b*v + e0 + m, a*v + e1)`` for ``pk = (b, a)``."""

    def __init__(self, context, pk):
        self.pk = pk
        self.context = context

    def encrypt(self, ptx):
        """Encrypt the plaintext polynomial ``ptx``.

        ``v`` comes from the ZO sampler and ``e0``, ``e1`` (drawn in that
        order) from the DG sampler.
        """
        samplers = self.context
        v = samplers.zo.sample()
        e0 = samplers.dg.sample()
        e1 = samplers.dg.sample()

        b, a = self.pk[0], self.pk[1]
        c0 = b * v + e0 + ptx
        c1 = a * v + e1
        return c0, c1

In [63]:
class Decryptor:
    """Secret-key decryptor."""

    def __init__(self, context, s):
        self.s = s
        self.context = context

    def decrypt(self, ctx):
        """Decrypt ``(c0, c1, ..., ck)`` as ``c0 + s*c1 + s^2*c2 + ...``.

        A two-component ciphertext is handled exactly as ``c0 + s*c1``;
        longer ones, such as the three-component product of two ciphertexts
        before relinearization, use successive powers of the secret key.
        """
        ptx = ctx[0]
        s_power = None
        for component in ctx[1:]:
            s_power = self.s if s_power is None else s_power * self.s
            ptx = ptx + s_power * component
        return ptx

In [64]:
# Ring degree N and cyclotomic index M = 2N.
N = 512
M = 2 * N
# Encoding scale; the modulus q is set equal to it in this demo.
scale = 2 ** 30
q = scale

In [65]:
# Shared context (ring degree N, modulus q) and a key generator for it.
context = Context(N, q)
keygen = Keygen(context)

In [66]:
# The secret key is sampled first; the public and relinearization keys are derived from it.
s = keygen.generate_secret_key()
pk = keygen.generate_public_key(s)
relin_key = keygen.generate_relin_key(s)

In [69]:
# Encoder that multiplies messages by `scale` before rounding to integers.
encoder = CKKSEncoder(context, scale)

In [70]:
# Encryption uses the public key; decryption uses the secret key s.
encryptor = Encryptor(context, pk)
decryptor = Decryptor(context, s)

In [71]:
# Test message 0, 1, ..., N/2 - 1: one value per slot (N // 2 == encoder.slot_count).
z = np.arange(N//2)

In [72]:
# Encode the message into an integer-coefficient plaintext polynomial.
ptx = encoder.encode(z)

In [73]:
# Round trip: encrypt, decrypt and decode; the result should approximate z.
ctx = encryptor.encrypt(ptx)
ptx = decryptor.decrypt(ctx)
encoder.decode(ptx)

array([8.74666881e-08+2.80348299e-08j, 9.99999812e-01-7.54441863e-07j,
       2.00000015e+00+5.02332226e-07j, 2.99999996e+00-6.01744684e-07j,
       3.99999963e+00+8.73697666e-07j, 5.00000034e+00-1.26485933e-08j,
       6.00000025e+00-6.08418610e-07j, 6.99999892e+00+2.53661812e-06j,
       7.99999675e+00-1.91879824e-06j, 8.99999966e+00+1.48259995e-06j,
       9.99999913e+00-3.86483899e-07j, 1.09999998e+01-6.68513326e-07j,
       1.20000015e+01+9.99945883e-07j, 1.29999982e+01+8.47804440e-07j,
       1.40000010e+01+1.46102057e-06j, 1.49999998e+01-2.73778760e-08j,
       1.59999993e+01+6.45048150e-07j, 1.70000003e+01-2.20513385e-10j,
       1.79999988e+01-1.22977695e-06j, 1.90000001e+01-9.03615565e-08j,
       2.00000009e+01-2.41782207e-07j, 2.09999998e+01-7.42053953e-07j,
       2.19999987e+01+2.58537234e-07j, 2.30000000e+01-4.33259686e-08j,
       2.39999998e+01+1.86527174e-07j, 2.50000000e+01-7.62345316e-08j,
       2.60000000e+01-1.01423968e-07j, 2.70000008e+01+9.51836004e-07j,
      

In [74]:
# Homomorphic addition: adding ctx to itself component-wise should decrypt to about 2*z.
ct_add = (ctx[0] + ctx[0], ctx[1] + ctx[1])
ptx = decryptor.decrypt(ct_add)
encoder.decode(ptx)

array([1.74933376e-07+5.60696598e-08j, 1.99999962e+00-1.50888373e-06j,
       4.00000029e+00+1.00466445e-06j, 5.99999992e+00-1.20348937e-06j,
       7.99999926e+00+1.74739533e-06j, 1.00000007e+01-2.52971866e-08j,
       1.20000005e+01-1.21683722e-06j, 1.39999978e+01+5.07323624e-06j,
       1.59999935e+01-3.83759648e-06j, 1.79999993e+01+2.96519989e-06j,
       1.99999983e+01-7.72967798e-07j, 2.19999995e+01-1.33702665e-06j,
       2.40000030e+01+1.99989177e-06j, 2.59999963e+01+1.69560888e-06j,
       2.80000021e+01+2.92204113e-06j, 2.99999997e+01-5.47557519e-08j,
       3.19999986e+01+1.29009630e-06j, 3.40000007e+01-4.41026771e-10j,
       3.59999976e+01-2.45955390e-06j, 3.80000001e+01-1.80723113e-07j,
       4.00000017e+01-4.83564413e-07j, 4.19999996e+01-1.48410791e-06j,
       4.39999974e+01+5.17074469e-07j, 4.59999999e+01-8.66519372e-08j,
       4.79999997e+01+3.73054348e-07j, 4.99999999e+01-1.52469063e-07j,
       5.20000001e+01-2.02847936e-07j, 5.40000016e+01+1.90367201e-06j,
      

In [75]:
# Two constant messages (all ones and all fives) for the multiplication test.
a = np.ones(encoder.slot_count)
b = np.ones(encoder.slot_count) * 5

In [76]:
# Encode both constant messages.
ptx1 = encoder.encode(a)
ptx2 = encoder.encode(b)

In [77]:
# Encrypt both plaintexts under the same public key.
ctx1 = encryptor.encrypt(ptx1)
ctx2 = encryptor.encrypt(ptx2)

In [78]:
# Tensor product of (b1, a1) and (b2, a2): (b1*b2, a1*b2 + b1*a2, a1*a2), decryptable with (1, s, s^2).
ct_mul = (ctx1[0] * ctx2[0], ctx1[1] * ctx2[0] + ctx1[0] * ctx2[1], ctx1[1] * ctx2[1])

In [93]:
# Relinearization: multiply the s^2 component by both parts of the relinearization key and
# divide by the special prime (the resulting coefficients are generally not integers).
d = (ct_mul[2] * relin_key[0] / context.special_prime, ct_mul[2] * relin_key[1] / context.special_prime)

In [94]:
# Randomly round the scaled-down coefficients back to integers.
d = (context.QPolynomial(coordinate_wise_random_rounding(d[0].coef)), 
     context.QPolynomial(coordinate_wise_random_rounding(d[1].coef)))

In [95]:
# Fold the relinearized s^2 term into the first two components -> a 2-component ciphertext.
ct_relin = (ct_mul[0] + d[0], ct_mul[1] + d[1])

In [97]:
# The product carries scale**2 while decode removes one factor of scale, so divide once more.
ptx = decryptor.decrypt(ct_relin)
encoder.decode(ptx) / scale

array([5.03750131-3.31889031e-01j, 5.00749586+4.65599258e-01j,
       5.19769979-1.40269111e-01j, 5.01486511+4.97945796e-01j,
       5.06654593-1.07228585e-01j, 4.91182808-5.28833979e-01j,
       4.73969924-3.50959178e-01j, 5.10322369-2.15165387e-02j,
       4.95393491-2.92524571e-02j, 4.77256184+1.65123046e-01j,
       4.82863266+5.03841968e-01j, 5.4827891 -1.82881831e-01j,
       5.17102362+2.35680454e-01j, 4.80463359-4.08136147e-02j,
       5.02366343+1.22010441e-01j, 4.59527279-4.85108985e-01j,
       4.73414342-3.08973132e-01j, 4.81276417+4.08715148e-01j,
       4.94518062-5.60241467e-01j, 5.51095687+2.50942039e-01j,
       4.87954653+2.36569864e-01j, 4.95533197-1.76754156e-01j,
       4.6041859 +1.20881519e-01j, 4.57801498+6.20952891e-02j,
       5.19599657-2.87952326e-01j, 4.97654409-2.07463838e-01j,
       5.09982764-3.65451652e-01j, 4.94642676+3.52133627e-02j,
       4.84414665-2.70589883e-01j, 5.21072802-2.34105938e-02j,
       4.89485056+1.96965659e-02j, 4.99616835+4.4671723

In [79]:
# Reference: decrypt the 3-component product directly with (1, s, s^2), without relinearization.
pt_mul = ct_mul[0] + ct_mul[1] * s + ct_mul[2] * s * s

In [80]:
# Same extra division by scale as for the relinearized result (scale**2 after multiplication).
encoder.decode(pt_mul) / scale

array([5.00000212+3.99724677e-07j, 4.99999959-2.54464347e-06j,
       5.00000148-2.33034100e-06j, 4.99999618-4.11619193e-06j,
       5.00000638+1.65983180e-06j, 5.00000166-2.57470953e-06j,
       5.00000196-5.02832331e-06j, 4.99999627+7.64515826e-06j,
       4.99999948-5.62028747e-06j, 5.00000209+4.84288463e-06j,
       4.99999704-2.86492502e-06j, 5.00000273-9.99231079e-07j,
       4.99999245+8.80343933e-06j, 5.00000083-5.32245286e-06j,
       5.00001025+5.77940802e-06j, 4.99999473+5.36343441e-07j,
       4.99999249+4.37800987e-06j, 4.99999415-1.23396844e-06j,
       4.99999154-7.14831125e-06j, 5.00000169-2.03000445e-06j,
       4.99999464-2.12081633e-06j, 5.00000102-4.04948779e-07j,
       4.99999222+4.67891247e-06j, 5.00000351+2.66693471e-08j,
       4.99999898+2.10269179e-06j, 5.00000019+1.01514947e-07j,
       4.99998908-1.57141625e-06j, 5.00000087-4.07140915e-06j,
       4.99999881+1.29367334e-06j, 4.99999127+4.03209139e-06j,
       4.99999187-1.08489180e-06j, 5.00000613-3.8032974