In [35]:
!pip install Pyfhel
!pip install pynacl
!pip install cryptography
!pip install tqdm
!pip install scikit-learn



In [36]:
no_clients = 3
epochs = 10

In [37]:
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Check if GPU is available
gpus = tf.config.list_physical_devices("GPU")
if gpus:
	print("GPUs available:", len(gpus))
	for gpu in gpus:
		print(gpu)
else:
	print("No GPU available.")

TensorFlow version: 2.16.1
No GPU available.


In [38]:
import tensorflow as tf
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
import matplotlib.pyplot as plt
import numpy as np
import numpy as np
import os
import tensorflow as tf
from Pyfhel import Pyfhel
import nacl.utils
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
import nacl.utils
from nacl.public import PrivateKey, SealedBox
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import Accuracy

# from src.models.FMLEE import FMLEE
# from src.data.load_data import load_mnist

In [39]:
import os

os.environ["TF_USE_LEGACY_KERAS"] = "True"

In [40]:


save_dir = "dataset/mnist_data/"
os.makedirs(save_dir, exist_ok=True)

In [41]:
import tensorflow as tf


class MAML(tf.keras.Model):
	def __init__(self, model):
		super(MAML, self).__init__()
		self.model = model

	def call(self, inputs):
		x = tf.reshape(inputs, (-1, 28, 28, 1))  # Reshape the input tensor
		return self.model(x)

	def get_config(self):
		return {"model": self.model.get_config()}

	@classmethod
	def from_config(cls, config):
		model = tf.keras.models.Model.from_config(config["model"])
		return cls(model)

	def train_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		with tf.GradientTape() as tape:
			y_pred = self.model(x)
			loss = self.compiled_loss(y, y_pred)
		gradients = tape.gradient(loss, self.model.trainable_variables)
		self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}

	def test_step(self, data):
		x, y = data
		x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
		y = tf.reshape(y, (-1,))  # Reshape the target labels
		y_pred = self.model(x)
		self.compiled_loss(y, y_pred)
		self.compiled_metrics.update_state(y, y_pred)
		return {m.name: m.result() for m in self.metrics}


num_meta_updates = 10
num_inner_updates = 5
meta_batch_size = 32
inner_batch_size = 10

In [42]:
class FMLEE:
	def __init__(self, no_clients, epochs):
		self.no_clients = no_clients
		self.epochs = epochs
		print("Initializing CKKS scheme...")
		self.HE = self.CKKS()
		self.clients = []
		print("Initializing clients...")
		self.init_clients()
		print("Generating asymmetric keys...")
		self.pvt_key, self.pub_key = self.asym_keygen()
		print("Initialization complete.")

	def model_spec(self):
		model = tf.keras.models.Sequential(
			[
				tf.keras.layers.Conv2D(
					32, (3, 3), activation="relu", input_shape=(28, 28, 1)
				),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.MaxPooling2D((2, 2)),
				tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
				tf.keras.layers.Flatten(),
				tf.keras.layers.Dense(64, activation="relu"),
				tf.keras.layers.Dense(10),
			]
		)
		return model

	def init_model(self):
		model = MAML(self.model_spec())
		model.compile(
			optimizer=tf.keras.optimizers.Adam(),
			loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
			metrics=["accuracy"],
		)
		return model

	def CKKS(self):
		HE = Pyfhel()
		ckks_params = {
			"scheme": "CKKS",
			"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
			"scale": 2**30,  # All the encodings will use it for float->fixed point
			"qi_sizes": [
				60,
				30,
				30,
				30,
				60,
			],
		}
		print("Generating context for CKKS scheme...")
		HE.contextGen(**ckks_params)  # Generate context for ckks scheme
		print("Generating keys for CKKS scheme...")
		HE.keyGen()  # Key Generation: generates a pair of public/secret keys
		HE.rotateKeyGen()
		HE.relinKeyGen()
		print("CKKS scheme initialized.")
		return HE

	def asym_keygen(self):
		print("Generating private key...")
		pvt_key = PrivateKey.generate()
		print("Private key generated.")
		pub_key = pvt_key.public_key
		print("Public key generated.")
		return pvt_key, pub_key

	def init_clients(self):
		for i in range(self.no_clients):
			print(f"Initializing model for client {i}...")
			self.clients.append(self.init_model())
			print(f"Client {i} initialized.")

