In [50]:
import torch


def check_cuda():
    """Print whether CUDA is usable and, when it is, details of the visible GPUs.

    Output goes to stdout only; nothing is returned.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    print("CUDA is available.")
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")
    # Name is reported for the device PyTorch currently targets, not for every GPU.
    active_index = torch.cuda.current_device()
    print(f"Current GPU: {torch.cuda.get_device_name(active_index)}")
    print(f"CUDA version: {torch.version.cuda}")


check_cuda()

CUDA is available.
Number of GPUs available: 1
Current GPU: NVIDIA GeForce RTX 3090
CUDA version: 11.8


In [51]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys

import os



In [52]:

from datetime import datetime

# Timestamp taken once, so every checkpoint written during this run shares one folder.
current_timestamp = datetime.now()

# Folder name is day_hour_minute (e.g. "07_14_32"). It carries no month or year,
# so runs started at the same day/hour/minute of different months share a folder.
folder_path = current_timestamp.strftime("%d_%H_%M")
fp = f"models/{folder_path}"
# exist_ok=True replaces the exists()/makedirs() pair, which could raise
# FileExistsError if the directory appeared between the check and the creation.
os.makedirs(fp, exist_ok=True)


In [53]:
import pandas as pd
import seaborn as sns
from PIL import Image
import os
import matplotlib.pyplot as plt
import cv2

from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as A

import time
import os
from tqdm.notebook import tqdm

from torchsummary import summary
import segmentation_models_pytorch as smp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [54]:
df = pd.read_csv("drone/class_dict_seg.csv")

In [55]:
clients = [0 , 1 , 2]
no_clients = len(clients)
epochs = 100

In [56]:
image_path = "drone/dataset/semantic_drone_dataset/label_images_semantic"
original_image_path = "drone/dataset/semantic_drone_dataset/original_images"

In [57]:
files = os.listdir(image_path)

# Keep only entries with an image extension, ordered by file name
image_files = sorted(
	name for name in files if name.endswith((".jpg", ".jpeg", ".png", ".gif", ".bmp"))
)

# Paths of the files found under image_path
image_list = []
for image_file in image_files:
	image_list += [f"{image_path}/{image_file}"]


files_2 = os.listdir(original_image_path)
# Same extension filter and ordering for the second directory
image_files = sorted(
	name for name in files_2 if name.endswith((".jpg", ".jpeg", ".png", ".gif", ".bmp"))
)

# Paths of the files found under original_image_path
original_image_list = []
for image_file in image_files:
	original_image_list += [f"{original_image_path}/{image_file}"]

In [58]:
IMAGE_PATH = "drone/dataset/semantic_drone_dataset/original_images"
MASK_PATH = "drone/dataset/semantic_drone_dataset/label_images_semantic"

In [59]:
n_classes = 23


def create_df():
	"""Collect the ids of all files found under IMAGE_PATH.

	The directory tree is walked recursively; an id is the part of the file
	name before its first ".".

	Returns:
		pandas.DataFrame with a single column "id" and a 0..n-1 integer index.
	"""
	name = [
		filename.split(".")[0]
		for _dirname, _subdirs, filenames in os.walk(IMAGE_PATH)
		for filename in filenames
	]
	return pd.DataFrame({"id": name}, index=np.arange(0, len(name)))


df = create_df()
print("Total Images: ", len(df))

Total Images:  400


In [60]:
df.iloc[0]

id    515
Name: 0, dtype: object

In [61]:
# spliting the data for traning , testing and validation
X_trainval, X_test = train_test_split(df["id"].values, test_size=0.1, random_state=42)
X_train, X_val = train_test_split(X_trainval, test_size=0.15, random_state=42)

In [62]:
split_train = [X_train[i*len(X_train)//no_clients:(i+1)*len(X_train)//no_clients] for i in range(no_clients)]
split_val = [X_val[i*len(X_val)//no_clients:(i+1)*len(X_val)//no_clients] for i in range(no_clients)]

In [63]:
len(split_train[0])

102

In [64]:
class DroneDataset(Dataset):
	"""Pairs each drone image (<id>.jpg) with its semantic mask (<id>.png).

	Items are read with OpenCV, optionally passed through a transform that is
	called as ``transform(image=..., mask=...)`` and returns a dict with
	"image" and "mask", then converted to a normalised image tensor and a
	long mask tensor.
	"""

	def __init__(self, img_path, mask_path, X, mean, std, transform=None, patch=False):
		"""Store the configuration; no file is touched until an item is requested.

		Args:
			img_path: directory holding the "<id>.jpg" images.
			mask_path: directory holding the "<id>.png" masks.
			X: indexable collection of id strings.
			mean, std: per-channel statistics handed to ``T.Normalize``.
			transform: optional joint image/mask transform.
			patch: when True, items are returned as tiles (see ``tiles``).
		"""
		self.img_path, self.mask_path = img_path, mask_path
		self.X = X
		self.mean, self.std = mean, std
		self.transform = transform
		self.patches = patch

	def __len__(self):
		"""Number of ids in this dataset."""
		return len(self.X)

	def __getitem__(self, idx):
		"""Load, augment and tensorise the image/mask pair at position ``idx``.

		Raises:
			FileNotFoundError: if OpenCV cannot read the image or the mask.
		"""
		sample_id = self.X[idx]
		img_full_path = os.path.join(self.img_path, sample_id + ".jpg")
		mask_full_path = os.path.join(self.mask_path, sample_id + ".png")

		# cv2.imread signals failure by returning None rather than raising
		img = cv2.imread(img_full_path)
		if img is None:
			raise FileNotFoundError(f"Image not found at {img_full_path}")

		mask = cv2.imread(mask_full_path, cv2.IMREAD_GRAYSCALE)
		if mask is None:
			raise FileNotFoundError(f"Mask not found at {mask_full_path}")

		# OpenCV loads BGR; everything downstream works on RGB
		img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

		if self.transform is None:
			img = Image.fromarray(img)
		else:
			aug = self.transform(image=img, mask=mask)
			img = Image.fromarray(aug["image"])
			mask = aug["mask"]

		to_normalised_tensor = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
		img = to_normalised_tensor(img)
		mask = torch.from_numpy(mask).long()

		if not self.patches:
			return img, mask
		return self.tiles(img, mask)

	def tiles(self, img, mask):
		"""Cut an image/mask pair into non-overlapping 512x768 tiles.

		Args:
			img: tensor indexed (channel, height, width) with 3 channels.
			mask: tensor indexed (height, width).

		Returns:
			(img_patches, mask_patches) shaped (n_tiles, 3, 512, 768) and
			(n_tiles, 512, 768); tiles are ordered row by row.
		"""
		tile_h, tile_w = 512, 768

		img_patches = (
			img.unfold(1, tile_h, tile_h)
			.unfold(2, tile_w, tile_w)
			.contiguous()
			.view(3, -1, tile_h, tile_w)
			.permute(1, 0, 2, 3)
		)
		mask_patches = (
			mask.unfold(0, tile_h, tile_h)
			.unfold(1, tile_w, tile_w)
			.contiguous()
			.view(-1, tile_h, tile_w)
		)
		return img_patches, mask_patches

In [65]:
# Per-channel statistics handed to T.Normalize inside DroneDataset
# (the usual ImageNet values, matching the "imagenet" encoder weights used below).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline: resize plus random geometric and photometric augmentation.
t_train = A.Compose(
	[
		A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
		A.HorizontalFlip(),
		A.VerticalFlip(),
		A.GridDistortion(p=0.2),
		A.RandomBrightnessContrast((0, 0.5), (0, 0.5)),
		A.GaussNoise(),
	]
)

# Validation pipeline: resize only. Random flips/distortions here made the
# validation loss vary from pass to pass, and that loss is what fit() uses for
# early stopping and checkpointing, so validation inputs must be deterministic.
t_val = A.Compose(
	[
		A.Resize(704, 1056, interpolation=cv2.INTER_NEAREST),
	]
)

# datasets
train_set = DroneDataset(
	IMAGE_PATH, MASK_PATH, X_train, mean, std, t_train, patch=False
)
val_set = DroneDataset(IMAGE_PATH, MASK_PATH, X_val, mean, std, t_val, patch=False)

# dataloader
batch_size = 1

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so keep it fixed.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [66]:
# One (train, val) DataLoader pair per federated client, built from that client's id shard
batch_size = 1
train_loaders, val_loaders = [], []
for i in range(no_clients):
	b_train, b_val = split_train[i], split_val[i]

	# datasets
	train_set = DroneDataset(
		IMAGE_PATH, MASK_PATH, b_train, mean, std, t_train, patch=False
	)
	val_set = DroneDataset(IMAGE_PATH, MASK_PATH, b_val, mean, std, t_val, patch=False)

	# dataloaders
	train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
	val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

	train_loaders += [train_loader]
	val_loaders += [val_loader]

In [67]:
# One U-Net per federated client, all built with the same configuration.
models = []
for i in range(no_clients):
	model = smp.Unet(
		classes=n_classes,
		in_channels=3,
		encoder_weights="imagenet",
		encoder_name="resnet34",
	)
	models += [model]

In [68]:
def pixel_accuracy(output, mask):
	"""Fraction of positions whose predicted class equals the label.

	Args:
		output: class scores with the class axis at dim 1.
		mask: integer labels, comparable element-wise with ``output.argmax(dim=1)``.

	Returns:
		Python float in [0, 1].
	"""
	with torch.no_grad():
		probabilities = F.softmax(output, dim=1)
		predicted = probabilities.argmax(dim=1)
		hits = (predicted == mask).int()
		return float(hits.sum()) / float(hits.numel())

In [69]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=23):
	"""Mean intersection-over-union across the classes present in ``mask``.

	Args:
		pred_mask: class scores with the class axis at dim 1.
		mask: integer labels; flattened together with the arg-max prediction.
		smooth: added to both intersection and union.
		n_classes: class ids 0..n_classes-1 are evaluated.

	Returns:
		Mean of the per-class IoU values. Classes that never occur in ``mask``
		are recorded as NaN and left out of the mean.
	"""
	with torch.no_grad():
		predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
		target = mask.contiguous().view(-1)

		iou_per_class = []
		for class_id in range(n_classes):
			predicted_here = predicted == class_id
			labelled_here = target == class_id

			if labelled_here.long().sum().item() != 0:
				intersect = (
					torch.logical_and(predicted_here, labelled_here).sum().float().item()
				)
				union = torch.logical_or(predicted_here, labelled_here).sum().float().item()
				iou_per_class.append((intersect + smooth) / (union + smooth))
			else:
				# class absent from the ground truth: skipped by nanmean
				iou_per_class.append(np.nan)
		return np.nanmean(iou_per_class)

In [70]:
import numpy as np
import matplotlib.pyplot as plt


def plot(pred_masks, true_masks):
	"""Show predicted and ground-truth masks side by side, one row per image.

	Args:
		pred_masks: array indexed as ``pred_masks[i, 0]`` — assumed to be
			(n_images, 1, H, W) class-id maps; TODO confirm against callers.
		true_masks: array indexed as ``true_masks[i]``, one class-id map per image.

	The colour scale spans 0..k-1 where k is the number of distinct values in
	``true_masks``.
	"""
	n_classes = len(np.unique(true_masks))  # Number of unique classes
	n_images = pred_masks.shape[0]  # Number of images

	# squeeze=False keeps `axes` two-dimensional even when n_images == 1;
	# without it plt.subplots returns a 1-D array and axes[i, 0] raises IndexError.
	fig, axes = plt.subplots(
		n_images, 2, figsize=(10, 5 * n_images), squeeze=False
	)

	for i in range(n_images):
		pred_ax, true_ax = axes[i, 0], axes[i, 1]

		im_pred = pred_ax.imshow(
			pred_masks[i, 0], cmap="tab20", vmin=0, vmax=n_classes - 1
		)  # Assuming single channel masks
		pred_ax.set_title("Predicted Mask")
		pred_ax.set_axis_off()
		fig.colorbar(im_pred, ax=pred_ax, label="Predicted Class")

		im_true = true_ax.imshow(
			true_masks[i], cmap="tab20", vmin=0, vmax=n_classes - 1
		)
		true_ax.set_title("True Mask")
		true_ax.set_axis_off()
		fig.colorbar(im_true, ax=true_ax, label="True Class")

	plt.tight_layout()
	plt.show()

In [71]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import time
from tqdm import tqdm
import learn2learn as l2l


def get_lr(optimizer):
	"""Return the "lr" of the optimizer's first parameter group, or None if it has no groups."""
	first_group = next(iter(optimizer.param_groups), None)
	return None if first_group is None else first_group["lr"]


