# Encryption in CKKS

In [134]:
from numpy.polynomial import Polynomial
import numpy as np

def round_coordinates(coordinates):
    """Return the fractional part ``x - floor(x)`` of every coordinate.

    Despite its name this does not round: each entry lands in [0, 1)
    (up to floating-point rounding).
    """
    integral_part = np.floor(coordinates)
    fractional_part = coordinates - integral_part
    return fractional_part

def coordinate_wise_random_rounding(coordinates):
    """Randomly round each coordinate to one of its two neighbouring integers.

    A coordinate ``x`` with fractional part ``c = x - floor(x)`` becomes
    ``floor(x) + 1`` with probability ``c`` and ``floor(x)`` with probability
    ``1 - c``, so the rounding is unbiased: its expected value is ``x``.

    Args:
        coordinates: array-like of real numbers (flattened to 1-D).

    Returns:
        list[int]: the rounded coordinates as Python ints.
    """
    values = np.asarray(coordinates, dtype=float).reshape(-1)
    floors = np.floor(values)
    fractional = values - floors
    # One uniform draw per coordinate: rounding up happens with probability c.
    round_up = np.random.random(fractional.shape) < fractional
    # floor + {0, 1} is exactly integral, so int() cannot truncate a value
    # such as 2.9999999 down to 2.
    rounded = floors + round_up
    return [int(value) for value in rounded]

class CKKSEncoder:
    """Basic CKKS encoder between complex vectors and integer polynomials.

    A vector of ``C^{N/2}`` is expanded to a conjugate-symmetric vector of
    ``C^N`` (``pi_inverse``), multiplied by ``scale``, projected onto the
    lattice ``sigma(R)`` by randomized rounding and mapped back to a
    polynomial with ``sigma_inverse``.  Here ``N = M // 2`` and ``sigma``
    evaluates a polynomial at the odd powers ``xi**(2i+1)`` of
    ``xi = exp(2*pi*1j/M)``.
    """

    def __init__(self, context, scale: float):
        """Bind the encoder to ``context`` and store the scaling factor.

        Args:
            context: object providing ``M`` and, for ``encode``, a
                ``QPolynomial`` factory.
            scale: factor applied before rounding in ``encode``; ``decode``
                divides it back out.
        """
        M = context.M
        self.context = context
        # Primitive M-th root of unity.
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.create_sigma_R_basis()
        self.scale = scale

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> list:
        """Return the N x N Vandermonde matrix (N = M // 2) as nested lists.

        Row ``i`` holds ``root**0, ..., root**(N-1)`` for
        ``root = xi**(2*i + 1)``.
        """
        N = M // 2
        matrix = []
        for i in range(N):
            root = xi ** (2 * i + 1)
            matrix.append([root ** j for j in range(N)])
        return matrix

    def sigma_inverse(self, b: np.ndarray) -> Polynomial:
        """Return the polynomial whose evaluations at the odd roots equal ``b``.

        Solves ``A @ coeffs = b`` with ``A`` the Vandermonde matrix of
        ``self.xi``.
        """
        # sigma_R_basis is the transposed Vandermonde matrix built from
        # self.M in __init__; transposing it back avoids depending on a
        # module-level ``M`` and avoids rebuilding the matrix per call.
        A = self.sigma_R_basis.T
        coeffs = np.linalg.solve(A, b)
        return Polynomial(coeffs)

    def sigma(self, p: Polynomial) -> np.ndarray:
        """Evaluate ``p`` at ``xi**(2i+1)`` for ``i`` in ``0 .. N-1``."""
        N = self.M // 2
        return np.array([p(self.xi ** (2 * i + 1)) for i in range(N)])

    def pi(self, z: np.ndarray) -> np.ndarray:
        """Project a vector of H onto C^{N/2} by keeping its first M // 4 entries."""
        n_slots = self.M // 4
        return z[:n_slots]

    def pi_inverse(self, z: np.ndarray) -> np.ndarray:
        """Expand ``z`` in C^{N/2} to ``[z, conj(reversed(z))]`` in H."""
        z_conjugate = [np.conjugate(x) for x in z[::-1]]
        return np.concatenate([z, z_conjugate])

    def create_sigma_R_basis(self):
        """Store the basis (sigma(1), sigma(X), ..., sigma(X**(N-1))) as rows."""
        self.sigma_R_basis = np.array(self.vandermonde(self.xi, self.M)).T

    def compute_basis_coordinates(self, z):
        """Return the real coordinates of ``z`` in the orthogonal basis ``sigma_R_basis``.

        Each coordinate is ``Re(vdot(z, b) / vdot(b, b))`` for a basis row ``b``.
        """
        return np.array(
            [np.real(np.vdot(z, b) / np.vdot(b, b)) for b in self.sigma_R_basis]
        )

    def sigma_R_discretization(self, z):
        """Project ``z`` onto the lattice sigma(R) using coordinate-wise random rounding."""
        coordinates = self.compute_basis_coordinates(z)
        rounded_coordinates = coordinate_wise_random_rounding(coordinates)
        # Recombine: sum_j rounded_j * sigma(X**j).
        return np.matmul(self.sigma_R_basis.T, rounded_coordinates)

    def encode(self, z: np.ndarray):
        """Encode ``z`` in C^{N/2} as a context ``QPolynomial``.

        Expands ``z`` to H, scales it, projects it on sigma(R), applies
        ``sigma_inverse`` and rounds the real parts of the coefficients to
        integers.
        """
        pi_z = self.pi_inverse(z)
        scaled_pi_z = self.scale * pi_z
        rounded_scale_pi_zi = self.sigma_R_discretization(scaled_pi_z)
        p = self.sigma_inverse(rounded_scale_pi_zi)

        # Round afterwards to remove numerical imprecision from the solve.
        coef = np.round(np.real(p.coef)).astype(int)
        return self.context.QPolynomial(coef)

    def decode(self, p) -> np.ndarray:
        """Decode ``p``: divide by ``scale``, evaluate on the roots, project to C^{N/2}."""
        rescaled_p = p / self.scale
        z = self.sigma(rescaled_p)
        return self.pi(z)