In [43]:
def download_and_save_mnist(save_dir):
	(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

	# Save training data with progress bar
	for array, name in zip(
		[x_train_all, y_train_all, x_test, y_test],
		["x_train.npy", "y_train.npy", "x_test.npy", "y_test.npy"],
	):
		with tqdm(total=len(array), desc=f"Saving {name}") as pbar:
			np.save(os.path.join(save_dir, name), array)
			pbar.update(len(array))

	print(f"Dataset downloaded and saved locally at {save_dir}")


def load_mnist_from_local(save_dir):
	x_train_all = np.load(os.path.join(save_dir, "x_train.npy"))
	y_train_all = np.load(os.path.join(save_dir, "y_train.npy"))
	x_test = np.load(os.path.join(save_dir, "x_test.npy"))
	y_test = np.load(os.path.join(save_dir, "y_test.npy"))
	print(f"Dataset loaded from local files at {save_dir}")
	x_train_all = x_train_all.astype(np.float32) / 255
	x_test = x_test.astype(np.float32) / 255

	return (x_train_all, y_train_all), (x_test, y_test)


def load_mnist():
	if not os.path.exists(os.path.join(save_dir, "x_train.npy")):
		download_and_save_mnist(save_dir)
	return load_mnist_from_local(save_dir)

In [44]:
(x_train_all, y_train_all), (x_test, y_test)  = load_mnist()


Dataset loaded from local files at dataset/mnist_data/


In [45]:
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split

# Load MNIST data
(x_train_all, y_train_all), (x_test, y_test) = mnist.load_data()

# Normalize and reshape data
x_train_all = x_train_all.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Splitting data into training and test sets
print("Splitting data into training and test sets...")
X_train, X_temp, y_train, y_temp = train_test_split(
	x_train_all, y_train_all, test_size=0.2, random_state=42
)
print("Data split complete.")
print(
	f"Training set size: {len(X_train)}, Temp set size: {len(X_temp)}, Test set size: {len(x_test)}"
)

# Further split the temporary set into validation and testing sets
print("Splitting temp set into validation and testing sets...")
X_val, X_test, y_val, y_test = train_test_split(
	X_temp, y_temp, test_size=0.15, random_state=42
)
print("Validation and test set split complete.")
print(f"Validation set size: {len(X_val)}, Test set size: {len(X_test)}")

# Split training data into n parts
n_parts = no_clients
part_size = len(X_train) // n_parts
dataset_parts = []

print(f"Splitting training data into {n_parts} parts...")
for i in range(n_parts):
	start = i * part_size
	end = (i + 1) * part_size if i != n_parts - 1 else len(X_train)
	X_part = X_train[start:end]
	y_part = y_train[start:end]
	dataset_parts.append((X_part, y_part))
	print(f"Part {i + 1} created: {len(X_part)} samples.")

print("Data splitting into parts complete.")

Splitting data into training and test sets...
Data split complete.
Training set size: 48000, Temp set size: 12000, Test set size: 10000
Splitting temp set into validation and testing sets...
Validation and test set split complete.
Validation set size: 10200, Test set size: 1800
Splitting training data into 3 parts...
Part 1 created: 16000 samples.
Part 2 created: 16000 samples.
Part 3 created: 16000 samples.
Data splitting into parts complete.


In [46]:
fml = FMLEE(no_clients, epochs)

Initializing CKKS scheme...
Generating context for CKKS scheme...
Generating keys for CKKS scheme...
CKKS scheme initialized.
Initializing clients...
Initializing model for client 0...
Client 0 initialized.
Initializing model for client 1...
Client 1 initialized.
Initializing model for client 2...
Client 2 initialized.
Generating asymmetric keys...
Generating private key...
Private key generated.
Public key generated.
Initialization complete.


In [47]:
fml.clients

[<MAML name=maml_3, built=False>,
 <MAML name=maml_4, built=False>,
 <MAML name=maml_5, built=False>]

In [48]:
fml.HE

<ckks Pyfhel obj at 0x75d07e063490, [pk:Y, sk:Y, rtk:Y, rlk:Y, contx(n=16384, t=0, sec=128, qi=[60, 30, 30, 30, 60], scale=1073741824.0, )]>

In [49]:
# fml.clients[0].fit(x_train_all, y_train_all)

In [50]:
import numpy as np
from Pyfhel import Pyfhel


def CKKS_keygen():
	print("Initializing Pyfhel for CKKS scheme...")
	HE = Pyfhel()

	ckks_params = {
		"scheme": "CKKS",
		"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
		"scale": 2**30,  # All the encodings will use it for float->fixed point
		"qi_sizes": [60, 30, 30, 30, 60],  # Number of bits of each prime in the chain.
	}

	print("Generating context for CKKS scheme...")
	HE.contextGen(**ckks_params)  # Generate context for ckks scheme
	print("Context generation complete.")

	print("Generating public and secret keys...")
	HE.keyGen()  # Key Generation: generates a pair of public/secret keys
	print("Public and secret key generation complete.")

	print("Generating rotation keys...")
	HE.rotateKeyGen()
	print("Rotation keys generation complete.")

	return HE


HE = CKKS_keygen()

Initializing Pyfhel for CKKS scheme...
Generating context for CKKS scheme...
Context generation complete.
Generating public and secret keys...
Public and secret key generation complete.
Generating rotation keys...
Rotation keys generation complete.


In [51]:
def asym_keygen():
	pvt_key = PrivateKey.generate()
	pub_key = pvt_key.public_key
	return pvt_key, pub_key


agg_pvt_key, agg_pub_key = asym_keygen()

In [52]:
def nacl_session_keygen():
	return nacl.utils.random(32)

def encrypt_symmetric_key(pub_key, symmetric_key):
	sealed_box = SealedBox(pub_key)
	return sealed_box.encrypt(symmetric_key)

def decrypt_symmetric_key(pvt_key, encrypted_key):
	sealed_box = SealedBox(pvt_key)
	return sealed_box.decrypt(encrypted_key)



In [53]:
import tensorflow as tf


def maml_train_step(model, x_train, y_train, inner_lr, num_inner_updates):

	model.fit(
			x_train,
			y_train,
			epochs=1,
			batch_size=64,
			verbose=1,
			validation_data=(x_train, y_train),
		)
	with tf.GradientTape() as outer_tape:
		for i in range(num_inner_updates):
			with tf.GradientTape() as inner_tape:
				predictions = model(x_train, training=True)
				loss = tf.reduce_mean(
					tf.keras.losses.sparse_categorical_crossentropy(
						y_train, predictions
					)
				)
			grads = inner_tape.gradient(loss, model.trainable_variables)
			for var, grad in zip(model.trainable_variables, grads):
				if grad is not None:
					var.assign_sub(inner_lr * grad)

		predictions = model(x_train, training=True)
		outer_loss = tf.reduce_mean(
			tf.keras.losses.sparse_categorical_crossentropy(y_train, predictions)
		)

	outer_grads = outer_tape.gradient(outer_loss, model.trainable_variables)
	return outer_loss, outer_grads

In [54]:
def HE_encrypt(wtarray):
	cwt = []
	for layer in wtarray:
		flat_array = layer.astype(np.float64).flatten()

		chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
		clayer = []
		
		for chunk in chunks:
			ptxt = HE.encodeFrac(chunk)
			ctxt = HE.encryptPtxt(ptxt)
			clayer.append(ctxt)
		cwt.append(clayer.copy())
		
	return cwt

In [55]:
def encrypt_message_sym_AES(key, message):
	serialized_obj = pickle.dumps(message)

	iv = nacl.utils.random(16)

	padder = padding.PKCS7(algorithms.AES.block_size).padder()
	padded_obj = padder.update(serialized_obj) + padder.finalize()

	cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
	encryptor = cipher.encryptor()
	ciphertext = encryptor.update(padded_obj) + encryptor.finalize()

	return iv + ciphertext

In [56]:
def decrypt_message_sym_AES(key , ciphertext):
	iv = ciphertext[:16]
	ciphertext = ciphertext[16:]

	cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
	decryptor = cipher.decryptor()
	padded_obj = decryptor.update(ciphertext) + decryptor.finalize()

	unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
	unpadded_obj = unpadder.update(padded_obj) + unpadder.finalize()

	return pickle.loads(unpadded_obj)

In [57]:
def aggregate_wts_ckks(encrypted_wts):
	res_wts = []
	print(len(encrypted_wts) )
	print(len(encrypted_wts[0]) )
	for j in range(len(encrypted_wts[0])):
		layer = []
		for k in range(len(encrypted_wts[0][j])):
			tmp = encrypted_wts[0][j][k].copy()

			for i in range(1, len(encrypted_wts)):
				tmp = tmp + encrypted_wts[i][j][k]

			tmp = tmp / len(encrypted_wts)
			layer.append(tmp)

		res_wts.append(layer.copy())
		
	return res_wts

In [58]:
def decrypt_wts_ckks(encrypted_wts):
	decrypted_wts = []
	wtarray = dummy_model.get_weights()

	for layer_wts , layer in zip(encrypted_wts , wtarray):
		decrypted_layer = []
		flat_array = layer.astype(np.float64).flatten()
		chunks = np.array_split(flat_array, (len(flat_array) + 2**13 - 1) // 2**13)

		for chunk , cchunk in zip(chunks , layer_wts):
			decrypted_chunk = HE.decryptFrac(cchunk)
			original_chunk_size  = len(chunk)
			decrypted_chunk = decrypted_chunk[:original_chunk_size]
			
			decrypted_layer.append(decrypted_chunk)
			
		decrypted_layer = np.concatenate(decrypted_layer)
		decrypted_layer = decrypted_layer.reshape(layer.shape)
		decrypted_wts.append(decrypted_layer)
		
	return decrypted_wts

In [59]:
def aggregate_wts(wts):
	peeled_wts = []
	for client_id in range(no_clients):
		sesion_key = decrypt_symmetric_key(agg_pvt_key , agg_sesion_keys[client_id])
		peeled_wt = decrypt_message_sym_AES(sesion_key, wts[client_id])
		peeled_wts.append(peeled_wt)
	res_wts = aggregate_wts_ckks(peeled_wts)
	return res_wts

In [60]:
dummy_model = fml.clients[0]

In [61]:
agg_sesion_keys = [0 for i in range(no_clients)]

In [62]:
inner_lr = 0.001
num_inner_updates = 1
outer_lr = 0.001


In [63]:
accuracies = [[] for i in range(no_clients)]
losses = [[] for i in range(no_clients)]

In [64]:
client_session_keys = [0 for i in range(no_clients)]

In [65]:
enc_wts = [0 for i in range(no_clients)]

In [66]:
for r in tqdm(range(epochs)):
	for client_id , (client , client_dataset) in enumerate(zip(fml.clients , dataset_parts)):
		model = client
		x_train, y_train = client_dataset

		outer_loss, outer_grads = maml_train_step(
				model, x_train, y_train, inner_lr, num_inner_updates
			)
		optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
		optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
		history = model.evaluate(
			x_train,
			y_train,
			batch_size=64,
			verbose=0,
		)
		accuracies[client_id].append(history[1])
		losses[client_id].append(history[0])
		trained_weights = model.get_weights()

		session_key = nacl_session_keygen()
		client_session_keys[client_id] = session_key
		enc_session_key = encrypt_symmetric_key(agg_pub_key , session_key)
		agg_sesion_keys[client_id] = enc_session_key

		He_ciphertext = HE_encrypt(trained_weights)
		sym_ctxt = encrypt_message_sym_AES(session_key , He_ciphertext)
		enc_wts[client_id] = sym_ctxt
	print(len(enc_wts[0]))
	agg_wts = aggregate_wts(enc_wts)

	for client_id , client in enumerate(fml.clients):
		new_wts = decrypt_wts_ckks(agg_wts)        

		client.set_weights(new_wts)

  0%|          | 0/10 [00:00<?, ?it/s]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.7216 - loss: 0.7292 - val_accuracy: 0.9635 - val_loss: 0.7961
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.7306 - loss: -1.4648 - val_accuracy: 0.9635 - val_loss: -2.1062
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.7350 - loss: -1.5182 - val_accuracy: 0.9624 - val_loss: -2.3780
101725776
3
10


 10%|█         | 1/10 [00:48<07:16, 48.48s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.5791 - loss: -0.8708 - val_accuracy: 0.9104 - val_loss: -1.6532
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5856 - loss: -1.9201 - val_accuracy: 0.9286 - val_loss: -3.4923
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5665 - loss: -1.7410 - val_accuracy: 0.9289 - val_loss: -2.9385
101725776
3
10


 20%|██        | 2/10 [01:34<06:15, 46.96s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.5606 - loss: -1.1808 - val_accuracy: 0.9345 - val_loss: -1.8441
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5462 - loss: -1.3596 - val_accuracy: 0.8953 - val_loss: -2.4912
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.5402 - loss: -1.4272 - val_accuracy: 0.9199 - val_loss: -2.8295
101725776
3
10


 30%|███       | 3/10 [02:20<05:24, 46.38s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6445 - loss: -2.0370 - val_accuracy: 0.9321 - val_loss: -2.8044
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5800 - loss: -1.1738 - val_accuracy: 0.9367 - val_loss: -2.1918
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.5871 - loss: -1.5873 - val_accuracy: 0.9324 - val_loss: -2.6717
101725776
3
10


 40%|████      | 4/10 [03:05<04:36, 46.16s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6225 - loss: -1.6123 - val_accuracy: 0.9421 - val_loss: -2.4566
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.5940 - loss: -1.3814 - val_accuracy: 0.9370 - val_loss: -2.6127
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6279 - loss: -1.5690 - val_accuracy: 0.9384 - val_loss: -2.5204
101725776
3
10


 50%|█████     | 5/10 [03:51<03:50, 46.10s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6348 - loss: -1.8173 - val_accuracy: 0.9471 - val_loss: -3.1523
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6266 - loss: -1.8748 - val_accuracy: 0.9355 - val_loss: -3.0329
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6323 - loss: -1.7521 - val_accuracy: 0.9373 - val_loss: -2.4331
101725776
3
10


 60%|██████    | 6/10 [04:37<03:04, 46.04s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6326 - loss: -1.7109 - val_accuracy: 0.9411 - val_loss: -2.8056
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6415 - loss: -1.8211 - val_accuracy: 0.9394 - val_loss: -2.3398
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6258 - loss: -1.6418 - val_accuracy: 0.9515 - val_loss: -2.6599
101725776
3
10


 70%|███████   | 7/10 [05:23<02:18, 46.00s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6347 - loss: -1.8933 - val_accuracy: 0.9196 - val_loss: -3.3097
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6485 - loss: -1.9488 - val_accuracy: 0.9486 - val_loss: -3.5262
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6521 - loss: -1.8711 - val_accuracy: 0.9559 - val_loss: -3.0917
101725776
3
10


 80%|████████  | 8/10 [06:09<01:31, 45.95s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6289 - loss: -2.0256 - val_accuracy: 0.9539 - val_loss: -3.2408
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6333 - loss: -2.0009 - val_accuracy: 0.9479 - val_loss: -3.2118
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6363 - loss: -1.9173 - val_accuracy: 0.9534 - val_loss: -2.9811
101725776
3
10


 90%|█████████ | 9/10 [06:55<00:45, 45.94s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6218 - loss: -2.1187 - val_accuracy: 0.9406 - val_loss: -3.5638
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 14ms/step - accuracy: 0.6406 - loss: -2.1376 - val_accuracy: 0.9514 - val_loss: -3.6898
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.6431 - loss: -2.1438 - val_accuracy: 0.9568 - val_loss: -3.4688
101725776
3
10


100%|██████████| 10/10 [07:41<00:00, 46.15s/it]


In [67]:
history

[-4.029541492462158, 0.9583125114440918]

In [68]:
dataset_parts[0][1].shape

(16000,)

In [69]:
fmlee_accuracies = accuracies
fmlee_losses = losses

In [70]:
no_fml = FMLEE(no_clients, epochs)

Initializing CKKS scheme...
Generating context for CKKS scheme...
Generating keys for CKKS scheme...
CKKS scheme initialized.
Initializing clients...
Initializing model for client 0...
Client 0 initialized.
Initializing model for client 1...
Client 1 initialized.
Initializing model for client 2...
Client 2 initialized.
Generating asymmetric keys...
Generating private key...
Private key generated.
Public key generated.
Initialization complete.


In [71]:
accuracies = [[] for i in range(no_clients)]
losses = [[] for i in range(no_clients)]

In [73]:
def aggregate_wts_noenc(weight_list):
    n_models = len(weight_list)
    # Initialize the aggregated weights with zeros of the same shape as the first model's weights
    aggregated_weights = [np.zeros_like(w) for w in weight_list[0]]

    # Sum the weights for each layer
    for weights in weight_list:
        for i, weight in enumerate(weights):
            aggregated_weights[i] += weight

    # Divide by the number of models to get the average
    for i in range(len(aggregated_weights)):
        aggregated_weights[i] /= n_models

    return aggregated_weights

In [74]:
for r in tqdm(range(epochs)):
    for client_id, (client, client_dataset) in enumerate(
        zip(fml.clients, dataset_parts)
    ):
        model = client
        x_train, y_train = client_dataset

        outer_loss, outer_grads = maml_train_step(
            model, x_train, y_train, inner_lr, num_inner_updates
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
        optimizer.apply_gradients(zip(outer_grads, model.trainable_variables))
        history = model.evaluate(
            x_train,
            y_train,
            batch_size=64,
            verbose=0,
        )
        accuracies[client_id].append(history[1])
        losses[client_id].append(history[0])
        trained_weights = model.get_weights()

        #     session_key = nacl_session_keygen()
        #     client_session_keys[client_id] = session_key
        #     enc_session_key = encrypt_symmetric_key(agg_pub_key, session_key)
        #     agg_sesion_keys[client_id] = enc_session_key

        #     He_ciphertext = HE_encrypt(trained_weights)
        #     sym_ctxt = encrypt_message_sym_AES(session_key, He_ciphertext)
        enc_wts[client_id] = trained_weights
    # print(len(enc_wts[0]))
    agg_wts = aggregate_wts_noenc(enc_wts)

    for client_id, client in enumerate(fml.clients):
        # new_wts = decrypt_wts_ckks(agg_wts)

        client.set_weights(agg_wts)

  0%|          | 0/10 [00:00<?, ?it/s]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.6277 - loss: -2.1087 - val_accuracy: 0.9430 - val_loss: -3.6915
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.6355 - loss: -2.2194 - val_accuracy: 0.9356 - val_loss: -4.0544
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.6337 - loss: -2.2919 - val_accuracy: 0.9394 - val_loss: -3.8906


 10%|█         | 1/10 [00:41<06:14, 41.59s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9482 - loss: -3.8147 - val_accuracy: 0.9599 - val_loss: -3.6688
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9406 - loss: -3.6164 - val_accuracy: 0.9630 - val_loss: -3.9531
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9466 - loss: -3.6404 - val_accuracy: 0.9691 - val_loss: -4.1117


 20%|██        | 2/10 [01:22<05:29, 41.13s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9656 - loss: -4.3576 - val_accuracy: 0.9683 - val_loss: -4.3057
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9659 - loss: -4.2816 - val_accuracy: 0.9778 - val_loss: -4.0303
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9699 - loss: -4.3272 - val_accuracy: 0.9756 - val_loss: -3.8032


 30%|███       | 3/10 [02:03<04:46, 40.99s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9774 - loss: -4.5882 - val_accuracy: 0.9769 - val_loss: -4.8248
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9762 - loss: -4.6576 - val_accuracy: 0.9768 - val_loss: -4.1148
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9720 - loss: -4.5678 - val_accuracy: 0.9841 - val_loss: -4.0649


 40%|████      | 4/10 [02:44<04:05, 40.92s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9773 - loss: -4.8736 - val_accuracy: 0.9846 - val_loss: -4.8787
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9800 - loss: -4.8226 - val_accuracy: 0.9837 - val_loss: -4.4000
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9813 - loss: -4.8756 - val_accuracy: 0.9854 - val_loss: -5.0519


 50%|█████     | 5/10 [03:24<03:24, 40.83s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9842 - loss: -5.2925 - val_accuracy: 0.9898 - val_loss: -4.9288
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9849 - loss: -5.0627 - val_accuracy: 0.9847 - val_loss: -4.8973
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9857 - loss: -5.1272 - val_accuracy: 0.9843 - val_loss: -5.5619


 60%|██████    | 6/10 [04:05<02:43, 40.85s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 14ms/step - accuracy: 0.9829 - loss: -5.4935 - val_accuracy: 0.9862 - val_loss: -5.1508
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9840 - loss: -5.3474 - val_accuracy: 0.9876 - val_loss: -5.3561
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9848 - loss: -5.6244 - val_accuracy: 0.9876 - val_loss: -5.4084


 70%|███████   | 7/10 [04:46<02:02, 40.84s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9884 - loss: -5.8608 - val_accuracy: 0.9888 - val_loss: -5.4715
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9841 - loss: -5.5287 - val_accuracy: 0.9886 - val_loss: -5.9524
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9862 - loss: -5.6957 - val_accuracy: 0.9912 - val_loss: -5.6242


 80%|████████  | 8/10 [05:27<01:21, 40.81s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9892 - loss: -6.0792 - val_accuracy: 0.9909 - val_loss: -5.5064
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9890 - loss: -5.9512 - val_accuracy: 0.9916 - val_loss: -5.3993
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9878 - loss: -6.1036 - val_accuracy: 0.9921 - val_loss: -5.7827


 90%|█████████ | 9/10 [06:08<00:40, 40.88s/it]

[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9881 - loss: -6.2107 - val_accuracy: 0.9914 - val_loss: -6.2212
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 15ms/step - accuracy: 0.9873 - loss: -6.2943 - val_accuracy: 0.9922 - val_loss: -5.2541
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 16ms/step - accuracy: 0.9893 - loss: -6.1763 - val_accuracy: 0.9933 - val_loss: -6.0266


100%|██████████| 10/10 [06:50<00:00, 41.01s/it]


In [41]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Create a dummy dataset
num_samples = 10000
num_classes = 10
input_shape = (224, 224, 3)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, num_classes, num_samples)

# Convert labels to one-hot encoding
y_dummy = tf.keras.utils.to_categorical(y_dummy, num_classes)

# Load MobileNetV2 with pre-trained weights and exclude the top layers
base_model = MobileNetV2(weights="imagenet", include_top=False, input_shape=input_shape)

# Add custom top layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(num_classes, activation="softmax")(x)

# Create the full model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])

# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=32)

# Print the training history
print(history.history)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5


2024-06-19 06:24:55.214502: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:559] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
2024-06-19 06:24:55.215441: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_ops.cc:624 : INTERNAL: libdevice not found at ./libdevice.10.bc
2024-06-19 06:24:55.232139: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:559] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
2024-06-19 06:24:55.233222: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_ops.cc:624 : INTERNAL: libdevice not found at ./libdevice.10.bc
2024-06-19 06:24:55.249400: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:559] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
2024-06-19 06:24:55.250296: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at xla_ops.cc:6

InternalError: Graph execution error:

Detected at node Adam/StatefulPartitionedCall_159 defined at (most recent call last):
  File "/home/voy/.conda/envs/he39/lib/python3.9/runpy.py", line 197, in _run_module_as_main

  File "/home/voy/.conda/envs/he39/lib/python3.9/runpy.py", line 87, in _run_code

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/home/voy/.conda/envs/he39/lib/python3.9/asyncio/base_events.py", line 601, in run_forever

  File "/home/voy/.conda/envs/he39/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once

  File "/home/voy/.conda/envs/he39/lib/python3.9/asyncio/events.py", line 80, in _run

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 359, in execute_request

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 446, in do_execute

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code

  File "/tmp/ipykernel_1454134/4227392670.py", line 35, in <module>

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/engine/training.py", line 1783, in fit

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/engine/training.py", line 1377, in train_function

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/engine/training.py", line 1360, in step_function

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/engine/training.py", line 1349, in run_step

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/engine/training.py", line 1130, in train_step

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 544, in minimize

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1223, in apply_gradients

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 652, in apply_gradients

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1253, in _internal_apply_gradients

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1345, in _distributed_apply_gradients_fn

  File "/home/voy/.conda/envs/he39/lib/python3.9/site-packages/keras/src/optimizers/optimizer.py", line 1340, in apply_grad_to_update_var

libdevice not found at ./libdevice.10.bc
	 [[{{node Adam/StatefulPartitionedCall_159}}]] [Op:__inference_train_function_20970]

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    UpSampling2D,
    Concatenate,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


# Define U-Net architecture
def unet(input_size=(128, 128, 1)):
    inputs = Input(input_size)

    # Encoder
    conv1 = Conv2D(64, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(1024, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(1024, (3, 3), activation="relu", padding="same")(conv5)

    # Decoder
    up6 = UpSampling2D(size=(2, 2))(conv5)
    up6 = Conv2D(512, (2, 2), activation="relu", padding="same")(up6)
    merge6 = Concatenate()([conv4, up6])
    conv6 = Conv2D(512, (3, 3), activation="relu", padding="same")(merge6)
    conv6 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    up7 = Conv2D(256, (2, 2), activation="relu", padding="same")(up7)
    merge7 = Concatenate()([conv3, up7])
    conv7 = Conv2D(256, (3, 3), activation="relu", padding="same")(merge7)
    conv7 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv7)

    up8 = UpSampling2D(size=(2, 2))(conv7)
    up8 = Conv2D(128, (2, 2), activation="relu", padding="same")(up8)
    merge8 = Concatenate()([conv2, up8])
    conv8 = Conv2D(128, (3, 3), activation="relu", padding="same")(merge8)
    conv8 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv8)

    up9 = UpSampling2D(size=(2, 2))(conv8)
    up9 = Conv2D(64, (2, 2), activation="relu", padding="same")(up9)
    merge9 = Concatenate()([conv1, up9])
    conv9 = Conv2D(64, (3, 3), activation="relu", padding="same")(merge9)
    conv9 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv9)
    conv9 = Conv2D(2, (3, 3), activation="relu", padding="same")(conv9)
    conv10 = Conv2D(1, (1, 1), activation="sigmoid")(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    return model


# Create a dummy dataset
num_samples = 10000
input_shape = (128, 128, 1)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, 2, (num_samples, 128, 128, 1)).astype(np.float32)

# Instantiate the U-Net model
model = unet(input_shape)

# Compile the model
model.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy"])

# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)

# Print the training history
print(history.history)

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Create a dummy dataset
num_samples = 100
num_classes = 10
input_shape = (224, 224, 3)

# Generate random images and labels
x_dummy = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_dummy = np.random.randint(0, num_classes, num_samples)

# Convert labels to one-hot encoding
y_dummy = tf.keras.utils.to_categorical(y_dummy, num_classes)

# Load ResNet50 with pre-trained weights and exclude the top layers
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=input_shape)

# Add custom top layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(num_classes, activation="softmax")(x)

# Create the full model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(), loss="categorical_crossentropy", metrics=["accuracy"])

# Train the model for one epoch
history = model.fit(x_dummy, y_dummy, epochs=1, batch_size=32)

# Print the training history
print(history.history)