In [14]:
from scipy.io import loadmat
from utils import *
data = loadmat('emnist-digits.mat')

In [15]:
import matplotlib.pyplot as plt
import numpy as np


In [16]:
dataset = data['dataset'][0, 0]
train = dataset[0][0, 0]  
test = dataset[1][0, 0]  
mapping = dataset[2]

train_images = train['images']   # Shape: (N, 28*28)
train_labels = train['labels']  # Shape: (N, 1)
train_writers = train['writers']

In [17]:
X = train_images.astype(np.float32) / 255.0  # Normalize to [0, 1]
y = train_labels.flatten().astype(np.int64)

# Wrap into a datalist with a single client
datalist = [(X, y)]

test_images = test['images'].astype(np.float32) / 255.0
test_labels = test['labels'].flatten().astype(np.int64)

In [18]:
from labels_utils import *

In [6]:

# Hyperparameters
T = 5       # number of global rounds
K = 10      # number of client GD steps
gamma = 0.1 # learning rate

# Run FedAvg with 1 client
print("now training the baseline, i.e. fedAvg with one client holding all the data")
model = fedavg(datalist, T, K, gamma)


now training the baseline, i.e. fedAvg with one client holding all the data
round :  1
round :  2
round :  3
round :  4
round :  5


In [7]:

test_accuracy = evaluate(model, test_images, test_labels)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

Test Accuracy: 84.58%


In [22]:
n_clients = 5
beta = 0.5 
datalist = create_dirichlet_clients(X, y, n_clients, beta)

# Hyperparameters
T = 5       # number of global rounds
K = 10      # number of client GD steps
gamma = 0.1 # learning rate
print("case with 5 clients, beta=0.5 skewed distribution!")
model = fedavg(datalist, T, K, gamma)



case with 5 clients, beta=0.5 skewed distribution!
round :  1
round :  2
round :  3
round :  4
round :  5


In [23]:
# Evaluate
test_accuracy = evaluate(model, test_images, test_labels)
print(f"Test Accuracy with {n_clients} clients and Dir({beta}): {test_accuracy * 100:.2f}%")

Test Accuracy with 5 clients and Dir(0.5): 72.24%


In [18]:
n_clients = 5
beta = 1e5
datalist = create_dirichlet_clients(X, y, n_clients, beta)

# Hyperparameters
T = 5       # number of global rounds
K = 10      # number of client GD steps
gamma = 0.1 # learning rate
print("case with 5 clients, beta=10^5, which means close to IID clients (baseline 2)!")
model = fedavg(datalist, T, K, gamma)


case with 5 clients, beta=10^5, which means close to IID clients (baseline 2)!
round :  1
round :  2
round :  3
round :  4
round :  5


In [19]:
# Evaluate
test_accuracy = evaluate(model, test_images, test_labels)
print(f"Test Accuracy with {n_clients} clients and Dir({beta}): {test_accuracy * 100:.2f}%")

Test Accuracy with 5 clients and Dir(100000.0): 84.07%


In [20]:
n_clients = 5
beta = 0.5 
datalist = create_dirichlet_clients(X, y, n_clients, beta)

# Hyperparameters
T = 5       # number of global rounds
K = 10      # number of client GD steps
gamma = 0.1 # learning rate
print("case with 5 clients, beta=0.5 skewed distribution, with weights based on the inverse KL-divergence")
weights = compute_inverse_kl_weights(datalist)
model = fedavg(datalist, T=5, K=10, gamma=0.1, weights=weights)


case with 5 clients, beta=0.5 skewed distribution, with weights based on the inverse KL-divergence
round :  1
round :  2
round :  3
round :  4
round :  5


In [21]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.6942750215530396


### Testing the Moon Federated learning on the same dataset

In [9]:
import importlib
import utils
importlib.reload(utils)
from utils import *

In [19]:
n_clients = 5
beta = 0.5 
datalist = create_dirichlet_clients(X, y, n_clients, beta)

In [20]:
# Hyperparameters
T = 5       # number of global rounds
K = 10      # number of client GD steps
gamma = 0.1 # learning rate
print("case with 5 clients, beta=0.5 skewed distribution!")
model = fedavg(datalist, T, K, gamma)

case with 5 clients, beta=0.5 skewed distribution!
round :  1
round :  2
round :  3


KeyboardInterrupt: 

In [12]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.7438250184059143


In [21]:
model, loss_curve = fedavg_moon(datalist, T=5, K=10, gamma=0.1, mu=0.5)


In [29]:
test_accuracy = evaluate(model, test_images, test_labels)
print(test_accuracy)

0.8203250169754028
