In [1]:
from dataclasses import dataclass
from typing import Any, Dict
from time import time, process_time
import pickle
import tenseal as ts
import torch
import sys
#from pympler import asizeof
import statistics
@dataclass
class Results:
    """Outcome of one timed encryption or decryption pass.

    As filled in by ``encrypt``/``encrypt_n``/``decrypt`` in this notebook:

    Attributes:
        time: CPU seconds spent in the operation (a ``process_time`` delta).
        value: the produced data -- {name: ciphertext or [ciphertexts]} after
            encryption, {name: tensor} after decryption.
        shapes: {name: original tensor shape} for encryption results;
            ``None`` for decryption results.
    """
    time: float
    value: Any
    shapes: Any

In [2]:
def chunks(lst, n):
    """Generate consecutive slices of ``lst``, each ``n`` items long.

    The final slice holds whatever remains and may be shorter than ``n``.
    """
    yield from (lst[pos:pos + n] for pos in range(0, len(lst), n))

def encrypt(weights, context):
    """Encrypt ``weights`` assuming a poly_modulus_degree=8192 context.

    Kept for existing callers. It is exactly
    ``encrypt_n(weights, context, 8192)``: tensors are packed into
    ciphertexts of at most 4096 values. It returns that function's
    ``Results``. Delegating avoids maintaining two copies of the
    encryption loop.
    """
    return encrypt_n(weights, context, 8192)


def encrypt_n(weights, context, n):
    """Encrypt every tensor in ``weights`` as CKKS vectors and time it.

    Each tensor is flattened and packed into ciphertexts of at most ``n // 2``
    values, the CKKS slot count for ring degree ``n``. A tensor that fits in
    one ciphertext maps to a single vector. A larger one maps to a list of
    vectors, one per consecutive slice.

    Args:
        weights: mapping of parameter name -> torch tensor (presumably a
            model ``state_dict``).
        context: TenSEAL CKKS context used for encryption.
        n: poly_modulus_degree of ``context``; ``n // 2`` is the chunk size.

    Returns:
        Results(time=CPU seconds spent encrypting,
                value={name: vector or [vectors]},
                shapes={name: original tensor shape}).

    Raises:
        ValueError: if ``n // 2`` is smaller than 1 (no slots to pack into).
    """
    slots = n // 2
    if slots < 1:
        raise ValueError(f"n must be >= 2, got {n}")
    start = process_time()
    res = {}
    shapes = {}
    for key, tensor in weights.items():
        shapes[key] = tensor.shape
        # reshape (not view) so non-contiguous tensors can be flattened as well.
        flat = tensor.reshape(-1)
        if len(flat) > slots:
            # split yields consecutive slices of `slots` values; the last may be shorter.
            res[key] = [ts.ckks_vector(context, chunk) for chunk in flat.split(slots)]
        else:
            res[key] = ts.ckks_vector(context, flat)
    stop = process_time()
    return Results(stop - start, res, shapes)

In [3]:
def decrypt(weights, shapes:Dict):
    """Decrypt ciphertexts back into tensors of their original shapes, timing it.

    Args:
        weights: {name: ciphertext or list of ciphertexts}, the ``value`` of an
            encryption ``Results``. Lists are decrypted in order and concatenated.
        shapes: {name: original shape}, the ``shapes`` of that same ``Results``.

    Returns:
        Results(time=CPU seconds spent decrypting, value={name: tensor}, shapes=None).
    """
    start = process_time()
    res = {}
    for key, enc in weights.items():
        # Normalise: a tensor small enough for one ciphertext is stored bare.
        parts = enc if isinstance(enc, list) else [enc]
        flat = []
        for part in parts:
            flat.extend(part.decrypt())
        # Restore the shape for BOTH cases; previously single-ciphertext
        # tensors came back flattened.
        res[key] = torch.Tensor(flat).view(shapes[key])
    stop = process_time()
    return Results(stop - start, res, None)

In [4]:
import humanize

def get_human_readable_bytes(byte_count):
    """Render ``byte_count`` as a short human-readable size string.

    Delegates to ``humanize.naturalsize`` with its default formatting.
    """
    pretty = humanize.naturalsize(byte_count)
    return pretty

def fsize(stuff, shapes) -> str:
    """Human-readable size of the pickled ciphertexts plus their shape table.

    Each ciphertext's ``serialize()`` output is pickled on its own and the
    pickle lengths are summed. ``len(pickle.dumps(shapes))`` is then added.

    Args:
        stuff: {name: ciphertext or list of ciphertexts}, as produced by the
            encryption helpers.
        shapes: the matching {name: shape} mapping.

    Returns:
        The byte total formatted by ``get_human_readable_bytes`` (a string,
        not an int).
    """
    total = len(pickle.dumps(shapes))
    for item in stuff.values():
        vectors = item if isinstance(item, list) else [item]
        for vec in vectors:
            total += len(pickle.dumps(vec.serialize()))
    return get_human_readable_bytes(total)


