# CKKS with Numpy

In [427]:
import numpy as np

# First we set the parameters
M = 128
N = M //2

# We set xi, which will be used in our computations
xi = np.exp(2 * np.pi * 1j / M)
xi

(0.9987954562051724+0.049067674327418015j)

In [428]:
from numpy.polynomial import Polynomial
from numpy.polynomial.polynomial import polyval

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials."""
    
    def __init__(self, M: int):
        """Initialization of the encoder for M a power of 2. 
        
        xi, which is an M-th root of unity will, be used as a basis for our computations.
        """
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        
    def sigma_inverse(self, b: np.array) -> Polynomial:
        """Encodes the vector b in a polynomial using an M-th root of unity."""

        # First we create the Vandermonde matrix
        # A = CKKSEncoder.vandermonde(self.xi, M)
        N = M // 2
        root = self.xi
        roots = np.power(root, 2 * np.arange(N) + 1)
        A = np.vander(roots, N, increasing=True)    

        # Then we solve the system
        coeffs = np.linalg.solve(A, b)

        # Finally we output the polynomial
        p = Polynomial(coeffs)
        return p

    def sigma(self, p: Polynomial) -> np.array:
        """Decodes a polynomial by applying it to the M-th roots of unity."""

        outputs = []
        N = self.M //2

        # We simply apply the polynomial on the roots
        for i in range(N):
            root = self.xi ** (2 * i + 1)
            # output = polyval(root, p)
            output = p(root)
            outputs.append(output)
        return np.array(outputs)

In [458]:
# First we initialize our encoder
encoder = CKKSEncoder(M)

b = np.array(list(range(N)))
print("Message:\n", b)
p = encoder.sigma_inverse(b)
print("\nMessage_Poly:\n", p)
b_reconstructed = encoder.sigma(p)
print("\nMessage_Reconstructed:\n", b_reconstructed)
print("\nError: ", np.linalg.norm(b_reconstructed - b))

Message:
 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
(64, 64) (64,)

Message_Poly:
 (31.499999999999652-1.4566126083082054e-13j) -
(1.0036416142611415e-13-10.190008123548022j)·x¹ -
(1.092459456231154e-13-5.101148618689225j)·x² -
(1.3078427230084344e-13-3.407608418468759j)·x³ -
(1.177946629127291e-13-2.5629154477414877j)·x⁴ -
(7.394085344003543e-14-2.057781009953395j)·x⁵ -
(6.550315845288424e-14-1.7224470982383433j)·x⁶ -
(7.30526750203353e-14-1.4841646163141828j)·x⁷ -
(9.137135492665038e-14-1.3065629648763863j)·x⁸ -
(7.671641100159832e-14-1.1694399334328702j)·x⁹ -
(5.284661597215745e-14-1.0606776859903406j)·x¹⁰ -
(4.496403249731884e-14-0.972568237861962j)·x¹¹ -
(4.96269692007445e-14-0.8999762231364206j)·x¹² -
(5.767608612927688e-14-0.8393496454155284j)·x¹³ -
(4.846123502488808e-14-0.7881546234512515j)·x¹⁴ -
(5.2791104820926193e-14-0.744

In [459]:
b1, b2 = np.arange(N), np.arange(N)
p1, p2 = encoder.sigma_inverse(b1), encoder.sigma_inverse(b2)
modulo = np.zeros(N+1)
modulo[0] = 1; modulo[-1] = 1

encoder.sigma(p1 * p2 % modulo)

(64, 64) (64,)
(64, 64) (64,)


array([0.000e+00+2.74954301e-13j, 1.000e+00-4.15883237e-12j,
       4.000e+00-7.28522041e-12j, 9.000e+00-8.25155853e-12j,
       1.600e+01-1.08095124e-11j, 2.500e+01-1.04684519e-11j,
       3.600e+01-6.03466519e-12j, 4.900e+01-8.76314930e-12j,
       6.400e+01-9.21789665e-12j, 8.100e+01-4.78410998e-12j,
       1.000e+02-9.21789665e-12j, 1.210e+02-6.31888229e-12j,
       1.440e+02-5.80729152e-12j, 1.690e+02-5.12517049e-12j,
       1.960e+02-3.30618109e-12j, 2.250e+02-4.87648053e-12j,
       2.560e+02-1.67903822e-12j, 2.890e+02-3.78745016e-13j,
       3.240e+02+1.24129242e-12j, 3.610e+02+1.92341345e-12j,
       4.000e+02+2.15078712e-12j, 4.410e+02+2.26447396e-12j,
       4.840e+02+2.49184764e-12j, 5.290e+02+1.41182268e-12j,
       5.760e+02+2.66237789e-12j, 6.250e+02+4.70874097e-12j,
       6.760e+02+2.88975157e-12j, 7.290e+02+3.88451140e-12j,
       7.840e+02+6.07298303e-12j, 8.410e+02+2.70501046e-12j,
       9.000e+02+5.33401858e-12j, 9.610e+02-3.68086875e-13j,
       1.024e+03+5.17947

In [460]:
encoder.sigma(
    Polynomial(np.polydiv(np.polymul(p1.coef, p2.coef), modulo)[0])
    )

array([-842.42253112-1.24820138e+03j, -745.73070861-6.03099511e+02j,
       -646.97008121-2.90830746e+02j, -546.77395971-9.41375932e+01j,
       -445.77984894+4.13248562e+01j, -344.62330827+1.37368824e+02j,
       -243.93183081+2.05003365e+02j, -144.31880029+2.50641897e+02j,
        -46.37758388+2.78428784e+02j,   49.32418228+2.91271912e+02j,
        142.25005973+2.91357323e+02j,  231.8999739 +2.80428168e+02j,
        317.81503005+2.59945573e+02j,  399.58193274+2.31185747e+02j,
        476.83696605+1.95300492e+02j,  549.26949637+1.53355649e+02j,
        616.624964  +1.06355649e+02j,  678.70733464+5.52590336e+01j,
        735.38098738+9.87909301e-01j,  786.57202073-5.55667551e+01j,
        832.26896441-1.13542722e+02j,  872.52288973-1.72106117e+02j,
        907.44691783-2.30452483e+02j,  937.21513018-2.87808830e+02j,
        962.06089216-3.43436338e+02j,  982.27460555-3.96633464e+02j,
        998.20091187-4.46739262e+02j, 1010.23537323-4.93136741e+02j,
       1018.8206629 -5.35256148e+0

# CKKS with Jax

In [285]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import numpy as np
import jax
import jax.numpy as jnp
from jax import device_put
from jax import grad, jit, vmap

# First we set the parameters
M = 128
N = M //2

# We set xi, which will be used in our computations
xi = jnp.exp(2 * jnp.pi * 1j / M)
xi

DeviceArray(0.99879545+0.04906768j, dtype=complex64, weak_type=True)

In [469]:
from functools import partial

class CKKSEncoder:
    
    def __init__(self, M: int):
        """Initialization of the encoder for M a power of 2. 
        
        xi, which is an M-th root of unity will, be used as a basis for our computations.
        """
        self.xi = jnp.exp(2 * jnp.pi * 1j / M)
        self.M = M
        
    @partial(jit, static_argnums=(0,))
    def sigma_inverse(self, b: jnp.array) -> jnp.array:
        """Encodes the vector b in a polynomial using an M-th root of unity."""

        N = M // 2
        root = self.xi
        roots = jnp.power(root, 2 * jnp.arange(N) + 1)
        A = jnp.vander(roots, N)

        # Then we solve the system
        coeffs = jnp.linalg.solve(A, b)

        # Finally we output the polynomial
        return coeffs

    @partial(jit, static_argnums=(0,))
    def sigma(self, p: jnp.array) -> jnp.array:
        """Decodes a polynomial by applying it to the M-th roots of unity."""

        outputs = []
        N = self.M //2

        # We simply apply the polynomial on the roots
        for i in range(N):
            root = self.xi ** (2 * i + 1)
            output = jnp.polyval(p, root)
            outputs.append(output)
        return jnp.array(outputs)

## Encoding&Decoding

In [473]:
# First we initialize our encoder
encoder = CKKSEncoder(M)

b = jnp.array(list(range(N)))
print("Message:\n", b)
p = encoder.sigma_inverse(b)
# print("\nMessage_Poly:\n", p)
b_reconstructed = encoder.sigma(p)
print("\nMessage_Reconstructed:\n", b_reconstructed)
print("\nError: ", jnp.linalg.norm(b_reconstructed - b))

Message:
 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]

Message_Reconstructed:
 [ 0  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22
 23 24 25 26 27 29 30 31 31 32 33 34 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63]

Error:  0.010645335


  return _convert_element_type(operand, new_dtype, weak_type=False)


## Multiplications

In [479]:
b1, b2 = jnp.arange(N), jnp.arange(N)
p1, p2 = encoder.sigma_inverse(b1), encoder.sigma_inverse(b2)

In [481]:
# encoder.sigma()
pp = jnp.polymul(p1, p2)
modulo = np.zeros(N+1)
modulo[0] = 1; modulo[-1] = 1
_, r = np.polydiv(pp, modulo)
encoder.sigma(r).astype(jnp.int32)

DeviceArray([   0,    1,    4,    9,   16,   25,   36,   49,   63,   81,
               99,  121,  143,  168,  195,  224,  255,  288,  323,  360,
              399,  440,  483,  529,  575,  624,  675,  728,  783,  841,
              900,  961, 1023, 1088, 1155, 1225, 1296, 1369, 1444, 1521,
             1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401,
             2500, 2601, 2704, 2809, 2916, 3025, 3136, 3249, 3364, 3481,
             3600, 3721, 3844, 3969], dtype=int32)

In [483]:
b1*b2

DeviceArray([   0,    1,    4,    9,   16,   25,   36,   49,   64,   81,
              100,  121,  144,  169,  196,  225,  256,  289,  324,  361,
              400,  441,  484,  529,  576,  625,  676,  729,  784,  841,
              900,  961, 1024, 1089, 1156, 1225, 1296, 1369, 1444, 1521,
             1600, 1681, 1764, 1849, 1936, 2025, 2116, 2209, 2304, 2401,
             2500, 2601, 2704, 2809, 2916, 3025, 3136, 3249, 3364, 3481,
             3600, 3721, 3844, 3969], dtype=int32)

# Performance - Jax vs NumPy

## Polymul

In [398]:
N = 4096
modulo = np.zeros(N+1); modulo[0] = 1; modulo[-1] = 1
p2 = np.random.rand(N)
p1 = device_put(p2)

In [391]:
%timeit pp1 = jnp.polymul(p1, p1) # .block_until_ready()

652 µs ± 124 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [392]:
%timeit pp2 = np.polymul(p2, p2)

3.44 ms ± 353 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [477]:
pp1 = jnp.polymul(p1, p1)
pp2 = np.polymul(p2, p2)

In [478]:
pp1, pp2

(DeviceArray([-1.03861481e+02-6.4343917e-03j,
              -1.03986572e+02-3.7704463e-04j,
              -9.54910965e+01+2.4767614e-03j,
              -8.70177460e+01+4.1205338e-03j,
              -7.97153549e+01+5.0234604e-03j,
              -7.35813370e+01+5.7864068e-03j,
              -6.84279709e+01+5.9999418e-03j,
              -6.40705872e+01+6.1287032e-03j,
              -6.03543930e+01+6.5655857e-03j,
              -5.71610069e+01+6.2806401e-03j,
              -5.43963242e+01+6.2824506e-03j,
              -5.19885483e+01+6.6246684e-03j,
              -4.98800316e+01+6.6925110e-03j,
              -4.80270462e+01+6.7446260e-03j,
              -4.63928223e+01+6.7938929e-03j,
              -4.49486084e+01+6.9023184e-03j,
              -4.36706390e+01+6.6472217e-03j,
              -4.25398407e+01+6.7230482e-03j,
              -4.15399132e+01+6.2594768e-03j,
              -4.06574516e+01+6.3963202e-03j,
              -3.98820076e+01+6.6454317e-03j,
              -3.92040215e+01+6.52

In [508]:
N = 4096
M = 1000
P1 = [jax.random.uniform(jax.random.PRNGKey(i), (N,)) for i in range(M)]
P2 = [np.random.rand(N) for i in range(M)]

@jit
def polymul_many_times_gpu():
    for p in P1:
        jnp.polymul(p, p)

def polymul_many_times_cpu():
    for p in P2:
        np.polymul(p, p)

In [492]:
%timeit polymul_many_times_gpu()

The slowest run took 115.07 times longer than the fastest. This could mean that an intermediate result is being cached.
190 µs ± 380 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [493]:
%timeit polymul_many_times_cpu()

3.82 s ± 417 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [498]:
len(P1), len(P2)

(1000, 1000)

In [497]:
P1[0].shape, P2[0].shape

((4096,), (4096,))

## Polyval

In [514]:
N = 4096
M = 100
R1 = [jax.random.uniform(jax.random.PRNGKey(i), (N,)) for i in range(M)]
R2 = [np.random.rand(N) for i in range(M)]

@jit
def polyval_gpu():
    for r in R1:
        jnp.polyval(r, 0.1)

def polyval_cpu():
    for r in R2:
        np.polyval(r, 0.1)

In [515]:
%timeit polyval_gpu()

The slowest run took 48.24 times longer than the fastest. This could mean that an intermediate result is being cached.
83.6 µs ± 172 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [516]:
%timeit polyval_cpu()

67.9 ms ± 1.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