def fit(
	epochs,
	model,
	train_loader,
	val_loader,
	criterion,
	optimizer,
	scheduler,
	patch=False,
	adaptation_steps=5,
	inner_lr=0.01,
):
	"""Train `model` with a MAML-style inner/outer loop and validate after every epoch.

	For each training batch a learn2learn MAML clone of the model is adapted
	`adaptation_steps` times on that batch (inner loop, step size `inner_lr`).
	The loss of the adapted clone on the same batch is then back-propagated
	and applied with `optimizer` (outer loop); `scheduler.step()` runs once
	per batch.

	NOTE: `fit` is defined again in a later cell; when the cells run in order
	that later definition is the one in effect.

	Args:
		epochs: maximum number of epochs.
		model: network to train; moved to the module-level `device`.
		train_loader: sized iterable of (image, mask) training batches.
		val_loader: sized iterable of (image, mask) validation batches.
		criterion: loss, called as ``criterion(output, mask)``.
		optimizer: optimizer used for the outer update.
		scheduler: learning-rate scheduler, stepped once per training batch.
		patch: when True, batches arrive as (batch, n_tiles, c, h, w) and are
			flattened to (batch * n_tiles, c, h, w) before the forward pass.
		adaptation_steps: inner-loop steps per batch.
		inner_lr: inner-loop learning rate.

	Returns:
		dict with per-epoch lists "train_loss", "val_loss", "train_miou",
		"val_miou", "train_acc", "val_acc" (all the same length) and the
		per-batch list "lrs".

	Side effects:
		Prints progress. Whenever the improvement counter (which starts at 1)
		reaches a multiple of 5, the whole model is written to the working
		directory with `torch.save`. Training stops once the validation loss
		has failed to beat the best value 7 times in total.
	"""

	def _prepare(batch):
		"""Flatten tiled batches when `patch` is set and move the pair to `device`."""
		image_tiles, mask_tiles = batch
		if patch:
			bs, n_tiles, c, h, w = image_tiles.size()
			image_tiles = image_tiles.view(-1, c, h, w)
			mask_tiles = mask_tiles.view(-1, h, w)
		return image_tiles.to(device), mask_tiles.to(device)

	torch.cuda.empty_cache()
	train_losses = []
	test_losses = []
	val_iou = []
	val_acc = []
	train_iou = []
	train_acc = []
	lrs = []
	min_loss = np.inf
	decrease = 1
	not_improve = 0

	model.to(device)
	# The wrapper only holds the model and the inner learning rate, so build it
	# once; just the per-batch clone has to be fresh.
	maml = l2l.algorithms.MAML(model, lr=inner_lr)
	fit_time = time.time()
	for e in range(epochs):
		since = time.time()
		running_loss = 0
		iou_score = 0
		accuracy = 0

		# training loop
		model.train()
		for data in tqdm(train_loader):
			image, mask = _prepare(data)

			optimizer.zero_grad()

			# Inner loop: adapt a differentiable clone on this batch
			learner = maml.clone()
			for _ in range(adaptation_steps):
				loss = criterion(learner(image), mask)
				learner.adapt(loss)

			# Meta-update: the adapted clone's loss updates the original parameters
			output = learner(image)
			loss = criterion(output, mask)
			loss.backward()
			optimizer.step()

			# Record the learning rate used for this step, then advance the schedule
			lrs.append(get_lr(optimizer))
			scheduler.step()

			running_loss += loss.item()
			iou_score += mIoU(output, mask)
			accuracy += pixel_accuracy(output, mask)

		# Validation loop (un-adapted model)
		model.eval()
		test_loss = 0
		test_accuracy = 0
		val_iou_score = 0
		with torch.no_grad():
			for data in tqdm(val_loader):
				image, mask = _prepare(data)

				output = model(image)
				val_iou_score += mIoU(output, mask)
				test_accuracy += pixel_accuracy(output, mask)
				loss = criterion(output, mask)
				test_loss += loss.item()

		# Per-batch means for this epoch
		epoch_train_loss = running_loss / len(train_loader)
		epoch_val_loss = test_loss / len(val_loader)
		epoch_train_iou = iou_score / len(train_loader)
		epoch_val_iou = val_iou_score / len(val_loader)
		epoch_train_acc = accuracy / len(train_loader)
		epoch_val_acc = test_accuracy / len(val_loader)

		# Record everything before any early exit so all history lists stay aligned.
		train_losses.append(epoch_train_loss)
		test_losses.append(epoch_val_loss)
		val_iou.append(epoch_val_iou)
		train_iou.append(epoch_train_iou)
		train_acc.append(epoch_train_acc)
		val_acc.append(epoch_val_acc)

		stop_early = False
		if min_loss > epoch_val_loss:
			print(
				"Loss Decreasing.. {:.3f} >> {:.3f} ".format(min_loss, epoch_val_loss)
			)
			min_loss = epoch_val_loss
			decrease += 1
			if decrease % 5 == 0:
				print("saving model...")
				torch.save(
					model,
					"Unet-Mobilenet_v2_mIoU-{:.3f}.pt".format(epoch_val_iou),
				)
		elif epoch_val_loss > min_loss:
			# min_loss is deliberately left untouched here: it must stay the best
			# loss seen so far, otherwise a later, still-worse epoch could be
			# reported (and checkpointed) as an improvement.
			not_improve += 1
			print(f"Loss Not Decrease for {not_improve} time")
			stop_early = not_improve == 7

		print(
			"Epoch:{}/{}..".format(e + 1, epochs),
			"Train Loss: {:.3f}..".format(epoch_train_loss),
			"Val Loss: {:.3f}..".format(epoch_val_loss),
			"Train mIoU:{:.3f}..".format(epoch_train_iou),
			"Val mIoU: {:.3f}..".format(epoch_val_iou),
			"Train Acc:{:.3f}..".format(epoch_train_acc),
			"Val Acc:{:.3f}..".format(epoch_val_acc),
			"Time: {:.2f}m".format((time.time() - since) / 60),
		)

		if stop_early:
			print("Loss not decrease for 7 times, Stop Training")
			break

	history = {
		"train_loss": train_losses,
		"val_loss": test_losses,
		"train_miou": train_iou,
		"val_miou": val_iou,
		"train_acc": train_acc,
		"val_acc": val_acc,
		"lrs": lrs,
	}
	print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
	return history

