In [None]:
# default_exp seal_helper

In [None]:
# hide
%load_ext autoreload
%autoreload 2

# seal_helper

In [None]:
# export
from seal import *
from typing import List

In [None]:
# export
def print_vector(vec, print_size=4, prec=3):
    """Print a vector on one line, eliding the middle of long vectors.

    Args:
        vec: Indexable sequence of numbers that supports ``len()``.
        print_size: Number of entries shown at each end when the vector holds
            more than ``2 * print_size`` entries; shorter vectors are printed
            in full.
        prec: Number of digits printed after the decimal point.

    The line is indented by four spaces, wrapped in ``[ ... ]`` and surrounded
    by one blank line above and below. An empty vector prints as ``[ ]``.
    """
    slot_count = len(vec)
    fmt = f"%.{prec}f"

    if slot_count <= 2 * print_size:
        shown = [fmt % vec[i] for i in range(slot_count)]
    else:
        # Index explicitly instead of slicing: only len() and item access are
        # required from `vec`.
        head = [fmt % vec[i] for i in range(print_size)]
        tail = [fmt % vec[i] for i in range(slot_count - print_size, slot_count)]
        shown = head + ["..."] + tail

    print()
    # Joining the pieces guarantees a closing bracket even when there is
    # nothing to show; an element-by-element loop leaves "[" unclosed then.
    print(("    [ " + ", ".join(shown) + " ]") if shown else "    [ ]")
    print()
    
def ptx_value(ptx, i=0):
    """Decode a Plaintext and return the entry at index `i`.

    Decoding goes through the name `encoder`, which is not a parameter and
    must therefore be resolvable as a global (or builtin) at call time.
    """
    decoded = DoubleVector()
    encoder.decode(ptx, decoded)
    return decoded[i]
    
def ctx_value(ctx, i=0):
    """Decrypt a Ciphertext and return the decoded entry at index `i`.

    Decryption uses the name `decryptor`, which is not a parameter and must be
    resolvable as a global (or builtin); decoding is delegated to `ptx_value`.
    """
    decrypted = Plaintext()
    decryptor.decrypt(ctx, decrypted)
    return ptx_value(decrypted, i)
    
def print_ctx(ctx):
    """Decrypt a Ciphertext, then decode and print it via `print_ptx`.

    Relies on the name `decryptor` being resolvable as a global (or builtin).
    """
    decrypted = Plaintext()
    decryptor.decrypt(ctx, decrypted)
    # print_ptx performs the decode + print_vector(result, 3, 7) step.
    print_ptx(decrypted)
    
def print_ptx(ptx):
    """Decode a Plaintext and print it with 3 entries per end and 7 decimals.

    Relies on the name `encoder` being resolvable as a global (or builtin).
    """
    decoded = DoubleVector()
    encoder.decode(ptx, decoded)
    print_vector(decoded, print_size=3, prec=7)
    
def print_range_ctx(ctx, end=0, begin=0):
    """Decrypt a Ciphertext and print its entries with index in [begin, end).

    Note the parameter order: `end` comes before `begin`. Decryption uses the
    name `decryptor`, resolved as a global (or builtin); printing is delegated
    to `print_range_ptx`.
    """
    decrypted = Plaintext()
    decryptor.decrypt(ctx, decrypted)
    print_range_ptx(decrypted, end, begin)
        
def print_range_ptx(ptx, end=0, begin=0):
    """Decode a Plaintext and print one "index : value" line per index in [begin, end).

    Note the parameter order: `end` comes before `begin`. With the defaults the
    range is empty and nothing is printed. Decoding uses the name `encoder`,
    resolved as a global (or builtin).
    """
    indices = range(begin, end)

    decoded = DoubleVector()
    encoder.decode(ptx, decoded)

    for idx in indices:
        print(f"{idx} : {decoded[idx]}")

