In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations for the MNIST dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Load the MNIST dataset
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)


# Define the neural network architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Instantiate the model, define the loss function and the optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Training the model
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
            )


# Testing the model
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n"
    )


# Running the training and testing loops
for epoch in range(1, 11):
    train(model, train_loader, criterion, optimizer, epoch)
    test(model, test_loader, criterion)


Test set: Average loss: 0.0000, Accuracy: 9843/10000 (98%)


Test set: Average loss: 0.0000, Accuracy: 9867/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9905/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9897/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9876/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9929/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9908/10000 (99%)



KeyboardInterrupt: 

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
import numpy as np
import time
from tqdm import tqdm
import learn2learn as l2l
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def split_dataset_by_classes(dataset, n_clients):
    class_indices = [[] for _ in range(10)]
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)

    client_indices = [[] for _ in range(n_clients)]
    for class_idx in range(10):
        for client_idx in range(n_clients):
            client_indices[client_idx].extend(
                class_indices[class_idx][client_idx::n_clients]
            )

    client_datasets = [Subset(dataset, indices) for indices in client_indices]
    return client_datasets


def average_weights(models):
    avg_model = copy.deepcopy(models[0])
    for key in avg_model.state_dict().keys():
        avg_model.state_dict()[key] = torch.mean(
            torch.stack([model.state_dict()[key] for model in models]), dim=0
        )
    return avg_model


def federated_fit(
    epochs,
    model,
    client_loaders,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    adaptation_steps=5,
    inner_lr=0.01,
):
    torch.cuda.empty_cache()
    train_losses = []
    val_losses = []
    val_acc = []
    train_acc = []
    lrs = []
    min_loss = np.inf
    decrease = 1
    not_improve = 0

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        overall_accuracy = 0

        model.train()
        client_accuracies = []
        client_models = []
        for client_idx, client_loader in enumerate(client_loaders):
            client_model = copy.deepcopy(model)
            client_model.to(device)
            client_optimizer = optim.Adam(
                client_model.parameters(), lr=inner_lr, weight_decay=1e-5
            )
            client_accuracy = 0
            for i, data in enumerate(tqdm(client_loader)):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                client_optimizer.zero_grad()

                learner = l2l.algorithms.MAML(client_model, lr=inner_lr).clone()

                for step in range(adaptation_steps):
                    output = learner(images)
                    loss = criterion(output, labels)
                    learner.adapt(loss)

                output = learner(images)
                loss = criterion(output, labels)
                loss.backward()
                client_optimizer.step()

                running_loss += loss.item()
                client_accuracy += (
                    (output.argmax(dim=1) == labels).float().mean().item()
                )

            client_accuracy /= len(client_loader)
            client_accuracies.append(client_accuracy)
            client_models.append(client_model)
            print(
                f"Epoch {e + 1}, Client {client_idx + 1} Accuracy: {client_accuracy:.4f}"
            )

        # Aggregate the client models' weights
        model = average_weights(client_models)
        model.to(device)

        overall_accuracy = sum(client_accuracies) / len(client_accuracies)

        model.eval()
        test_loss = 0
        test_accuracy = 0
        with torch.no_grad():
            for i, data in enumerate(tqdm(val_loader)):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                output = model(images)
                test_accuracy += (output.argmax(dim=1) == labels).float().mean().item()
                loss = criterion(output, labels)
                test_loss += loss.item()

        train_losses.append(
            running_loss / sum(len(loader.dataset) for loader in client_loaders)
        )
        val_losses.append(test_loss / len(val_loader))
        val_accuracy = test_accuracy / len(val_loader)
        val_acc.append(val_accuracy)

        if min_loss > (test_loss / len(val_loader)):
            print(
                "Loss Decreasing.. {:.3f} >> {:.3f}".format(
                    min_loss, (test_loss / len(val_loader))
                )
            )
            min_loss = test_loss / len(val_loader)
            decrease += 1
            if decrease % 5 == 0:
                print("saving model...")
                torch.save(
                    model,
                    "Federated-MAML-Model-Accuracy-{:.3f}.pt".format(val_accuracy),
                )

        if (test_loss / len(val_loader)) > min_loss:
            not_improve += 1
            min_loss = test_loss / len(val_loader)
            print(f"Loss Not Decrease for {not_improve} time")
            if not_improve == 7:
                print("Loss not decrease for 7 times, Stop Training")
                break

        train_acc.append(overall_accuracy)
        print(
            "Epoch:{}/{}..".format(e + 1, epochs),
            "Train Loss: {:.3f}..".format(
                running_loss / sum(len(loader.dataset) for loader in client_loaders)
            ),
            "Val Loss: {:.3f}..".format(test_loss / len(val_loader)),
            "Train Acc:{:.3f}..".format(overall_accuracy),
            "Val Acc:{:.3f}..".format(val_accuracy),
            "Time: {:.2f}m".format((time.time() - since) / 60),
        )

    history = {
        "train_loss": train_losses,
        "val_loss": val_losses,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "lrs": lrs,
    }
    print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
    return history