In [72]:
# max_lr = 1e-3
# epoch = 2
# weight_decay = 1e-4

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# sched = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr, epochs=epoch, steps_per_epoch=len(train_loader)
# )

# history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)

In [73]:
def fed_train(model):
	"""Run `fit` on `model` and return ``(model, history)``.

	NOTE(review): `epoch`, `train_loader`, `val_loader`, `criterion`,
	`optimizer` and `sched` are read from module globals at call time.
	In the visible notebook `epoch` is only assigned in commented-out code,
	so a call presumably raises NameError on a fresh kernel — confirm, and
	pass these values in explicitly if this helper is still needed.
	"""
	history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)
	return model , history

In [74]:
import numpy as np
from Pyfhel import Pyfhel

# Shared homomorphic-encryption context used by encrypt_wt / decrypt_weights.
HE = Pyfhel()
ckks_params = dict(
	scheme="CKKS",
	# Polynomial modulus degree. The weight helpers in this notebook split every
	# layer into chunks of at most 2**13 (= n/2) values per ciphertext.
	n=2**14,
	# Scaling factor used when encoding floats as fixed-point values.
	scale=2**30,
	# Bit size of each prime in the modulus chain.
	qi_sizes=[60, 30, 30, 30, 60],
)
HE.contextGen(**ckks_params)  # build the CKKS context
HE.keyGen()  # public/secret key pair
HE.rotateKeyGen()  # rotation keys