In [None]:
# export
def float_to_ctx(x, encoder: CKKSEncoder, encryptor: Encryptor):
    """Encode a number, or a sequence of numbers, and encrypt the result.

    Args:
        x: A single number, or a sized iterable of numbers. Sized inputs are
            copied into a ``DoubleVector`` before encoding; anything without a
            length is handed to the encoder unchanged.
        encoder: Encoder used to fill the Plaintext.
        encryptor: Encryptor used to fill the Ciphertext.

    Returns:
        The Ciphertext written by ``encryptor.encrypt``.

    Note:
        ``scale`` is not a parameter; it must be resolvable as a global (or
        builtin) name when this function is called.
    """
    # A plain float has no len(); calling len() on it unconditionally raises
    # TypeError, so probe for a length instead of assuming one.
    try:
        len(x)
    except TypeError:
        pass  # scalar input: encode as-is
    else:
        # NOTE(review): single-element sequences are wrapped as well, assuming
        # encoder.encode expects a DoubleVector for vector input - confirm.
        x = DoubleVector(list(x))

    ptx = Plaintext()
    encoder.encode(x, scale, ptx)

    ctx = Ciphertext()
    encryptor.encrypt(ptx, ctx)

    return ctx

In [None]:
# export
def vrep(x, n):
    """Repeat the sequence `x` cyclically until it has `n` elements.

    `x` is concatenated with itself as many whole times as fit into `n`, then
    topped up with a leading slice of `x` for the remainder.
    """
    whole_copies, leftover = divmod(n, len(x))
    return x * whole_copies + x[:leftover]

In [None]:
# export
def create_seal_globals(globals: dict, poly_modulus_degree: int, moduli: List[int], PRECISION_BITS: int):
    """Build a CKKS setup and store every piece of it in the `globals` dict.

    Keys written, in order: "parms", "context", "scale" (2.0 raised to
    PRECISION_BITS), "public_key", "secret_key", "relin_keys", "galois_keys",
    "encryptor", "evaluator", "decryptor" and "encoder".
    """
    params = EncryptionParameters(scheme_type.CKKS)
    params.set_poly_modulus_degree(poly_modulus_degree)
    coeff_modulus = CoeffModulus.Create(poly_modulus_degree, moduli)
    params.set_coeff_modulus(coeff_modulus)

    ctx = SEALContext.Create(params)
    key_generator = KeyGenerator(ctx)

    globals["parms"] = params
    globals["context"] = ctx
    globals["scale"] = 2.0 ** PRECISION_BITS

    # Each dict key doubles as the name of the KeyGenerator method producing it.
    for key_name in ("public_key", "secret_key", "relin_keys", "galois_keys"):
        globals[key_name] = getattr(key_generator, key_name)()

    globals["encryptor"] = Encryptor(ctx, globals["public_key"])
    globals["evaluator"] = Evaluator(ctx)
    globals["decryptor"] = Decryptor(ctx, globals["secret_key"])
    globals["encoder"] = CKKSEncoder(ctx)
    
def append_globals_to_builtins(globals, builtins):
    """Copy the SEAL setup entries from `globals` onto `builtins` as attributes.

    After this call the copied names can be used from any function without
    being passed in. Only use for testing purposes.

    Raises:
        KeyError: if one of the expected entries is missing from `globals`;
            entries listed before the missing one have already been set.
    """
    exported_names = (
        "public_key", "secret_key", "relin_keys", "galois_keys",
        "encryptor", "evaluator", "decryptor", "encoder",
        "scale", "parms", "context",
    )

    for name in exported_names:
        setattr(builtins, name, globals[name])

In [None]:
# export
from pathlib import Path

def save_seal_globals(globals, path:Path = Path("seal")):
    """Save the encryption parameters and the four keys as files under `path`.

    Args:
        globals: Mapping holding "parms", "public_key", "secret_key",
            "relin_keys" and "galois_keys"; each value must offer
            ``save(filename)``.
        path: Target directory (``Path`` or ``str``). It is created, together
            with any missing parent directories, when it does not exist yet.

    Each object is written to ``path/<key name>``.

    Raises:
        KeyError: if one of the five entries is missing from `globals`
            (raised before anything is written).
        FileExistsError: if `path` exists but is not a directory.
    """
    # Look everything up first so a missing entry fails before touching disk.
    names = ("parms", "public_key", "secret_key", "relin_keys", "galois_keys")
    to_save = [(name, globals[name]) for name in names]

    # Accept plain strings too, and create nested targets in one go;
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    for name, obj in to_save:
        obj.save(str(path / name))