n_clients = 3
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
full_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
client_datasets = split_dataset_by_classes(train_dataset, n_clients)
client_loaders = [
    DataLoader(dataset, batch_size=64, shuffle=True) for dataset in client_datasets
]
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
criterion = nn.CrossEntropyLoss()

history = federated_fit(
    epochs=10,
    model=model,
    client_loaders=client_loaders,
    val_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    adaptation_steps=5,
    inner_lr=0.01,
)

100%|██████████| 251/251 [00:11<00:00, 22.34it/s]


Epoch 1, Client 1 Accuracy: 0.9386


100%|██████████| 251/251 [00:11<00:00, 22.34it/s]


Epoch 1, Client 2 Accuracy: 0.7956


100%|██████████| 250/250 [00:11<00:00, 22.33it/s]


Epoch 1, Client 3 Accuracy: 0.9359


100%|██████████| 12/12 [00:01<00:00,  7.98it/s]


Loss Decreasing.. inf >> 13.196
Epoch:1/10.. Train Loss: 0.006.. Val Loss: 13.196.. Train Acc:0.890.. Val Acc:0.608.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 22.18it/s]


Epoch 2, Client 1 Accuracy: 0.9576


100%|██████████| 251/251 [00:11<00:00, 22.16it/s]


Epoch 2, Client 2 Accuracy: 0.6938


100%|██████████| 250/250 [00:11<00:00, 22.03it/s]


Epoch 2, Client 3 Accuracy: 0.9669


100%|██████████| 12/12 [00:01<00:00,  8.32it/s]


Loss Not Decrease for 1 time
Epoch:2/10.. Train Loss: 1.358.. Val Loss: 65.187.. Train Acc:0.873.. Val Acc:0.197.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 22.11it/s]


Epoch 3, Client 1 Accuracy: 0.9654


100%|██████████| 251/251 [00:11<00:00, 22.08it/s]


Epoch 3, Client 2 Accuracy: 0.9492


100%|██████████| 250/250 [00:11<00:00, 21.95it/s]


Epoch 3, Client 3 Accuracy: 0.9587


100%|██████████| 12/12 [00:01<00:00,  8.30it/s]


Loss Decreasing.. 65.187 >> 50.324
Epoch:3/10.. Train Loss: 0.005.. Val Loss: 50.324.. Train Acc:0.958.. Val Acc:0.348.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 22.04it/s]


Epoch 4, Client 1 Accuracy: 0.9701


100%|██████████| 251/251 [00:11<00:00, 22.10it/s]


Epoch 4, Client 2 Accuracy: 0.9173


100%|██████████| 250/250 [00:11<00:00, 22.03it/s]


Epoch 4, Client 3 Accuracy: 0.9592


100%|██████████| 12/12 [00:01<00:00,  8.09it/s]


Loss Not Decrease for 2 time
Epoch:4/10.. Train Loss: 0.006.. Val Loss: 69.231.. Train Acc:0.949.. Val Acc:0.361.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 21.85it/s]


