# Encryption in CKKS

In [61]:
from numpy.polynomial import Polynomial
import numpy as np

def round_coordinates(coordinates):
    """Returns the fractional part ``c - floor(c)`` of each coordinate.

    Despite its name this does not round anything: for ordinary finite
    inputs each entry of the result lies in [0, 1).
    """
    whole_parts = np.floor(coordinates)
    return coordinates - whole_parts

def coordinate_wise_random_rounding(coordinates):
    """Rounds coordinates randomly to one of their two neighbouring integers.

    Each coordinate ``c`` becomes ``floor(c) + 1`` with probability equal to
    its fractional part ``c - floor(c)``, and ``floor(c)`` otherwise, so the
    rounding is unbiased in expectation.

    Args:
        coordinates: sequence or array of real numbers (flattened to 1-D).

    Returns:
        list[int]: the randomly rounded coordinates, in input order.
    """
    values = np.asarray(coordinates, dtype=float).reshape(-1)
    floors = np.floor(values)
    # One vectorised uniform draw replaces a Python-level np.random.choice
    # call per coordinate; P(round up) == fractional part.
    round_up = np.random.random_sample(values.shape) < (values - floors)
    return [int(value) for value in floors + round_up]

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials.

    A vector of ``M // 4`` complex slots is expanded with its conjugates to
    ``M // 2`` values (``pi_inverse``), multiplied by ``scale``, projected on
    the lattice spanned by ``sigma(1), sigma(X), ..., sigma(X^(N-1))`` with
    random rounding (``sigma_R_discretization``) and interpolated into a
    polynomial (``sigma_inverse``). ``decode`` reverses these steps up to
    rounding error and noise.
    """

    def __init__(self, context, scale: float):
        """Initializes the encoder.

        Args:
            context: object exposing ``M`` and a ``QPolynomial`` factory
                (the factory is used by :meth:`encode`).
            scale: factor applied to the input before rounding to integers.
        """
        M = context.M
        self.context = context
        # Primitive M-th root of unity; its odd powers are the evaluation points.
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        self.create_sigma_R_basis()
        self.scale = scale

        # Number of independent complex slots (M / 4).
        self.slot_count = M // 4

    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> list:
        """Computes the Vandermonde matrix of the odd powers of ``xi``.

        Row ``i`` contains ``root ** j`` for ``j = 0 .. M//2 - 1`` where
        ``root = xi ** (2 * i + 1)``.

        Returns:
            list of ``M // 2`` rows, each a list of ``M // 2`` complex values.
        """
        N = M // 2
        matrix = []
        for i in range(N):
            # Each row uses a different odd power of xi as its root.
            root = xi ** (2 * i + 1)
            matrix.append([root ** j for j in range(N)])
        return matrix

    def sigma_inverse(self, b: np.ndarray) -> Polynomial:
        """Encodes the vector ``b`` in a polynomial using an M-th root of unity.

        Solves ``A @ coeffs = b`` with ``A`` the Vandermonde matrix of the odd
        powers of ``self.xi``, so the returned polynomial evaluates to ``b`` on
        those roots.
        """
        # Use the encoder's own M; a bare ``M`` would read a module-level
        # global that need not match this encoder's context.
        A = CKKSEncoder.vandermonde(self.xi, self.M)
        coeffs = np.linalg.solve(A, b)
        return Polynomial(coeffs)

    def sigma(self, p: Polynomial) -> np.ndarray:
        """Decodes a polynomial by evaluating it on the odd powers of ``xi``."""
        N = self.M // 2
        outputs = []
        for i in range(N):
            root = self.xi ** (2 * i + 1)
            outputs.append(p(root))
        return np.array(outputs)

    def pi(self, z: np.ndarray) -> np.ndarray:
        """Projects a vector of H into C^{N/2} by keeping its first M // 4 entries."""
        N = self.M // 4
        return z[:N]

    def pi_inverse(self, z: np.ndarray) -> np.ndarray:
        """Expands a vector of C^{N/2} by appending its reversed complex conjugate."""
        return np.concatenate([z, np.conjugate(z[::-1])])

    def create_sigma_R_basis(self):
        """Creates the basis (sigma(1), sigma(X), ..., sigma(X** N-1)).

        Stored as the rows of ``self.sigma_R_basis`` (the transposed
        Vandermonde matrix).
        """
        self.sigma_R_basis = np.array(self.vandermonde(self.xi, self.M)).T

    def compute_basis_coordinates(self, z):
        """Computes the real coordinates of ``z`` in ``self.sigma_R_basis``.

        Each coordinate is ``Re(<b, z> / <b, b>)`` for a basis vector ``b``;
        this equals the true coordinate only because the basis vectors are
        assumed to be mutually orthogonal.
        """
        output = np.array([np.real(np.vdot(z, b) / np.vdot(b, b)) for b in self.sigma_R_basis])
        return output

    def sigma_R_discretization(self, z):
        """Projects a vector on the lattice using coordinate wise random rounding."""
        coordinates = self.compute_basis_coordinates(z)
        rounded_coordinates = coordinate_wise_random_rounding(coordinates)
        y = np.matmul(self.sigma_R_basis.T, rounded_coordinates)
        return y

    def encode(self, z: np.ndarray) -> "QPolynomial":
        """Encodes a vector by expanding it first to H,
        scale it, project it on the lattice of sigma(R), and performs
        sigma inverse.

        Returns:
            ``self.context.QPolynomial`` built from the integer coefficients.
        """
        pi_z = self.pi_inverse(z)
        scaled_pi_z = self.scale * pi_z
        rounded_scale_pi_zi = self.sigma_R_discretization(scaled_pi_z)
        p = self.sigma_inverse(rounded_scale_pi_zi)

        # We round it afterwards due to numerical imprecision
        coef = np.round(np.real(p.coef)).astype(int)
        p = self.context.QPolynomial(coef)

        return p

    def decode(self, p: Polynomial) -> np.ndarray:
        """Decodes a polynomial by removing the scale,
        evaluating on the roots, and project it on C^(N/2)"""
        rescaled_p = p / self.scale
        z = self.sigma(rescaled_p)
        pi_z = self.pi(z)
        return pi_z

    def decode_float(self, p: Polynomial) -> np.ndarray:
        """Decodes ``p`` like :meth:`decode` and keeps only the real parts."""
        pi_z = self.decode(p)
        pi_z = np.real(pi_z)
        return pi_z