def load_seal_globals(globals, path:Path = Path("seal"), load_pk:bool = False, load_sk:bool = False):
    """Load a saved SEAL setup from `path` and store it in the `globals` dict.

    Args:
        globals: Dict that receives the loaded objects.
        path: Directory (``Path`` or ``str``) containing the files "parms",
            "relin_keys", "galois_keys" and, when requested, "public_key" /
            "secret_key".
        load_pk: Also load the public key and build "encryptor" from it.
        load_sk: Also load the secret key and build "decryptor" from it.

    Keys always written: "parms", "context", "relin_keys", "galois_keys",
    "evaluator" and "encoder". "public_key"/"encryptor" and
    "secret_key"/"decryptor" are written only when the matching flag is set.
    "scale" is never written here.

    Raises:
        FileNotFoundError: if `path` does not exist.
    """
    path = Path(path)  # tolerate plain string paths
    if not path.exists():
        raise FileNotFoundError("Path not found")

    parms = EncryptionParameters(scheme_type.CKKS)
    parms.load(str(path/"parms"))

    context = SEALContext.Create(parms)

    # Store parms and context too, matching create_seal_globals; without them
    # append_globals_to_builtins raises KeyError on a freshly loaded dict.
    globals["parms"] = parms
    globals["context"] = context

    if load_pk:
        public_key = PublicKey()
        public_key.load(context, str(path/"public_key"))
        globals["public_key"] = public_key
        globals["encryptor"] = Encryptor(context, public_key)

    if load_sk:
        secret_key = SecretKey()
        secret_key.load(context, str(path/"secret_key"))
        globals["secret_key"] = secret_key
        globals["decryptor"] = Decryptor(context, secret_key)

    relin_keys = RelinKeys()
    relin_keys.load(context, str(path/"relin_keys"))

    galois_keys = GaloisKeys()
    galois_keys.load(context, str(path/"galois_keys"))

    globals["relin_keys"] = relin_keys
    globals["galois_keys"] = galois_keys

    globals["evaluator"] = Evaluator(context)
    globals["encoder"] = CKKSEncoder(context)

The examples below show how one can play with SEAL.

First, we initialize the SEAL context:

In [None]:
from seal import *

# Arguments handed to create_seal_globals below.
poly_modulus_degree = 4096
# NOTE(review): presumably the bit sizes of the coefficient-modulus primes,
# since the list is forwarded to CoeffModulus.Create - confirm.
moduli = [35,30,35]
# create_seal_globals stores 2.0 ** PRECISION_BITS under the name "scale".
PRECISION_BITS = 30

# Fills this notebook's global namespace with parms, context, scale, the four
# keys, encryptor, evaluator, decryptor and encoder.
create_seal_globals(globals(), poly_modulus_degree, moduli, PRECISION_BITS)

We can also save those parameters.

In [None]:
# Writes parms and the four keys as files into the default "seal" directory.
save_seal_globals(globals())

They can then be loaded again later, or sent to a third party that will perform computations on our data.

In [None]:
# With the default load_pk=False / load_sk=False the public and secret keys are
# not reloaded, so `encryptor` and `decryptor` keep the objects created earlier.
load_seal_globals(globals())

Now we can start using the SEAL context to encrypt data and perform arithmetic on it.

In [None]:
def dot_product_plain(ctx, ptx, n_slots:int):
    """Performs the dot product between ctx and ptx.

    NOTE(review): placeholder - the body consists of this docstring only, so
    the function currently ignores its arguments and returns None.
    """
    

In [None]:
# First we encode x in a Plaintext, using the `scale` created during setup
x = DoubleVector([1,2,3])

ptx = Plaintext()
encoder.encode(x, scale, ptx)

# Then we display it (print_ptx decodes the Plaintext before printing)
print_ptx(ptx)

# Then we encrypt it
ctx = Ciphertext()
encryptor.encrypt(ptx, ctx)

# print_ctx decrypts and decodes before printing
print_ctx(ctx)

# Add the plaintext to the ciphertext in place, then display the updated ctx
evaluator.add_plain_inplace(ctx, ptx)
print_ctx(ctx)


    [ 1.0000000, 2.0000000, 3.0000000, ..., -0.0000000, -0.0000000, -0.0000000 ]


    [ 0.9999997, 2.0000005, 2.9999996, ..., -0.0000000, -0.0000000, -0.0000001 ]


    [ 1.9999997, 4.0000005, 5.9999996, ..., -0.0000001, -0.0000000, -0.0000001 ]



In [None]:
import numpy as np

n_slot = 5
n = int(np.ceil(np.log2(n_slot)))

output = Ciphertext()
evaluator.multiply_plain(ctx, ptx)
for i in range(n):
    

3