Epoch 5, Client 1 Accuracy: 0.9441


100%|██████████| 251/251 [00:11<00:00, 21.79it/s]


Epoch 5, Client 2 Accuracy: 0.9662


100%|██████████| 250/250 [00:11<00:00, 21.54it/s]


Epoch 5, Client 3 Accuracy: 0.9526


100%|██████████| 12/12 [00:01<00:00,  7.72it/s]


Loss Not Decrease for 3 time
Epoch:5/10.. Train Loss: 0.007.. Val Loss: 94.983.. Train Acc:0.954.. Val Acc:0.404.. Time: 0.60m


100%|██████████| 251/251 [00:11<00:00, 21.97it/s]


Epoch 6, Client 1 Accuracy: 0.8436


100%|██████████| 251/251 [00:11<00:00, 22.70it/s]


Epoch 6, Client 2 Accuracy: 0.7367


100%|██████████| 250/250 [00:11<00:00, 22.27it/s]


Epoch 6, Client 3 Accuracy: 0.8993


100%|██████████| 12/12 [00:01<00:00,  8.39it/s]


Loss Not Decrease for 4 time
Epoch:6/10.. Train Loss: 0.132.. Val Loss: 211.805.. Train Acc:0.827.. Val Acc:0.196.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 22.55it/s]


Epoch 7, Client 1 Accuracy: 0.9468


100%|██████████| 251/251 [00:11<00:00, 22.52it/s]


Epoch 7, Client 2 Accuracy: 0.9349


100%|██████████| 250/250 [00:11<00:00, 22.52it/s]


Epoch 7, Client 3 Accuracy: 0.5614


100%|██████████| 12/12 [00:01<00:00,  8.09it/s]


Loss Not Decrease for 5 time
Epoch:7/10.. Train Loss: 0.169.. Val Loss: 406.192.. Train Acc:0.814.. Val Acc:0.137.. Time: 0.58m


100%|██████████| 251/251 [00:11<00:00, 22.44it/s]


Epoch 8, Client 1 Accuracy: 0.7971


100%|██████████| 251/251 [00:11<00:00, 22.48it/s]


Epoch 8, Client 2 Accuracy: 0.7905


100%|██████████| 250/250 [00:11<00:00, 22.33it/s]


Epoch 8, Client 3 Accuracy: 0.9526


100%|██████████| 12/12 [00:01<00:00,  8.65it/s]


Loss Decreasing.. 406.192 >> 207.780
Epoch:8/10.. Train Loss: 0.025.. Val Loss: 207.780.. Train Acc:0.847.. Val Acc:0.161.. Time: 0.58m


100%|██████████| 251/251 [00:11<00:00, 22.08it/s]


Epoch 9, Client 1 Accuracy: 0.6421


100%|██████████| 251/251 [00:11<00:00, 22.54it/s]


Epoch 9, Client 2 Accuracy: 0.4847


100%|██████████| 250/250 [00:11<00:00, 22.36it/s]


Epoch 9, Client 3 Accuracy: 0.7383


100%|██████████| 12/12 [00:01<00:00,  8.25it/s]


Loss Not Decrease for 6 time
Epoch:9/10.. Train Loss: 243800726645.141.. Val Loss: 5209.912.. Train Acc:0.622.. Val Acc:0.107.. Time: 0.59m


100%|██████████| 251/251 [00:11<00:00, 22.48it/s]


Epoch 10, Client 1 Accuracy: 0.1307


100%|██████████| 251/251 [00:11<00:00, 22.47it/s]


Epoch 10, Client 2 Accuracy: 0.1390


100%|██████████| 250/250 [00:11<00:00, 22.44it/s]


Epoch 10, Client 3 Accuracy: 0.1238


100%|██████████| 12/12 [00:01<00:00,  8.48it/s]


Loss Decreasing.. 5209.912 >> 1681.246
saving model...
Epoch:10/10.. Train Loss: 8884541.179.. Val Loss: 1681.246.. Train Acc:0.131.. Val Acc:0.107.. Time: 0.58m
Total time: 5.88 m