In [62]:
from numpy.polynomial import Polynomial

def mod_q(coeffs: np.ndarray, q: int) -> np.ndarray:
    """Returns ``coeffs`` reduced modulo ``q`` into the centred range (-q/2, q/2].

    ``%`` first maps every entry into [0, q); entries strictly above q/2 are
    then shifted down by q. The input array itself is left unmodified.
    """
    residues = coeffs % q
    upper_half = residues > q / 2
    residues[upper_half] -= q
    return residues

def mod_q_polynomial(p: Polynomial, q: int) -> Polynomial:
    """Returns a new Polynomial whose coefficients are those of ``p`` centred modulo ``q`` by ``mod_q``."""
    return Polynomial(mod_q(p.coef, q))

In [63]:
from numpy.polynomial import polynomial as poly
from numpy.polynomial import Polynomial
from __future__ import annotations
        
class QPolynomialGenerator:
    """Factory producing QPolynomial objects that share ``q`` and the ring modulus X^N + 1.

    Args:
        N: ring degree; the polynomial modulus is X^N + 1.
        q: modulus stored here and forwarded to every QPolynomial created.
    """

    def __init__(self, N: int, q: int):
        self.q = q
        # Increasing-degree coefficients of X^N + 1: 1, 0, ..., 0, 1.
        modulus_coeffs = np.zeros(N + 1)
        modulus_coeffs[[0, N]] = 1
        self.poly_modulus = Polynomial(modulus_coeffs)

    def __call__(self, coef) -> "QPolynomial":
        """Wraps ``coef`` (increasing-degree coefficients) in a QPolynomial bound to this generator."""
        return QPolynomial(coef, self.q, self.poly_modulus, self)
    