def fsize2(stuff)->int:
    """Total raw serialized size, in bytes, of a set of ciphertexts.

    Sums ``len(vec.serialize())`` over every ciphertext in ``stuff``. Each
    value is either a single ciphertext or a list of them. Unlike ``fsize``,
    this neither pickles the payloads nor counts any shape metadata.
    """
    total = 0
    for item in stuff.values():
        vectors = item if isinstance(item, list) else [item]
        total += sum(len(vec.serialize()) for vec in vectors)
    return total



In [5]:
def write_model(fname: str, model: Dict[str, Any]) -> int:
    """Pickle ``model`` to ``fname`` behind a 32-byte length header.

    File layout: the pickle payload length as a 32-byte big-endian unsigned
    integer, followed by the payload itself.

    Returns:
        Total number of bytes written (32 + payload length).
    """
    # pickle.dumps produces the same bytes pickle.dump would write to a
    # buffer. This also avoids the unimported `io` module.
    payload = pickle.dumps(model)
    size = len(payload)
    with open(fname, "wb") as fptr:
        fptr.write(size.to_bytes(32, byteorder="big", signed=False))
        fptr.write(payload)
    return 32 + size

def read_model(fname)->Dict[str, Any]:
    """Load the pickled model stored in ``fname``.

    The file is expected to begin with a 32-byte length header, matching the
    layout ``write_model`` produces. The header is skipped, not validated,
    and the rest of the file is unpickled.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only read
    files from a trusted source.
    """
    header_len = 32
    with open(fname, "rb") as stream:
        stream.seek(header_len)
        return pickle.load(stream)

def get_model(fname: str = "./data.pickle"):
    """Load the plaintext model weights (parameter name -> tensor) from ``fname``.

    The default path keeps existing ``get_model()`` calls working unchanged.
    """
    return read_model(fname)


x = get_model()


In [6]:
# CKKS context with ring degree 32768 (32768 // 2 = 16384 slots per ciphertext),
# used by the single-vector sanity check that follows.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=32768,
    coeff_mod_bit_sizes=[60] * 5,
)
context.generate_galois_keys()
context.global_scale = 2 ** 20

In [7]:
# Sanity check: encrypt the first slot-full of fc1.weight with the
# 32768-degree context. CKKS packs poly_modulus_degree // 2 values per ciphertext.
SLOTS_32768 = 32768 // 2
v1 = x["fc1.weight"]
shape = v1.shape
print(shape)
v1 = v1.view(-1)
print("===")
print(len(v1))
print(len(v1[:SLOTS_32768]))
v1 = v1[:SLOTS_32768]
print("---")
v1 = v1.tolist()
print(x.keys())
enc_v1 = ts.ckks_vector(context, v1)
print(enc_v1)

torch.Size([128, 9216])
===
1179648
16384
---
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
<tenseal.tensors.ckksvector.CKKSVector object at 0x2b2234cef850>


In [None]:
# Benchmark: mean CPU time of encrypt/decrypt and raw ciphertext size over
# 100 runs, each with a freshly created context and Galois keys.
ptimes_e = []
ptimes_d = []
sizes = []
x = get_model()  # loaded once: encrypt() only reads the weights

for _ in range(100):
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[40, 21, 21, 21, 21, 21, 21, 40],
    )
    context.generate_galois_keys()
    context.global_scale = 2**20

    enc = encrypt(x, context)
    ptimes_e.append(enc.time)
    sizes.append(fsize2(enc.value))

    dec = decrypt(enc.value, enc.shapes)
    ptimes_d.append(dec.time)

print("encrypt", statistics.fmean(ptimes_e))
print("decrypt", statistics.fmean(ptimes_d))
print("sizes", statistics.fmean(sizes))



In [None]:
# Benchmark: mean CPU time of encrypt/decrypt and raw ciphertext size over
# 100 runs, each with a freshly created context and Galois keys.
ptimes_e = []
ptimes_d = []
sizes = []
x = get_model()  # loaded once: encrypt() only reads the weights

for _ in range(100):
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[40, 21, 21, 21, 40],
    )
    context.generate_galois_keys()
    context.global_scale = 2**20

    enc = encrypt(x, context)
    ptimes_e.append(enc.time)
    sizes.append(fsize2(enc.value))

    dec = decrypt(enc.value, enc.shapes)
    ptimes_d.append(dec.time)