In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys

import os



In [None]:

from datetime import datetime

# Get the current timestamp
current_timestamp = datetime.now()

# Format the timestamp in a human-readable form
folder_path = current_timestamp.strftime("%d_%H_%M")
fp = f"models/{folder_path}"
if not os.path.exists(fp):
	os.makedirs(fp)


In [None]:
import pandas as pd
import seaborn as sns
from PIL import Image
import os
import matplotlib.pyplot as plt
import cv2

from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as A

import time
import os
from tqdm.notebook import tqdm

from torchsummary import summary
import segmentation_models_pytorch as smp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
clients = [0 , 1 , 2]
no_clients = len(clients)
epochs = 100

In [None]:
# max_lr = 1e-3
# epoch = 2
# weight_decay = 1e-4

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# sched = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer, max_lr, epochs=epoch, steps_per_epoch=len(train_loader)
# )

# history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)

In [None]:
import numpy as np
from Pyfhel import Pyfhel

HE = Pyfhel()
ckks_params = {
	"scheme": "CKKS",
	"n": 2**14,  # Polynomial modulus degree. For CKKS, n/2 values can be
	"scale": 2**30,  # All the encodings will use it for float->fixed point
	"qi_sizes": [60, 30, 30, 30, 60],  # Number of bits of each prime in the chain.
}
HE.contextGen(**ckks_params)  # Generate context for ckks scheme
HE.keyGen()  # Key Generation: generates a pair of public/secret keys
HE.rotateKeyGen()

In [None]:
def generate_diffie_hellman_parameters():
	parameters = dh.generate_parameters(generator=2, key_size=512)
	return parameters


def generate_diffie_hellman_keys(parameters):
	private_key = parameters.generate_private_key()
	public_key = private_key.public_key()
	return private_key, public_key


def derive_key(private_key, peer_public_key):
	shared_key = private_key.exchange(peer_public_key)
	derived_key = HKDF(
		algorithm=hashes.SHA256(),
		length=32,
		salt=None,
		info=b"handshake data",
	).derive(shared_key)
	return derived_key


def encrypt_message_AES(key, message):
	serialized_obj = pickle.dumps(message)
	cipher = Cipher(algorithms.AES(key), modes.ECB())
	encryptor = cipher.encryptor()
	padded_obj = serialized_obj + b" " * (16 - len(serialized_obj) % 16)
	ciphertext = encryptor.update(padded_obj) + encryptor.finalize()
	return ciphertext


def decrypt_message_AES(key, ciphertext):
	cipher = Cipher(algorithms.AES(key), modes.ECB())
	decryptor = cipher.decryptor()
	padded_obj = decryptor.update(ciphertext) + decryptor.finalize()
	serialized_obj = padded_obj.rstrip(b" ")
	obj = pickle.loads(serialized_obj)
	return obj


def setup_AES():
	num_clients = len(clients)
	parameters = generate_diffie_hellman_parameters()
	server_private_key, server_public_key = generate_diffie_hellman_keys(parameters)
	client_keys = [generate_diffie_hellman_keys(parameters) for _ in range(num_clients)]
	shared_keys = [
		derive_key(server_private_key, client_public_key)
		for _, client_public_key in client_keys
	]
	client_shared_keys = [
		derive_key(client_private_key, server_public_key)
		for client_private_key, _ in client_keys
	]

	return client_keys, shared_keys, client_shared_keys

client_keys, shared_keys, client_shared_keys = setup_AES()

In [None]:
def load_weights(model, weights):
	with torch.no_grad(): 
		for param, weight in zip(model.parameters(), weights):
			param.copy_(torch.tensor(weight))
	return model

In [None]:
def get_weights(model):
	return [param.cpu().detach().numpy() for param in model.parameters()]