In [75]:
def generate_diffie_hellman_parameters(key_size=2048):
	"""Generate fresh Diffie-Hellman group parameters with generator 2.

	Args:
		key_size: bit length of the prime modulus. Defaults to 2048; a 512-bit
			modulus is small enough for the discrete logarithm to be computed
			in practice, so it must not be used to protect anything.
			Generation time grows quickly with this value.

	Returns:
		The parameters object produced by ``dh.generate_parameters``.
	"""
	parameters = dh.generate_parameters(generator=2, key_size=key_size)
	return parameters


def generate_diffie_hellman_keys(parameters):
	"""Create a key pair from `parameters` and return it as ``(private_key, public_key)``."""
	private_key = parameters.generate_private_key()
	return private_key, private_key.public_key()


def derive_key(private_key, peer_public_key):
	"""Derive a 32-byte key from a Diffie-Hellman exchange.

	The raw shared secret from ``private_key.exchange(peer_public_key)`` is run
	through HKDF-SHA256 (no salt, info ``b"handshake data"``).
	"""
	hkdf = HKDF(
		algorithm=hashes.SHA256(),
		length=32,
		salt=None,
		info=b"handshake data",
	)
	return hkdf.derive(private_key.exchange(peer_public_key))


def encrypt_message_AES(key, message):
	"""Pickle `message` and encrypt it with AES-GCM under `key`.

	ECB mode (used previously) encrypts equal plaintext blocks to equal
	ciphertext blocks and gives no integrity protection. GCM with a fresh
	random nonce hides that structure and authenticates the data.

	Args:
		key: AES key (16, 24 or 32 bytes).
		message: any picklable object.

	Returns:
		bytes laid out as 12-byte nonce || 16-byte tag || ciphertext.
	"""
	serialized_obj = pickle.dumps(message)
	nonce = os.urandom(12)  # must never repeat for the same key
	encryptor = Cipher(algorithms.AES(key), modes.GCM(nonce)).encryptor()
	ciphertext = encryptor.update(serialized_obj) + encryptor.finalize()
	return nonce + encryptor.tag + ciphertext


def decrypt_message_AES(key, ciphertext):
	"""Verify and decrypt a blob produced by `encrypt_message_AES`, then unpickle it.

	Raises:
		ValueError: if the blob is too short to contain a nonce and a tag.
		cryptography.exceptions.InvalidTag: if the key is wrong or the data
			was modified.
	"""
	if len(ciphertext) < 28:
		raise ValueError("ciphertext too short: expected nonce (12) + tag (16) + data")
	nonce, tag, body = ciphertext[:12], ciphertext[12:28], ciphertext[28:]
	decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce, tag)).decryptor()
	# finalize() checks the tag; nothing is unpickled unless it verifies.
	serialized_obj = decryptor.update(body) + decryptor.finalize()
	# pickle.loads can execute arbitrary code, so it must only ever see data
	# authenticated under a key shared with a trusted peer.
	obj = pickle.loads(serialized_obj)
	return obj


def setup_AES():
	"""Run a Diffie-Hellman handshake between one server and every client.

	One parameter set is generated, then a key pair for the server and one per
	entry of the module-level `clients` list.

	Returns:
		(client_keys, shared_keys, client_shared_keys) where
		client_keys[i] is client i's (private_key, public_key) pair,
		shared_keys[i] is the key derived from the server's private key and
		client i's public key, and client_shared_keys[i] is the key derived
		from client i's private key and the server's public key.
	"""
	num_clients = len(clients)
	parameters = generate_diffie_hellman_parameters()
	server_private_key, server_public_key = generate_diffie_hellman_keys(parameters)

	client_keys = []
	for _ in range(num_clients):
		client_keys.append(generate_diffie_hellman_keys(parameters))

	# Server side of each exchange
	shared_keys = [
		derive_key(server_private_key, public_key) for _private, public_key in client_keys
	]
	# Client side of each exchange
	client_shared_keys = [
		derive_key(private_key, server_public_key) for private_key, _public in client_keys
	]

	return client_keys, shared_keys, client_shared_keys

client_keys, shared_keys, client_shared_keys = setup_AES()

In [76]:
def load_weights(model, weights):
	"""Overwrite the model's parameters in place with `weights` and return the model.

	`weights` is consumed in the order of ``model.parameters()``; pairing stops
	at the shorter of the two sequences.
	"""
	pairs = zip(model.parameters(), weights)
	with torch.no_grad():
		for target, values in pairs:
			target.copy_(torch.tensor(values))
	return model

In [77]:
def get_weights(model):
	"""Return the model's parameters as a list of NumPy arrays, in ``model.parameters()`` order."""
	weights = []
	for param in model.parameters():
		weights.append(param.detach().cpu().numpy())
	return weights

In [78]:
def aggregate_wt(encypted_cwts):
	"""Average the clients' encrypted weights chunk by chunk.

	Args:
		encypted_cwts: one entry per client, each a list of layers, each layer
			a list of chunks supporting ``.copy()``, ``+`` and ``/``.

	Returns:
		A list with ``len(clients)`` entries. Every entry is a shallow copy of
		the same averaged model, so the layer lists and chunks are shared.

	The AES transport decryption step is currently disabled; the inputs are
	used exactly as received.
	"""
	cwts = encypted_cwts
	n_contributions = len(cwts)

	resmodel = []
	for j, reference_layer in enumerate(cwts[0]):  # layers
		layer = []
		for k, first_chunk in enumerate(reference_layer):  # chunks
			total = first_chunk.copy()
			for other in cwts[1:]:  # remaining clients
				total = total + other[j][k]
			layer.append(total / n_contributions)
		resmodel.append(layer)

	return [resmodel.copy() for _ in range(len(clients))]

In [79]:
def encrypt_wt(wtarray, i):
	"""Encrypt a model's weights with the module-level `HE` context.

	Every layer is flattened to float64 and split into near-equal chunks of at
	most 2**13 values; each chunk is encoded and encrypted separately.

	Args:
		wtarray: list of NumPy arrays, one per layer.
		i: client index; unused while the AES transport layer is disabled.

	Returns:
		List (per layer) of lists (per chunk) of ciphertexts.
	"""
	max_chunk = 2**13
	cwt = []
	for layer in wtarray:
		flat_array = layer.astype(np.float64).flatten()
		n_chunks = (len(flat_array) + max_chunk - 1) // max_chunk
		clayer = [
			HE.encryptPtxt(HE.encodeFrac(chunk))
			for chunk in np.array_split(flat_array, n_chunks)
		]
		cwt.append(clayer)
	return cwt

