In [75]:
!pip install Pyfhel
!pip install pynacl
!pip install cryptography
!pip install tqdm
!pip install scikit-learn



In [76]:
no_clients = 3
epochs = 3

In [77]:
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Check if GPU is available
gpus = tf.config.list_physical_devices("GPU")
if gpus:
	print("GPUs available:", len(gpus))
	for gpu in gpus:
		print(gpu)
else:
	print("No GPU available.")

TensorFlow version: 2.16.1
No GPU available.


In [78]:
import tensorflow as tf
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import os
import tensorflow as tf
from Pyfhel import Pyfhel
import nacl.utils
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
import nacl.utils
from nacl.public import PrivateKey, SealedBox
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy

# from src.models.FMLEE import FMLEE
# from src.data.load_data import load_mnist

In [79]:
import os

os.environ["TF_USE_LEGACY_KERAS"] = "True"

In [80]:


save_dir = "dataset/mnist_data/"
os.makedirs(save_dir, exist_ok=True)

In [81]:
import tensorflow as tf


class MAML(tf.keras.Model):
	def __init__(self, model):
		super(MAML, self).__init__()
		self.model = model

	def call(self, inputs):
		x = tf.reshape(inputs, (-1, 28, 28, 1))  # Reshape the input tensor
		return self.model(x)

	def get_config(self):
		return {"model": self.model.get_config()}

	@classmethod
	def from_config(cls, config):
		model = tf.keras.models.Model.from_config(config["model"])
		return cls(model)

	def train_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		with tf.GradientTape() as tape:
			y_pred = self.model(x)
			loss = self.compiled_loss(y, y_pred)
		gradients = tape.gradient(loss, self.model.trainable_variables)
		self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}

	def test_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		y_pred = self.model(x)
		self.compiled_loss(y, y_pred)
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}


num_meta_updates = 10
num_inner_updates = 5
meta_batch_size = 32
inner_batch_size = 10

In [82]:
class FMLEE:
	def __init__(self, no_clients, epochs):
		self.no_clients = no_clients
		self.epochs = epochs
		print("Initializing CKKS scheme...")
		self.HE = self.CKKS()
		self.clients = []
		print("Initializing clients...")
		self.init_clients()
		print("Generating asymmetric keys...")
		self.pvt_key, self.pub_key = self.asym_keygen()
		print("Initialization complete.")

	def model_spec(self):
		model = tf.keras.models.Sequential(
			[
				tf.keras.layers.Conv2D(
					32, (3, 3), activation="relu", input_shape=(28, 28, 1)
				),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.Flatten(),
				tf.keras.layers.Dense(64, activation="relu"),
				tf.keras.layers.Dense(10),
			]
		)
		return model

	def init_model(self):
		model = MAML(self.model_spec())
		model.compile(
			optimizer=tf.keras.optimizers.Adam(),
			loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
			metrics=["accuracy"],
		)
		return model

	def CKKS(self):
		HE = Pyfhel()
		ckks_params = {
			"scheme": "CKKS",
			"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
			"scale": 2**30,  # All the encodings will use it for float->fixed point
			"qi_sizes": [
				60,
				30,
				30,
				30,
				60,
			],
		}
		print("Generating context for CKKS scheme...")
		HE.contextGen(**ckks_params)  # Generate context for ckks scheme
		print("Generating keys for CKKS scheme...")
		HE.keyGen()  # Key Generation: generates a pair of public/secret keys
		HE.rotateKeyGen()
		HE.relinKeyGen()
		print("CKKS scheme initialized.")
		return HE

	def asym_keygen(self):
		print("Generating private key...")
		pvt_key = PrivateKey.generate()
		print("Private key generated.")
		pub_key = pvt_key.public_key
		print("Public key generated.")
		return pvt_key, pub_key

	def init_clients(self):
		for i in range(self.no_clients):
			print(f"Initializing model for client {i}...")
			self.clients.append(self.init_model())
			print(f"Client {i} initialized.")

