In [62]:
!pip install Pyfhel
!pip install pynacl
!pip install cryptography
!pip install tqdm
!pip install scikit-learn



In [63]:
no_clients = 3
epochs = 3

In [64]:
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Check if GPU is available
gpus = tf.config.list_physical_devices("GPU")
if gpus:
	print("GPUs available:", len(gpus))
	for gpu in gpus:
		print(gpu)
else:
	print("No GPU available.")

TensorFlow version: 2.16.1
No GPU available.


In [65]:
import tensorflow as tf
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import os
import tensorflow as tf
from Pyfhel import Pyfhel
import nacl.utils
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
import nacl.utils
from nacl.public import PrivateKey, SealedBox
# from src.models.FMLEE import FMLEE
# from src.data.load_data import load_mnist

In [66]:
import os

os.environ["TF_USE_LEGACY_KERAS"] = "True"

In [67]:


save_dir = "dataset/mnist_data/"
os.makedirs(save_dir, exist_ok=True)

In [68]:
import tensorflow as tf


class MAML(tf.keras.Model):
	def __init__(self, model):
		super(MAML, self).__init__()
		self.model = model

	def call(self, inputs):
		x = tf.reshape(inputs, (-1, 28, 28, 1))  # Reshape the input tensor
		return self.model(x)

	def get_config(self):
		return {"model": self.model.get_config()}

	@classmethod
	def from_config(cls, config):
		model = tf.keras.models.Model.from_config(config["model"])
		return cls(model)

	def train_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		with tf.GradientTape() as tape:
			y_pred = self.model(x)
			loss = self.compiled_loss(y, y_pred)
		gradients = tape.gradient(loss, self.model.trainable_variables)
		self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}

	def test_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		y_pred = self.model(x)
		self.compiled_loss(y, y_pred)
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}


num_meta_updates = 10
num_inner_updates = 5
meta_batch_size = 32
inner_batch_size = 10

In [69]:
class FMLEE:
	def __init__(self, no_clients, epochs):
		self.no_clients = no_clients
		self.epochs = epochs
		print("Initializing CKKS scheme...")
		self.HE = self.CKKS()
		self.clients = []
		print("Initializing clients...")
		self.init_clients()
		print("Generating asymmetric keys...")
		self.pvt_key, self.pub_key = self.asym_keygen()
		print("Initialization complete.")

	def model_spec(self):
		model = tf.keras.models.Sequential(
			[
				tf.keras.layers.Conv2D(
					32, (3, 3), activation="relu", input_shape=(28, 28, 1)
				),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.Flatten(),
				tf.keras.layers.Dense(64, activation="relu"),
				tf.keras.layers.Dense(10),
			]
		)
		return model

	def init_model(self):
		model = MAML(self.model_spec())
		model.compile(
			optimizer=tf.keras.optimizers.Adam(),
			loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
			metrics=["accuracy"],
		)
		return model

	def CKKS(self):
		HE = Pyfhel()
		ckks_params = {
			"scheme": "CKKS",
			"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
			"scale": 2**30,  # All the encodings will use it for float->fixed point
			"qi_sizes": [
				60,
				30,
				30,
				30,
				60,
			],
		}
		print("Generating context for CKKS scheme...")
		HE.contextGen(**ckks_params)  # Generate context for ckks scheme
		print("Generating keys for CKKS scheme...")
		HE.keyGen()  # Key Generation: generates a pair of public/secret keys
		HE.rotateKeyGen()
		HE.relinKeyGen()
		print("CKKS scheme initialized.")
		return HE

	def asym_keygen(self):
		print("Generating private key...")
		pvt_key = PrivateKey.generate()
		print("Private key generated.")
		pub_key = pvt_key.public_key
		print("Public key generated.")
		return pvt_key, pub_key

	def init_clients(self):
		for i in range(self.no_clients):
			print(f"Initializing model for client {i}...")
			self.clients.append(self.init_model())
			print(f"Client {i} initialized.")