class QPolynomial:
    """Polynomial in Z[X]/(X^N + 1), tagged with a modulus ``q``.

    Addition, subtraction, multiplication, negation and unary ``+`` reduce the
    result modulo ``poly_modulus`` (X^N + 1 when created by
    ``QPolynomialGenerator``). Division by a scalar divides every coefficient.
    Every result is re-wrapped through ``generator``. Operands may be
    QPolynomials, numpy Polynomials or scalars.

    NOTE(review): despite the ``q`` attribute, none of these operators reduce
    coefficients modulo ``q``; callers must do that explicitly if required.

    Non-dunder attributes not defined here (``coef``, ``degree``, ...) are
    delegated to the wrapped ``numpy.polynomial.Polynomial`` in ``self.p``.
    """

    def __init__(self, coef, q, poly_modulus, generator):
        self.p = Polynomial(coef)  # wrapped numpy polynomial
        self.q = q
        self.poly_modulus = poly_modulus
        self.generator = generator  # factory used to wrap every result

    def __getattr__(self, k):
        # Only reached when normal lookup fails. Protocol (dunder) names are
        # never delegated, so copy/pickle/numpy see this object's real protocol.
        if k.startswith("__") and k.endswith("__"):
            raise AttributeError(k)
        # Read ``p`` from __dict__ directly: a half-built instance (copy and
        # pickle create one without __init__) must raise AttributeError
        # instead of recursing through __getattr__ forever.
        try:
            wrapped = self.__dict__["p"]
        except KeyError:
            raise AttributeError(k) from None
        return getattr(wrapped, k)

    @staticmethod
    def _unwrap(value):
        """Returns the underlying Polynomial of a QPolynomial, else ``value`` unchanged."""
        return value.p if isinstance(value, QPolynomial) else value

    def _reduce_and_wrap(self, p) -> "QPolynomial":
        """Reduces ``p`` modulo ``poly_modulus`` and wraps it via the generator."""
        return self.generator((p % self.poly_modulus).coef)

    def __mul__(self, right) -> "QPolynomial":
        return self._reduce_and_wrap(self.p * self._unwrap(right))

    def __rmul__(self, left) -> "QPolynomial":
        return self.__mul__(left)

    def __add__(self, right) -> "QPolynomial":
        return self._reduce_and_wrap(self.p + self._unwrap(right))

    def __radd__(self, left) -> "QPolynomial":
        return self.__add__(left)

    def __sub__(self, qpoly) -> "QPolynomial":
        """Returns ``self - qpoly``; ``qpoly`` may be a QPolynomial, Polynomial or scalar."""
        return self._reduce_and_wrap(self.p - self._unwrap(qpoly))

    def __rsub__(self, qpoly) -> "QPolynomial":
        """Returns ``qpoly - self`` for a non-QPolynomial left operand (e.g. ``5 - x``)."""
        return self._reduce_and_wrap(self._unwrap(qpoly) - self.p)

    def __neg__(self):
        return self._reduce_and_wrap(-self.p)

    def __pos__(self):
        return self._reduce_and_wrap(self.p)

    def __truediv__(self, scale):
        """Divides every coefficient by ``scale`` (no ring reduction needed)."""
        p_div = self.p / scale
        return self.generator(p_div.coef)

    def __call__(self, x):
        """Evaluates the wrapped polynomial at ``x``."""
        return self.p(x)

    def __repr__(self):
        return self.p.__repr__()