In [83]:
def download_and_save_mnist(save_dir):
	(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

	# Save training data with progress bar
	for array, name in zip(
		[x_train_all, y_train_all, x_test, y_test],
		["x_train.npy", "y_train.npy", "x_test.npy", "y_test.npy"],
	):
		with tqdm(total=len(array), desc=f"Saving {name}") as pbar:
			np.save(os.path.join(save_dir, name), array)
			pbar.update(len(array))

	print(f"Dataset downloaded and saved locally at {save_dir}")


def load_mnist_from_local(save_dir):
	x_train_all = np.load(os.path.join(save_dir, "x_train.npy"))
	y_train_all = np.load(os.path.join(save_dir, "y_train.npy"))
	x_test = np.load(os.path.join(save_dir, "x_test.npy"))
	y_test = np.load(os.path.join(save_dir, "y_test.npy"))
	print(f"Dataset loaded from local files at {save_dir}")
	x_train_all = x_train_all.astype(np.float32) / 255
	x_test = x_test.astype(np.float32) / 255

	return (x_train_all, y_train_all), (x_test, y_test)


def load_mnist():
	if not os.path.exists(os.path.join(save_dir, "x_train.npy")):
		download_and_save_mnist(save_dir)
	return load_mnist_from_local(save_dir)

In [84]:
(x_train_all, y_train_all), (x_test, y_test)  = load_mnist()


Dataset loaded from local files at dataset/mnist_data/


In [85]:
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

# Load MNIST data
(x_train_all, y_train_all), (x_test, y_test) = mnist.load_data()

# Normalize and reshape data
x_train_all = x_train_all.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Splitting data into training and test sets
print("Splitting data into training and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(
	x_train_all, y_train_all, test_size=0.2, random_state=42
)
print("Data split complete.")
print(
	f"Training set size: {len(X_train)}, Temp set size: {len(X_temp)}, Test set size: {len(x_test)}"
)

# Further split the temporary set into validation and testing sets
print("Splitting temp set into validation and testing sets...")
X_val, X_test, y_val, y_test = train_test_split(
	X_temp, y_temp, test_size=0.15, random_state=42
)
print("Validation and test set split complete.")
print(f"Validation set size: {len(X_val)}, Test set size: {len(X_test)}")

# Split training data into n parts
n_parts = no_clients
part_size = len(X_train) // n_parts
dataset_parts = []

print(f"Splitting training data into {n_parts} parts...")
for i in range(n_parts):
	start = i * part_size
	end = (i + 1) * part_size if i != n_parts - 1 else len(X_train)
	X_part = X_train[start:end]
	y_part = y_train[start:end]
	dataset_parts.append((X_part, y_part))
	print(f"Part {i + 1} created: {len(X_part)} samples.")

print("Data splitting into parts complete.")

Splitting data into training and test sets...
Data split complete.
Training set size: 48000, Temp set size: 12000, Test set size: 10000
Splitting temp set into validation and testing sets...
Validation and test set split complete.
Validation set size: 10200, Test set size: 1800
Splitting training data into 3 parts...
Part 1 created: 16000 samples.
Part 2 created: 16000 samples.
Part 3 created: 16000 samples.
Data splitting into parts complete.


In [86]:
fml = FMLEE(no_clients, epochs)

Initializing CKKS scheme...
Generating context for CKKS scheme...
Generating keys for CKKS scheme...
CKKS scheme initialized.
Initializing clients...
Initializing model for client 0...
Client 0 initialized.
Initializing model for client 1...
Client 1 initialized.
Initializing model for client 2...
Client 2 initialized.
Generating asymmetric keys...
Generating private key...
Private key generated.
Public key generated.
Initialization complete.


In [87]:
fml.clients

[<MAML name=maml_9, built=False>,
 <MAML name=maml_10, built=False>,
 <MAML name=maml_11, built=False>]

In [88]:
fml.HE

<ckks Pyfhel obj at 0x75d07cdf6da0, [pk:Y, sk:Y, rtk:Y, rlk:Y, contx(n=16384, t=0, sec=128, qi=[60, 30, 30, 30, 60], scale=1073741824.0, )]>

In [89]:
# fml.clients[0].fit(x_train_all, y_train_all)

In [90]:
import numpy as np
from Pyfhel import Pyfhel


def CKKS_keygen():
	print("Initializing Pyfhel for CKKS scheme...")
	HE = Pyfhel()

	ckks_params = {
		"scheme": "CKKS",
		"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
		"scale": 2**30,  # All the encodings will use it for float->fixed point
		"qi_sizes": [60, 30, 30, 30, 60],  # Number of bits of each prime in the chain.
	}

	print("Generating context for CKKS scheme...")
	HE.contextGen(**ckks_params)  # Generate context for ckks scheme
	print("Context generation complete.")

	print("Generating public and secret keys...")
	HE.keyGen()  # Key Generation: generates a pair of public/secret keys
	print("Public and secret key generation complete.")

	print("Generating rotation keys...")
	HE.rotateKeyGen()
	print("Rotation keys generation complete.")

	return HE


HE = CKKS_keygen()

Initializing Pyfhel for CKKS scheme...
Generating context for CKKS scheme...
Context generation complete.
Generating public and secret keys...
Public and secret key generation complete.
Generating rotation keys...
Rotation keys generation complete.


In [91]:
def asym_keygen():
	pvt_key = PrivateKey.generate()
	pub_key = pvt_key.public_key
	return pvt_key, pub_key


agg_pvt_key, agg_pub_key = asym_keygen()

In [92]:
def nacl_session_keygen():
	return nacl.utils.random(32)

def encrypt_symmetric_key(pub_key, symmetric_key):
	sealed_box = SealedBox(pub_key)
	return sealed_box.encrypt(symmetric_key)

def decrypt_symmetric_key(pvt_key, encrypted_key):
	sealed_box = SealedBox(pvt_key)
	return sealed_box.decrypt(encrypted_key)



In [93]:
import tensorflow as tf


def maml_train_step(model, x_train, y_train, inner_lr, num_inner_updates):

	model.fit(
			x_train,
			y_train,
			epochs=1,
			batch_size=64,
			verbose=1,
			validation_data=(x_train, y_train),
		)
	with tf.GradientTape() as outer_tape:
		for i in range(num_inner_updates):
			with tf.GradientTape() as inner_tape:
				predictions = model(x_train, training=True)
				loss = tf.reduce_mean(
					tf.keras.losses.sparse_categorical_crossentropy(
						y_train, predictions
					)
				)
			grads = inner_tape.gradient(loss, model.trainable_variables)
			for var, grad in zip(model.trainable_variables, grads):
				if grad is not None:
					var.assign_sub(inner_lr * grad)

		predictions = model(x_train, training=True)
		outer_loss = tf.reduce_mean(
			tf.keras.losses.sparse_categorical_crossentropy(y_train, predictions)
		)

	outer_grads = outer_tape.gradient(outer_loss, model.trainable_variables)
	return outer_loss, outer_grads

In [94]:
def HE_encrypt(wtarray):
	cwt = []
	for layer in wtarray:
		flat_array = layer.astype(np.float64).flatten()

		chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
		clayer = []
		
		for chunk in chunks:
			ptxt = HE.encodeFrac(chunk)
			ctxt = HE.encryptPtxt(ptxt)
			clayer.append(ctxt)
		cwt.append(clayer.copy())
		
	return cwt

In [95]:
def encrypt_message_sym_AES(key, message):
	serialized_obj = pickle.dumps(message)

	iv = nacl.utils.random(16)

	padder = padding.PKCS7(algorithms.AES.block_size).padder()
	padded_obj = padder.update(serialized_obj) + padder.finalize()

	cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
	encryptor = cipher.encryptor()
	ciphertext = encryptor.update(padded_obj) + encryptor.finalize()

	return iv + ciphertext

In [96]:
def decrypt_message_sym_AES(key , ciphertext):
	iv = ciphertext[:16]
	ciphertext = ciphertext[16:]

	cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
	decryptor = cipher.decryptor()
	padded_obj = decryptor.update(ciphertext) + decryptor.finalize()

	unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
	unpadded_obj = unpadder.update(padded_obj) + unpadder.finalize()

	return pickle.loads(unpadded_obj)

In [97]:
def aggregate_wts_ckks(encrypted_wts):
	res_wts = []
	print(len(encrypted_wts) )
	print(len(encrypted_wts[0]) )
	for j in range(len(encrypted_wts[0])):
		layer = []
		for k in range(len(encrypted_wts[0][j])):
			tmp = encrypted_wts[0][j][k].copy()

			for i in range(1, len(encrypted_wts)):
				tmp = tmp + encrypted_wts[i][j][k]

			tmp = tmp / len(encrypted_wts)
			layer.append(tmp)

		res_wts.append(layer.copy())
		
	return res_wts

In [98]:
def decrypt_wts_ckks(encrypted_wts):
	decrypted_wts = []
	wtarray = dummy_model.get_weights()

	for layer_wts , layer in zip(encrypted_wts , wtarray):
		decrypted_layer = []
		flat_array = layer.astype(np.float64).flatten()
		chunks = np.array_split(flat_array, (len(flat_array) + 2**13 - 1) // 2**13)

		for chunk , cchunk in zip(chunks , layer_wts):
			decrypted_chunk = HE.decryptFrac(cchunk)
			original_chunk_size  = len(chunk)
			decrypted_chunk = decrypted_chunk[:original_chunk_size]
			
			decrypted_layer.append(decrypted_chunk)
			
		decrypted_layer = np.concatenate(decrypted_layer)
		decrypted_layer = decrypted_layer.reshape(layer.shape)
		decrypted_wts.append(decrypted_layer)
		
	return decrypted_wts

In [99]:
def aggregate_wts(wts):
	peeled_wts = []
	for client_id in range(no_clients):
		sesion_key = decrypt_symmetric_key(agg_pvt_key , agg_sesion_keys[client_id])
		peeled_wt = decrypt_message_sym_AES(sesion_key, wts[client_id])
		peeled_wts.append(peeled_wt)
	res_wts = aggregate_wts_ckks(peeled_wts)
	return res_wts

In [100]:
dummy_model = fml.clients[0]

In [101]:
agg_sesion_keys = [0 for i in range(no_clients)]

In [102]:
inner_lr = 0.001
num_inner_updates = 1
outer_lr = 0.001


In [103]:
accuracies = [[] for i in range(no_clients)]
losses = [[] for i in range(no_clients)]

In [104]:
client_session_keys = [0 for i in range(no_clients)]

In [105]:
enc_wts = [0 for i in range(no_clients)]

In [106]:
for r in tqdm(range(epochs)):
	for client_id , (client , client_dataset) in enumerate(zip(fml.clients , dataset_parts)):
		model = client
		x_train, y_train = client_dataset

		outer_loss, outer_grads = maml_train_step(
				model, x_train, y_train, inner_lr, num_inner_updates
			)
		optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
		optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
		history = model.evaluate(
			x_train,
			y_train,
			batch_size=64,
			verbose=0,
		)
		accuracies[client_id].append(history[1])
		losses[client_id].append(history[0])
		trained_weights = model.get_weights()

		session_key = nacl_session_keygen()
		client_session_keys[client_id] = session_key
		enc_session_key = encrypt_symmetric_key(agg_pub_key , session_key)
		agg_sesion_keys[client_id] = enc_session_key

		He_ciphertext = HE_encrypt(trained_weights)
		sym_ctxt = encrypt_message_sym_AES(session_key , He_ciphertext)
		enc_wts[client_id] = sym_ctxt
	print(len(enc_wts[0]))
	agg_wts = aggregate_wts(enc_wts)

	for client_id , client in enumerate(fml.clients):
		new_wts = decrypt_wts_ckks(agg_wts)        

		client.set_weights(new_wts)

  0%|          | 0/3 [00:00<?, ?it/s]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.7395 - loss: -0.0288 - val_accuracy: 0.9573 - val_loss: -0.0958
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.7445 - loss: -0.2572 - val_accuracy: 0.9545 - val_loss: -0.6079
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.7172 - loss: 0.7054 - val_accuracy: 0.9564 - val_loss: 0.3457
101725776
3
10


 33%|███▎      | 1/3 [00:49<01:38, 49.25s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5957 - loss: -1.0453 - val_accuracy: 0.9247 - val_loss: -1.8873
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5735 - loss: -1.2432 - val_accuracy: 0.9319 - val_loss: -2.1239
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.6469 - loss: -0.7827 - val_accuracy: 0.9199 - val_loss: -1.3466
101725776
3
10


 67%|██████▋   | 2/3 [01:35<00:47, 47.60s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6171 - loss: -1.0251 - val_accuracy: 0.9433 - val_loss: -2.2790
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6505 - loss: -1.3621 - val_accuracy: 0.9430 - val_loss: -2.4384
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6086 - loss: -0.8146 - val_accuracy: 0.9412 - val_loss: -1.5550
101725776
3
10


100%|██████████| 3/3 [02:22<00:00, 47.35s/it]


In [107]:
history

[-2.116286277770996, 0.9089999794960022]

In [108]:
dataset_parts[0][1].shape

(16000,)

In [109]:
fmlee_accuracies = accuracies
fmlee_losses = losses

In [110]:
no_fml = FMLEE(no_clients, epochs)

Initializing CKKS scheme...
Generating context for CKKS scheme...
Generating keys for CKKS scheme...
CKKS scheme initialized.
Initializing clients...
Initializing model for client 0...
Client 0 initialized.
Initializing model for client 1...
Client 1 initialized.
Initializing model for client 2...
Client 2 initialized.
Generating asymmetric keys...
Generating private key...
Private key generated.
Public key generated.
Initialization complete.


In [111]:
accuracies = [[] for i in range(no_clients)]
losses = [[] for i in range(no_clients)]

In [112]:
def aggregate_wts_noenc(weight_list):
    n_models = len(weight_list)
    # Initialize the aggregated weights with zeros of the same shape as the first model's weights
    aggregated_weights = [np.zeros_like(w) for w in weight_list[0]]

    # Sum the weights for each layer
    for weights in weight_list:
        for i, weight in enumerate(weights):
            aggregated_weights[i] += weight

    # Divide by the number of models to get the average
    for i in range(len(aggregated_weights)):
        aggregated_weights[i] /= n_models

    return aggregated_weights

In [113]:
for r in tqdm(range(epochs)):
    for client_id, (client, client_dataset) in enumerate(
        zip(fml.clients, dataset_parts)
    ):
        model = client
        x_train, y_train = client_dataset

        outer_loss, outer_grads = maml_train_step(
            model, x_train, y_train, inner_lr, num_inner_updates
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
        optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
        history = model.evaluate(
            x_train,
            y_train,
            batch_size=64,
            verbose=0,
        )
        accuracies[client_id].append(history[1])
        losses[client_id].append(history[0])
        trained_weights = model.get_weights()

        enc_wts[client_id] = trained_weights
    agg_wts = aggregate_wts_noenc(enc_wts)

    for client_id, client in enumerate(fml.clients):
        client.set_weights(agg_wts)

  0%|          | 0/3 [00:00<?, ?it/s]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6005 - loss: -1.4667 - val_accuracy: 0.9352 - val_loss: -2.9217
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.5982 - loss: -0.9547 - val_accuracy: 0.9500 - val_loss: -1.8954
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.5711 - loss: -1.2228 - val_accuracy: 0.9467 - val_loss: -2.3903


 33%|███▎      | 1/3 [00:41<01:22, 41.07s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9283 - loss: -2.4199 - val_accuracy: 0.9663 - val_loss: -2.4230
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9295 - loss: -2.3908 - val_accuracy: 0.9699 - val_loss: -2.3831
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9275 - loss: -2.5641 - val_accuracy: 0.9649 - val_loss: -2.6319


 67%|██████▋   | 2/3 [01:21<00:40, 40.78s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9654 - loss: -3.0352 - val_accuracy: 0.9770 - val_loss: -2.9188
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9643 - loss: -2.9149 - val_accuracy: 0.9754 - val_loss: -2.8710
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9668 - loss: -3.0407 - val_accuracy: 0.9749 - val_loss: -2.8254


100%|██████████| 3/3 [02:02<00:00, 40.79s/it]


In [114]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

num_samples = 100
num_classes = 10
input_shape = (224, 224, 3)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, num_classes, num_samples)

# Convert labels to one-hot encoding
y_dummy = tf.keras.utils.to_categorical(y_dummy, num_classes)

# Load MobileNetV2 with pre-trained weights and exclude the top layers
base_model = MobileNetV2(weights="imagenet", include_top=False, input_shape=input_shape)

# Add custom top layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(num_classes, activation="softmax")(x)

# Create the full model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])

# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=32)

# Print the training history
print(history.history)

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m924s[0m 3s/step - accuracy: 0.0988 - loss: 2.3847
{'accuracy': [0.09839999675750732], 'loss': [2.3259079456329346]}


In [115]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Concatenate,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# Define U-Net architecture
def unet(input_size=(128, 128, 1)):
    inputs = Input(input_size)

    # Encoder
    conv1 = Conv2D(64, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(1024, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(1024, (3, 3), activation="relu", padding="same")(conv5)

    # Decoder
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Conv2D(512, (2, 2), activation="relu", padding="same")(up6)
    merge6 = Concatenate()([conv4, up6])
    conv6 = Conv2D(512, (3, 3), activation="relu", padding="same")(merge6)
    conv6 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    up7 = Conv2D(256, (2, 2), activation="relu", padding="same")(up7)
    merge7 = Concatenate()([conv3, up7])
    conv7 = Conv2D(256, (3, 3), activation="relu", padding="same")(merge7)
    conv7 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv7)

    up8 = UpSampling2D(size=(2, 2))(conv7)
    up8 = Conv2D(128, (2, 2), activation="relu", padding="same")(up8)
    merge8 = Concatenate()([conv2, up8])
    conv8 = Conv2D(128, (3, 3), activation="relu", padding="same")(merge8)
    conv8 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv8)

    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Conv2D(64, (2, 2), activation="relu", padding="same")(up9)
    merge9 = Concatenate()([conv1, up9])
    conv9 = Conv2D(64, (3, 3), activation="relu", padding="same")(merge9)
    conv9 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv9)
    conv9 = Conv2D(2, (3, 3), activation="relu", padding="same")(conv9)
    conv10 = Conv2D(1, (1, 1), activation="sigmoid")(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model


# Create a dummy dataset
num_samples = 100
input_shape = (128, 128, 1)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, 2, (num_samples, 128, 128, 1)).astype(np.float32)

unet_models = []
for _ in range(len(clients)):
# Instantiate the U-Net model
    model = unet(input_shape)

    # Compile the model
    model.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy"])
    unet_models.append(model)
# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)

# Print the training history
print(history.history)

[1m1226/1250[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m1:02[0m 3s/step - accuracy: 0.5000 - loss: 0.6931

KeyboardInterrupt: 

In [1]:
import time as tm

In [None]:
for r in tqdm(range(1)):
    for client_id, (unet_models, client_dataset) in enumerate(
        zip(clients, dataset_parts)
    ):
        model = client
        x_train, y_train = client_dataset

        outer_loss, outer_grads = maml_train_step(
            model, x_train, y_train, inner_lr, num_inner_updates
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
        optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
        history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)
        
        trained_weights = model.get_weights()
        resnet_enc_st = tm.time()
        
        session_key = nacl_session_keygen()
        client_session_keys[client_id] = session_key
        enc_session_key = encrypt_symmetric_key(agg_pub_key, session_key)
        agg_sesion_keys[client_id] = enc_session_key

        He_ciphertext = HE_encrypt(trained_weights)
        sym_ctxt = encrypt_message_sym_AES(session_key, He_ciphertext)
        enc_wts[client_id] = sym_ctxt
        resnet_enc_stop = tm.time()
    print(len(enc_wts[0]))
    
    resnet_agg_st = tm.time()
    agg_wts = aggregate_wts(enc_wts)
    resnet_agg_stop = tm.time()
    
    for client_id, client in enumerate(fml.clients):
        resnet_dec_st = tm.time()
        new_wts = decrypt_wts_ckks(agg_wts)
        resnet_dec_stop = tm.time()
        client.set_weights(new_wts)

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Create a dummy dataset
num_samples = 100
num_classes = 10
input_shape = (224, 224, 3)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, num_classes, num_samples)

# Convert labels to one-hot encoding
y_dummy = tf.keras.utils.to_categorical(y_dummy, num_classes)

# Load ResNet50 with pre-trained weights and exclude the top layers
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=input_shape)

# Add custom top layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(num_classes, activation="softmax")(x)

renet_models = []
for _ in range(len(clients)):
    # Create the full model
    model = Model(inputs=base_model.input, outputs=predictions)

    # Compile the model
    model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])

    resnet_models.append(model)
    # 
# Create the full model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])

# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=32)

# Print the training history
print(history.history)

In [None]:
for r in tqdm(range(1)):
    for client_id, (resnet_models, client_dataset) in enumerate(
        zip(clients, dataset_parts)
    ):
        model = client
        x_train, y_train = client_dataset

        outer_loss, outer_grads = maml_train_step(
            model, x_train, y_train, inner_lr, num_inner_updates
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
        optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
        history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)

        trained_weights = model.get_weights()
        unet_enc_st = tm.time()

        session_key = nacl_session_keygen()
        client_session_keys[client_id] = session_key
        enc_session_key = encrypt_symmetric_key(agg_pub_key, session_key)
        agg_sesion_keys[client_id] = enc_session_key

        He_ciphertext = HE_encrypt(trained_weights)
        sym_ctxt = encrypt_message_sym_AES(session_key, He_ciphertext)
        enc_wts[client_id] = sym_ctxt
        unet_enc_stop = tm.time()
    print(len(enc_wts[0]))

    unet_agg_st = tm.time()
    agg_wts = aggregate_wts(enc_wts)
    unet_agg_stop = tm.time()

    for client_id, client in enumerate(fml.clients):
        unet_dec_st = tm.time()
        new_wts = decrypt_wts_ckks(agg_wts)
        resnet_dec_stop = tm.time()
        client.set_weights(new_wts)