In [135]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial

def mod_q(coeffs, q):
    """Reduce coefficients into the centered range (-q/2, q/2].

    Args:
        coeffs: array-like (or scalar) of integer or real coefficients.
        q: positive modulus.

    Returns:
        numpy array ``r`` with ``r == coeffs (mod q)`` and ``-q/2 < r <= q/2``
        element-wise.
    """
    r = np.asarray(coeffs) % q
    # ``%`` with a positive modulus lands in [0, q); shift the upper half down.
    return np.where(r > q / 2, r - q, r)

def mod_q_polynomial(p, q):
    """Return a new numpy ``Polynomial`` with ``p``'s coefficients centered mod ``q``."""
    return Polynomial(mod_q(p.coef, q))

In [136]:
class PolynomialSampler:
    """Base class for polynomial samplers bound to a context.

    Attributes not found on the sampler are looked up on the wrapped
    context, so subclasses can read e.g. ``self.N`` or ``self.q`` directly.
    """

    def __init__(self, context):
        self.context = context

    def __getattr__(self, k):
        # Only reached when normal lookup fails.  Read ``context`` from the
        # instance dict: if it is absent (copy/pickle rebuild the object
        # before restoring state), ``self.context`` would re-enter
        # __getattr__ and recurse without bound.
        try:
            context = self.__dict__["context"]
        except KeyError:
            raise AttributeError(k) from None
        return getattr(context, k)

    def polynomial(self, coeffs):
        """Wrap ``coeffs`` with the context's ``QPolynomial`` factory."""
        return self.context.QPolynomial(coeffs)

In [137]:
class UniformPolynomial(PolynomialSampler):
    """Sampler for polynomials with coefficients uniform modulo q."""

    # The return annotation is quoted: the global ``QPolynomial`` is created
    # only in a later cell, so an unquoted name raises NameError when this
    # class is defined during a fresh top-to-bottom run.
    def sample(self) -> "QPolynomial":
        """Draw N coefficients from {0, ..., q-1}, center them with ``mod_q``
        and wrap them as a context polynomial."""
        coeffs = np.random.choice(self.q, size=self.N)
        coeffs = mod_q(coeffs, self.q)

        return self.polynomial(coeffs)

In [138]:
class ZO(PolynomialSampler):
    """Sampler for ternary polynomials with coefficients in {0, 1, -1}."""

    def sample(self):
        """Each coefficient is 0 with probability ``1 - p`` and 1 or -1 with
        probability ``p / 2`` each, ``p`` being read from the context."""
        weights = [1 - self.p, self.p / 2, self.p / 2]
        ternary = np.random.choice([0, 1, -1], size=self.N, p=weights)
        return self.polynomial(ternary)

