# Encryption in CKKS

In [212]:
from numpy.polynomial import Polynomial
import numpy as np

def round_coordinates(coordinates):
    """Returns the non-negative fractional part ``x - floor(x)`` of each coordinate.

    Despite the name, nothing is rounded here: the result is what remains after
    removing the integer part (floor) of every entry.
    """
    integer_part = np.floor(coordinates)
    return coordinates - integer_part

def coordinate_wise_random_rounding(coordinates):
    """Randomly rounds each coordinate to one of its two neighbouring integers.

    A coordinate whose fractional part is c is rounded up with probability c
    and down with probability 1 - c, so the expected value equals the input.

    Returns:
        A Python list of ints, one per input coordinate.
    """
    fractional = round_coordinates(coordinates)
    offsets = []
    for c in fractional:
        # Subtracting c gives floor(x); subtracting c - 1 gives floor(x) + 1.
        offsets.append(np.random.choice([c, c - 1], 1, p=[1 - c, c]))
    offsets = np.array(offsets).reshape(-1)

    return [int(value) for value in coordinates - offsets]

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials.

    A message is a vector of ``M // 4`` complex numbers. Encoding expands it to
    its conjugate-symmetric extension, multiplies by ``scale``, projects it onto
    the lattice spanned by sigma(1), sigma(X), ..., sigma(X^(N-1)) with
    randomized rounding (N = M // 2), and maps the result back to integer
    polynomial coefficients.
    """

    def __init__(self, context, scale: float):
        """Initializes the encoder.

        Args:
            context: object providing ``M`` (the cyclotomic index; the ring
                degree is N = M // 2) and ``QPolynomial`` (a callable building
                a polynomial from integer coefficients, used by ``encode``).
            scale: factor the message is multiplied by before rounding.
        """
        M = context.M
        self.context = context
        # Primitive M-th root of unity; its odd powers are the evaluation points.
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.create_sigma_R_basis()
        self.scale = scale

        # Number of complex values one plaintext can carry.
        self.slot_count = M // 4

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> list:
        """Computes the Vandermonde matrix from an M-th root of unity.

        Row i holds the powers 0 .. N-1 (N = M // 2) of ``xi ** (2 * i + 1)``,
        so multiplying it by a coefficient vector evaluates that polynomial at
        the odd powers of ``xi``.

        Returns:
            The N x N matrix as a list of rows.
        """
        N = M // 2
        matrix = []
        for i in range(N):
            # Each row uses a different odd power of xi as its root.
            root = xi ** (2 * i + 1)
            matrix.append([root ** j for j in range(N)])
        return matrix

    def sigma_inverse(self, b: np.ndarray) -> Polynomial:
        """Encodes the vector b in a polynomial using an M-th root of unity.

        Solves ``A @ coeffs = b`` where A is the Vandermonde matrix of the odd
        powers of ``self.xi`` and returns the polynomial with those coefficients.
        """
        # sigma_R_basis stores sigma(X^j) as rows, i.e. the transposed
        # Vandermonde matrix, so its transpose is exactly A. It is built from
        # this encoder's own xi and M, never from module-level globals.
        A = self.sigma_R_basis.T
        coeffs = np.linalg.solve(A, b)
        return Polynomial(coeffs)

    def sigma(self, p: Polynomial) -> np.ndarray:
        """Decodes a polynomial by applying it to the odd powers of ``self.xi``."""
        N = self.M // 2
        return np.array([p(self.xi ** (2 * i + 1)) for i in range(N)])

    def pi(self, z: np.ndarray) -> np.ndarray:
        """Projects a vector of H into C^{N/2} by keeping its first M // 4 entries."""
        return z[:self.M // 4]

    def pi_inverse(self, z: np.ndarray) -> np.ndarray:
        """Expands a vector of C^{N/2} by appending its reversed complex conjugate."""
        return np.concatenate([z, np.conjugate(z[::-1])])

    def create_sigma_R_basis(self):
        """Creates the basis (sigma(1), sigma(X), ..., sigma(X** N-1)) as rows."""
        self.sigma_R_basis = np.array(self.vandermonde(self.xi, self.M)).T

    def compute_basis_coordinates(self, z):
        """Computes the coordinates of a vector with respect to the orthogonal lattice basis."""
        return np.array([np.real(np.vdot(z, b) / np.vdot(b, b)) for b in self.sigma_R_basis])

    def sigma_R_discretization(self, z):
        """Projects a vector on the lattice using coordinate wise random rounding."""
        coordinates = self.compute_basis_coordinates(z)
        rounded_coordinates = coordinate_wise_random_rounding(coordinates)
        return np.matmul(self.sigma_R_basis.T, rounded_coordinates)

    def encode(self, z: np.ndarray):
        """Encodes a vector by expanding it first to H,
        scale it, project it on the lattice of sigma(R), and performs
        sigma inverse.

        Returns:
            ``context.QPolynomial`` built from the integer coefficients.
        """
        pi_z = self.pi_inverse(z)
        scaled_pi_z = self.scale * pi_z
        rounded_scale_pi_zi = self.sigma_R_discretization(scaled_pi_z)
        p = self.sigma_inverse(rounded_scale_pi_zi)

        # The lattice point has integer coordinates, so the exact coefficients
        # are integers; round away the floating-point error of the solve.
        coef = np.round(np.real(p.coef)).astype(int)
        return self.context.QPolynomial(coef)

    def decode(self, p) -> np.ndarray:
        """Decodes a polynomial by removing the scale,
        evaluating on the roots, and project it on C^(N/2)"""
        rescaled_p = p / self.scale
        z = self.sigma(rescaled_p)
        return self.pi(z)

In [213]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial

def mod_q(coeffs, q):
    """Reduces coefficients modulo q into the centered range (-q/2, q/2].

    Args:
        coeffs: array-like of coefficients (lists and scalars are accepted).
        q: modulus.

    Returns:
        A new numpy array (0-d for scalar input); ``coeffs`` is not modified.
    """
    r = np.asarray(coeffs) % q
    # Residues above q/2 are shifted down by q so the result is centred on 0.
    return np.where(r > q / 2, r - q, r)

def mod_q_polynomial(p, q):
    """Returns a numpy Polynomial whose coefficients are ``p.coef`` centred mod q."""
    reduced = mod_q(p.coef, q)
    return Polynomial(reduced)

In [214]:
class PolynomialSampler:
    """Base class for samplers that draw polynomials from a context.

    Attribute lookups that fail on the sampler (e.g. ``N``, ``q``, ``h``,
    ``sigma``, ``p``) are forwarded to the wrapped context.
    """

    def __init__(self, context):
        self.context = context

    def __getattr__(self, k):
        """Forwards missing attributes to the wrapped context."""
        # __getattr__ only runs when normal lookup fails. If ``context`` itself
        # is missing (e.g. an instance made by copy/unpickling before its state
        # is restored), forwarding would call __getattr__ again forever.
        if k == "context":
            raise AttributeError(k)
        return getattr(self.context, k)

    def polynomial(self, coeffs):
        """Builds a polynomial from ``coeffs`` with the context's QPolynomial."""
        return self.context.QPolynomial(coeffs)

In [215]:
class UniformPolynomial(PolynomialSampler):
    """Samples polynomials whose coefficients are uniform modulo q."""

    # The return annotation is a string: a bare ``QPolynomial`` is evaluated
    # when the class is created, before that name is bound (NameError).
    def sample(self) -> "QPolynomial":
        """Draws N coefficients uniformly from range(q), centres them with mod_q.

        Returns:
            The polynomial built by ``self.polynomial`` (the context's QPolynomial).
        """
        coeffs = np.random.choice(self.q, size=self.N)
        coeffs = mod_q(coeffs, self.q)
        return self.polynomial(coeffs)

In [216]:
class ZO(PolynomialSampler):
    """Samples ternary polynomials with coefficients in {0, 1, -1}."""

    def sample(self):
        """Each coefficient is 0 with probability 1 - p and +1 / -1 with p/2 each."""
        nonzero = self.p
        weights = [1 - nonzero, nonzero / 2, nonzero / 2]
        values = np.random.choice([0, 1, -1], size=self.N, p=weights)
        return self.polynomial(values)

In [217]:
class HWT(PolynomialSampler):
    """Samples polynomials with exactly h nonzero coefficients, each +1 or -1."""

    def sample(self):
        """Draws random signs, then zeroes N - h distinct positions."""
        signs = np.random.choice([-1, 1], size=self.N)
        zero_positions = np.random.choice(range(self.N), self.N - self.h, replace=False)
        signs[zero_positions] = 0
        return self.polynomial(signs)

In [218]:
class DG(PolynomialSampler):
    """Samples discrete-Gaussian-like error polynomials."""

    def sample(self):
        """Draws N normal samples (std ``sigma``), randomly rounds, centres mod q."""
        noise = np.random.normal(np.zeros(self.N), self.sigma)
        integral_noise = np.array(coordinate_wise_random_rounding(noise))
        return self.polynomial(mod_q(integral_noise, self.q))

In [219]:
class Context:
    """Bundles ring parameters, the polynomial constructor and the samplers.

    Args:
        N: ring degree (polynomials are reduced modulo X^N + 1).
        q: coefficient modulus.
        h: number of nonzero coefficients drawn by the HWT sampler.
        sigma: standard deviation used by the DG sampler.
        p: probability that a ZO coefficient is nonzero.
    """

    def __init__(self, N, q, h=64, sigma=3, p=0.5):
        self.N = N
        self.M = N * 2
        self.q = q

        self.QPolynomial = QPolynomialGenerator(N, q)

        self.setup_parameters(h=h, sigma=sigma, p=p)

        # Samplers read N, q, h, sigma and p from this context.
        self.hwt = HWT(self)
        self.dg = DG(self)
        self.uniform = UniformPolynomial(self)
        self.zo = ZO(self)

    def setup_parameters(self, h=64, sigma=3, p=0.5):
        """Stores the sampler parameters; the defaults match the original constants."""
        self.h = h
        self.sigma = sigma
        self.p = p

In [220]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial
from __future__ import annotations
        
class QPolynomialGenerator:
    """Factory for polynomials reduced modulo X^N + 1.

    Calling an instance with a coefficient sequence returns a
    ``QPolynomialGenerator.QPolynomial`` bound to this generator's modulus.
    """

    def __init__(self, N, q):
        """Stores q and builds the polynomial modulus X^N + 1.

        Args:
            N: ring degree.
            q: coefficient modulus. NOTE: it is stored on every QPolynomial, but
                none of the arithmetic below reduces coefficients modulo q.
        """
        self.q = q

        coeffs = np.zeros(N + 1)
        coeffs[0] = 1
        coeffs[-1] = 1

        self.poly_modulus = Polynomial(coeffs)

    def __call__(self, coef):
        """Wraps the coefficient sequence ``coef`` in a QPolynomial."""
        return QPolynomialGenerator.QPolynomial(coef, self.q, self.poly_modulus, self)

    class QPolynomial:
        """Wrapper around numpy's Polynomial with arithmetic modulo ``poly_modulus``.

        The results of +, -, * and unary +/- are reduced modulo ``poly_modulus``.
        The other operand may be a QPolynomial or anything numpy's Polynomial
        accepts (a scalar or a Polynomial), so ``2 * x`` and ``sum(...)`` work.
        Unknown attributes (e.g. ``coef``) are read from the wrapped Polynomial.
        """

        def __init__(self, coef, q, poly_modulus, generator):
            self.p = Polynomial(coef)
            self.q = q
            self.poly_modulus = poly_modulus
            self.generator = generator

        def __getattr__(self, k):
            """Forwards attributes missing on the wrapper to the wrapped Polynomial.

            Raises:
                AttributeError: if neither the wrapper nor the Polynomial has ``k``.
            """
            # Only reached when normal lookup fails. Guard ``p`` so that a
            # partially built instance (copy/unpickle) raises instead of
            # recursing forever.
            if k == "p":
                raise AttributeError(k)
            return getattr(self.p, k)

        def _other_p(self, other):
            """Returns the numpy Polynomial behind ``other``, or ``other`` unchanged."""
            return other.p if isinstance(other, type(self)) else other

        def _wrap(self, p):
            """Builds a QPolynomial from the coefficients of numpy Polynomial ``p``."""
            return self.generator(p.coef)

        def __mul__(self, qpoly) -> "QPolynomial":
            return self._wrap(self.p * self._other_p(qpoly) % self.poly_modulus)

        def __rmul__(self, qpoly) -> "QPolynomial":
            # Ring multiplication is commutative.
            return self.__mul__(qpoly)

        def __add__(self, qpoly) -> "QPolynomial":
            return self._wrap((self.p + self._other_p(qpoly)) % self.poly_modulus)

        def __radd__(self, qpoly) -> "QPolynomial":
            return self.__add__(qpoly)

        def __sub__(self, qpoly) -> "QPolynomial":
            return self._wrap((self.p - self._other_p(qpoly)) % self.poly_modulus)

        def __rsub__(self, qpoly) -> "QPolynomial":
            return self._wrap((self._other_p(qpoly) - self.p) % self.poly_modulus)

        def __neg__(self):
            return self._wrap(-self.p % self.poly_modulus)

        def __pos__(self):
            return self._wrap(self.p % self.poly_modulus)

        def __truediv__(self, scale):
            # Scalar division keeps the degree, so no modular reduction is needed.
            return self._wrap(self.p / scale)

        def __call__(self, x):
            return self.p(x)

        def __repr__(self):
            return self.p.__repr__()

In [221]:
# Standalone ring-polynomial constructor for experimentation. It reads the
# globals N and q. NOTE(review): in this listing they are only assigned in a
# later cell (N = 512, q = scale), so run that cell first when going top to bottom.
QPolynomial = QPolynomialGenerator(N, q)

In [222]:
# Random 0/1 coefficient vector of length N, wrapped as a ring polynomial.
coef = np.random.choice([0,1], N)
# NOTE(review): this rebinds `poly`, which earlier cells import as the
# numpy.polynomial.polynomial module; any later use of that module name breaks.
poly = QPolynomial(coef)

In [223]:
# Ring product via QPolynomial.__mul__ (reduced modulo X^N + 1).
a = poly * poly 

In [224]:
# Same product computed directly on the wrapped numpy Polynomials, for comparison.
a = poly.p * poly.p % poly.poly_modulus

In [225]:
# Ring product again, kept as a QPolynomial.
b = poly * poly

In [226]:
class Keygen:
    """Generates a secret key and the matching public key from a context."""

    def __init__(self, context):
        self.context = context

    def generate_secret_key(self):
        """Draws the secret key s from the context's HWT sampler."""
        return self.context.hwt.sample()

    def generate_public_key(self, s):
        """Returns the public key (b, a) with b = -(a * s) + e.

        a comes from the uniform sampler and e from the DG sampler (in that order).
        """
        ctx = self.context
        a = ctx.uniform.sample()
        e = ctx.dg.sample()
        masked = -(a * s)
        return (masked + e, a)

In [227]:
class Encryptor:
    """Public-key encryption of plaintext polynomials."""

    def __init__(self, context, pk):
        self.pk = pk
        self.context = context

    def encrypt(self, ptx):
        """Returns (pk[0]*v + e0 + ptx, pk[1]*v + e1).

        v comes from the ZO sampler; e0 and e1 come from the DG sampler (drawn in
        that order).
        """
        samplers = self.context
        v = samplers.zo.sample()
        e0 = samplers.dg.sample()
        e1 = samplers.dg.sample()

        b, a = self.pk[0], self.pk[1]
        b_v = b * v
        a_v = a * v
        return (b_v + e0 + ptx, a_v + e1)

In [228]:
class Decryptor:
    """Secret-key decryption of ciphertexts of any degree."""

    def __init__(self, context, s):
        self.s = s
        self.context = context

    def decrypt(self, ctx):
        """Returns c0 + s*c1 + s^2*c2 + ... for ctx = (c0, c1, c2, ...).

        A fresh ciphertext has two components; the product of two ciphertexts
        has three. For two components the result is exactly c0 + s * c1.
        """
        ptx = ctx[0]
        s_power = self.s
        for i, component in enumerate(ctx[1:]):
            if i > 0:
                # Advance to the next power of s for each further component.
                s_power = s_power * self.s
            ptx = ptx + s_power * component
        return ptx

In [229]:
# Ring degree N, cyclotomic index M = 2N, encoding scale, and coefficient
# modulus (chosen equal to the scale here).
N = 512
M = 2 * N
scale = 2 ** 30
q = scale

In [230]:
# Build the context (ring constructor and samplers) and a key generator.
context = Context(N, q)
keygen = Keygen(context)

In [231]:
# Secret key s, then public key (b, a) with b = -(a * s) + e.
s = keygen.generate_secret_key()
pk = keygen.generate_public_key(s)

In [232]:
# Encoder using `scale` (equal to q in this demo) as the encoding factor.
encoder = CKKSEncoder(context, scale)

In [233]:
# Encrypt with the public key; decrypt with the secret key.
encryptor = Encryptor(context, pk)
decryptor = Decryptor(context, s)

In [234]:
# Message 0, 1, ..., N/2 - 1: one value per slot (N/2 == encoder.slot_count).
z = np.arange(N//2)

In [235]:
# Encode the message into a plaintext polynomial.
ptx = encoder.encode(z)

In [236]:
# Round trip: encrypt, decrypt, decode. The output should be close to z,
# with small noise from encoding and encryption.
ctx = encryptor.encrypt(ptx)
ptx = decryptor.decrypt(ctx)
encoder.decode(ptx)

array([-7.57317054e-08-7.05216714e-07j,  9.99999757e-01+2.01132499e-07j,
        1.99999991e+00+6.40854466e-07j,  2.99999992e+00+4.13944560e-07j,
        4.00000106e+00+5.02470143e-07j,  4.99999940e+00+1.76148262e-07j,
        6.00000042e+00+1.14551535e-06j,  6.99999944e+00-5.80426544e-07j,
        8.00000028e+00+9.35107806e-07j,  8.99999782e+00+3.82756054e-07j,
        1.00000003e+01+6.41280078e-08j,  1.09999996e+01+1.08810639e-06j,
        1.19999993e+01-8.42371300e-07j,  1.29999994e+01+7.32190774e-07j,
        1.40000004e+01-3.30819255e-07j,  1.49999996e+01-1.07693328e-06j,
        1.60000010e+01-6.00029754e-07j,  1.69999996e+01+4.11432733e-07j,
        1.80000001e+01+3.91897366e-07j,  1.89999996e+01-2.32060199e-07j,
        2.00000010e+01-2.98118231e-07j,  2.10000005e+01+1.00157820e-08j,
        2.20000006e+01-8.16639467e-07j,  2.29999997e+01-7.14923978e-08j,
        2.39999999e+01-1.03322524e-07j,  2.49999993e+01-6.76689559e-07j,
        2.59999998e+01-4.36773497e-08j,  2.70000003

In [237]:
# Homomorphic addition is component-wise; ctx + ctx decodes to about 2 * z.
ct_add = (ctx[0] + ctx[0], ctx[1] + ctx[1])
ptx = decryptor.decrypt(ct_add)
encoder.decode(ptx)

array([-1.51463411e-07-1.41043343e-06j,  1.99999951e+00+4.02264998e-07j,
        3.99999982e+00+1.28170893e-06j,  5.99999984e+00+8.27889121e-07j,
        8.00000211e+00+1.00494029e-06j,  9.99999880e+00+3.52296524e-07j,
        1.20000008e+01+2.29103069e-06j,  1.39999989e+01-1.16085309e-06j,
        1.60000006e+01+1.87021561e-06j,  1.79999956e+01+7.65512109e-07j,
        2.00000006e+01+1.28256016e-07j,  2.19999992e+01+2.17621277e-06j,
        2.39999986e+01-1.68474260e-06j,  2.59999988e+01+1.46438155e-06j,
        2.80000008e+01-6.61638509e-07j,  2.99999993e+01-2.15386656e-06j,
        3.20000020e+01-1.20005951e-06j,  3.39999992e+01+8.22865466e-07j,
        3.60000002e+01+7.83794732e-07j,  3.79999993e+01-4.64120397e-07j,
        4.00000021e+01-5.96236461e-07j,  4.20000009e+01+2.00315640e-08j,
        4.40000012e+01-1.63327893e-06j,  4.59999993e+01-1.42984796e-07j,
        4.79999997e+01-2.06645048e-07j,  4.99999986e+01-1.35337912e-06j,
        5.19999996e+01-8.73546995e-08j,  5.40000006

In [245]:
# Two constant messages filling every slot: all ones and all twos.
a = np.ones(encoder.slot_count)
b = np.ones(encoder.slot_count) * 2

In [246]:
# Encode both messages.
ptx1 = encoder.encode(a)
ptx2 = encoder.encode(b)

In [247]:
# Encrypt both plaintexts.
ctx1 = encryptor.encrypt(ptx1)
ctx2 = encryptor.encrypt(ptx2)

In [248]:
# Tensor product of two 2-component ciphertexts: a 3-component ciphertext
# (d0, d1, d2) = (c0*c0', c1*c0' + c0*c1', c1*c1'), decrypted under (1, s, s^2).
ct_mul = (ctx1[0] * ctx2[0], ctx1[1] * ctx2[0] + ctx1[0] * ctx2[1], ctx1[1] * ctx2[1])

In [249]:
# Manual decryption of the 3-component ciphertext: d0 + d1*s + d2*s^2.
pt_mul = ct_mul[0] + ct_mul[1] * s + ct_mul[2] * s * s

In [250]:
# No rescaling was done after the multiplication. Each encoding carries one
# factor of `scale`, so the product carries scale**2, and decode divides by
# `scale` only once. The slots therefore come out near 2 * scale (about 2.147e9),
# not 2.
encoder.decode(pt_mul)

array([2.14748462e+09+1.10114562e+03j, 2.14748428e+09+8.79125044e+02j,
       2.14748333e+09-5.69100079e+02j, 2.14748259e+09-2.33486692e+03j,
       2.14748453e+09-5.26272181e+02j, 2.14748310e+09+1.38128896e+03j,
       2.14748442e+09+1.17802594e+03j, 2.14748599e+09+1.31097597e+03j,
       2.14748499e+09+1.56802720e+03j, 2.14748173e+09+1.50104918e+03j,
       2.14748511e+09-2.53062561e+03j, 2.14748222e+09+3.91550865e+03j,
       2.14748173e+09-5.64535559e+02j, 2.14747828e+09-1.11709513e+03j,
       2.14748361e+09+1.13839707e+03j, 2.14748318e+09-3.21219608e+03j,
       2.14748541e+09+2.67056186e+02j, 2.14747977e+09-1.91322897e+03j,
       2.14748408e+09+2.77829937e+03j, 2.14748361e+09-1.88514873e+03j,
       2.14748915e+09-2.66335163e+03j, 2.14748166e+09+1.67899730e+03j,
       2.14748509e+09-3.75610880e+03j, 2.14748335e+09+4.00332586e+02j,
       2.14748344e+09+1.88017542e+02j, 2.14748326e+09-4.30222182e+02j,
       2.14748568e+09+1.42552256e+03j, 2.14748442e+09+9.79316152e+02j,
      