In [80]:
def decrypt_weights(res):
	"""Decrypt per-client encrypted weights back into NumPy arrays.

	The current weights of the matching entry in the module-level `models`
	list serve as a template: they give each layer's shape and the length of
	every chunk, so anything decrypted beyond a chunk's length is cut off.

	Args:
		res: one entry per client, each a list of layers of encrypted chunks.

	Returns:
		List (per client) of lists (per layer) of arrays shaped like the
		corresponding model parameter.
	"""
	max_chunk = 2**13
	decrypted_weights = []
	for client_weights, model in zip(res, models):
		template_layers = get_weights(model)
		decrypted_client_weights = []
		for layer_weights, layer in zip(client_weights, template_layers):
			flat_array = layer.astype(np.float64).flatten()
			chunks = np.array_split(
				flat_array, (len(flat_array) + max_chunk - 1) // max_chunk
			)
			pieces = [
				HE.decryptFrac(encrypted_chunk)[: len(chunk)]
				for chunk, encrypted_chunk in zip(chunks, layer_weights)
			]
			restored = np.concatenate(pieces, axis=0).reshape(layer.shape)
			decrypted_client_weights.append(restored)
		decrypted_weights.append(decrypted_client_weights)
	return decrypted_weights

In [81]:
max_lr = 0.01
weight_decay = 0.01

In [82]:
# histories = []
# cwts = [encrypt_wt(get_weights(model), i) for i, model in enumerate(models)]
# print("Initial encrypted weights generated for all clients.")

# for e in range(epochs):
#     print(f"Epoch {e+1}/{epochs} started")
	
#     cwts = aggregate_wt(cwts)
#     print(f"Aggregated encrypted weights after epoch {e+1}")
	
#     wts = decrypt_weights(cwts)
#     print(f"Decrypted weights after aggregation for epoch {e+1}")
	
#     cwts = []
#     epoch_histories = []
	
#     for i in range(no_clients):
#         print(f"Client {i} preparing to load weights and datasets for epoch {e+1}")
		
#         wt = wts[i]
#         model = load_weights(models[i], wt)
		
#         train_loader = train_loaders[i]
#         val_loader = val_loaders[i]
		
#         criterion = nn.CrossEntropyLoss()
#         optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
#         sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=1, steps_per_epoch=len(train_loader))
		
#         print(f"Client {i} Epoch {e+1} started")
		
#         history = fit(1, model, train_loader, val_loader, criterion, optimizer, sched)
		
#         epoch_histories.append(history)
#         print(f"Client {i} Epoch {e+1} completed with history: {history}")
		
#         wtarray = get_weights(model)
#         cwts.append(encrypt_wt(wtarray, i))
#         print(f"Client {i} weights encrypted and added to cwts for epoch {e+1}")

#     histories.append(epoch_histories)
#     print(f"Epoch {e+1} completed and histories updated")

# print("Training completed.")


In [83]:
# from tqdm import tqdm

# histories = []
# previous_losses = {i: [] for i in range(no_clients)}

# cwts = [encrypt_wt(get_weights(model), i) for i, model in enumerate(models)]
# print("Initial encrypted weights generated for all clients.")

# for e in tqdm(range(epochs), desc="Epochs", colour='green'):
#     print(f"Epoch {e+1}/{epochs} started")
	
#     cwts = aggregate_wt(cwts)
#     print(f"Aggregated encrypted weights after epoch {e+1}")
	
#     wts = decrypt_weights(cwts)
#     print(f"Decrypted weights after aggregation for epoch {e+1}")
	
#     cwts = []
#     epoch_histories = []
	
#     for i in range(no_clients):
#         print(f"Client {i} preparing to load weights and datasets for epoch {e+1}")
		
#         wt = wts[i]
#         model = load_weights(models[i], wt)
		
#         train_loader = train_loaders[i]
#         val_loader = val_loaders[i]
		
#         criterion = nn.CrossEntropyLoss()
#         optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
#         sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=1, steps_per_epoch=len(train_loader))
		
#         print(f"Client {i} previous losses before epoch {e+1}: {previous_losses[i]}")
#         print(f"Client {i} Epoch {e+1} started")
		
#         history = fit(1, model, tqdm(train_loader, desc=f"Client {i} Training", colour='blue'), val_loader, criterion, optimizer, sched)
		
#         epoch_histories.append(history)
#         print(f"Client {i} Epoch {e+1} completed with history: {history}")
		
#         previous_losses[i].append({
#             'train_loss': history['train_loss'][-1],
#             'val_loss': history['val_loss'][-1]
#         })
#         print(f"Client {i} previous losses updated: {previous_losses[i]}")
		
#         wtarray = get_weights(model)
#         cwts.append(encrypt_wt(wtarray, i))
#         print(f"Client {i} weights encrypted and added to cwts for epoch {e+1}")

#     histories.append(epoch_histories)
#     print(f"Epoch {e+1} completed and histories updated")

# print("Training completed.")



In [84]:
# # Plotting accuracy over time
# train_accuracies = {i: [] for i in range(no_clients)}
# val_accuracies = {i: [] for i in range(no_clients)}

# for epoch_histories in histories:
#     for i, history in enumerate(epoch_histories):
#         train_accuracies[i].append(history["train_accuracy"][-1])
#         val_accuracies[i].append(history["val_accuracy"][-1])

# plt.figure(figsize=(10, 6))

# for i in range(no_clients):
#     plt.plot(train_accuracies[i], label=f"Client {i} Train Accuracy")
#     plt.plot(val_accuracies[i], label=f"Client {i} Val Accuracy")

# plt.xlabel("Epoch")
# plt.ylabel("Accuracy")
# plt.title("Training and Validation Accuracy Over Time")
# plt.legend()
# plt.grid(True)
# plt.show()

In [85]:

# for e, epoch_histories in enumerate(histories):
#     print(f"Epoch {e+1} histories:")
#     for i, history in enumerate(epoch_histories):
#         print(f"  Client {i}: {history}")

