In [1]:
%load_ext autoreload
%autoreload 2

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as nd

import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import data as dt, neuronshap as ns, sim
from cfgs.fedargs import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MnistModel(torch.nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = torch.nn.Conv2d(1, 32, 5, padding=2)
        # feature map size is 14*14 by pooling
        # padding=2 for same padding
        self.conv2 = torch.nn.Conv2d(32, 64, 5, padding=2)
        # feature map size is 7*7 by pooling
        self.fc1 = torch.nn.Linear(64*7*7, 1024)
        self.fc2 = torch.nn.Linear(1024, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)   # reshape Variable
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)
    
    def forward_test(self, x):
        res = {}
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        res["layer_1"] = x
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        res["layer_2"] = x
        x = x.view(-1, 64*7*7)   # reshape Variable
        res["layer_3"] = x        
        x = F.relu(self.fc1(x))
        res["layer_4"] = x        
        x = F.dropout(x, training=self.training)
        res["layer_5"] = x        
        x = self.fc2(x)
        res["layer_6"] = x        
        return res
    
model = MnistModel()
model

MnistModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (fc1): Linear(in_features=3136, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=10, bias=True)
)

In [3]:
batch_size = 128
train_data, test_data = dt.load_dataset(fedargs.dataset)

train_loader = torch.utils.data.DataLoader(train_data,batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000)

In [12]:
# Define the regularization strength
weight_decay = 0.001
#optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=weight_decay)
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [6]:
model.train()
train_loss = []
train_accu = []
i = 0

#loss_fn = nn.MSELoss()
loss_fn = torch.nn.CrossEntropyLoss()
for epoch in range(5):
    for data, target in train_loader:
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        #loss = F.nll_loss(output, target)
        loss= loss_fn(output, target)
        loss.backward()    # calc gradients
        train_loss.append(loss.item())
        optimizer.step()   # update gradients
        prediction = output.data.max(1)[1]   # first column has actual prob.
        accuracy = prediction.eq(target.data).sum()/batch_size*100
        train_accu.append(accuracy)
        if i % 1000 == 0:
            print('Train Step: {}\tLoss: {:.3f}\tAccuracy: {:.3f}'.format(i, loss.item(), accuracy))
        i += 1

  return F.log_softmax(x)


Train Step: 0	Loss: 2.300	Accuracy: 7.812
Train Step: 1000	Loss: 0.145	Accuracy: 96.094
Train Step: 2000	Loss: 0.073	Accuracy: 98.438


In [7]:
model.eval()
correct = 0
for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    loss = F.nll_loss(output, target)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()

print('\nTest set: \tLoss: {:.3f}\tAccuracy: {:.3f}'.format(loss, 100. * correct / len(test_loader.dataset)))

  data, target = Variable(data, volatile=True), Variable(target)
  return F.log_softmax(x)



Test set: 	Loss: 0.068	Accuracy: 97.990


In [6]:
'''
r_proj = nd.random.randint(2, size=(10000,10000))
r_proj[r_proj == 0] = -1
r_inv_proj = nd.linalg.pinv(r_proj)

print(r_proj.shape, r_inv_proj.shape)

with open('proj.npy', 'wb') as f:
    nd.save(f, r_proj)
    
with open('inv.npy', 'wb') as f:
    nd.save(f, r_inv_proj)  
'''

with open('proj.npy', 'rb') as f:
    r_proj = nd.load(f)

with open('inv.npy', 'rb') as f:
    r_inv_proj = nd.load(f)    

def get_enc_model(model):
    arr, slist = sim.get_net_arr(model)

    rem = nd.zeros(10000- (len(arr) % 10000))
    if len(arr) % 10000 != 0:
        arr = nd.concatenate((arr, rem), axis=None)

    #enc_model = []
    enc_model = nd.array([])
    index = 0
    while index < len(arr):
        #enc_model.append(arr[index:index+10000] @ r_proj)
        enc_model = nd.concatenate((enc_model, (arr[index:index+10000] @ r_proj)), axis = None)
        index = index + 10000
        #print(index)

    return enc_model

def get_dec_model(enc_model):
    arr, slist = sim.get_net_arr(model)
    
    rem = nd.zeros(10000- (len(arr) % 10000))
    if len(arr) % 10000 != 0:
        arr = nd.concatenate((arr, rem), axis=None)
    
    dec_model = nd.zeros(len(arr))
    index = 0
    while index < len(arr):
        dec = enc_model[index: index + 10000] @ r_inv_proj
        dec_model[index: index + 10000] = dec
        index = index + 10000
        #print(index)

    dec_model = sim.get_arr_net(model, dec_model, slist)
    return dec_model