In [139]:
class HWT(PolynomialSampler):
    """Sampler for polynomials of fixed Hamming weight ``h``."""

    def sample(self):
        """Return a polynomial with exactly ``h`` nonzero coefficients, each -1 or 1."""
        signs = np.random.choice([-1, 1], size=self.N)
        # Zero out N - h distinct positions; the remaining h keep their sign.
        zero_count = self.N - self.h
        zeroed_positions = np.random.choice(range(self.N), zero_count, replace=False)
        signs[zeroed_positions] = 0
        return self.polynomial(signs)

In [140]:
class DG(PolynomialSampler):
    """Sampler for small integer error polynomials."""

    def sample(self):
        """Draw N normal samples (mean 0, std ``sigma``), round each randomly
        to a neighbouring integer, center them mod q and wrap them."""
        noise = np.random.normal(np.zeros(self.N), self.sigma)
        integral = np.array(coordinate_wise_random_rounding(noise))
        centered = mod_q(integral, self.q)
        return self.polynomial(centered)

In [141]:
class Context:
    """Bundle of scheme parameters together with the samplers that use them.

    Args:
        N: ring degree; polynomials are reduced modulo X^N + 1.
        q: modulus used by the uniform and DG samplers.
        h: Hamming weight of HWT samples (secret keys).
        sigma: standard deviation of the DG error sampler.
        p: probability that a ZO coefficient is nonzero.
    """

    def __init__(self, N, q, *, h=64, sigma=3, p=0.5):
        self.N = N
        self.M = N * 2
        self.q = q

        self.QPolynomial = QPolynomialGenerator(N, q)

        self.setup_parameters(h=h, sigma=sigma, p=p)

        self.hwt = HWT(self)
        self.dg = DG(self)
        self.uniform = UniformPolynomial(self)
        self.zo = ZO(self)

    def setup_parameters(self, h=64, sigma=3, p=0.5):
        """Store the sampler parameters (defaults match the original constants)."""
        self.h = h
        self.sigma = sigma
        self.p = p

In [194]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial
from __future__ import annotations
        
class QPolynomialGenerator:
    """Factory for polynomials reduced modulo X^N + 1.

    Calling the generator with a coefficient sequence returns a
    ``QPolynomialGenerator.QPolynomial`` whose arithmetic reduces results
    modulo ``X^N + 1``.

    NOTE(review): despite the name, coefficients are never reduced modulo
    ``q``; ``q`` is only stored.  The demo uses ``q == scale``, so adding
    that reduction would presumably break its decryption — confirm before
    changing.
    """

    def __init__(self, N, q):
        self.q = q

        # Coefficients of X^N + 1, lowest degree first.
        coeffs = np.zeros(N + 1)
        coeffs[0] = 1
        coeffs[-1] = 1

        self.poly_modulus = Polynomial(coeffs)

    def __call__(self, coef):
        """Wrap ``coef`` as a QPolynomial bound to this generator."""
        return QPolynomialGenerator.QPolynomial(coef, self.q, self.poly_modulus, self)

    class QPolynomial:
        """numpy ``Polynomial`` wrapper whose arithmetic reduces modulo ``poly_modulus``.

        Operands may be other QPolynomials or anything numpy's ``Polynomial``
        arithmetic accepts (e.g. scalars).
        """

        def __init__(self, coef, q, poly_modulus, generator):
            self.p = Polynomial(coef)
            self.q = q
            self.poly_modulus = poly_modulus
            self.generator = generator

        def __getattr__(self, k):
            # Only reached when normal lookup fails: delegate to the wrapped
            # numpy Polynomial (``coef``, ``degree``, ...).  Read ``p`` from
            # the instance dict so a half-built object cannot recurse, and
            # raise AttributeError rather than printing and returning None.
            try:
                inner = self.__dict__["p"]
            except KeyError:
                raise AttributeError(k) from None
            return getattr(inner, k)

        def _operand(self, other):
            """Return the numpy value to combine with ``self.p``."""
            if isinstance(other, QPolynomialGenerator.QPolynomial):
                return other.p
            return other

        def _wrap(self, poly):
            """Reduce ``poly`` modulo ``poly_modulus`` and wrap it via the generator."""
            return self.generator((poly % self.poly_modulus).coef)

        def __mul__(self, qpoly: "QPolynomial") -> "QPolynomial":
            return self._wrap(self.p * self._operand(qpoly))

        def __rmul__(self, qpoly: "QPolynomial") -> "QPolynomial":
            # Multiplication in the quotient ring is commutative.
            return self.__mul__(qpoly)

        def __add__(self, qpoly: "QPolynomial") -> "QPolynomial":
            return self._wrap(self.p + self._operand(qpoly))

        def __radd__(self, qpoly: "QPolynomial") -> "QPolynomial":
            return self.__add__(qpoly)

        def __sub__(self, qpoly: "QPolynomial") -> "QPolynomial":
            return self._wrap(self.p - self._operand(qpoly))

        def __rsub__(self, qpoly: "QPolynomial") -> "QPolynomial":
            return self._wrap(self._operand(qpoly) - self.p)

        def __neg__(self):
            return self._wrap(-self.p)

        def __pos__(self):
            return self._wrap(self.p)

        def __truediv__(self, scale):
            # Division by a scalar cannot raise the degree, so no reduction.
            return self.generator((self.p / scale).coef)

        def __call__(self, x):
            """Evaluate the wrapped polynomial at ``x``."""
            return self.p(x)

        def __repr__(self):
            return self.p.__repr__()