In [86]:
def fit(
	epochs,
	model,
	train_loader,
	val_loader,
	criterion,
	optimizer,
	scheduler,
	patch=False,
	adaptation_steps=5,
	inner_lr=0.01,
):
	"""Train `model` with a MAML-style inner/outer loop and validate after every epoch.

	For each training batch a learn2learn MAML clone of the model is adapted
	`adaptation_steps` times on that batch (inner loop, step size `inner_lr`).
	The loss of the adapted clone on the same batch is then back-propagated
	and applied with `optimizer` (outer loop); `scheduler.step()` runs once
	per batch.

	NOTE: this redefines `fit` from an earlier cell; when the cells run in
	order, this later definition is the one in effect.

	Args:
		epochs: maximum number of epochs.
		model: network to train; moved to the module-level `device`.
		train_loader: sized iterable of (image, mask) training batches.
		val_loader: sized iterable of (image, mask) validation batches.
		criterion: loss, called as ``criterion(output, mask)``.
		optimizer: optimizer used for the outer update.
		scheduler: learning-rate scheduler, stepped once per training batch.
		patch: when True, batches arrive as (batch, n_tiles, c, h, w) and are
			flattened to (batch * n_tiles, c, h, w) before the forward pass.
		adaptation_steps: inner-loop steps per batch.
		inner_lr: inner-loop learning rate.

	Returns:
		dict with per-epoch lists "train_loss", "val_loss", "train_miou",
		"val_miou", "train_acc", "val_acc" (all the same length) and the
		per-batch list "lrs".

	Side effects:
		Prints progress. Whenever the improvement counter (which starts at 1)
		reaches a multiple of 5, the whole model is written to the working
		directory with `torch.save`. Training stops once the validation loss
		has failed to beat the best value 7 times in total.
	"""

	def _prepare(batch):
		"""Flatten tiled batches when `patch` is set and move the pair to `device`."""
		image_tiles, mask_tiles = batch
		if patch:
			bs, n_tiles, c, h, w = image_tiles.size()
			image_tiles = image_tiles.view(-1, c, h, w)
			mask_tiles = mask_tiles.view(-1, h, w)
		return image_tiles.to(device), mask_tiles.to(device)

	torch.cuda.empty_cache()
	train_losses = []
	test_losses = []
	val_iou = []
	val_acc = []
	train_iou = []
	train_acc = []
	lrs = []
	min_loss = np.inf
	decrease = 1
	not_improve = 0

	model.to(device)
	# The wrapper only holds the model and the inner learning rate, so build it
	# once; just the per-batch clone has to be fresh.
	maml = l2l.algorithms.MAML(model, lr=inner_lr)
	fit_time = time.time()
	for e in range(epochs):
		since = time.time()
		running_loss = 0
		iou_score = 0
		accuracy = 0

		# training loop
		model.train()
		for data in tqdm(train_loader):
			image, mask = _prepare(data)

			optimizer.zero_grad()

			# Inner loop: adapt a differentiable clone on this batch
			learner = maml.clone()
			for _ in range(adaptation_steps):
				loss = criterion(learner(image), mask)
				learner.adapt(loss)

			# Meta-update: the adapted clone's loss updates the original parameters
			output = learner(image)
			loss = criterion(output, mask)
			loss.backward()
			optimizer.step()

			# Record the learning rate used for this step, then advance the schedule
			lrs.append(get_lr(optimizer))
			scheduler.step()

			running_loss += loss.item()
			iou_score += mIoU(output, mask)
			accuracy += pixel_accuracy(output, mask)

		# Validation loop (un-adapted model)
		model.eval()
		test_loss = 0
		test_accuracy = 0
		val_iou_score = 0
		with torch.no_grad():
			for data in tqdm(val_loader):
				image, mask = _prepare(data)

				output = model(image)
				val_iou_score += mIoU(output, mask)
				test_accuracy += pixel_accuracy(output, mask)
				loss = criterion(output, mask)
				test_loss += loss.item()

		# Per-batch means for this epoch
		epoch_train_loss = running_loss / len(train_loader)
		epoch_val_loss = test_loss / len(val_loader)
		epoch_train_iou = iou_score / len(train_loader)
		epoch_val_iou = val_iou_score / len(val_loader)
		epoch_train_acc = accuracy / len(train_loader)
		epoch_val_acc = test_accuracy / len(val_loader)

		# Record everything before any early exit so all history lists stay aligned.
		train_losses.append(epoch_train_loss)
		test_losses.append(epoch_val_loss)
		val_iou.append(epoch_val_iou)
		train_iou.append(epoch_train_iou)
		train_acc.append(epoch_train_acc)
		val_acc.append(epoch_val_acc)

		stop_early = False
		if min_loss > epoch_val_loss:
			print(
				"Loss Decreasing.. {:.3f} >> {:.3f} ".format(min_loss, epoch_val_loss)
			)
			min_loss = epoch_val_loss
			decrease += 1
			if decrease % 5 == 0:
				print("saving model...")
				torch.save(
					model,
					"Unet-Mobilenet_v2_mIoU-{:.3f}.pt".format(epoch_val_iou),
				)
		elif epoch_val_loss > min_loss:
			# min_loss is deliberately left untouched here: it must stay the best
			# loss seen so far, otherwise a later, still-worse epoch could be
			# reported (and checkpointed) as an improvement.
			not_improve += 1
			print(f"Loss Not Decrease for {not_improve} time")
			stop_early = not_improve == 7

		print(
			"Epoch:{}/{}..".format(e + 1, epochs),
			"Train Loss: {:.3f}..".format(epoch_train_loss),
			"Val Loss: {:.3f}..".format(epoch_val_loss),
			"Train mIoU:{:.3f}..".format(epoch_train_iou),
			"Val mIoU: {:.3f}..".format(epoch_val_iou),
			"Train Acc:{:.3f}..".format(epoch_train_acc),
			"Val Acc:{:.3f}..".format(epoch_val_acc),
			"Time: {:.2f}m".format((time.time() - since) / 60),
		)

		if stop_early:
			print("Loss not decrease for 7 times, Stop Training")
			break

	history = {
		"train_loss": train_losses,
		"val_loss": test_losses,
		"train_miou": train_iou,
		"val_miou": val_iou,
		"train_acc": train_acc,
		"val_acc": val_acc,
		"lrs": lrs,
	}
	print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
	return history

In [87]:
# Federated training: each round averages the clients' encrypted weights,
# loads the average into every client model, and trains each client for one epoch.
histories = []  # histories[round][client] -> history dict returned by fit()
previous_losses = {i: [] for i in range(no_clients)}

# Round 0 input: every client's initial weights, encrypted
cwts = [encrypt_wt(get_weights(model), i) for i, model in enumerate(models)]
print("Initial encrypted weights generated for all clients.")

