In [1]:
import piheaan as heaan
from piheaan.math import sort
from piheaan.math import approx # for piheaan math function
import math
import numpy as np
import pandas as pd
import os

In [2]:
# --- Parameters ---
# FGb preset. The context carries the parameter set; it is made
# bootstrappable so later cells can call eval.bootstrap to refresh levels.
params = heaan.ParameterPreset.FGb
context = heaan.make_context(params)
heaan.make_bootstrappable(context)

# --- Key generation ---
# All keys live under ./keys so the next cell can reload them from disk.
key_file_path = "./keys"
os.makedirs(key_file_path, mode=0o775, exist_ok=True)

# Fresh secret key, persisted next to the public keys.
sk = heaan.SecretKey(context)
sk.save(f"{key_file_path}/secretkey.bin")

# Derive the common public keys from the secret key and save them
# (the next cell loads the encryption and multiplication keys from here).
key_generator = heaan.KeyGenerator(context, sk)
key_generator.gen_common_keys()
key_generator.save(f"{key_file_path}/")

In [3]:
# Reload the keys written by the key-generation cell. Once the key files
# exist on disk they can be reloaded here instead of generating new ones.
key_file_path = "./keys"

secret_key_file = f"{key_file_path}/secretkey.bin"
public_key_dir = f"{key_file_path}/"

sk = heaan.SecretKey(context, secret_key_file)  # secret key (decryption)
pk = heaan.KeyPack(context, public_key_dir)      # public key pack
for load_key in (pk.load_enc_key, pk.load_mult_key):
    load_key()

# Homomorphic operations, decryption and encryption helpers.
# NOTE: `eval` shadows Python's builtin; kept because later cells use it.
eval = heaan.HomEvaluator(context, pk)
dec = heaan.Decryptor(context)
enc = heaan.Encryptor(context)

In [4]:
# Slots per ciphertext are a power of two: num_slots = 2 ** log_slots.
# The usable maximum depends on the ParameterPreset; 15 is the maximum for
# the preset chosen above, but smaller exponents (e.g. 2, 3, 5, 7) also work.
log_slots = 15
num_slots = 1 << log_slots  # 32768 slots for log_slots = 15

In [5]:
# Test vector [0, 1e-5, 2e-5, ...], one value per slot, encrypted into a
# single ciphertext with the public key.
step = 0.00001
input_msg = heaan.Message(log_slots)
for slot in range(num_slots):
    input_msg[slot] = slot * step

input_ctxt = heaan.Ciphertext(context)
enc.encrypt(input_msg, pk, input_ctxt)



In [6]:
# mean
def cal_mean(input_ctxt):
    """Return a new ciphertext whose slots all hold the mean of ``input_ctxt``.

    Uses the rotate-and-add trick: rotating by 1, 2, 4, ..., num_slots/2 and
    adding after each rotation leaves the total of all slots in every slot.
    Multiplying by ``1/num_slots`` then turns that total into the mean.

    Relies on the notebook globals ``eval`` (heaan.HomEvaluator) and
    ``num_slots`` (a power of two, set in the slot-count cell).

    Parameters
    ----------
    input_ctxt : heaan.Ciphertext
        Encrypted vector. It is copied, not modified.

    Returns
    -------
    heaan.Ciphertext
        Ciphertext with the mean replicated in every slot.
    """
    sum_ctxt = heaan.Ciphertext(input_ctxt)
    rotated_ctxt = heaan.Ciphertext(input_ctxt)

    # Rotations and additions do not consume levels, so the level cannot
    # change inside the loop. A single check up front is therefore enough.
    # Previously both operands were re-checked every iteration, which meant
    # two bootstraps when the input sat at the threshold.
    if sum_ctxt.level == eval.min_level_for_bootstrap:
        eval.bootstrap(sum_ctxt, sum_ctxt)

    # 1. Sum all slots: after the loop every slot holds the total.
    shift = 1
    while shift < num_slots:
        eval.left_rotate(sum_ctxt, shift, rotated_ctxt)
        eval.add(sum_ctxt, rotated_ctxt, sum_ctxt)
        shift *= 2

    # 2. Divide by the number of slots. Constant multiplication consumes a
    # level, so refresh the result if it lands on the bootstrap threshold.
    eval.mult(sum_ctxt, 1 / num_slots, sum_ctxt)
    if sum_ctxt.level == eval.min_level_for_bootstrap:
        eval.bootstrap(sum_ctxt, sum_ctxt)

    return sum_ctxt


In [8]:
# Encrypted mean, decrypted for comparison.
mean_ctxt = cal_mean(input_ctxt)
mean_msg = heaan.Message(log_slots)
dec.decrypt(mean_ctxt, sk, mean_msg)

# Reference: rebuild the plaintext input (same values encrypted above)
# and take its mean with numpy.
msg = [i * 0.00001 for i in range(num_slots)]
plain_mean = np.mean(msg)

# cal_mean replicates the mean into every slot; slot 0 is representative.
print("piheaan result : ", mean_msg[0].real)
print("plaintext result : ", plain_mean)

piheaan result :  0.16383500000029116
palintext result :  0.163835


In [11]:
# variance: Var(X) = E[X^2] - (E[X])^2
# 1. mean(input**2)
square_ctxt = heaan.Ciphertext(context)
eval.mult(input_ctxt, input_ctxt, square_ctxt)

# Ciphertext multiplication consumes a level; refresh at the threshold.
if square_ctxt.level == eval.min_level_for_bootstrap:
    eval.bootstrap(square_ctxt, square_ctxt)

square_mean = cal_mean(square_ctxt)

# 2. mean(input)**2 -- reuses mean_ctxt computed in the mean cell
mean_square = heaan.Ciphertext(context)
eval.mult(mean_ctxt, mean_ctxt, mean_square)

if mean_square.level == eval.min_level_for_bootstrap:
    eval.bootstrap(mean_square, mean_square)

# 3. variance = (1) - (2)
var_ctxt = heaan.Ciphertext(context)
eval.sub(square_mean, mean_square, var_ctxt)

var_msg = heaan.Message(log_slots)
dec.decrypt(var_ctxt, sk, var_msg)

# Population variance in plaintext (np.var defaults to ddof=0, which matches
# E[X^2] - E[X]^2 computed above).
plain_var = np.var(msg)

print("piheaan result : ", var_msg[0].real)
print("plaintext result : ", plain_var)

piheaan result :  0.008947848525119519
palintext result :  0.008947848525


In [12]:
# standard deviation = sqrt(variance), using piheaan's approximate sqrt
# (imported above as `approx`). The result is an approximation, so expect
# small deviations from the plaintext value.
# NOTE(review): confirm the valid input range of approx.sqrt in the piheaan docs.
std_ctxt = heaan.Ciphertext(context)
approx.sqrt(eval, var_ctxt, std_ctxt)

std_msg = heaan.Message(log_slots)
dec.decrypt(std_ctxt, sk, std_msg)

# Population standard deviation in plaintext (np.std defaults to ddof=0).
plain_std = np.std(msg)

print("piheaan result : ", std_msg[0].real)
print("plaintext result : ", plain_std)

piheaan result :  0.09459306806061159
palintext result :  0.09459306805997995