In [None]:
def aggregate_wt(encypted_cwts):
	# cwts = []
	# for i, ecwt in enumerate(encypted_cwts):
	# 	cwts.append(decrypt_message_AES(shared_keys[i], ecwt))
	cwts = encypted_cwts
	resmodel = []
	for j in range(len(cwts[0])):  # for layers
		layer = []
		for k in range(len(cwts[0][j])):  # for chunks
			tmp = cwts[0][j][k].copy()
			for i in range(1, len(cwts)):  # for clients
				tmp = tmp + cwts[i][j][k]
			tmp = tmp / len(cwts)
			layer.append(tmp)
		resmodel.append(layer)

	res = [resmodel.copy() for _ in range(len(clients))]
	return res

In [None]:
def encrypt_wt(wtarray, i):
	cwt = []
	for layer in wtarray:
		flat_array = layer.astype(np.float64).flatten()

		chunks = np.array_split(flat_array, (len(flat_array) + 2**13 - 1) // 2**13)
		clayer = []
		for chunk in chunks:
			ptxt = HE.encodeFrac(chunk)
			ctxt = HE.encryptPtxt(ptxt)
			clayer.append(ctxt)
		cwt.append(clayer.copy())
	# ciphertext = encrypt_message_AES(client_shared_keys[i], cwt)
	# return ciphertext
	return cwt

In [None]:
def decrypt_weights(res):
	decrypted_weights = []
	for client_weights, model in zip(res, models):
		decrypted_client_weights = []
		wtarray = get_weights(model)
		for layer_weights, layer in zip(client_weights, wtarray):
			decrypted_layer_weights = []
			flat_array = layer.astype(np.float64).flatten()
			chunks = np.array_split(flat_array, (len(flat_array) + 2**13 - 1) // 2**13)
			for chunk, encrypted_chunk in zip(chunks, layer_weights):
				decrypted_chunk = HE.decryptFrac(encrypted_chunk)
				original_chunk_size = len(chunk)
				decrypted_chunk = decrypted_chunk[:original_chunk_size]
				decrypted_layer_weights.append(decrypted_chunk)
			decrypted_layer_weights = np.concatenate(decrypted_layer_weights, axis=0)
			decrypted_layer_weights = decrypted_layer_weights.reshape(layer.shape)
			decrypted_client_weights.append(decrypted_layer_weights)
		decrypted_weights.append(decrypted_client_weights)
	return decrypted_weights

In [None]:
max_lr = 0.01
weight_decay = 0.01

In [None]:
histories = []
previous_losses = {i: [] for i in range(no_clients)}

cwts = [encrypt_wt(get_weights(model), i) for i, model in enumerate(models)]
print("Initial encrypted weights generated for all clients.")

for e in tqdm(range(epochs), desc="Epochs", colour="green"):
	print(f"Epoch {e+1}/{epochs} started")
	cwts = aggregate_wt(cwts)
	print(f"Aggregated encrypted weights after epoch {e+1}")
	wts = decrypt_weights(cwts)
	print(f"Decrypted weights after aggregation for epoch {e+1}")

	cwts = []
	epoch_histories = []

	for i in range(no_clients):
		print(f"Client {i} preparing for epoch {e+1}")
		wt = wts[i]
		model = load_weights(models[i], wt)
		if (e % 5 == 0) and i == 0:
			torch.save(model, f"{fp}/{e}_model.pth")
		train_loader = train_loaders[i]
		val_loader = val_loaders[i]

		criterion = nn.CrossEntropyLoss()
		optimizer = torch.optim.AdamW(
			model.parameters(), lr=max_lr, weight_decay=weight_decay
		)
		sched = torch.optim.lr_scheduler.OneCycleLR(
			optimizer, max_lr, epochs=1, steps_per_epoch=len(train_loader)
		)

		print(f"Client {i} previous losses: {previous_losses[i]}")
		history = fit(
			1,
			model,
			tqdm(train_loader, desc=f"Client {i} Training", colour="blue"),
			val_loader,
			criterion,
			optimizer,
			sched,
		)
		epoch_histories.append(history)

		previous_losses[i].append(
			{
				"train_loss": history["train_loss"][-1],
				"val_loss": history["val_loss"][-1],
				"train_acc": history["train_acc"][-1],
				"val_acc": history["val_acc"][-1],
			}
		)
		print(f"Client {i} updated losses: {previous_losses[i]}")

		wtarray = get_weights(model)
		cwts.append(encrypt_wt(wtarray, i))
		print(f"Client {i} weights encrypted for epoch {e+1}")

	histories.append(epoch_histories)
	print(f"Epoch {e+1} completed")

print("Training completed.")

In [None]:
import matplotlib.pyplot as plt

# Initialize dictionaries to store accuracies, losses, and mIoU for each client
train_accuracies = {i: [] for i in range(no_clients)}
val_accuracies = {i: [] for i in range(no_clients)}
train_losses = {i: [] for i in range(no_clients)}
val_losses = {i: [] for i in range(no_clients)}
train_miou = {i: [] for i in range(no_clients)}
val_miou = {i: [] for i in range(no_clients)}

# Populate the dictionaries with data from histories
for epoch_histories in histories:
	for i, history in enumerate(epoch_histories):
		train_accuracies[i].append(history["train_acc"][-1])
		val_accuracies[i].append(history["val_acc"][-1])
		train_losses[i].append(history["train_loss"][-1])
		val_losses[i].append(history["val_loss"][-1])
		train_miou[i].append(history["train_miou"][-1])
		val_miou[i].append(history["val_miou"][-1])

# Plotting training accuracy for each client independently
for i in range(no_clients):
	plt.figure(figsize=(10, 6))
	plt.plot(train_accuracies[i], label=f"Client {i} Train Accuracy")
	plt.xlabel("Aggregation Round")
	plt.ylabel("Accuracy")
	plt.title(f"Client {i} Training Accuracy Over Aggregation Rounds")
	plt.legend()
	plt.grid(True)
	plt.show()

# Plotting validation accuracy for each client independently
for i in range(no_clients):
	plt.figure(figsize=(10, 6))
	plt.plot(val_accuracies[i], label=f"Client {i} Val Accuracy")
	plt.xlabel("Aggregation Round")
	plt.ylabel("Accuracy")
	plt.title(f"Client {i} Validation Accuracy Over Aggregation Rounds")
	plt.legend()
	plt.grid(True)
	plt.show()

# Plotting training loss for each client independently
for i in range(no_clients):
	plt.figure(figsize=(10, 6))
	plt.plot(train_losses[i], label=f"Client {i} Train Loss")
	plt.xlabel("Aggregation Round")
	plt.ylabel("Loss")
	plt.title(f"Client {i} Training Loss Over Aggregation Rounds")
	plt.legend()
	plt.grid(True)
	plt.show()

# Plotting validation loss for each client independently
for i in range(no_clients):
	plt.figure(figsize=(10, 6))
	plt.plot(val_losses[i], label=f"Client {i} Val Loss")
	plt.xlabel("Aggregation Round")
	plt.ylabel("Loss")
	plt.title(f"Client {i} Validation Loss Over Aggregation Rounds")
	plt.legend()
	plt.grid(True)
	plt.show()

# Plotting mean IoU for each client independently
for i in range(no_clients):
	plt.figure(figsize=(10, 6))
	plt.plot(train_miou[i], label=f"Client {i} Train mIoU")
	plt.xlabel("Aggregation Round")
	plt.ylabel("Mean IoU")
	plt.title(f"Client {i} Training Mean IoU Over Aggregation Rounds")
	plt.legend()
	plt.grid(True)
	plt.show()

# Plotting all clients together for training accuracy
plt.figure(figsize=(10, 6))
for i in range(no_clients):
	plt.plot(train_accuracies[i], label=f"Client {i} Train Accuracy")
plt.xlabel("Aggregation Round")
plt.ylabel("Accuracy")
plt.title("Training Accuracy Over Aggregation Rounds for All Clients")
plt.legend()
plt.grid(True)
plt.show()

# Display the detailed history for each client and each aggregation round
for e, epoch_histories in enumerate(histories):
	print(f"Aggregation Round {e+1} histories:")
	for i, history in enumerate(epoch_histories):
		print(f"  Client {i}: {history}")

In [None]:
class DroneTestDataset(Dataset):

	def __init__(self, img_path, mask_path, X, transform=None):
		self.img_path = img_path
		self.mask_path = mask_path
		self.X = X
		self.transform = transform

	def __len__(self):
		return len(self.X)

	def __getitem__(self, idx):
		img_full_path = os.path.join(self.img_path, self.X[idx] + ".jpg")
		mask_full_path = os.path.join(self.mask_path, self.X[idx] + ".png")

		img = cv2.imread(img_full_path)
		if img is None:
			raise FileNotFoundError(f"Image not found at {img_full_path}")

		mask = cv2.imread(mask_full_path, cv2.IMREAD_GRAYSCALE)
		if mask is None:
			raise FileNotFoundError(f"Mask not found at {mask_full_path}")

		img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

		if self.transform is not None:
			aug = self.transform(image=img, mask=mask)
			img = Image.fromarray(aug["image"])
			mask = aug["mask"]

		if self.transform is None:
			img = Image.fromarray(img)

		mask = torch.from_numpy(mask).long()

		return img, mask


t_test = A.Resize(768, 1152, interpolation=cv2.INTER_NEAREST)
test_set = DroneTestDataset(IMAGE_PATH, MASK_PATH, X_test, transform=t_test)

In [None]:
def predict_image_mask_miou(
	model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
	model.eval()
	t = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
	image = t(image)
	model.to(device)
	image = image.to(device)
	mask = mask.to(device)
	with torch.no_grad():

		image = image.unsqueeze(0)
		mask = mask.unsqueeze(0)

		output = model(image)
		score = mIoU(output, mask)
		masked = torch.argmax(output, dim=1)
		masked = masked.cpu().squeeze(0)
	return masked, score

In [None]:
def predict_image_mask_pixel(
	model, image, mask, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
):
	model.eval()
	t = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
	image = t(image)
	model.to(device)
	image = image.to(device)
	mask = mask.to(device)
	with torch.no_grad():

		image = image.unsqueeze(0)
		mask = mask.unsqueeze(0)

		output = model(image)
		acc = pixel_accuracy(output, mask)
		masked = torch.argmax(output, dim=1)
		masked = masked.cpu().squeeze(0)
	return masked, acc

In [None]:
image, mask = test_set[3]
pred_mask, score = predict_image_mask_miou(model, image, mask)

In [None]:
def miou_score(model, test_set):
	score_iou = []
	for i in tqdm(range(len(test_set))):
		img, mask = test_set[i]
		pred_mask, score = predict_image_mask_miou(model, img, mask)
		score_iou.append(score)
	return score_iou

In [None]:
mob_miou = miou_score(model, test_set)

In [None]:
def pixel_acc(model, test_set):
	accuracy = []
	for i in tqdm(range(len(test_set))):
		img, mask = test_set[i]
		pred_mask, acc = predict_image_mask_pixel(model, img, mask)
		accuracy.append(acc)
	return accuracy

In [None]:
mob_acc = pixel_acc(model, test_set)

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
ax1.imshow(image)
ax1.set_title("Picture")

ax2.imshow(mask)
ax2.set_title("Ground truth")
ax2.set_axis_off()

ax3.imshow(pred_mask)
ax3.set_title("UNet-MobileNet | mIoU {:.3f}".format(score))
ax3.set_axis_off()

In [None]:
image3, mask3 = test_set[6]
pred_mask3, score3 = predict_image_mask_miou(model, image3, mask3)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
ax1.imshow(image3)
ax1.set_title("Picture")

ax2.imshow(mask3)
ax2.set_title("Ground truth")
ax2.set_axis_off()

ax3.imshow(pred_mask3)
ax3.set_title("UNet-MobileNet | mIoU {:.3f}".format(score3))
ax3.set_axis_off()

In [None]:
print("Test Set mIoU", np.mean(mob_miou))

In [None]:
print("Test Set Pixel Accuracy", np.mean(mob_acc))