for e in tqdm(range(epochs), desc="Epochs", colour="green"):
	print(f"Epoch {e+1}/{epochs} started")
	cwts = aggregate_wt(cwts)
	print(f"Aggregated encrypted weights after epoch {e+1}")
	wts = decrypt_weights(cwts)
	print(f"Decrypted weights after aggregation for epoch {e+1}")

	cwts = []
	epoch_histories = []

	for i in range(no_clients):
		print(f"Client {i} preparing for epoch {e+1}")
		wt = wts[i]
		model = load_weights(models[i], wt)
		# Checkpoint the aggregated model (as loaded into client 0) every 5 rounds
		if (e % 5 == 0) and i == 0:
			torch.save(model, f"{fp}/{e}_model.pth")
		train_loader = train_loaders[i]
		val_loader = val_loaders[i]

		# Optimizer and one-cycle schedule are rebuilt every round, so the
		# optimizer state does not carry over between rounds.
		criterion = nn.CrossEntropyLoss()
		optimizer = torch.optim.AdamW(
			model.parameters(), lr=max_lr, weight_decay=weight_decay
		)
		sched = torch.optim.lr_scheduler.OneCycleLR(
			optimizer, max_lr, epochs=1, steps_per_epoch=len(train_loader)
		)

		print(f"Client {i} previous losses: {previous_losses[i]}")
		# Pass the plain DataLoader: fit() wraps its loaders in tqdm itself, and
		# wrapping here as well produced two nested progress bars per client.
		history = fit(
			1,
			model,
			train_loader,
			val_loader,
			criterion,
			optimizer,
			sched,
		)
		epoch_histories.append(history)

		previous_losses[i].append(
			{
				"train_loss": history["train_loss"][-1],
				"val_loss": history["val_loss"][-1],
				"train_acc": history["train_acc"][-1],
				"val_acc": history["val_acc"][-1],
			}
		)
		print(f"Client {i} updated losses: {previous_losses[i]}")

		wtarray = get_weights(model)
		cwts.append(encrypt_wt(wtarray, i))
		print(f"Client {i} weights encrypted for epoch {e+1}")

	histories.append(epoch_histories)
	print(f"Epoch {e+1} completed")

print("Training completed.")

Initial encrypted weights generated for all clients.