FileNotFoundError: [Errno 2] No such file or directory: 'proj.npy'

In [None]:
enc_model = get_enc_model(model)
print(enc_model)

In [None]:
print(len(enc_model))

In [None]:
dec_model = get_dec_model(enc_model)
print(dec_model)

In [None]:
data, target = Variable(data, volatile=True), Variable(target)
output = dec_model(data)
print(output[0].argmax(), target[0])

<h1>Homomorphic Encryption</h1>

In [4]:
import tenseal as ts

def ckks_context():
    #context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60,40,40,60])
    context = ts.context(ts.SCHEME_TYPE.CKKS, 16384, coeff_mod_bit_sizes=[31, 26, 26, 26, 26, 26, 26, 26, 26, 26, 31])
    context.global_scale = pow(2, 40)
    return context

context = ckks_context()
context.generate_galois_keys()

p_context = context.serialize(save_public_key=False, save_secret_key=False, save_galois_keys=False, save_relin_keys=False)

In [5]:
import base64

def writeCkks(ckks_vec, filename):
    ser_ckks_vec = base64.b64encode(ckks_vec)

    with open(filename, 'wb') as f:
        f.write(ser_ckks_vec)

def readCkks(filename):
    with open(filename, 'rb') as f:
        ser_ckks_vec = f.read()
    
    return base64.b64decode(ser_ckks_vec)

In [8]:
arr, slist = sim.get_net_arr(model)[0:8000]
enc_ckks_model = ts.ckks_vector(context, arr)

#import pickle
#with open('plain_model.pkl','wb') as f:
#    pickle.dump(arr, f)

start = time.time()
enc_ckks_model_ser = enc_ckks_model.serialize()
end = time.time()
print(end-start)

writeCkks(enc_ckks_model_ser, "enc_model")
print(len(arr))

The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
4.270610094070435
3274634


In [8]:
import numpy as np

r_context = ts.context_from(p_context)
r_enc_ckks_model = ts.lazy_ckks_vector_from(enc_ckks_model_ser)
r_enc_ckks_model.link_context(r_context)

dec_ckks_model = sim.get_arr_net(model, np.array(enc_ckks_model.decrypt()), slist)
data, target = Variable(data, volatile=True), Variable(target)
output = dec_ckks_model(data)
print(output[0].argmax(), target[0])

NameError: name 'data' is not defined

In [9]:
enc_ckks_model = (enc_ckks_model + enc_ckks_model)
dec_ckks_model = sim.get_arr_net(model, np.array(enc_ckks_model.decrypt())/2, slist)
data, target = Variable(data, volatile=True), Variable(target)
output = dec_ckks_model(data)
print(output[0].argmax(), target[0])

NameError: name 'data' is not defined

<h1>Using OpenFHE</h1>

In [4]:
from openfhe import *

In [5]:
arr, slist = sim.get_net_arr(model)

def get_enc_model(arr, he_batch_size = 2048):
    if len(arr)%he_batch_size != 0:
        arr = nd.append(arr, [0 for i in range(len(arr)%he_batch_size)])

    enc_model = nd.array([])
    for element in arr:
        ptxt = cryptocontext.MakeCKKSPackedPlaintext(element,1,depth-1)
        ptxt.SetLength(he_batch_size)

        ciph = cryptocontext.Encrypt(key_pair.publicKey, ptxt)
        ciphertext_after = cryptocontext.EvalBootstrap(ciph)
        
        enc_model = nd.concatenate(enc_model, ciphertext_after)
    return enc_model

def get_dec_model(enc_model, he_batch_size = 2048):
    arr, slist = sim.get_net_arr(model)
    
    rem = nd.zeros(he_batch_size- (len(arr) % he_batch_size))
    if len(arr) % he_batch_size != 0:
        arr = nd.concatenate((arr, rem), axis=None)
    
    dec_model = nd.zeros(len(arr))
    index = 0
    while index < len(arr):
        dec = enc_model[index: index + 10000] @ r_inv_proj
        dec_model[index: index + 10000] = dec
        index = index + he_batch_size
        #print(index)

    dec_model = sim.get_arr_net(model, dec_model, slist)
    return dec_model

In [6]:
parameters = CCParamsCKKSRNS()

secret_key_dist = SecretKeyDist.UNIFORM_TERNARY
parameters.SetSecretKeyDist(secret_key_dist)