In [195]:
# Standalone generator for the quick experiments below.  N and q are set here
# too because the parameters cell that defines them comes later in the
# notebook; without this a fresh top-to-bottom run raises NameError.  The
# values match that cell.
N = 512
q = pow(2, 30)
QPolynomial = QPolynomialGenerator(N, q)

In [196]:
# Random 0/1 coefficient vector of length N for a quick multiplication check.
# NOTE(review): ``poly`` shadows the ``numpy.polynomial.polynomial`` alias
# imported as ``poly`` in earlier cells.
coef = np.random.choice([0, 1], size=N)
poly = QPolynomial(coef)

In [197]:
# Square via QPolynomial.__mul__, which reduces the product modulo the ring modulus.
a = poly * poly 

In [198]:
# Same square computed on the wrapped numpy Polynomials; parsed as
# (poly.p * poly.p) % poly.poly_modulus.  Overwrites ``a`` with a plain
# numpy Polynomial instead of a QPolynomial.
a = poly.p * poly.p % poly.poly_modulus

In [199]:
# Repeat the QPolynomial square into ``b`` (presumably to compare against ``a``).
b = poly * poly

In [200]:
class Keygen:
    """Generate secret and public keys from a context's samplers."""

    def __init__(self, context):
        self.context = context

    def generate_secret_key(self):
        """Return a secret key ``s`` drawn from the context's HWT sampler."""
        return self.context.hwt.sample()

    def generate_public_key(self, s):
        """Return the public key ``(b, a)`` with ``b = -(a * s) + e``.

        ``a`` comes from the uniform sampler, ``e`` from the DG sampler.
        """
        ctx = self.context
        a = ctx.uniform.sample()
        e = ctx.dg.sample()
        return (-(a * s) + e, a)

In [201]:
class Encryptor:
    """Encrypt plaintext polynomials under a public key ``(b, a)``."""

    def __init__(self, context, pk):
        self.pk = pk
        self.context = context

    def encrypt(self, ptx):
        """Return the ciphertext ``(b*v + e0 + ptx, a*v + e1)``.

        ``v`` is drawn from the ZO sampler, then ``e0`` and ``e1`` from the
        DG sampler, in that order.
        """
        samplers = self.context
        v = samplers.zo.sample()
        e0 = samplers.dg.sample()
        e1 = samplers.dg.sample()

        masked_b = self.pk[0] * v
        masked_a = self.pk[1] * v
        return (masked_b + e0 + ptx, masked_a + e1)

In [202]:
class Decryptor:
    """Decrypt ciphertexts with a secret key ``s``."""

    def __init__(self, context, s):
        self.s = s
        self.context = context

    def decrypt(self, ctx):
        """Return ``c0 + s * c1`` for the ciphertext ``ctx = (c0, c1)``."""
        c0, c1 = ctx[0], ctx[1]
        return c0 + self.s * c1

In [203]:
# Demo parameters: ring degree N, cyclotomic index M = 2N, scaling factor 2**30.
N = 512
M = 2 * N
scale = 2 ** 30
# NOTE(review): q equals the scale; q is only used by the uniform and DG
# samplers' mod_q centering, since polynomial arithmetic never reduces mod q.
q = scale