In [64]:
class PolynomialSampler:
    """Base class to sample polynomials.

    Subclasses implement ``sample()``. Attributes they read that are not set
    on the sampler itself (for example ``N``, ``q``, ``h``, ``sigma``, ``p``)
    are looked up on the shared ``context``.
    """

    def __init__(self, context):
        self.context = context

    def __getattr__(self, k):
        """Context variables are directly linked to the instance."""
        # Protocol (dunder) names are not forwarded to the context.
        if k.startswith("__") and k.endswith("__"):
            raise AttributeError(k)
        # Fetch ``context`` from __dict__ so a half-built instance (copy and
        # pickle create one without __init__) raises AttributeError instead of
        # recursing through __getattr__ forever.
        try:
            context = self.__dict__["context"]
        except KeyError:
            raise AttributeError(k) from None
        return getattr(context, k)

    def polynomial(self, coeffs):
        """Wraps ``coeffs`` with the context's ``QPolynomial`` factory."""
        return self.context.QPolynomial(coeffs)

In [65]:
class UniformPolynomial(PolynomialSampler):
    """Samples polynomials whose N coefficients are uniform modulo a modulus."""

    def sample(self) -> QPolynomial:
        """Samples using the context modulus ``q``."""
        return self.sample_manually(self.q)

    def sample_manually(self, q) -> QPolynomial:
        """Draws N coefficients uniformly from [0, q), centres them with ``mod_q`` and wraps them."""
        residues = np.random.choice(q, size=self.N)
        centred = mod_q(residues, q)
        return self.polynomial(centred)

In [66]:
class ZO(PolynomialSampler):
    """Samples polynomials with coefficients in {0, 1, -1}.

    Each coefficient is 0 with probability ``1 - p`` and +1 or -1 with
    probability ``p / 2`` each; ``p`` comes from the context.
    """

    def sample(self):
        nonzero_prob = self.p
        half = nonzero_prob / 2
        coeffs = np.random.choice([0, 1, -1], size=self.N, p=[1 - nonzero_prob, half, half])
        return self.polynomial(coeffs)

In [67]:
class HWT(PolynomialSampler):
    """Samples polynomials with exactly ``h`` nonzero coefficients, each -1 or +1.

    ``N`` and ``h`` come from the context; ``h`` must not exceed ``N``.
    """

    def sample(self):
        # Start with random signs everywhere ...
        signs = np.random.choice([-1, 1], size=self.N)
        # ... then zero out N - h distinct positions.
        zeroed_positions = np.random.choice(range(self.N), self.N - self.h, replace=False)
        signs[zeroed_positions] = 0
        return self.polynomial(signs)

In [68]:
class DG(PolynomialSampler):
    """Samples rounded-Gaussian noise polynomials.

    Draws N values from a normal distribution with standard deviation
    ``sigma``, rounds them randomly to integers and centres them modulo ``q``.
    """

    def sample(self):
        gaussian = np.random.normal(np.zeros(self.N), self.sigma)
        integers = np.array(coordinate_wise_random_rounding(gaussian))
        centred = mod_q(integers, self.q)
        return self.polynomial(centred)

In [69]:
class Context:
    """Bundle of CKKS parameters plus the polynomial factory and samplers.

    Args:
        N: ring degree; ``QPolynomial`` reduces modulo X^N + 1 and ``M = 2N``.
        moduli: sequence of at least two moduli. All but the last are
            multiplied into the ciphertext modulus ``q``; the last is stored
            as ``special_prime``.

    Raises:
        ValueError: if ``moduli`` has fewer than two entries.
    """

    def __init__(self, N, moduli):
        if len(moduli) < 2:
            raise ValueError(
                "moduli must contain at least one ciphertext modulus and a special prime"
            )
        self.N = N
        self.M = N * 2

        # Multiply as Python ints so q is exact; np.cumprod on int64 wraps
        # around silently once the product exceeds 2**63 - 1.
        q = 1
        for modulus in moduli[:-1]:
            q *= int(modulus)
        self.q = q
        self.special_prime = moduli[-1]

        # Build the factory from this context's q (a bare ``q`` here would
        # read an unrelated global name).
        self.QPolynomial = QPolynomialGenerator(N, self.q)

        self.setup_parameters()

        # Samplers read N, q, h, sigma and p back from this context.
        self.hwt = HWT(self)
        self.dg = DG(self)
        self.uniform = UniformPolynomial(self)
        self.zo = ZO(self)

    def setup_parameters(self):
        """Sets the parameters read by the HWT, DG and ZO samplers."""
        self.h = 64  # number of nonzero coefficients drawn by HWT
        self.sigma = 3  # standard deviation used by DG
        self.p = 0.5  # probability of a nonzero coefficient in ZO