In [70]:
def download_and_save_mnist(save_dir):
	(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

	# Save training data with progress bar
	for array, name in zip(
		[x_train_all, y_train_all, x_test, y_test],
		["x_train.npy", "y_train.npy", "x_test.npy", "y_test.npy"],
	):
		with tqdm(total=len(array), desc=f"Saving {name}") as pbar:
			np.save(os.path.join(save_dir, name), array)
			pbar.update(len(array))

	print(f"Dataset downloaded and saved locally at {save_dir}")


def load_mnist_from_local(save_dir):
	x_train_all = np.load(os.path.join(save_dir, "x_train.npy"))
	y_train_all = np.load(os.path.join(save_dir, "y_train.npy"))
	x_test = np.load(os.path.join(save_dir, "x_test.npy"))
	y_test = np.load(os.path.join(save_dir, "y_test.npy"))
	print(f"Dataset loaded from local files at {save_dir}")
	x_train_all = x_train_all.astype(np.float32) / 255
	x_test = x_test.astype(np.float32) / 255

	return (x_train_all, y_train_all), (x_test, y_test)


def load_mnist():
	if not os.path.exists(os.path.join(save_dir, "x_train.npy")):
		download_and_save_mnist(save_dir)
	return load_mnist_from_local(save_dir)

In [71]:
(x_train_all, y_train_all), (x_test, y_test)  = load_mnist()


Dataset loaded from local files at dataset/mnist_data/


In [72]:
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

# Load MNIST data
(x_train_all, y_train_all), (x_test, y_test) = mnist.load_data()

# Normalize and reshape data
x_train_all = x_train_all.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Splitting data into training and test sets
print("Splitting data into training and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(
	x_train_all, y_train_all, test_size=0.2, random_state=42
)
print("Data split complete.")
print(
	f"Training set size: {len(X_train)}, Temp set size: {len(X_temp)}, Test set size: {len(x_test)}"
)

# Further split the temporary set into validation and testing sets
print("Splitting temp set into validation and testing sets...")
X_val, X_test, y_val, y_test = train_test_split(
	X_temp, y_temp, test_size=0.15, random_state=42
)
print("Validation and test set split complete.")
print(f"Validation set size: {len(X_val)}, Test set size: {len(X_test)}")

# Split training data into n parts
n_parts = no_clients
part_size = len(X_train) // n_parts
dataset_parts = []

print(f"Splitting training data into {n_parts} parts...")
for i in range(n_parts):
	start = i * part_size
	end = (i + 1) * part_size if i != n_parts - 1 else len(X_train)
	X_part = X_train[start:end]
	y_part = y_train[start:end]
	dataset_parts.append((X_part, y_part))
	print(f"Part {i + 1} created: {len(X_part)} samples.")

print("Data splitting into parts complete.")

Splitting data into training and test sets...
Data split complete.
Training set size: 48000, Temp set size: 12000, Test set size: 10000
Splitting temp set into validation and testing sets...
Validation and test set split complete.
Validation set size: 10200, Test set size: 1800
Splitting training data into 3 parts...
Part 1 created: 16000 samples.
Part 2 created: 16000 samples.
Part 3 created: 16000 samples.
Data splitting into parts complete.


In [73]:
fml = FMLEE(no_clients, epochs)

Initializing CKKS scheme...
Generating context for CKKS scheme...
Generating keys for CKKS scheme...
CKKS scheme initialized.
Initializing clients...
Initializing model for client 0...
Client 0 initialized.
Initializing model for client 1...
Client 1 initialized.
Initializing model for client 2...
Client 2 initialized.
Generating asymmetric keys...
Generating private key...
Private key generated.
Public key generated.
Initialization complete.


In [74]:
fml.clients

[<MAML name=maml_6, built=False>,
 <MAML name=maml_7, built=False>,
 <MAML name=maml_8, built=False>]

In [75]:
fml.HE

<ckks Pyfhel obj at 0x79c32e06a760, [pk:Y, sk:Y, rtk:Y, rlk:Y, contx(n=16384, t=0, sec=128, qi=[60, 30, 30, 30, 60], scale=1073741824.0, )]>

In [76]:
# fml.clients[0].fit(x_train_all, y_train_all)

In [77]:
import numpy as np
from Pyfhel import Pyfhel


def CKKS_keygen():
	print("Initializing Pyfhel for CKKS scheme...")
	HE = Pyfhel()

	ckks_params = {
		"scheme": "CKKS",
		"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
		"scale": 2**30,  # All the encodings will use it for float->fixed point
		"qi_sizes": [60, 30, 30, 30, 60],  # Number of bits of each prime in the chain.
	}

	print("Generating context for CKKS scheme...")
	HE.contextGen(**ckks_params)  # Generate context for ckks scheme
	print("Context generation complete.")

	print("Generating public and secret keys...")
	HE.keyGen()  # Key Generation: generates a pair of public/secret keys
	print("Public and secret key generation complete.")

	print("Generating rotation keys...")
	HE.rotateKeyGen()
	print("Rotation keys generation complete.")

	return HE


HE = CKKS_keygen()

Initializing Pyfhel for CKKS scheme...
Generating context for CKKS scheme...
Context generation complete.
Generating public and secret keys...
Public and secret key generation complete.
Generating rotation keys...
Rotation keys generation complete.


In [78]:
def asym_keygen():
	pvt_key = PrivateKey.generate()
	pub_key = pvt_key.public_key
	return pvt_key, pub_key


agg_pvt_key, agg_pub_key = asym_keygen()

In [79]:
def nacl_session_keygen():
	return nacl.utils.random(32)

def encrypt_symmetric_key(pub_key, symmetric_key):
	sealed_box = SealedBox(pub_key)
	return sealed_box.encrypt(symmetric_key)

def decrypt_symmetric_key(pvt_key, encrypted_key):
	sealed_box = SealedBox(pvt_key)
	return sealed_box.decrypt(encrypted_key)



In [80]:
import tensorflow as tf


def maml_train_step(model, x_train, y_train, inner_lr, num_inner_updates):

	model.fit(
			x_train,
			y_train,
			epochs=1,
			batch_size=64,
			verbose=1,
			validation_data=(x_train, y_train),
		)
	with tf.GradientTape() as outer_tape:
		for i in range(num_inner_updates):
			with tf.GradientTape() as inner_tape:
				predictions = model(x_train, training=True)
				loss = tf.reduce_mean(
					tf.keras.losses.sparse_categorical_crossentropy(
						y_train, predictions
					)
				)
			grads = inner_tape.gradient(loss, model.trainable_variables)
			for var, grad in zip(model.trainable_variables, grads):
				if grad is not None:
					var.assign_sub(inner_lr * grad)

		predictions = model(x_train, training=True)
		outer_loss = tf.reduce_mean(
			tf.keras.losses.sparse_categorical_crossentropy(y_train, predictions)
		)

	outer_grads = outer_tape.gradient(outer_loss, model.trainable_variables)
	return outer_loss, outer_grads

In [2]:
def HE_encrypt(wtarray):
    cwt = []
    for layer in wtarray:
        flat_array = layer.astype(np.float64).flatten()

        chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
        clayer = []
        
        for chunk in chunks:
            ptxt = HE.encodeFrac(chunk)
            ctxt = HE.encryptPtxt(ptxt)
            clayer.append(ctxt)
        cwt.append(clayer.copy())
        
    return ciphertext

In [None]:
def encrypt_message_sym_AES(key, message):
    serialized_obj = pickle.dumps(message)

    iv = nacl.utils.random(16)

    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded_obj = padder.update(serialized_obj) + padder.finalize()

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    ciphertext = encryptor.update(padded_obj) + encryptor.finalize()

    return iv + ciphertext

In [None]:
def decrypt_message_sym_AES(key , ciphertext):
    iv = ciphertext[:16]
    ciphertext = ciphertext[16:]

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded_obj = decryptor.update(ciphertext) + decryptor.finalize()

    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
    unpadded_obj = unpadder.update(padded_obj) + unpadder.finalize()

    return pickle.loads(unpadded_obj)

In [None]:
def aggregate_wts_ckks(encrypted_wts):
    res_wts = []

    for j in range(len(encrypted_wts[0])):
        layer = []
        for k in range(len(encrypted_wts[0][j])):
            tmp = encrypted_wts[0][j][k].copy()

            for i in range(1, len(encrypted_wts)):
                tmp = temp + encrypted_wts[i][j][k]

            tmp = tmp / len(encrypted_wts)
            layer.append(tmp)

        res_wts.append(layer.copy())
        
    return res_wts

In [None]:
def decrypt_wts_ckks(encrypted_wts):
    decrypted_wts = []
    wtarray = dummy_model.get_weights()

    for layer_wts , layer in zip(encrypted_wts , wtarray):
        decrypted_layer = []
        flat_array = layer.astype(np.float64).flatten()
        chunks = np.array_split(flat_array, (len(flat_array) + 2**13 - 1) // 2**13)

        for chunk , cchunk in zip(chunks , layer_wts):
            decrypted_chunk = HE.decryptFrac(cchunk)
            original_chunk_size  = len(chunk)
            decrypted_chunk = decrypted_chunk[:original_chunk_size]
            
            decrypted_layer.append(decrypted_chunk)
            
        decrypted_layer = np.concatenate(decrypted_layer)
        decrypted_layer = decrypted_layer.reshape(layer.shape)
        decrypted_wts.append(decrypted_layer)
        
    return decrypted_wts

In [5]:
def aggregate_wts(wts):
    peeled_wts = []
    for client_id in range(no_clients):
        sesion_key = decrypt_symmetric_key(agg_pvt_key , agg_sesion_keys[client_id])
        peeled_wt = decrypt_message_sym_AES(sesion_key, wts[client_id])
    
    res_wts = aggregate_wts_ckks(peeled_wts)
    return res_wts

In [81]:
dummy_model = fml.clients[0]

In [82]:
agg_sesion_keys = [0 for i in range(no_clients)]

In [83]:
inner_lr = 0.001
num_inner_updates = 1
outer_lr = 0.001


In [84]:
accuracies = [[] for i in range(no_clients)]
losses = [[] for i in range(no_clients)]

In [None]:
client_session_keys = [0 for i in range(no_clients)]

In [4]:
for r in tqdm(range(epochs)):
    for client_id , (client , client_dataset) in enumerate(zip(fml.clients , dataset_parts)):
        model = client
        x_train, y_train = client_dataset

        outer_loss, outer_grads = maml_train_step(
				model, x_train, y_train, inner_lr, num_inner_updates
			)
        optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
        optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
        history = model.evaluate(
            x_train,
            y_train,
            batch_size=64,
            verbose=0,
        )
        accuracies[client_id].append(history[1])
        losses[client_id].append(history[0])
        trained_weights = model.get_weights()

        session_key = nacl_session_keygen()
        client_session_keys[client_id] = session_key
        enc_session_key = encrypt_symmetric_key(agg_pub_key , session_key)
        agg_sesion_keys[client_id] = enc_session_key

        He_ciphertext = HE_encrypt(trained_weights)
        sym_ctxt = encrypt_message_sym_AES(session_key , He_ciphertext)
        enc_wts[client_id] = sym_ctxt
        
    agg_wts = agggregate_wts()
    
    for client_id , client in enumerate(fml.clients):
        new_wts = decrypt_wts_ckks(agg_wts)        

        client.set_weights(new_wts)

NameError: name 'tqdm' is not defined

In [91]:
history

[-3.5314900875091553, 0.9926249980926514]

In [86]:
dataset_parts[0][1].shape

(16000,)

In [87]:
accuracies

[[0.9706249833106995, 0.9886875152587891, 0.992562472820282],
 [0.9772499799728394, 0.9846875071525574, 0.9941874742507935],
 [0.9736250042915344, 0.9865000247955322, 0.9913125038146973]]