parameters.SetSecurityLevel(SecurityLevel.HEStd_NotSet)
parameters.SetRingDim(1<<12)

if get_native_int()==128:
    rescale_tech = ScalingTechnique.FIXEDAUTO
    dcrt_bits = 78
    first_mod = 89
else:
    rescale_tech = ScalingTechnique.FLEXIBLEAUTO
    dcrt_bits = 59
    first_mod = 60

parameters.SetScalingModSize(dcrt_bits)
parameters.SetScalingTechnique(rescale_tech)
parameters.SetFirstModSize(first_mod)

level_budget = [4, 4]

levels_available_after_bootstrap = 10

depth = levels_available_after_bootstrap + FHECKKSRNS.GetBootstrapDepth(level_budget, secret_key_dist)

parameters.SetMultiplicativeDepth(depth)

In [None]:
cryptocontext = GenCryptoContext(parameters)
cryptocontext.Enable(PKESchemeFeature.PKE)
cryptocontext.Enable(PKESchemeFeature.KEYSWITCH)
cryptocontext.Enable(PKESchemeFeature.LEVELEDSHE)
cryptocontext.Enable(PKESchemeFeature.ADVANCEDSHE)
cryptocontext.Enable(PKESchemeFeature.FHE)

ring_dim = cryptocontext.GetRingDimension()
# This is the mazimum number of slots that can be used full packing.

num_slots = int(ring_dim / 2)
print(f"CKKS is using ring dimension {ring_dim}")

cryptocontext.EvalBootstrapSetup(level_budget)

key_pair = cryptocontext.KeyGen()
cryptocontext.EvalMultKeyGen(key_pair.secretKey)
cryptocontext.EvalBootstrapKeyGen(key_pair.secretKey, num_slots)

In [4]:
#x = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0]
#x = [i for i in range(10000)]
#encoded_length = len(x)
#encoded_length = len(x)

#ptxt = cryptocontext.MakeCKKSPackedPlaintext(x,1,depth-1)
#ptxt.SetLength(encoded_length)

#print(f"Input: {ptxt}")

#ciph = cryptocontext.Encrypt(key_pair.publicKey, ptxt)

#print(f"Initial number of levels remaining: {depth - ciph.GetLevel()}")

#ciphertext_after = cryptocontext.EvalBootstrap(ciph)

#print(f"Number of levels remaining after bootstrapping: {depth - ciphertext_after.GetLevel()}")

enc_model = get_enc_model(arr)
print(enc_model)

'''
result = cryptocontext.Decrypt(ciphertext_after,key_pair.secretKey)
result.SetLength(encoded_length)
print(f"Output after bootstrapping: {result}")
'''

Input: (0.25, 0.5, 0.75, 1, 2, 3, 4, 5,  ... ); Estimated precision: 59 bits

Initial number of levels remaining: 1
Number of levels remaining after bootstrapping: 10
Output after bootstrapping: (0.250001, 0.500003, 0.75, 1, 2, 3, 4, 5,  ... ); Estimated precision: 18 bits



array([4.13170642e-01, 2.24534548e-01, 7.45997796e-03, 3.16243573e-01,
       7.10044608e-02, 1.22103451e-01, 8.57850172e-01, 6.25656515e-01,
       1.83341988e-01, 6.37512279e-02, 6.38942622e-01, 5.59309137e-01,
       2.42711674e-01, 2.65905043e-01, 2.90837745e-01, 5.03035203e-01,
       1.94449566e-01, 8.63909611e-01, 6.53846379e-01, 6.73665133e-02,
       9.79637133e-01, 9.24821259e-01, 9.44064746e-01, 3.45074725e-01,
       6.56453591e-01, 3.75502925e-01, 9.00758740e-01, 7.38177063e-01,
       3.60649399e-01, 2.40396296e-01, 8.45726725e-01, 9.27700552e-01,
       7.02997196e-01, 9.76046857e-01, 9.22864558e-01, 1.55953582e-01,
       7.21227022e-01, 4.98312618e-02, 1.22130776e-03, 6.79851841e-01,
       7.29915874e-01, 3.86427185e-01, 5.17484104e-03, 4.99808248e-01,
       6.34858317e-02, 8.68954899e-01, 6.65203549e-01, 9.81372471e-01,
       5.06317752e-02, 9.06889403e-01, 4.08691954e-01, 4.84048719e-01,
       2.89525897e-01, 3.30418260e-02, 4.18406859e-01, 8.29152588e-01,
      