In [70]:
class Keygen:
    """Generates the secret, public and relinearization keys for a context."""

    def __init__(self, context):
        self.context = context

    def generate_secret_key(self):
        """Samples the secret key ``s`` from the context's HWT sampler."""
        return self.context.hwt.sample()

    def generate_public_key(self, s):
        """Returns the public key ``(b, a)`` with ``b = -(a * s) + e``.

        ``a`` is drawn from the context's uniform sampler, then ``e`` from its
        DG sampler.
        """
        ctx = self.context
        a = ctx.uniform.sample()
        e = ctx.dg.sample()
        return (-(a * s) + e, a)

    def generate_relin_key(self, s):
        """Returns ``(b, a)`` with ``b = -(a * s) + e + special_prime * s * s``.

        ``a`` is uniform modulo ``q * special_prime`` and ``e`` is drawn from
        the DG sampler afterwards.
        """
        ctx = self.context
        extended_modulus = ctx.q * ctx.special_prime
        a = ctx.uniform.sample_manually(extended_modulus)
        e = ctx.dg.sample()
        b = -(a * s) + e + ctx.special_prime * s * s
        return (b, a)

In [71]:
class Encryptor:
    """Public-key encryptor turning a plaintext polynomial into a ciphertext pair."""

    def __init__(self, context, pk):
        self.pk = pk
        self.context = context

    def encrypt(self, ptx):
        """Returns ``(pk[0]*v + e0 + ptx, pk[1]*v + e1)``.

        ``v`` is drawn from the context's ZO sampler, then ``e0`` and ``e1``
        from its DG sampler, in that order.
        """
        samplers = self.context
        v = samplers.zo.sample()
        e0 = samplers.dg.sample()
        e1 = samplers.dg.sample()

        pk_b, pk_a = self.pk[0], self.pk[1]
        masked_b = pk_b * v
        masked_a = pk_a * v
        return (masked_b + e0 + ptx, masked_a + e1)

In [72]:
class Decryptor:
    """Secret-key decryptor for ciphertexts of the form ``(c0, c1, ...)``."""

    def __init__(self, context, s):
        self.s = s
        self.context = context

    def decrypt(self, ctx):
        """Returns ``c0 + s * c1`` using only the first two ciphertext components."""
        c0, c1 = ctx[0], ctx[1]
        return c0 + self.s * c1

In [73]:
N = 512  # ring degree: polynomials live modulo X^N + 1
M = N * 2  # M = 2N
scale = 2 ** 15  # encoding scale passed to CKKSEncoder

# Context multiplies all but the last modulus into q (here 2**35) and keeps
# the last one (2**20) as the special prime used for relinearization.
# NOTE(review): the special prime is smaller than q; relinearization
# normally wants it at least as large — confirm this is intended.
moduli = [scale * 2 ** 5, scale, scale * 2 ** 5]

In [74]:
# Shared parameters, polynomial factory and samplers, plus a key generator bound to them.
context = Context(N, moduli)
keygen = Keygen(context)

In [94]:
# Secret key s, public key (b, a) and relinearization key; every call draws
# fresh randomness, so re-running this cell changes all keys.
s = keygen.generate_secret_key()
pk = keygen.generate_public_key(s)
relin_key = keygen.generate_relin_key(s)

In [95]:
encoder = CKKSEncoder(context, scale)

In [96]:
# Encrypt with the public key, decrypt with the matching secret key.
encryptor = Encryptor(context, pk)
decryptor = Decryptor(context, s)