In [204]:
# Build the context (parameters, polynomial factory, samplers) and a key generator.
context = Context(N, q)
keygen = Keygen(context)

In [205]:
# Secret key s from the HWT sampler; public key pk = (b, a) with b = -(a*s) + e.
s = keygen.generate_secret_key()
pk = keygen.generate_public_key(s)

In [206]:
# Encoder using the context's M and the scaling factor defined above.
encoder = CKKSEncoder(context, scale)

In [207]:
# Public-key encryptor and secret-key decryptor.
encryptor = Encryptor(context, pk)
decryptor = Decryptor(context, s)

In [208]:
# Test message: the integers 0 .. N/2 - 1, one per slot.
z = np.arange(N//2)

In [209]:
# Encode the message into a plaintext polynomial.
ptx = encoder.encode(z)

In [210]:
# Encrypt/decrypt round trip; the decoded values approximate z (see output).
ctx = encryptor.encrypt(ptx)
ptx = decryptor.decrypt(ctx)
encoder.decode(ptx)

array([7.83721958e-08+9.91340920e-07j, 9.99999995e-01-4.14623626e-07j,
       2.00000067e+00+2.59877656e-07j, 2.99999908e+00+3.38044073e-08j,
       3.99999966e+00+2.05753176e-07j, 5.00000053e+00-1.97412830e-06j,
       6.00000018e+00+9.03122290e-07j, 6.99999972e+00+6.18335296e-07j,
       7.99999902e+00-5.87381844e-07j, 8.99999981e+00+1.01920412e-06j,
       1.00000006e+01-2.30881372e-06j, 1.09999997e+01-4.90301609e-07j,
       1.19999999e+01+1.50289765e-07j, 1.29999996e+01+5.49561641e-08j,
       1.39999997e+01+3.58636509e-07j, 1.50000009e+01-7.69647404e-07j,
       1.60000001e+01-1.94750083e-07j, 1.70000005e+01+7.20651062e-07j,
       1.79999995e+01-2.16965994e-07j, 1.90000003e+01+1.06361366e-07j,
       1.99999986e+01-1.87727782e-07j, 2.10000006e+01-2.25247717e-07j,
       2.19999986e+01+2.37763889e-07j, 2.30000005e+01+5.71838196e-07j,
       2.40000011e+01+1.23709750e-07j, 2.50000009e+01+7.71208832e-07j,
       2.59999996e+01+2.78433102e-07j, 2.69999997e+01+3.89273939e-07j,
      

In [211]:
# Homomorphic addition: add the ciphertext to itself component-wise.
# NOTE(review): despite the name ``ct_mult`` this is an addition; the
# decoded result approximates 2*z (see output).
ct_mult = (ctx[0] + ctx[0], ctx[1] + ctx[1])
ptx = decryptor.decrypt(ct_mult)
encoder.decode(ptx)

array([1.56744392e-07+1.98268184e-06j, 1.99999999e+00-8.29247252e-07j,
       4.00000134e+00+5.19755312e-07j, 5.99999816e+00+6.76088145e-08j,
       7.99999932e+00+4.11506353e-07j, 1.00000011e+01-3.94825659e-06j,
       1.20000004e+01+1.80624458e-06j, 1.39999994e+01+1.23667059e-06j,
       1.59999980e+01-1.17476369e-06j, 1.79999996e+01+2.03840824e-06j,
       2.00000013e+01-4.61762744e-06j, 2.19999994e+01-9.80603218e-07j,
       2.39999999e+01+3.00579529e-07j, 2.59999991e+01+1.09912328e-07j,
       2.79999994e+01+7.17273018e-07j, 3.00000018e+01-1.53929481e-06j,
       3.20000001e+01-3.89500165e-07j, 3.40000011e+01+1.44130212e-06j,
       3.59999990e+01-4.33931987e-07j, 3.80000006e+01+2.12722732e-07j,
       3.99999971e+01-3.75455564e-07j, 4.20000012e+01-4.50495435e-07j,
       4.39999972e+01+4.75527777e-07j, 4.60000009e+01+1.14367639e-06j,
       4.80000023e+01+2.47419500e-07j, 5.00000018e+01+1.54241766e-06j,
       5.19999992e+01+5.56866205e-07j, 5.39999994e+01+7.78547879e-07j,
      