Epochs:   0%|[32m          [0m| 0/100 [00:00<?, ?it/s]

Epoch 1/100 started
Aggregated encrypted weights after epoch 1
Decrypted weights after aggregation for epoch 1
Client 0 preparing for epoch 1
Client 0 previous losses: []



  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Client 0 Training:  74%|[34m███████▎  [0m| 75/102 [03:20<01:12,  2.68s/it]
 74%|███████▎  | 75/102 [03:20<01:12,  2.68s/it]
Epochs:   0%|[32m          [0m| 0/100 [03:53<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

def _plot_per_client(series, label, ylabel, title):
	"""Draw one figure per client showing `series[client]` over the aggregation rounds."""
	for i in range(no_clients):
		plt.figure(figsize=(10, 6))
		plt.plot(series[i], label=f"Client {i} {label}")
		plt.xlabel("Aggregation Round")
		plt.ylabel(ylabel)
		plt.title(f"Client {i} {title} Over Aggregation Rounds")
		plt.legend()
		plt.grid(True)
		plt.show()


# Per-client series, one value per aggregation round
train_accuracies, val_accuracies, train_losses, val_losses, train_miou, val_miou = (
	{i: [] for i in range(no_clients)} for _ in range(6)
)

# Each round's history holds single-epoch lists; take the last entry of each
for epoch_histories in histories:
	for i, history in enumerate(epoch_histories):
		for store, key in (
			(train_accuracies, "train_acc"),
			(val_accuracies, "val_acc"),
			(train_losses, "train_loss"),
			(val_losses, "val_loss"),
			(train_miou, "train_miou"),
			(val_miou, "val_miou"),
		):
			store[i].append(history[key][-1])

# One figure per client and metric (val_miou is collected but not plotted)
_plot_per_client(train_accuracies, "Train Accuracy", "Accuracy", "Training Accuracy")
_plot_per_client(val_accuracies, "Val Accuracy", "Accuracy", "Validation Accuracy")
_plot_per_client(train_losses, "Train Loss", "Loss", "Training Loss")
_plot_per_client(val_losses, "Val Loss", "Loss", "Validation Loss")
_plot_per_client(train_miou, "Train mIoU", "Mean IoU", "Training Mean IoU")

# Training accuracy of all clients on a single figure
plt.figure(figsize=(10, 6))
for i in range(no_clients):
	plt.plot(train_accuracies[i], label=f"Client {i} Train Accuracy")
plt.xlabel("Aggregation Round")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Over Aggregation Rounds for All Clients")
plt.legend()
plt.grid(True)
plt.show()

# Full history dump, round by round and client by client
for e, epoch_histories in enumerate(histories):
	print(f"Aggregation Round {e+1} histories:")
	for i, history in enumerate(epoch_histories):
		print(f"  Client {i}: {history}")

In [None]:
class DroneTestDataset(Dataset):
	"""Dataset yielding ``(image, mask)`` pairs loaded from disk by sample id.

	Images are read from ``<img_path>/<id>.jpg`` and masks from
	``<mask_path>/<id>.png``, with ``<id>`` taken from ``X``. The image is
	returned as a PIL image (converted from BGR to RGB); the mask is
	returned as a ``torch.long`` tensor.
	"""

	def __init__(self, img_path, mask_path, X, transform=None):
		"""Remember where the files live and which samples to serve.

		Args:
			img_path: Directory holding the ``.jpg`` images.
			mask_path: Directory holding the ``.png`` masks.
			X: Indexable collection of sample ids (file names without extension).
			transform: Optional callable invoked as
				``transform(image=..., mask=...)`` whose result is indexed
				with ``"image"`` and ``"mask"``.
		"""
		self.img_path, self.mask_path = img_path, mask_path
		self.X, self.transform = X, transform

	def __len__(self):
		"""Return the number of sample ids in ``X``."""
		return len(self.X)

	def __getitem__(self, idx):
		"""Load the ``idx``-th sample.

		Raises:
			FileNotFoundError: If OpenCV cannot read the image or the mask
				(``cv2.imread`` returned ``None``).
		"""
		sample_id = self.X[idx]
		img_full_path = os.path.join(self.img_path, sample_id + ".jpg")
		mask_full_path = os.path.join(self.mask_path, sample_id + ".png")

		# cv2.imread signals an unreadable file by returning None, not by raising.
		bgr_pixels = cv2.imread(img_full_path)
		if bgr_pixels is None:
			raise FileNotFoundError(f"Image not found at {img_full_path}")

		mask = cv2.imread(mask_full_path, cv2.IMREAD_GRAYSCALE)
		if mask is None:
			raise FileNotFoundError(f"Mask not found at {mask_full_path}")

		rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)

		if self.transform is None:
			img = Image.fromarray(rgb_pixels)
		else:
			# The transform is applied jointly to image and mask.
			aug = self.transform(image=rgb_pixels, mask=mask)
			img = Image.fromarray(aug["image"])
			mask = aug["mask"]

		return img, torch.from_numpy(mask).long()


# Test-time transform: only a resize to 768x1152 (height x width) with
# nearest-neighbour interpolation; it is applied to image and mask together
# inside DroneTestDataset.
t_test = A.Resize(height=768, width=1152, interpolation=cv2.INTER_NEAREST)
test_set = DroneTestDataset(
	img_path=IMAGE_PATH,
	mask_path=MASK_PATH,
	X=X_test,
	transform=t_test,
)

In [None]:
def predict_image_mask_miou(
	model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
	"""Predict a segmentation mask for one image and score it with ``mIoU``.

	The model is put in eval mode and moved to the notebook-level ``device``.
	The image is converted with ``ToTensor`` followed by
	``Normalize(mean, std)`` before the forward pass.

	Args:
		model: Network called on a single-image batch.
		image: Input accepted by ``torchvision.transforms.ToTensor``
			(the test dataset in this notebook yields PIL images).
		mask: Ground-truth tensor for ``image``; a leading batch dimension
			is added here before scoring.
		mean: Per-channel mean passed to ``Normalize``.
		std: Per-channel standard deviation passed to ``Normalize``.

	Returns:
		Tuple ``(masked, score)``: the arg-max of the model output over
		dim 1 (moved to CPU, batch dimension squeezed out) and the value
		returned by ``mIoU(output, mask)``.
	"""
	model.eval()
	preprocess = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
	image_tensor = preprocess(image)
	model.to(device)

	# Add the batch dimension the model and the metric are called with.
	batch = image_tensor.to(device).unsqueeze(0)
	target = mask.to(device).unsqueeze(0)

	with torch.no_grad():
		output = model(batch)
		score = mIoU(output, target)
		masked = torch.argmax(output, dim=1).cpu().squeeze(0)
	return masked, score

In [None]:
def predict_image_mask_pixel(
	model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
	"""Predict a segmentation mask for one image and compute its pixel accuracy.

	The model is put in eval mode and moved to the notebook-level ``device``.
	The image goes through ``ToTensor`` and ``Normalize(mean, std)`` before
	being fed to the model as a batch of one.

	Args:
		model: Network called on a single-image batch.
		image: Input accepted by ``torchvision.transforms.ToTensor``
			(the test dataset in this notebook yields PIL images).
		mask: Ground-truth tensor for ``image``; a leading batch dimension
			is added here before scoring.
		mean: Per-channel mean passed to ``Normalize``.
		std: Per-channel standard deviation passed to ``Normalize``.

	Returns:
		Tuple ``(masked, acc)``: the arg-max of the model output over dim 1
		(moved to CPU, batch dimension squeezed out) and the value returned
		by ``pixel_accuracy(output, mask)``.
	"""
	model.eval()
	to_normalised_tensor = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
	net_input = to_normalised_tensor(image)
	model.to(device)
	net_input = net_input.to(device)
	truth = mask.to(device)

	with torch.no_grad():
		# Both tensors get a leading batch dimension of size one.
		net_input, truth = net_input.unsqueeze(0), truth.unsqueeze(0)
		output = model(net_input)
		acc = pixel_accuracy(output, truth)
		prediction = torch.argmax(output, dim=1)
		masked = prediction.cpu().squeeze(0)
	return masked, acc

In [None]:
# Predict and score the test sample at index 3; `image`, `mask`, `pred_mask`
# and `score` are plotted by a later cell.
image, mask = test_set[3]
pred_mask, score = predict_image_mask_miou(model=model, image=image, mask=mask)

In [None]:
def miou_score(model, test_set):
	"""Collect the mIoU score of ``model`` on every sample of ``test_set``.

	Args:
		model: Network forwarded to ``predict_image_mask_miou``.
		test_set: Indexable dataset with ``len()``; each item unpacks to
			``(image, mask)``.

	Returns:
		List with one score per sample, in dataset order.
	"""

	def _score_at(position):
		# Only the score is kept; the predicted mask is discarded.
		sample_img, sample_mask = test_set[position]
		_, sample_score = predict_image_mask_miou(model, sample_img, sample_mask)
		return sample_score

	return [_score_at(position) for position in tqdm(range(len(test_set)))]

In [None]:
# Per-sample mIoU over the whole test set (one forward pass per sample).
mob_miou = miou_score(model=model, test_set=test_set)

In [None]:
def pixel_acc(model, test_set):
	"""Collect the pixel accuracy of ``model`` on every sample of ``test_set``.

	Args:
		model: Network forwarded to ``predict_image_mask_pixel``.
		test_set: Indexable dataset with ``len()``; each item unpacks to
			``(image, mask)``.

	Returns:
		List with one accuracy value per sample, in dataset order.
	"""
	accuracy = []
	sample_count = len(test_set)
	for sample_idx in tqdm(range(sample_count)):
		sample_img, sample_mask = test_set[sample_idx]
		# The predicted mask is not needed here, only its accuracy.
		_, sample_acc = predict_image_mask_pixel(model, sample_img, sample_mask)
		accuracy.append(sample_acc)
	return accuracy

In [None]:
# Per-sample pixel accuracy over the whole test set (one forward pass per sample).
mob_acc = pixel_acc(model=model, test_set=test_set)

In [None]:
# Side-by-side view of the sample predicted earlier: input picture,
# ground-truth mask and predicted mask (titled with its mIoU).
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))

# (axes, data to show, title, hide the axis frame?)
for axis, panel, heading, hide_axis in (
	(ax1, image, "Picture", False),
	(ax2, mask, "Ground truth", True),
	(ax3, pred_mask, "UNet-MobileNet | mIoU {:.3f}".format(score), True),
):
	axis.imshow(panel)
	axis.set_title(heading)
	if hide_axis:
		axis.set_axis_off()

In [None]:
# Predict the test sample at index 6 and show picture, ground truth and
# prediction side by side.
image3, mask3 = test_set[6]
pred_mask3, score3 = predict_image_mask_miou(model, image3, mask3)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))

panel_titles = (
	"Picture",
	"Ground truth",
	"UNet-MobileNet | mIoU {:.3f}".format(score3),
)
for axis, panel, heading in zip((ax1, ax2, ax3), (image3, mask3, pred_mask3), panel_titles):
	axis.imshow(panel)
	axis.set_title(heading)

# Only the two mask panels hide their axes; the picture keeps them.
for axis in (ax2, ax3):
	axis.set_axis_off()

In [None]:
# Average of the per-sample mIoU scores collected in `mob_miou`.
mean_test_miou = np.mean(mob_miou)
print("Test Set mIoU", mean_test_miou)

In [None]:
# Average of the per-sample pixel accuracies collected in `mob_acc`.
mean_test_pixel_acc = np.mean(mob_acc)
print("Test Set Pixel Accuracy", mean_test_pixel_acc)