In [97]:
z = np.arange(N//2)

In [98]:
ptx = encoder.encode(z)

In [99]:
# Round trip: the decoded values should approximate z up to encryption noise.
ctx = encryptor.encrypt(ptx)
ptx = decryptor.decrypt(ctx)
encoder.decode(ptx)

array([1.95242507e-02+0.02307015j, 9.87522660e-01+0.01956589j,
       1.95971907e+00-0.01554842j, 2.99916504e+00-0.03097077j,
       4.00722961e+00-0.01257957j, 4.97797569e+00-0.01011068j,
       6.01533298e+00+0.05620058j, 7.00060952e+00-0.01955817j,
       8.02398273e+00-0.00705084j, 9.04106358e+00+0.01193944j,
       9.98658761e+00+0.01726699j, 1.10093786e+01+0.02025468j,
       1.20172724e+01-0.01604947j, 1.30048634e+01-0.0051589j ,
       1.40093209e+01+0.02733144j, 1.49969724e+01+0.00509166j,
       1.60075420e+01-0.02975023j, 1.70627424e+01+0.03032853j,
       1.79742599e+01+0.0270433j , 1.89917879e+01-0.01220565j,
       2.00036910e+01-0.04791913j, 2.10056464e+01-0.03520789j,
       2.20004967e+01-0.0195347j , 2.29967046e+01+0.01381751j,
       2.40184623e+01-0.04258035j, 2.50199947e+01-0.01132079j,
       2.59529199e+01+0.06636023j, 2.70008553e+01+0.00230454j,
       2.79954756e+01+0.0482255j , 2.90534245e+01+0.02586398j,
       3.00078917e+01-0.02711116j, 3.10229399e+01+0.010

In [100]:
# Homomorphic addition is component-wise; ctx + ctx should decode to about 2 * z.
ct_add = (ctx[0] + ctx[0], ctx[1] + ctx[1])
ptx = decryptor.decrypt(ct_add)
encoder.decode(ptx)

array([3.90485014e-02+0.04614029j, 1.97504532e+00+0.03913179j,
       3.91943814e+00-0.03109683j, 5.99833008e+00-0.06194153j,
       8.01445923e+00-0.02515915j, 9.95595138e+00-0.02022136j,
       1.20306660e+01+0.11240117j, 1.40012190e+01-0.03911633j,
       1.60479655e+01-0.01410167j, 1.80821272e+01+0.02387888j,
       1.99731752e+01+0.03453397j, 2.20187572e+01+0.04050935j,
       2.40345448e+01-0.03209894j, 2.60097268e+01-0.0103178j ,
       2.80186418e+01+0.05466287j, 2.99939448e+01+0.01018333j,
       3.20150839e+01-0.05950045j, 3.41254849e+01+0.06065706j,
       3.59485198e+01+0.0540866j , 3.79835759e+01-0.0244113j ,
       4.00073820e+01-0.09583825j, 4.20112928e+01-0.07041578j,
       4.40009934e+01-0.0390694j , 4.59934091e+01+0.02763503j,
       4.80369246e+01-0.08516069j, 5.00399895e+01-0.02264158j,
       5.19058397e+01+0.13272046j, 5.40017107e+01+0.00460908j,
       5.59909511e+01+0.096451j  , 5.81068491e+01+0.05172796j,
       6.00157834e+01-0.05422233j, 6.20458799e+01+0.021

In [101]:
# Two full-slot vectors whose slot-wise product should be 5 everywhere.
# Note: this rebinds ``a`` and ``b`` at notebook scope.
a = np.ones(encoder.slot_count)
b = np.ones(encoder.slot_count) * 5

In [102]:
# Each plaintext carries one factor of ``scale``.
ptx1 = encoder.encode(a)
ptx2 = encoder.encode(b)

In [103]:
ctx1 = encryptor.encrypt(ptx1)
ctx2 = encryptor.encrypt(ptx2)

In [104]:
ct_mul = (ctx1[0] * ctx2[0], ctx1[1] * ctx2[0] + ctx1[0] * ctx2[1], ctx1[1] * ctx2[1])

In [105]:
d = (ct_mul[2] * relin_key[0] / context.special_prime, ct_mul[2] * relin_key[1] / context.special_prime)

In [106]:
# Randomly round d back to integer coefficients.
d = (context.QPolynomial(coordinate_wise_random_rounding(d[0].coef)), 
     context.QPolynomial(coordinate_wise_random_rounding(d[1].coef)))

In [107]:
ct_relin = (ct_mul[0] + d[0], ct_mul[1] + d[1])

In [108]:
# The product carries scale**2: decode removes one factor, the division the other.
# NOTE(review): the recorded output is nowhere near the expected 5. Likely
# causes: QPolynomial never reduces coefficients mod q, ciphertext products
# exceed float64 precision, and the special prime is smaller than q. Verify.
ptx = decryptor.decrypt(ct_relin)
encoder.decode(ptx) / scale

array([ 3.67942309e+12+1.44222254e+13j,  1.72190318e+12-6.71147807e+12j,
       -9.63824524e+12+3.09514183e+12j, -1.53583301e+13+2.48433341e+12j,
       -5.84136476e+12-9.41468967e+12j,  5.79015404e+12+7.93774668e+12j,
        4.15217958e+12-7.77578484e+12j,  1.22094194e+12+6.65815750e+12j,
        5.03807402e+12+5.44352568e+12j,  4.33078783e+12-8.32285611e+12j,
       -5.09398186e+12-1.81276414e+13j,  1.31595816e+13-1.33159869e+13j,
        6.44714033e+12+8.43099497e+12j,  3.85575296e+12+1.81740218e+11j,
       -9.73212589e+12+1.48712206e+13j, -1.02839119e+13+4.17555812e+12j,
        2.29673298e+12-5.58790190e+12j,  1.31368059e+13-1.05096267e+13j,
        4.44610683e+12-5.54353324e+12j,  2.69656638e+12+9.04671654e+12j,
        3.15566930e+12-1.48968755e+13j,  6.30210709e+12+8.88434453e+12j,
       -6.03966568e+12-1.39449105e+13j,  2.32434458e+13-3.37212146e+12j,
        1.49180933e+13+2.02971954e+12j,  2.47705488e+13+8.93393803e+12j,
       -2.67876684e+12-1.84447872e+13j, -2.31294214

In [109]:
pt_mul = ct_mul[0] + ct_mul[1] * s + ct_mul[2] * s * s

In [110]:
encoder.decode(pt_mul) / scale

array([ 5.54214881e+02-2.89972637e+02j,  6.81777741e+02+7.48734060e+02j,
        5.26697036e+02-3.75245932e+01j, -1.52396108e+02-7.01592983e+01j,
       -6.70670510e+02-2.20930285e+02j, -3.29489140e+01+2.92085089e+02j,
       -2.26052514e+00-5.07715403e+02j,  2.27298545e+01+1.94358918e+02j,
       -1.52586823e+02+1.39701539e+02j, -3.41207426e+02+1.15570690e+02j,
       -3.32481562e+01+7.05404080e+02j,  2.82900981e+02+5.29701631e+02j,
        3.19832364e+02+3.56942833e+01j, -3.09452562e+02-3.09068446e+01j,
       -1.71144758e+02+1.18537633e+03j,  3.83627614e+02+7.35277388e+01j,
        2.22443710e+02+1.05744502e+02j,  2.56457586e+02+3.61654476e+02j,
        2.46540987e+02+8.32959592e+02j,  4.20249338e+02-7.40552205e+01j,
        1.87535618e+02-3.02408535e+02j, -9.30404983e+02-9.07204549e+01j,
        2.26995044e+02+2.43341447e+02j, -2.06210182e+02-3.22705225e+00j,
        1.80085368e+00+3.78391011e+02j,  6.35774765e+01-6.19547179e+02j,
       -4.34222117e+02-1.88768667e+03j, -1.71133170