print("encrypt", statistics.fmean(ptimes_e))
print("decrypt", statistics.fmean(ptimes_d))
print("sizes", statistics.fmean(sizes))



In [None]:
# Benchmark: mean CPU time of encrypt/decrypt and raw ciphertext size over
# 100 runs, each with a freshly created context and Galois keys.
ptimes_e = []
ptimes_d = []
sizes = []
x = get_model()  # loaded once: encrypt() only reads the weights

for _ in range(100):
    context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=8192,
        coeff_mod_bit_sizes=[40, 21, 40],
    )
    context.generate_galois_keys()
    context.global_scale = 2**20

    enc = encrypt(x, context)
    ptimes_e.append(enc.time)
    sizes.append(fsize2(enc.value))

    dec = decrypt(enc.value, enc.shapes)
    ptimes_d.append(dec.time)

print("encrypt", statistics.fmean(ptimes_e))
print("decrypt", statistics.fmean(ptimes_d))
print("sizes", statistics.fmean(sizes))



In [None]:
# One timed encrypt/decrypt pass, coeff_mod_bit_sizes=[40, 21, 21, 21, 40].
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[40, 21, 21, 21, 40],
)
context.generate_galois_keys()
context.global_scale = 2**20
x = get_model()

encrypted = encrypt(x, context)
print(f"Encrypt Time {encrypted.time}")
print(f"Encrypt size {fsize(encrypted.value, encrypted.shapes)}")
res = decrypt(encrypted.value, encrypted.shapes)
print(f"Decrypt Time {res.time}")


In [None]:
# One timed encrypt/decrypt pass, coeff_mod_bit_sizes=[40, 21, 40].
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[40, 21, 40],
)
context.generate_galois_keys()
context.global_scale = 2**20
x = get_model()

encrypted = encrypt(x, context)
print(f"Encrypt Time {encrypted.time}")
print(f"Encrypt size {fsize(encrypted.value, encrypted.shapes)}")
res = decrypt(encrypted.value, encrypted.shapes)
print(f"Decrypt Time {res.time}")


In [None]:
# One timed encrypt/decrypt pass, coeff_mod_bit_sizes=[60, 60, 60].
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 60, 60],
)
context.generate_galois_keys()
context.global_scale = 2**20
x = get_model()

encrypted = encrypt(x, context)
print(f"Encrypt Time {encrypted.time}")
print(f"Encrypt size {fsize(encrypted.value, encrypted.shapes)}")
res = decrypt(encrypted.value, encrypted.shapes)
print(f"Decrypt Time {res.time}")


In [None]:
# One timed encrypt/decrypt pass, coeff_mod_bit_sizes=[40, 40, 40].
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[40, 40, 40],
)
context.generate_galois_keys()
context.global_scale = 2**20
x = get_model()

encrypted = encrypt(x, context)
print(f"Encrypt Time {encrypted.time}")
print(f"Encrypt size {fsize(encrypted.value, encrypted.shapes)}")
res = decrypt(encrypted.value, encrypted.shapes)
print(f"Decrypt Time {res.time}")


In [None]:
# One timed encrypt/decrypt pass, coeff_mod_bit_sizes=[20, 20, 20].
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[20, 20, 20],
)
context.generate_galois_keys()
context.global_scale = 2**20
x = get_model()

encrypted = encrypt(x, context)
print(f"Encrypt Time {encrypted.time}")
print(f"Encrypt size {fsize(encrypted.value, encrypted.shapes)}")
res = decrypt(encrypted.value, encrypted.shapes)
print(f"Decrypt Time {res.time}")


In [None]:
# Ring-degree sweep: for each poly_modulus_degree n, average 100 runs of
# encrypt_n/decrypt (n // 2 values per ciphertext) with fresh contexts/keys.
x = get_model()  # loaded once: encrypt_n() only reads the weights
for n in [8192, 16384, 32768]:
    ptimes_e = []
    ptimes_d = []
    sizes = []
    for _ in range(100):
        context = ts.context(
            ts.SCHEME_TYPE.CKKS,
            poly_modulus_degree=n,
            coeff_mod_bit_sizes=[40, 40, 40],
        )
        context.generate_galois_keys()
        context.global_scale = 2**20

        enc = encrypt_n(x, context, n)
        ptimes_e.append(enc.time)
        sizes.append(fsize2(enc.value))
        dec = decrypt(enc.value, enc.shapes)
        ptimes_d.append(dec.time)

    print(f"{n:>6} encrypt", statistics.fmean(ptimes_e))
    print(f"{n:>6} decrypt", statistics.fmean(ptimes_d))
    print(f"{n:>6} size", statistics.fmean(sizes))

    print("-"*45)
